1
0

Compare commits

...

52 Commits

Author SHA1 Message Date
Erik Johnston
6b2d6fdd33 Add some debugging 2020-05-15 15:33:11 +01:00
Erik Johnston
d263a4de02 Enable moving event persistence off of master 2020-05-14 17:25:42 +01:00
Erik Johnston
66c1dff3ba Use new writers config 2020-05-14 17:25:41 +01:00
Erik Johnston
96b6023e3b Make location of events writer configurable 2020-05-14 17:25:09 +01:00
Erik Johnston
452019064c Allow ReplicationRestResource to be added to workers 2020-05-14 17:25:09 +01:00
Erik Johnston
7c8e09bcf1 Add a worker store for search insertion 2020-05-14 17:25:09 +01:00
Erik Johnston
e7f5ac4ed8 Fix lint 2020-05-14 17:09:58 +01:00
Erik Johnston
208ab7b135 Fix typing and add assertion. 2020-05-14 17:09:58 +01:00
Erik Johnston
41f558ccf7 Newsfile 2020-05-14 17:09:58 +01:00
Erik Johnston
342796d6ac Move push rules ID gen to push rules worker 2020-05-14 17:09:58 +01:00
Erik Johnston
bc3fc3927f Move events ID gens to EventWorkerStore 2020-05-14 17:09:58 +01:00
Erik Johnston
d67a8b5455 Move repliction event stream handling out of slave store 2020-05-14 17:09:58 +01:00
Erik Johnston
4734a7bbe4 Move EventStream handling into default ReplicationDataHandler (#7493)
This is so that the logic can happen on both master and workers when we move event persistence out.
2020-05-14 14:01:39 +01:00
Erik Johnston
1de36407d1 Add instance_map config and route replication calls (#7495) 2020-05-14 14:00:58 +01:00
Richard van der Hoff
dede23ff1e Merge tag 'v1.13.0rc2' into develop
Synapse 1.13.0rc2 (2020-05-14)
==============================

Bugfixes
--------

- Fix a long-standing bug which could cause messages not to be sent over federation, when state events with state keys matching user IDs (such as custom user statuses) were received. ([\#7376](https://github.com/matrix-org/synapse/issues/7376))
- Restore compatibility with non-compliant clients during the user interactive authentication process, fixing a problem introduced in v1.13.0rc1. ([\#7483](https://github.com/matrix-org/synapse/issues/7483))

Internal Changes
----------------

- Fix linting errors in new version of Flake8. ([\#7470](https://github.com/matrix-org/synapse/issues/7470))
2020-05-14 11:46:38 +01:00
Erik Johnston
1124111a12 Allow censoring of events to happen on workers. (#7492)
This is safe as we can now write to cache invalidation stream on workers, and is required for when we move event persistence off master.
2020-05-13 17:15:40 +01:00
Paul Tötterman
46cb2550bb Fix copypasted comment (#7477)
Signed-off-by: Paul Tötterman <paul.totterman@iki.fi>
2020-05-13 16:55:43 +01:00
Erik Johnston
18c1e52d82 Clean up replication unit tests. (#7490) 2020-05-13 16:01:47 +01:00
Erik Johnston
00ba9c48bf Spelling 2020-05-13 13:38:51 +01:00
Erik Johnston
782e4e64df Shuffle persist event data store functions. (#7440)
The aim here is to get to a stage where we have a `PersistEventStore` that holds all the write methods used during event persistence, so that we can take that class out of the `DataStore` mixin and instansiate it separately. This will allow us to instansiate it on processes other than master, while also ensuring it is only available on processes that are configured to write to events stream.

This is a bit of an architectural change, where we end up with multiple classes per data store (rather than one per data store we have now). We end up having:

1. Storage classes that provide high level APIs that can talk to multiple data stores.
2. Data store modules that consist of classes that must point at the same database instance.
3. Classes in a data store that can be instantiated on processes depending on config.
2020-05-13 13:38:22 +01:00
Erik Johnston
7ee24c5674 Have all instances correctly respond to REPLICATE command. (#7475)
Before all streams were only written to from master, so only master needed to respond to `REPLICATE` commands.

Before all instances wrote to the cache invalidation stream, but didn't respond to `REPLICATE`. This was a bug, which could lead to missed rows from cache invalidation stream if an instance is restarted, however all the caches would be empty in that case so it wasn't a problem.
2020-05-13 10:27:02 +01:00
Erik Johnston
8ca79613e6 Fix Redis reconnection logic (#7482)
Proactively send out `POSITION` commands (as if we had just received a `REPLICATE`) when we connect to Redis. This is important as other instances won't notice we've connected to issue a `REPLICATE` command (unlike for direct TCP connections). This is only currently an issue if master process reconnects without restarting (if it restarts then it won't have written anything and so other instances probably won't have missed anything).
2020-05-13 09:57:15 +01:00
Patrick Cloke
51fb0fc2e5 Update documentation about SSO mapping providers (#7458) 2020-05-12 10:51:07 -04:00
Erik Johnston
1a1da60ad2 Fix new flake8 errors (#7470) 2020-05-12 11:20:48 +01:00
Patrick Cloke
8c8858e124 Convert federation handler to async/await. (#7459) 2020-05-11 15:12:46 -04:00
Patrick Cloke
be309d99cf Convert search code to async/await. (#7460) 2020-05-11 15:12:39 -04:00
Amber Brown
7cb8b4bc67 Allow configuration of Synapse's cache without using synctl or environment variables (#6391) 2020-05-11 18:45:23 +01:00
Andrew Morgan
a8580c5f19 Remove unused store method get_hosts_in_room (#7448) 2020-05-11 16:55:57 +01:00
Andrew Morgan
5cf758cdd6 Merge branch 'release-v1.13.0' into develop
* release-v1.13.0:
  Don't UPGRADE database rows
  RST indenting
  Put rollback instructions in upgrade notes
  Fix changelog typo
  Oh yeah, RST
  Absolute URL it is then
  Fix upgrade notes link
  Provide summary of upgrade issues in changelog. Fix )
  Move next version notes from changelog to upgrade notes
  Changelog fixes
  1.13.0rc1
  Documentation on setting up redis (#7446)
  Rework UI Auth session validation for registration (#7455)
  Fix errors from malformed log line (#7454)
  Drop support for redis.dbid (#7450)
2020-05-11 16:46:33 +01:00
Andrew Morgan
67feea8044 Extend spam checker to allow for multiple modules (#7435) 2020-05-08 19:25:48 +01:00
Quentin Gliech
616af44137 Implement OpenID Connect-based login (#7256) 2020-05-08 08:30:40 -04:00
Manuel Stahl
a4a5ec4096 Add room details admin endpoint (#7317) 2020-05-07 15:33:07 -04:00
Brendan Abolivier
5bb26b7c4f Merge branch 'release-v1.13.0' into develop 2020-05-07 17:31:19 +02:00
Patrick Cloke
9e0384dd3f Fixes typo (bellow -> below) (#7449) 2020-05-07 09:31:06 -04:00
Patrick Cloke
22246919e3 Add more type hints to SAML handler. (#7445) 2020-05-07 09:30:45 -04:00
Erik Johnston
d7983b63a6 Support any process writing to cache invalidation stream. (#7436) 2020-05-07 13:51:08 +01:00
Brendan Abolivier
2929ce29d6 Merge pull request #7398 from Starbix/alpine-3.11
Update docker runtime image to Alpine v3.11
2020-05-07 11:56:56 +02:00
Richard van der Hoff
62ee862119 Merge branch 'release-v1.13.0' into develop 2020-05-06 15:56:03 +01:00
Richard van der Hoff
fa0b2bd28d Merge pull request #7428 from matrix-org/rav/cross_signing_keys_cache
Make get_e2e_cross_signing_key delegate to get_e2e_cross_signing_keys_bulk
2020-05-06 12:00:01 +01:00
Richard van der Hoff
16b67c404d Make get_e2e_cross_signing_key delegate to get_e2e_cross_signing_keys_bulk
... mostly because the latter has a cache.
2020-05-06 11:59:19 +01:00
Richard van der Hoff
db5f9031b7 Fix batching for fetching cross-signing keys
There's no point carefully dividing a list into batches, and then completely
ignoring the batches.
2020-05-06 11:59:19 +01:00
Richard van der Hoff
2e0c46ca07 Merge branch 'release-v1.13.0' into develop 2020-05-06 11:58:31 +01:00
Richard van der Hoff
79007a42b2 Merge pull request #7429 from matrix-org/rav/upsert_for_device_list
use an upsert to update device_lists_outbound_last_success
2020-05-06 11:53:18 +01:00
Richard van der Hoff
30a19daa02 Merge branch 'develop' into rav/upsert_for_device_list 2020-05-06 11:43:11 +01:00
Richard van der Hoff
e48361545d use an upsert to update device_lists_outbound_last_success 2020-05-06 11:41:23 +01:00
Richard van der Hoff
0f6ebf393d Better type annotations for simple_upsert_txn
most of these params don't really need to be lists.
2020-05-06 11:41:23 +01:00
Richard van der Hoff
16b1a34e80 Fix typing annotations in synapse/federation (#7382)
We're pretty close to having mypy working for `synapse.federation`, so let's
finish the job.
2020-05-05 14:27:13 +01:00
Richard van der Hoff
d5aa7d93ed Fix catchup-on-reconnect for the Federation Stream (#7374)
looks like we managed to break this during the refactorathon.
2020-05-05 14:15:57 +01:00
Erik Johnston
8123b2f909 Add MultiWriterIdGenerator. (#7281)
This will be used to coordinate stream IDs across multiple writers.

Functions as the equivalent of both `StreamIdGenerator` and
`SlavedIdTracker`.
2020-05-04 17:17:45 +01:00
Brendan Abolivier
15aa09bbe6 Merge branch 'release-v1.13.0' into develop 2020-05-04 16:33:56 +02:00
Patrick Cloke
eab59d758d Convert the room handler to async/await. (#7396) 2020-05-04 07:43:52 -04:00
Cédric Laubacher
a251e0f4ba Update runtime docker image to Alpine v3.11 2020-05-03 16:07:24 +02:00
162 changed files with 5933 additions and 2349 deletions

1
changelog.d/6391.feature Normal file
View File

@@ -0,0 +1 @@
Synapse's cache factor can now be configured in `homeserver.yaml` by the `caches.global_factor` setting. Additionally, `caches.per_cache_factors` controls the cache factors for individual caches.

1
changelog.d/7256.feature Normal file
View File

@@ -0,0 +1 @@
Add OpenID Connect login/registration support. Contributed by Quentin Gliech, on behalf of [les Connecteurs](https://connecteu.rs).

1
changelog.d/7281.misc Normal file
View File

@@ -0,0 +1 @@
Add MultiWriterIdGenerator to support multiple concurrent writers of streams.

1
changelog.d/7317.feature Normal file
View File

@@ -0,0 +1 @@
Add room details admin endpoint. Contributed by Awesome Technologies Innovationslabor GmbH.

1
changelog.d/7374.misc Normal file
View File

@@ -0,0 +1 @@
Move catchup of replication streams logic to worker.

1
changelog.d/7382.misc Normal file
View File

@@ -0,0 +1 @@
Add typing annotations in `synapse.federation`.

1
changelog.d/7396.misc Normal file
View File

@@ -0,0 +1 @@
Convert the room handler to async/await.

1
changelog.d/7398.docker Normal file
View File

@@ -0,0 +1 @@
Update docker runtime image to Alpine v3.11. Contributed by @Starbix.

1
changelog.d/7428.misc Normal file
View File

@@ -0,0 +1 @@
Improve performance of `get_e2e_cross_signing_key`.

1
changelog.d/7429.misc Normal file
View File

@@ -0,0 +1 @@
Improve performance of `mark_as_sent_devices_by_remote`.

1
changelog.d/7435.feature Normal file
View File

@@ -0,0 +1 @@
Allow for using more than one spam checker module at once.

1
changelog.d/7436.misc Normal file
View File

@@ -0,0 +1 @@
Support any process writing to cache invalidation stream.

1
changelog.d/7440.misc Normal file
View File

@@ -0,0 +1 @@
Refactor event persistence database functions in preparation for allowing them to be run on non-master processes.

1
changelog.d/7445.misc Normal file
View File

@@ -0,0 +1 @@
Add type hints to the SAML handler.

1
changelog.d/7448.misc Normal file
View File

@@ -0,0 +1 @@
Remove storage method `get_hosts_in_room` that is no longer called anywhere.

1
changelog.d/7449.misc Normal file
View File

@@ -0,0 +1 @@
Fix some typos in the notice_expiry templates.

1
changelog.d/7458.doc Normal file
View File

@@ -0,0 +1 @@
Update information about mapping providers for SAML and OpenID.

1
changelog.d/7459.misc Normal file
View File

@@ -0,0 +1 @@
Convert the federation handler to async/await.

1
changelog.d/7460.misc Normal file
View File

@@ -0,0 +1 @@
Convert the search handler to async/await.

1
changelog.d/7470.misc Normal file
View File

@@ -0,0 +1 @@
Fix linting errors in new version of Flake8.

1
changelog.d/7475.misc Normal file
View File

@@ -0,0 +1 @@
Have all instance correctly respond to REPLICATE command.

1
changelog.d/7477.doc Normal file
View File

@@ -0,0 +1 @@
Fix copy-paste error in `ServerNoticesConfig` docstring. Contributed by @ptman.

1
changelog.d/7482.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix Redis reconnection logic that can result in missed updates over replication if master reconnects to Redis without restarting.

1
changelog.d/7490.misc Normal file
View File

@@ -0,0 +1 @@
Clean up replication unit tests.

1
changelog.d/7491.misc Normal file
View File

@@ -0,0 +1 @@
Move event stream handling out of slave store.

1
changelog.d/7492.misc Normal file
View File

@@ -0,0 +1 @@
Allow censoring of events to happen on workers.

1
changelog.d/7493.misc Normal file
View File

@@ -0,0 +1 @@
Move EventStream handling into default ReplicationDataHandler.

1
changelog.d/7495.feature Normal file
View File

@@ -0,0 +1 @@
Add `instance_map` config and route replication calls.

View File

@@ -55,7 +55,7 @@ RUN pip install --prefix="/install" --no-warn-script-location \
### Stage 1: runtime
###
FROM docker.io/python:${PYTHON_VERSION}-alpine3.10
FROM docker.io/python:${PYTHON_VERSION}-alpine3.11
# xmlsec is required for saml support
RUN apk add --no-cache --virtual .runtime_deps \

View File

@@ -264,3 +264,57 @@ Response:
Once the `next_token` parameter is no longer present, we know we've reached the
end of the list.
# DRAFT: Room Details API
The Room Details admin API allows server admins to get all details of a room.
This API is still a draft and details might change!
The following fields are possible in the JSON response body:
* `room_id` - The ID of the room.
* `name` - The name of the room.
* `canonical_alias` - The canonical (main) alias address of the room.
* `joined_members` - How many users are currently in the room.
* `joined_local_members` - How many local users are currently in the room.
* `version` - The version of the room as a string.
* `creator` - The `user_id` of the room creator.
* `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active.
* `federatable` - Whether users on other servers can join this room.
* `public` - Whether the room is visible in room directory.
* `join_rules` - The type of rules used for users wishing to join this room. One of: ["public", "knock", "invite", "private"].
* `guest_access` - Whether guests can join the room. One of: ["can_join", "forbidden"].
* `history_visibility` - Who can see the room history. One of: ["invited", "joined", "shared", "world_readable"].
* `state_events` - Total number of state_events of a room. Complexity of the room.
## Usage
A standard request:
```
GET /_synapse/admin/v1/rooms/<room_id>
{}
```
Response:
```
{
"room_id": "!mscvqgqpHYjBGDxNym:matrix.org",
"name": "Music Theory",
"canonical_alias": "#musictheory:matrix.org",
"joined_members": 127
"joined_local_members": 2,
"version": "1",
"creator": "@foo:matrix.org",
"encryption": null,
"federatable": true,
"public": true,
"join_rules": "invite",
"guest_access": null,
"history_visibility": "shared",
"state_events": 93534
}
```

175
docs/dev/oidc.md Normal file
View File

@@ -0,0 +1,175 @@
# How to test OpenID Connect
Any OpenID Connect Provider (OP) should work with Synapse, as long as it supports the authorization code flow.
There are a few options for that:
- start a local OP. Synapse has been tested with [Hydra][hydra] and [Dex][dex-idp].
Note that for an OP to work, it should be served under a secure (HTTPS) origin.
A certificate signed with a self-signed, locally trusted CA should work. In that case, start Synapse with a `SSL_CERT_FILE` environment variable set to the path of the CA.
- use a publicly available OP. Synapse has been tested with [Google][google-idp].
- setup a SaaS OP, like [Auth0][auth0] and [Okta][okta]. Auth0 has a free tier which has been tested with Synapse.
[google-idp]: https://developers.google.com/identity/protocols/OpenIDConnect#authenticatingtheuser
[auth0]: https://auth0.com/
[okta]: https://www.okta.com/
[dex-idp]: https://github.com/dexidp/dex
[hydra]: https://www.ory.sh/docs/hydra/
## Sample configs
Here are a few configs for providers that should work with Synapse.
### [Dex][dex-idp]
[Dex][dex-idp] is a simple, open-source, certified OpenID Connect Provider.
Although it is designed to help building a full-blown provider, with some external database, it can be configured with static passwords in a config file.
Follow the [Getting Started guide](https://github.com/dexidp/dex/blob/master/Documentation/getting-started.md) to install Dex.
Edit `examples/config-dev.yaml` config file from the Dex repo to add a client:
```yaml
staticClients:
- id: synapse
secret: secret
redirectURIs:
- '[synapse base url]/_synapse/oidc/callback'
name: 'Synapse'
```
Run with `dex serve examples/config-dex.yaml`
Synapse config:
```yaml
oidc_config:
enabled: true
skip_verification: true # This is needed as Dex is served on an insecure endpoint
issuer: "http://127.0.0.1:5556/dex"
discover: true
client_id: "synapse"
client_secret: "secret"
scopes:
- openid
- profile
user_mapping_provider:
config:
localpart_template: '{{ user.name }}'
display_name_template: '{{ user.name|capitalize }}'
```
### [Auth0][auth0]
1. Create a regular web application for Synapse
2. Set the Allowed Callback URLs to `[synapse base url]/_synapse/oidc/callback`
3. Add a rule to add the `preferred_username` claim.
<details>
<summary>Code sample</summary>
```js
function addPersistenceAttribute(user, context, callback) {
user.user_metadata = user.user_metadata || {};
user.user_metadata.preferred_username = user.user_metadata.preferred_username || user.user_id;
context.idToken.preferred_username = user.user_metadata.preferred_username;
auth0.users.updateUserMetadata(user.user_id, user.user_metadata)
.then(function(){
callback(null, user, context);
})
.catch(function(err){
callback(err);
});
}
```
</details>
```yaml
oidc_config:
enabled: true
issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
discover: true
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
scopes:
- openid
- profile
user_mapping_provider:
config:
localpart_template: '{{ user.preferred_username }}'
display_name_template: '{{ user.name }}'
```
### GitHub
GitHub is a bit special as it is not an OpenID Connect compliant provider, but just a regular OAuth2 provider.
The `/user` API endpoint can be used to retrieve informations from the user.
As the OIDC login mechanism needs an attribute to uniquely identify users and that endpoint does not return a `sub` property, an alternative `subject_claim` has to be set.
1. Create a new OAuth application: https://github.com/settings/applications/new
2. Set the callback URL to `[synapse base url]/_synapse/oidc/callback`
```yaml
oidc_config:
enabled: true
issuer: "https://github.com/"
discover: false
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
authorization_endpoint: "https://github.com/login/oauth/authorize"
token_endpoint: "https://github.com/login/oauth/access_token"
userinfo_endpoint: "https://api.github.com/user"
scopes:
- read:user
user_mapping_provider:
config:
subject_claim: 'id'
localpart_template: '{{ user.login }}'
display_name_template: '{{ user.name }}'
```
### Google
1. Setup a project in the Google API Console
2. Obtain the OAuth 2.0 credentials (see <https://developers.google.com/identity/protocols/oauth2/openid-connect>)
3. Add this Authorized redirect URI: `[synapse base url]/_synapse/oidc/callback`
```yaml
oidc_config:
enabled: true
issuer: "https://accounts.google.com/"
discover: true
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
scopes:
- openid
- profile
user_mapping_provider:
config:
localpart_template: '{{ user.given_name|lower }}'
display_name_template: '{{ user.name }}'
```
### Twitch
1. Setup a developer account on [Twitch](https://dev.twitch.tv/)
2. Obtain the OAuth 2.0 credentials by [creating an app](https://dev.twitch.tv/console/apps/)
3. Add this OAuth Redirect URL: `[synapse base url]/_synapse/oidc/callback`
```yaml
oidc_config:
enabled: true
issuer: "https://id.twitch.tv/oauth2/"
discover: true
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
client_auth_method: "client_secret_post"
scopes:
- openid
user_mapping_provider:
config:
localpart_template: '{{ user.preferred_username }}'
display_name_template: '{{ user.name }}'
```

View File

@@ -1,77 +0,0 @@
# SAML Mapping Providers
A SAML mapping provider is a Python class (loaded via a Python module) that
works out how to map attributes of a SAML response object to Matrix-specific
user attributes. Details such as user ID localpart, displayname, and even avatar
URLs are all things that can be mapped from talking to a SSO service.
As an example, a SSO service may return the email address
"john.smith@example.com" for a user, whereas Synapse will need to figure out how
to turn that into a displayname when creating a Matrix user for this individual.
It may choose `John Smith`, or `Smith, John [Example.com]` or any number of
variations. As each Synapse configuration may want something different, this is
where SAML mapping providers come into play.
## Enabling Providers
External mapping providers are provided to Synapse in the form of an external
Python module. Retrieve this module from [PyPi](https://pypi.org) or elsewhere,
then tell Synapse where to look for the handler class by editing the
`saml2_config.user_mapping_provider.module` config option.
`saml2_config.user_mapping_provider.config` allows you to provide custom
configuration options to the module. Check with the module's documentation for
what options it provides (if any). The options listed by default are for the
user mapping provider built in to Synapse. If using a custom module, you should
comment these options out and use those specified by the module instead.
## Building a Custom Mapping Provider
A custom mapping provider must specify the following methods:
* `__init__(self, parsed_config)`
- Arguments:
- `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by
the module here.
* `saml_response_to_user_attributes(self, saml_response, failures)`
- Arguments:
- `saml_response` - A `saml2.response.AuthnResponse` object to extract user
information from.
- `failures` - An `int` that represents the amount of times the returned
mxid localpart mapping has failed. This should be used
to create a deduplicated mxid localpart which should be
returned instead. For example, if this method returns
`john.doe` as the value of `mxid_localpart` in the returned
dict, and that is already taken on the homeserver, this
method will be called again with the same parameters but
with failures=1. The method should then return a different
`mxid_localpart` value, such as `john.doe1`.
- This method must return a dictionary, which will then be used by Synapse
to build a new user. The following keys are allowed:
* `mxid_localpart` - Required. The mxid localpart of the new user.
* `displayname` - The displayname of the new user. If not provided, will default to
the value of `mxid_localpart`.
* `parse_config(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:
- `config` - A `dict` representing the parsed content of the
`saml2_config.user_mapping_provider.config` homeserver config option.
Runs on homeserver startup. Providers should extract any option values
they need here.
- Whatever is returned will be passed back to the user mapping provider module's
`__init__` method during construction.
* `get_saml_attributes(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:
- `config` - A object resulting from a call to `parse_config`.
- Returns a tuple of two sets. The first set equates to the saml auth
response attributes that are required for the module to function, whereas
the second set consists of those attributes which can be used if available,
but are not necessary.
## Synapse's Default Provider
Synapse has a built-in SAML mapping provider if a custom provider isn't
specified in the config. It is located at
[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py).

View File

@@ -603,6 +603,45 @@ acme:
## Caching ##
# Caching can be configured through the following options.
#
# A cache 'factor' is a multiplier that can be applied to each of
# Synapse's caches in order to increase or decrease the maximum
# number of entries that can be stored.
# The number of events to cache in memory. Not affected by
# caches.global_factor.
#
#event_cache_size: 10K
caches:
# Controls the global cache factor, which is the default cache factor
# for all caches if a specific factor for that cache is not otherwise
# set.
#
# This can also be set by the "SYNAPSE_CACHE_FACTOR" environment
# variable. Setting by environment variable takes priority over
# setting through the config file.
#
# Defaults to 0.5, which will half the size of all caches.
#
#global_factor: 1.0
# A dictionary of cache name to cache factor for that individual
# cache. Overrides the global cache factor for a given cache.
#
# These can also be set through environment variables comprised
# of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital
# letters and underscores. Setting by environment variable
# takes priority over setting through the config file.
# Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0
#
per_cache_factors:
#get_users_who_share_room_with_user: 2.0
## Database ##
# The 'database' setting defines the database that synapse uses to store all of
@@ -646,10 +685,6 @@ database:
args:
database: DATADIR/homeserver.db
# Number of events to cache in memory.
#
#event_cache_size: 10K
## Logging ##
@@ -1470,6 +1505,94 @@ saml2_config:
#template_dir: "res/templates"
# Enable OpenID Connect for registration and login. Uses authlib.
#
oidc_config:
# enable OpenID Connect. Defaults to false.
#
#enabled: true
# use the OIDC discovery mechanism to discover endpoints. Defaults to true.
#
#discover: true
# the OIDC issuer. Used to validate tokens and discover the providers endpoints. Required.
#
#issuer: "https://accounts.example.com/"
# oauth2 client id to use. Required.
#
#client_id: "provided-by-your-issuer"
# oauth2 client secret to use. Required.
#
#client_secret: "provided-by-your-issuer"
# auth method to use when exchanging the token.
# Valid values are "client_secret_basic" (default), "client_secret_post" and "none".
#
#client_auth_method: "client_auth_basic"
# list of scopes to ask. This should include the "openid" scope. Defaults to ["openid"].
#
#scopes: ["openid"]
# the oauth2 authorization endpoint. Required if provider discovery is disabled.
#
#authorization_endpoint: "https://accounts.example.com/oauth2/auth"
# the oauth2 token endpoint. Required if provider discovery is disabled.
#
#token_endpoint: "https://accounts.example.com/oauth2/token"
# the OIDC userinfo endpoint. Required if discovery is disabled and the "openid" scope is not asked.
#
#userinfo_endpoint: "https://accounts.example.com/userinfo"
# URI where to fetch the JWKS. Required if discovery is disabled and the "openid" scope is used.
#
#jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
# skip metadata verification. Defaults to false.
# Use this if you are connecting to a provider that is not OpenID Connect compliant.
# Avoid this in production.
#
#skip_verification: false
# An external module can be provided here as a custom solution to mapping
# attributes returned from a OIDC provider onto a matrix user.
#
user_mapping_provider:
# The custom module's class. Uncomment to use a custom module.
# Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
#
#module: mapping_provider.OidcMappingProvider
# Custom configuration values for the module. Below options are intended
# for the built-in provider, they should be changed if using a custom
# module. This section will be passed as a Python dictionary to the
# module's `parse_config` method.
#
# Below is the config of the default mapping provider, based on Jinja2
# templates. Those templates are used to render user attributes, where the
# userinfo object is available through the `user` variable.
#
config:
# name of the claim containing a unique identifier for the user.
# Defaults to `sub`, which OpenID Connect compliant providers should provide.
#
#subject_claim: "sub"
# Jinja2 template for the localpart of the MXID
#
localpart_template: "{{ user.preferred_username }}"
# Jinja2 template for the display name to set on first login. Optional.
#
#display_name_template: "{{ user.given_name }} {{ user.last_name }}"
# Enable CAS for registration and login.
#
@@ -1554,6 +1677,13 @@ sso:
#
# This template has no additional variables.
#
# * HTML page to display to users if something goes wrong during the
# OpenID Connect authentication process: 'sso_error.html'.
#
# When rendering, this template is given two variables:
# * error: the technical name of the error
# * error_description: a human-readable message for the error
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
@@ -1772,10 +1902,17 @@ password_providers:
# include_content: true
#spam_checker:
# module: "my_custom_project.SuperSpamChecker"
# config:
# example_option: 'things'
# Spam checkers are third-party modules that can block specific actions
# of local users, such as creating rooms and registering undesirable
# usernames, as well as remote users by redacting incoming events.
#
spam_checker:
#- module: "my_custom_project.SuperSpamChecker"
# config:
# example_option: 'things'
#- module: "some_other_project.BadEventStopper"
# config:
# example_stop_events_from: ['@bad:example.com']
# Uncomment to allow non-server-admin users to create groups on this server

View File

@@ -64,10 +64,12 @@ class ExampleSpamChecker:
Modify the `spam_checker` section of your `homeserver.yaml` in the following
manner:
`module` should point to the fully qualified Python class that implements your
custom logic, e.g. `my_module.ExampleSpamChecker`.
Create a list entry with the keys `module` and `config`.
`config` is a dictionary that gets passed to the spam checker class.
* `module` should point to the fully qualified Python class that implements your
custom logic, e.g. `my_module.ExampleSpamChecker`.
* `config` is a dictionary that gets passed to the spam checker class.
### Example
@@ -75,12 +77,15 @@ This section might look like:
```yaml
spam_checker:
module: my_module.ExampleSpamChecker
config:
# Enable or disable a specific option in ExampleSpamChecker.
my_custom_option: true
- module: my_module.ExampleSpamChecker
config:
# Enable or disable a specific option in ExampleSpamChecker.
my_custom_option: true
```
More spam checkers can be added in tandem by appending more items to the list. An
action is blocked when at least one of the configured spam checkers flags it.
## Examples
The [Mjolnir](https://github.com/matrix-org/mjolnir) project is a full fledged

View File

@@ -0,0 +1,146 @@
# SSO Mapping Providers
A mapping provider is a Python class (loaded via a Python module) that
works out how to map attributes of a SSO response to Matrix-specific
user attributes. Details such as user ID localpart, displayname, and even avatar
URLs are all things that can be mapped from talking to a SSO service.
As an example, a SSO service may return the email address
"john.smith@example.com" for a user, whereas Synapse will need to figure out how
to turn that into a displayname when creating a Matrix user for this individual.
It may choose `John Smith`, or `Smith, John [Example.com]` or any number of
variations. As each Synapse configuration may want something different, this is
where SAML mapping providers come into play.
SSO mapping providers are currently supported for OpenID and SAML SSO
configurations. Please see the details below for how to implement your own.
External mapping providers are provided to Synapse in the form of an external
Python module. You can retrieve this module from [PyPi](https://pypi.org) or elsewhere,
but it must be importable via Synapse (e.g. it must be in the same virtualenv
as Synapse). The Synapse config is then modified to point to the mapping provider
(and optionally provide additional configuration for it).
## OpenID Mapping Providers
The OpenID mapping provider can be customized by editing the
`oidc_config.user_mapping_provider.module` config option.
`oidc_config.user_mapping_provider.config` allows you to provide custom
configuration options to the module. Check with the module's documentation for
what options it provides (if any). The options listed by default are for the
user mapping provider built in to Synapse. If using a custom module, you should
comment these options out and use those specified by the module instead.
### Building a Custom OpenID Mapping Provider
A custom mapping provider must specify the following methods:
* `__init__(self, parsed_config)`
- Arguments:
- `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by
the module here.
* `parse_config(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:
- `config` - A `dict` representing the parsed content of the
`oidc_config.user_mapping_provider.config` homeserver config option.
Runs on homeserver startup. Providers should extract and validate
any option values they need here.
- Whatever is returned will be passed back to the user mapping provider module's
`__init__` method during construction.
* `get_remote_user_id(self, userinfo)`
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
information from.
- This method must return a string, which is the unique identifier for the
user. Commonly the ``sub`` claim of the response.
* `map_user_attributes(self, userinfo, token)`
- This method should be async.
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
information from.
- `token` - A dictionary which includes information necessary to make
further requests to the OpenID provider.
- Returns a dictionary with two keys:
- localpart: A required string, used to generate the Matrix ID.
- displayname: An optional string, the display name for the user.
### Default OpenID Mapping Provider
Synapse has a built-in OpenID mapping provider if a custom provider isn't
specified in the config. It is located at
[`synapse.handlers.oidc_handler.JinjaOidcMappingProvider`](../synapse/handlers/oidc_handler.py).
## SAML Mapping Providers
The SAML mapping provider can be customized by editing the
`saml2_config.user_mapping_provider.module` config option.
`saml2_config.user_mapping_provider.config` allows you to provide custom
configuration options to the module. Check with the module's documentation for
what options it provides (if any). The options listed by default are for the
user mapping provider built in to Synapse. If using a custom module, you should
comment these options out and use those specified by the module instead.
### Building a Custom SAML Mapping Provider
A custom mapping provider must specify the following methods:
* `__init__(self, parsed_config)`
- Arguments:
- `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by
the module here.
* `parse_config(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:
- `config` - A `dict` representing the parsed content of the
`saml_config.user_mapping_provider.config` homeserver config option.
Runs on homeserver startup. Providers should extract and validate
any option values they need here.
- Whatever is returned will be passed back to the user mapping provider module's
`__init__` method during construction.
* `get_saml_attributes(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:
- `config` - A object resulting from a call to `parse_config`.
- Returns a tuple of two sets. The first set equates to the SAML auth
response attributes that are required for the module to function, whereas
the second set consists of those attributes which can be used if available,
but are not necessary.
* `get_remote_user_id(self, saml_response, client_redirect_url)`
- Arguments:
- `saml_response` - A `saml2.response.AuthnResponse` object to extract user
information from.
- `client_redirect_url` - A string, the URL that the client will be
redirected to.
- This method must return a string, which is the unique identifier for the
user. Commonly the ``uid`` claim of the response.
* `saml_response_to_user_attributes(self, saml_response, failures, client_redirect_url)`
- Arguments:
- `saml_response` - A `saml2.response.AuthnResponse` object to extract user
information from.
- `failures` - An `int` that represents the amount of times the returned
mxid localpart mapping has failed. This should be used
to create a deduplicated mxid localpart which should be
returned instead. For example, if this method returns
`john.doe` as the value of `mxid_localpart` in the returned
dict, and that is already taken on the homeserver, this
method will be called again with the same parameters but
with failures=1. The method should then return a different
`mxid_localpart` value, such as `john.doe1`.
- `client_redirect_url` - A string, the URL that the client will be
redirected to.
- This method must return a dictionary, which will then be used by Synapse
to build a new user. The following keys are allowed:
* `mxid_localpart` - Required. The mxid localpart of the new user.
* `displayname` - The displayname of the new user. If not provided, will default to
the value of `mxid_localpart`.
### Default SAML Mapping Provider
Synapse has a built-in SAML mapping provider if a custom provider isn't
specified in the config. It is located at
[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py).

View File

@@ -219,10 +219,6 @@ Asks the server for the current position of all streams.
Inform the server a pusher should be removed
#### INVALIDATE_CACHE (C)
Inform the server a cache should be invalidated
### REMOTE_SERVER_UP (S, C)
Inform other processes that a remote server may have come back online.

View File

@@ -75,3 +75,6 @@ ignore_missing_imports = True
[mypy-jwt.*]
ignore_missing_imports = True
[mypy-authlib.*]
ignore_missing_imports = True

View File

@@ -122,7 +122,7 @@ APPEND_ONLY_TABLES = [
"presence_stream",
"push_rules_stream",
"ex_outlier_stream",
"cache_invalidation_stream",
"cache_invalidation_stream_by_instance",
"public_room_list_stream",
"state_group_edges",
"stream_ordering_to_exterm",
@@ -188,7 +188,7 @@ class MockHomeserver:
self.clock = Clock(reactor)
self.config = config
self.hostname = config.server_name
self.version_string = "Synapse/"+get_version_string(synapse)
self.version_string = "Synapse/" + get_version_string(synapse)
def get_clock(self):
return self.clock

View File

@@ -37,7 +37,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.types import StateMap, UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
@@ -73,7 +73,7 @@ class Auth(object):
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
self.token_cache = LruCache(10000)
register_cache("cache", "token_cache", self.token_cache)
self._auth_blocking = AuthBlocking(self.hs)

View File

@@ -26,7 +26,6 @@ from twisted.web.resource import NoResource
import synapse
import synapse.events
from synapse.api.constants import EventTypes
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.api.urls import (
CLIENT_API_PREFIX,
@@ -48,6 +47,7 @@ from synapse.http.site import SynapseSite
from synapse.logging.context import LoggingContext
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@@ -81,11 +81,6 @@ from synapse.replication.tcp.streams import (
ToDeviceStream,
TypingStream,
)
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client.v1 import events
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
@@ -122,11 +117,13 @@ from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.rest.client.versions import VersionsRestServlet
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer
from synapse.storage.data_stores.main.censor_events import CensorEventsStore
from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
from synapse.storage.data_stores.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.data_stores.main.presence import UserPresenceState
from synapse.storage.data_stores.main.search import SearchWorkerStore
from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
@@ -442,6 +439,7 @@ class GenericWorkerSlavedStore(
SlavedGroupServerStore,
SlavedAccountDataStore,
SlavedPusherStore,
CensorEventsStore,
SlavedEventStore,
SlavedKeyStore,
RoomStore,
@@ -455,6 +453,7 @@ class GenericWorkerSlavedStore(
SlavedFilteringStore,
MonthlyActiveUsersWorkerStore,
MediaRepositoryStore,
SearchWorkerStore,
BaseSlavedStore,
):
def __init__(self, database, db_conn, hs):
@@ -572,6 +571,9 @@ class GenericWorkerServer(HomeServer):
if name in ["keys", "federation"]:
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp(
@@ -631,7 +633,7 @@ class GenericWorkerServer(HomeServer):
class GenericWorkerReplicationHandler(ReplicationDataHandler):
def __init__(self, hs):
super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore())
super(GenericWorkerReplicationHandler, self).__init__(hs)
self.store = hs.get_datastore()
self.typing_handler = hs.get_typing_handler()
@@ -657,30 +659,7 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
stream_name, token, rows
)
if stream_name == EventsStream.NAME:
# We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows.
for row in rows:
if row.type != EventsStreamEventRow.TypeId:
continue
assert isinstance(row, EventsStreamRow)
event = await self.store.get_event(
row.data.event_id, allow_rejected=True
)
if event.rejected_reason:
continue
extra_users = ()
if event.type == EventTypes.Member:
extra_users = (event.state_key,)
max_token = self.store.get_room_max_stream_ordering()
self.notifier.on_new_room_event(
event, token, max_token, extra_users
)
await self.pusher_pool.on_new_notifications(token, token)
elif stream_name == PushRulesStream.NAME:
if stream_name == PushRulesStream.NAME:
self.notifier.on_new_event(
"push_rules_key", token, users=[row.user_id for row in rows]
)

View File

@@ -69,7 +69,6 @@ from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.module_loader import load_module
@@ -192,6 +191,11 @@ class SynapseHomeServer(HomeServer):
}
)
if self.get_config().oidc_enabled:
from synapse.rest.oidc import OIDCResource
resources["/_synapse/oidc"] = OIDCResource(self)
if self.get_config().saml2_enabled:
from synapse.rest.saml2 import SAML2Resource
@@ -422,6 +426,13 @@ def setup(config_options):
# Check if it needs to be reprovisioned every day.
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
# Load the OIDC provider metadatas, if OIDC is enabled.
if hs.config.oidc_enabled:
oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid.
yield defer.ensureDeferred(oidc.load_metadata())
yield defer.ensureDeferred(oidc.load_jwks())
_base.start(hs, config.listeners)
hs.get_datastore().db.updates.start_doing_background_updates()
@@ -504,8 +515,8 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
stats["daily_sent_messages"] = daily_sent_messages
stats["cache_factor"] = CACHE_SIZE_FACTOR
stats["event_cache_size"] = hs.config.event_cache_size
stats["cache_factor"] = hs.config.caches.global_factor
stats["event_cache_size"] = hs.config.caches.event_cache_size
#
# Performance statistics

View File

@@ -13,6 +13,7 @@ from synapse.config import (
key,
logger,
metrics,
oidc_config,
password,
password_auth_providers,
push,
@@ -59,6 +60,7 @@ class RootConfig:
saml2: saml2_config.SAML2Config
cas: cas.CasConfig
sso: sso.SSOConfig
oidc: oidc_config.OIDCConfig
jwt: jwt_config.JWTConfig
password: password.PasswordConfig
email: emailconfig.EmailConfig

164
synapse/config/cache.py Normal file
View File

@@ -0,0 +1,164 @@
# -*- coding: utf-8 -*-
# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable, Dict
from ._base import Config, ConfigError
# The prefix for all cache factor-related environment variables
_CACHES = {}
_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"
_DEFAULT_FACTOR_SIZE = 0.5
_DEFAULT_EVENT_CACHE_SIZE = "10K"
class CacheProperties(object):
def __init__(self):
# The default factor size for all caches
self.default_factor_size = float(
os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
)
self.resize_all_caches_func = None
properties = CacheProperties()
def add_resizable_cache(cache_name: str, cache_resize_callback: Callable):
"""Register a cache that's size can dynamically change
Args:
cache_name: A reference to the cache
cache_resize_callback: A callback function that will be ran whenever
the cache needs to be resized
"""
_CACHES[cache_name.lower()] = cache_resize_callback
# Ensure all loaded caches are sized appropriately
#
# This method should only run once the config has been read,
# as it uses values read from it
if properties.resize_all_caches_func:
properties.resize_all_caches_func()
class CacheConfig(Config):
section = "caches"
_environ = os.environ
@staticmethod
def reset():
"""Resets the caches to their defaults. Used for tests."""
properties.default_factor_size = float(
os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
)
properties.resize_all_caches_func = None
_CACHES.clear()
def generate_config_section(self, **kwargs):
return """\
## Caching ##
# Caching can be configured through the following options.
#
# A cache 'factor' is a multiplier that can be applied to each of
# Synapse's caches in order to increase or decrease the maximum
# number of entries that can be stored.
# The number of events to cache in memory. Not affected by
# caches.global_factor.
#
#event_cache_size: 10K
caches:
# Controls the global cache factor, which is the default cache factor
# for all caches if a specific factor for that cache is not otherwise
# set.
#
# This can also be set by the "SYNAPSE_CACHE_FACTOR" environment
# variable. Setting by environment variable takes priority over
# setting through the config file.
#
# Defaults to 0.5, which will half the size of all caches.
#
#global_factor: 1.0
# A dictionary of cache name to cache factor for that individual
# cache. Overrides the global cache factor for a given cache.
#
# These can also be set through environment variables comprised
# of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital
# letters and underscores. Setting by environment variable
# takes priority over setting through the config file.
# Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0
#
per_cache_factors:
#get_users_who_share_room_with_user: 2.0
"""
def read_config(self, config, **kwargs):
self.event_cache_size = self.parse_size(
config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
)
self.cache_factors = {} # type: Dict[str, float]
cache_config = config.get("caches") or {}
self.global_factor = cache_config.get(
"global_factor", properties.default_factor_size
)
if not isinstance(self.global_factor, (int, float)):
raise ConfigError("caches.global_factor must be a number.")
# Set the global one so that it's reflected in new caches
properties.default_factor_size = self.global_factor
# Load cache factors from the config
individual_factors = cache_config.get("per_cache_factors") or {}
if not isinstance(individual_factors, dict):
raise ConfigError("caches.per_cache_factors must be a dictionary")
# Override factors from environment if necessary
individual_factors.update(
{
key[len(_CACHE_PREFIX) + 1 :].lower(): float(val)
for key, val in self._environ.items()
if key.startswith(_CACHE_PREFIX + "_")
}
)
for cache, factor in individual_factors.items():
if not isinstance(factor, (int, float)):
raise ConfigError(
"caches.per_cache_factors.%s must be a number" % (cache.lower(),)
)
self.cache_factors[cache.lower()] = factor
# Resize all caches (if necessary) with the new factors we've loaded
self.resize_all_caches()
# Store this function so that it can be called from other classes without
# needing an instance of Config
properties.resize_all_caches_func = self.resize_all_caches
def resize_all_caches(self):
"""Ensure all cache sizes are up to date
For each cache, run the mapped callback function with either
a specific cache factor or the default, global one.
"""
for cache_name, callback in _CACHES.items():
new_factor = self.cache_factors.get(cache_name, self.global_factor)
callback(new_factor)

View File

@@ -68,10 +68,6 @@ database:
name: sqlite3
args:
database: %(database_path)s
# Number of events to cache in memory.
#
#event_cache_size: 10K
"""
@@ -116,8 +112,6 @@ class DatabaseConfig(Config):
self.databases = []
def read_config(self, config, **kwargs):
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
# We *experimentally* support specifying multiple databases via the
# `databases` key. This is a map from a label to database config in the
# same format as the `database` config option, plus an extra

View File

@@ -17,6 +17,7 @@
from ._base import RootConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
from .cache import CacheConfig
from .captcha import CaptchaConfig
from .cas import CasConfig
from .consent_config import ConsentConfig
@@ -27,6 +28,7 @@ from .jwt_config import JWTConfig
from .key import KeyConfig
from .logger import LoggingConfig
from .metrics import MetricsConfig
from .oidc_config import OIDCConfig
from .password import PasswordConfig
from .password_auth_providers import PasswordAuthProviderConfig
from .push import PushConfig
@@ -54,6 +56,7 @@ class HomeServerConfig(RootConfig):
config_classes = [
ServerConfig,
TlsConfig,
CacheConfig,
DatabaseConfig,
LoggingConfig,
RatelimitConfig,
@@ -66,6 +69,7 @@ class HomeServerConfig(RootConfig):
AppServiceConfig,
KeyConfig,
SAML2Config,
OIDCConfig,
CasConfig,
SSOConfig,
JWTConfig,

View File

@@ -257,5 +257,6 @@ def setup_logging(
logging.warning("***** STARTING SERVER *****")
logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
logging.info("Server hostname: %s", config.server_name)
logging.info("Instance name: %s", hs.get_instance_name())
return logger

View File

@@ -0,0 +1,177 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module
from ._base import Config, ConfigError
DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"
class OIDCConfig(Config):
section = "oidc"
def read_config(self, config, **kwargs):
self.oidc_enabled = False
oidc_config = config.get("oidc_config")
if not oidc_config or not oidc_config.get("enabled", False):
return
try:
check_requirements("oidc")
except DependencyException as e:
raise ConfigError(e.message)
public_baseurl = self.public_baseurl
if public_baseurl is None:
raise ConfigError("oidc_config requires a public_baseurl to be set")
self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
self.oidc_enabled = True
self.oidc_discover = oidc_config.get("discover", True)
self.oidc_issuer = oidc_config["issuer"]
self.oidc_client_id = oidc_config["client_id"]
self.oidc_client_secret = oidc_config["client_secret"]
self.oidc_client_auth_method = oidc_config.get(
"client_auth_method", "client_secret_basic"
)
self.oidc_scopes = oidc_config.get("scopes", ["openid"])
self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint")
self.oidc_token_endpoint = oidc_config.get("token_endpoint")
self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
self.oidc_jwks_uri = oidc_config.get("jwks_uri")
self.oidc_subject_claim = oidc_config.get("subject_claim", "sub")
self.oidc_skip_verification = oidc_config.get("skip_verification", False)
ump_config = oidc_config.get("user_mapping_provider", {})
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
ump_config.setdefault("config", {})
(
self.oidc_user_mapping_provider_class,
self.oidc_user_mapping_provider_config,
) = load_module(ump_config)
# Ensure loaded user mapping module has defined all necessary methods
required_methods = [
"get_remote_user_id",
"map_user_attributes",
]
missing_methods = [
method
for method in required_methods
if not hasattr(self.oidc_user_mapping_provider_class, method)
]
if missing_methods:
raise ConfigError(
"Class specified by oidc_config."
"user_mapping_provider.module is missing required "
"methods: %s" % (", ".join(missing_methods),)
)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
# Enable OpenID Connect for registration and login. Uses authlib.
#
oidc_config:
# enable OpenID Connect. Defaults to false.
#
#enabled: true
# use the OIDC discovery mechanism to discover endpoints. Defaults to true.
#
#discover: true
# the OIDC issuer. Used to validate tokens and discover the providers endpoints. Required.
#
#issuer: "https://accounts.example.com/"
# oauth2 client id to use. Required.
#
#client_id: "provided-by-your-issuer"
# oauth2 client secret to use. Required.
#
#client_secret: "provided-by-your-issuer"
# auth method to use when exchanging the token.
# Valid values are "client_secret_basic" (default), "client_secret_post" and "none".
#
#client_auth_method: "client_auth_basic"
# list of scopes to ask. This should include the "openid" scope. Defaults to ["openid"].
#
#scopes: ["openid"]
# the oauth2 authorization endpoint. Required if provider discovery is disabled.
#
#authorization_endpoint: "https://accounts.example.com/oauth2/auth"
# the oauth2 token endpoint. Required if provider discovery is disabled.
#
#token_endpoint: "https://accounts.example.com/oauth2/token"
# the OIDC userinfo endpoint. Required if discovery is disabled and the "openid" scope is not asked.
#
#userinfo_endpoint: "https://accounts.example.com/userinfo"
# URI where to fetch the JWKS. Required if discovery is disabled and the "openid" scope is used.
#
#jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
# skip metadata verification. Defaults to false.
# Use this if you are connecting to a provider that is not OpenID Connect compliant.
# Avoid this in production.
#
#skip_verification: false
# An external module can be provided here as a custom solution to mapping
# attributes returned from a OIDC provider onto a matrix user.
#
user_mapping_provider:
# The custom module's class. Uncomment to use a custom module.
# Default is {mapping_provider!r}.
#
#module: mapping_provider.OidcMappingProvider
# Custom configuration values for the module. Below options are intended
# for the built-in provider, they should be changed if using a custom
# module. This section will be passed as a Python dictionary to the
# module's `parse_config` method.
#
# Below is the config of the default mapping provider, based on Jinja2
# templates. Those templates are used to render user attributes, where the
# userinfo object is available through the `user` variable.
#
config:
# name of the claim containing a unique identifier for the user.
# Defaults to `sub`, which OpenID Connect compliant providers should provide.
#
#subject_claim: "sub"
# Jinja2 template for the localpart of the MXID
#
localpart_template: "{{{{ user.preferred_username }}}}"
# Jinja2 template for the display name to set on first login. Optional.
#
#display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
""".format(
mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
)

View File

@@ -51,7 +51,7 @@ class ServerNoticesConfig(Config):
None if server notices are not enabled.
server_notices_mxid_avatar_url (str|None):
The display name to use for the server notices user.
The MXC URL for the avatar of the server notices user.
None if server notices are not enabled.
server_notices_room_name (str|None):

View File

@@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Tuple
from synapse.config import ConfigError
from synapse.util.module_loader import load_module
from ._base import Config
@@ -22,16 +25,35 @@ class SpamCheckerConfig(Config):
section = "spamchecker"
def read_config(self, config, **kwargs):
self.spam_checker = None
self.spam_checkers = [] # type: List[Tuple[Any, Dict]]
provider = config.get("spam_checker", None)
if provider is not None:
self.spam_checker = load_module(provider)
spam_checkers = config.get("spam_checker") or []
if isinstance(spam_checkers, dict):
# The spam_checker config option used to only support one
# spam checker, and thus was simply a dictionary with module
# and config keys. Support this old behaviour by checking
# to see if the option resolves to a dictionary
self.spam_checkers.append(load_module(spam_checkers))
elif isinstance(spam_checkers, list):
for spam_checker in spam_checkers:
if not isinstance(spam_checker, dict):
raise ConfigError("spam_checker syntax is incorrect")
self.spam_checkers.append(load_module(spam_checker))
else:
raise ConfigError("spam_checker syntax is incorrect")
def generate_config_section(self, **kwargs):
return """\
#spam_checker:
# module: "my_custom_project.SuperSpamChecker"
# config:
# example_option: 'things'
# Spam checkers are third-party modules that can block specific actions
# of local users, such as creating rooms and registering undesirable
# usernames, as well as remote users by redacting incoming events.
#
spam_checker:
#- module: "my_custom_project.SuperSpamChecker"
# config:
# example_option: 'things'
#- module: "some_other_project.BadEventStopper"
# config:
# example_stop_events_from: ['@bad:example.com']
"""

View File

@@ -36,17 +36,13 @@ class SSOConfig(Config):
if not template_dir:
template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
self.sso_redirect_confirm_template_dir = template_dir
self.sso_template_dir = template_dir
self.sso_account_deactivated_template = self.read_file(
os.path.join(
self.sso_redirect_confirm_template_dir, "sso_account_deactivated.html"
),
os.path.join(self.sso_template_dir, "sso_account_deactivated.html"),
"sso_account_deactivated_template",
)
self.sso_auth_success_template = self.read_file(
os.path.join(
self.sso_redirect_confirm_template_dir, "sso_auth_success.html"
),
os.path.join(self.sso_template_dir, "sso_auth_success.html"),
"sso_auth_success_template",
)
@@ -137,6 +133,13 @@ class SSOConfig(Config):
#
# This template has no additional variables.
#
# * HTML page to display to users if something goes wrong during the
# OpenID Connect authentication process: 'sso_error.html'.
#
# When rendering, this template is given two variables:
# * error: the technical name of the error
# * error_description: a human-readable message for the error
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#

View File

@@ -13,9 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import attr
from ._base import Config
@attr.s
class InstanceLocationConfig:
"""The host and port to talk to an instance via HTTP replication.
"""
host = attr.ib(type=str)
port = attr.ib(type=int)
@attr.s
class WriterLocations:
"""Specifies the instances that write various streams.
Attributes:
events: The instance that writes to the event and backfill streams.
"""
events = attr.ib(default="master", type=str)
class WorkerConfig(Config):
"""The workers are processes run separately to the main synapse process.
They have their own pid_file and listener configuration. They use the
@@ -71,6 +93,16 @@ class WorkerConfig(Config):
elif not bind_addresses:
bind_addresses.append("")
# A map from instance name to host/port of their HTTP replication endpoint.
instance_map = config.get("instance_map", {}) or {}
self.instance_map = {
name: InstanceLocationConfig(**c) for name, c in instance_map.items()
}
# Map from type of streams to source, c.f. WriterLocations.
writers = config.get("writers", {}) or {}
self.writers = WriterLocations(**writers)
def read_arguments(self, args):
# We support a bunch of command line arguments that override options in
# the config. A lot of these options have a worker_* prefix when running

View File

@@ -15,7 +15,7 @@
# limitations under the License.
import inspect
from typing import Dict
from typing import Any, Dict, List
from synapse.spam_checker_api import SpamCheckerApi
@@ -26,24 +26,17 @@ if MYPY:
class SpamChecker(object):
def __init__(self, hs: "synapse.server.HomeServer"):
self.spam_checker = None
self.spam_checkers = [] # type: List[Any]
module = None
config = None
try:
module, config = hs.config.spam_checker
except Exception:
pass
if module is not None:
for module, config in hs.config.spam_checkers:
# Older spam checkers don't accept the `api` argument, so we
# try and detect support.
spam_args = inspect.getfullargspec(module)
if "api" in spam_args.args:
api = SpamCheckerApi(hs)
self.spam_checker = module(config=config, api=api)
self.spam_checkers.append(module(config=config, api=api))
else:
self.spam_checker = module(config=config)
self.spam_checkers.append(module(config=config))
def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
"""Checks if a given event is considered "spammy" by this server.
@@ -58,10 +51,11 @@ class SpamChecker(object):
Returns:
True if the event is spammy.
"""
if self.spam_checker is None:
return False
for spam_checker in self.spam_checkers:
if spam_checker.check_event_for_spam(event):
return True
return self.spam_checker.check_event_for_spam(event)
return False
def user_may_invite(
self, inviter_userid: str, invitee_userid: str, room_id: str
@@ -78,12 +72,14 @@ class SpamChecker(object):
Returns:
True if the user may send an invite, otherwise False
"""
if self.spam_checker is None:
return True
for spam_checker in self.spam_checkers:
if (
spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
is False
):
return False
return self.spam_checker.user_may_invite(
inviter_userid, invitee_userid, room_id
)
return True
def user_may_create_room(self, userid: str) -> bool:
"""Checks if a given user may create a room
@@ -96,10 +92,11 @@ class SpamChecker(object):
Returns:
True if the user may create a room, otherwise False
"""
if self.spam_checker is None:
return True
for spam_checker in self.spam_checkers:
if spam_checker.user_may_create_room(userid) is False:
return False
return self.spam_checker.user_may_create_room(userid)
return True
def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
"""Checks if a given user may create a room alias
@@ -113,10 +110,11 @@ class SpamChecker(object):
Returns:
True if the user may create a room alias, otherwise False
"""
if self.spam_checker is None:
return True
for spam_checker in self.spam_checkers:
if spam_checker.user_may_create_room_alias(userid, room_alias) is False:
return False
return self.spam_checker.user_may_create_room_alias(userid, room_alias)
return True
def user_may_publish_room(self, userid: str, room_id: str) -> bool:
"""Checks if a given user may publish a room to the directory
@@ -130,10 +128,11 @@ class SpamChecker(object):
Returns:
True if the user may publish the room, otherwise False
"""
if self.spam_checker is None:
return True
for spam_checker in self.spam_checkers:
if spam_checker.user_may_publish_room(userid, room_id) is False:
return False
return self.spam_checker.user_may_publish_room(userid, room_id)
return True
def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
@@ -150,13 +149,14 @@ class SpamChecker(object):
Returns:
True if the user is spammy.
"""
if self.spam_checker is None:
return False
for spam_checker in self.spam_checkers:
# For backwards compatibility, only run if the method exists on the
# spam checker
checker = getattr(spam_checker, "check_username_for_spam", None)
if checker:
# Make a copy of the user profile object to ensure the spam checker
# cannot modify it.
if checker(user_profile.copy()):
return True
# For backwards compatibility, if the method does not exist on the spam checker, fallback to not interfering.
checker = getattr(self.spam_checker, "check_username_for_spam", None)
if not checker:
return False
# Make a copy of the user profile object to ensure the spam checker
# cannot modify it.
return checker(user_profile.copy())
return False

View File

@@ -31,6 +31,7 @@ Events are replicated via a separate events stream.
import logging
from collections import namedtuple
from typing import Dict, List, Tuple, Type
from six import iteritems
@@ -56,21 +57,35 @@ class FederationRemoteSendQueue(object):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.presence_map = {} # Pending presence map user_id -> UserPresenceState
self.presence_changed = SortedDict() # Stream position -> list[user_id]
# Pending presence map user_id -> UserPresenceState
self.presence_map = {} # type: Dict[str, UserPresenceState]
# Stream position -> list[user_id]
self.presence_changed = SortedDict() # type: SortedDict[int, List[str]]
# Stores the destinations we need to explicitly send presence to about a
# given user.
# Stream position -> (user_id, destinations)
self.presence_destinations = SortedDict()
self.presence_destinations = (
SortedDict()
) # type: SortedDict[int, Tuple[str, List[str]]]
self.keyed_edu = {} # (destination, key) -> EDU
self.keyed_edu_changed = SortedDict() # stream position -> (destination, key)
# (destination, key) -> EDU
self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu]
self.edus = SortedDict() # stream position -> Edu
# stream position -> (destination, key)
self.keyed_edu_changed = (
SortedDict()
) # type: SortedDict[int, Tuple[str, tuple]]
self.edus = SortedDict() # type: SortedDict[int, Edu]
# stream ID for the next entry into presence_changed/keyed_edu_changed/edus.
self.pos = 1
self.pos_time = SortedDict()
# map from stream ID to the time that stream entry was generated, so that we
# can clear out entries after a while
self.pos_time = SortedDict() # type: SortedDict[int, int]
# EVERYTHING IS SAD. In particular, python only makes new scopes when
# we make a new function, so we need to make a new function so the inner
@@ -158,8 +173,10 @@ class FederationRemoteSendQueue(object):
for edu_key in self.keyed_edu_changed.values():
live_keys.add(edu_key)
to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
for edu_key in to_del:
keys_to_del = [
edu_key for edu_key in self.keyed_edu if edu_key not in live_keys
]
for edu_key in keys_to_del:
del self.keyed_edu[edu_key]
# Delete things out of edu map
@@ -250,19 +267,23 @@ class FederationRemoteSendQueue(object):
self._clear_queue_before_pos(token)
async def get_replication_rows(
self, from_token, to_token, limit, federation_ack=None
):
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
"""Get rows to be sent over federation between the two tokens
Args:
from_token (int)
to_token(int)
limit (int)
federation_ack (int): Optional. The position where the worker is
explicitly acknowledged it has handled. Allows us to drop
data from before that point
instance_name: the name of the current process
from_token: the previous stream token: the starting point for fetching the
updates
to_token: the new stream token: the point to get updates up to
target_row_count: a target for the number of rows to be returned.
Returns: a triplet `(updates, new_last_token, limited)`, where:
* `updates` is a list of `(token, row)` entries.
* `new_last_token` is the new position in stream.
* `limited` is whether there are more updates to fetch.
"""
# TODO: Handle limit.
# TODO: Handle target_row_count.
# To handle restarts where we wrap around
if from_token > self.pos:
@@ -270,12 +291,7 @@ class FederationRemoteSendQueue(object):
# list of tuple(int, BaseFederationRow), where the first is the position
# of the federation stream.
rows = []
# There should be only one reader, so lets delete everything its
# acknowledged its seen.
if federation_ack:
self._clear_queue_before_pos(federation_ack)
rows = [] # type: List[Tuple[int, BaseFederationRow]]
# Fetch changed presence
i = self.presence_changed.bisect_right(from_token)
@@ -332,7 +348,11 @@ class FederationRemoteSendQueue(object):
# Sort rows based on pos
rows.sort()
return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
return (
[(pos, (row.TypeId, row.to_data())) for pos, row in rows],
to_token,
False,
)
class BaseFederationRow(object):
@@ -341,7 +361,7 @@ class BaseFederationRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
TypeId = "" # Unique string that ids the type. Must be overriden in sub classes.
@staticmethod
def from_data(data):
@@ -454,10 +474,14 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
TypeToRow = {
Row.TypeId: Row
for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow,)
}
_rowtypes = (
PresenceRow,
PresenceDestinationsRow,
KeyedEduRow,
EduRow,
) # type: Tuple[Type[BaseFederationRow], ...]
TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
ParsedFederationStreamData = namedtuple(

View File

@@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import Dict, Hashable, Iterable, List, Optional, Set
from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
from six import itervalues
@@ -498,14 +498,16 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction()
def get_current_token(self) -> int:
@staticmethod
def get_current_token() -> int:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return 0
@staticmethod
async def get_replication_rows(
self, from_token, to_token, limit, federation_ack=None
):
instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return []
return [], 0, False

View File

@@ -15,11 +15,10 @@
# limitations under the License.
import datetime
import logging
from typing import Dict, Hashable, Iterable, List, Tuple
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
from prometheus_client import Counter
import synapse.server
from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
@@ -34,6 +33,9 @@ from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
if TYPE_CHECKING:
import synapse.server
# This is defined in the Matrix spec and enforced by the receiver.
MAX_EDUS_PER_TRANSACTION = 100

View File

@@ -13,11 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List
from typing import TYPE_CHECKING, List
from canonicaljson import json
import synapse.server
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
@@ -31,6 +30,9 @@ from synapse.logging.opentracing import (
)
from synapse.util.metrics import measure_func
if TYPE_CHECKING:
import synapse.server
logger = logging.getLogger(__name__)

View File

@@ -126,13 +126,13 @@ class AuthHandler(BaseHandler):
# It notifies the user they are about to give access to their matrix account
# to the client.
self._sso_redirect_confirm_template = load_jinja2_templates(
hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
hs.config.sso_template_dir, ["sso_redirect_confirm.html"],
)[0]
# The following template is shown during user interactive authentication
# in the fallback auth scenario. It notifies the user that they are
# authenticating for an operation to occur on their account.
self._sso_auth_confirm_template = load_jinja2_templates(
hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"],
hs.config.sso_template_dir, ["sso_auth_confirm.html"],
)[0]
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.

View File

@@ -125,10 +125,9 @@ class FederationHandler(BaseHandler):
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_simple_http_client()
self._instance_name = hs.get_instance_name()
self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client(
hs
)
self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
hs
)
@@ -2681,8 +2680,7 @@ class FederationHandler(BaseHandler):
member_handler = self.hs.get_room_member_handler()
await member_handler.send_membership_event(None, event, context)
@defer.inlineCallbacks
def add_display_name_to_third_party_invite(
async def add_display_name_to_third_party_invite(
self, room_version, event_dict, event, context
):
key = (
@@ -2690,10 +2688,10 @@ class FederationHandler(BaseHandler):
event.content["third_party_invite"]["signed"]["token"],
)
original_invite = None
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()
original_invite_id = prev_state_ids.get(key)
if original_invite_id:
original_invite = yield self.store.get_event(
original_invite = await self.store.get_event(
original_invite_id, allow_none=True
)
if original_invite:
@@ -2714,14 +2712,13 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder)
event, context = yield self.event_creation_handler.create_new_client_event(
event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
)
EventValidator().validate_new(event, self.config)
return (event, context)
@defer.inlineCallbacks
def _check_signature(self, event, context):
async def _check_signature(self, event, context):
"""
Checks that the signature in the event is consistent with its invite.
@@ -2738,12 +2735,12 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
invite_event = None
if invite_event_id:
invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
invite_event = await self.store.get_event(invite_event_id, allow_none=True)
if not invite_event:
raise AuthError(403, "Could not find invite")
@@ -2792,7 +2789,7 @@ class FederationHandler(BaseHandler):
raise
try:
if "key_validity_url" in public_key_object:
yield self._check_key_revocation(
await self._check_key_revocation(
public_key, public_key_object["key_validity_url"]
)
except Exception:
@@ -2806,8 +2803,7 @@ class FederationHandler(BaseHandler):
last_exception = e
raise last_exception
@defer.inlineCallbacks
def _check_key_revocation(self, public_key, url):
async def _check_key_revocation(self, public_key, url):
"""
Checks whether public_key has been revoked.
@@ -2821,7 +2817,7 @@ class FederationHandler(BaseHandler):
for revocation.
"""
try:
response = yield self.http_client.get_json(url, {"public_key": public_key})
response = await self.http_client.get_json(url, {"public_key": public_key})
except Exception:
raise SynapseError(502, "Third party certificate could not be checked")
if "valid" not in response or not response["valid"]:
@@ -2840,8 +2836,9 @@ class FederationHandler(BaseHandler):
backfilled: Whether these events are a result of
backfilling or not
"""
if self.config.worker_app:
await self._send_events_to_master(
if self.config.worker.writers.events != self._instance_name:
await self._send_events(
instance_name=self.config.worker.writers.events,
store=self.store,
event_and_contexts=event_and_contexts,
backfilled=backfilled,
@@ -2916,8 +2913,7 @@ class FederationHandler(BaseHandler):
else:
user_joined_room(self.distributor, user, room_id)
@defer.inlineCallbacks
def get_room_complexity(self, remote_room_hosts, room_id):
async def get_room_complexity(self, remote_room_hosts, room_id):
"""
Fetch the complexity of a remote room over federation.
@@ -2931,12 +2927,12 @@ class FederationHandler(BaseHandler):
"""
for host in remote_room_hosts:
res = yield self.federation_client.get_room_complexity(host, room_id)
res = await self.federation_client.get_room_complexity(host, room_id)
# We got a result, return it.
if res:
defer.returnValue(res)
return res
# We fell off the bottom, couldn't get the complexity from anyone. Oh
# well.
defer.returnValue(None)
return None

View File

@@ -72,7 +72,6 @@ class MessageHandler(object):
self.state_store = self.storage.state
self._event_serializer = hs.get_event_client_serializer()
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
self._is_worker_app = bool(hs.config.worker_app)
# The scheduled call to self._expire_event. None if no call is currently
# scheduled.
@@ -260,7 +259,6 @@ class MessageHandler(object):
Args:
event (EventBase): The event to schedule the expiry of.
"""
assert not self._is_worker_app
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
if not isinstance(expiry_ts, int) or event.is_state():
@@ -367,10 +365,11 @@ class EventCreationHandler(object):
self.notifier = hs.get_notifier()
self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
self._instance_name = hs.get_instance_name()
self.room_invite_state_types = self.hs.config.room_invite_state_types
self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs)
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
# This is only used to get at ratelimit function, and maybe_kick_guest_users
self.base_handler = BaseHandler(hs)
@@ -824,8 +823,9 @@ class EventCreationHandler(object):
success = False
try:
# If we're a worker we need to hit out to the master.
if self.config.worker_app:
await self.send_event_to_master(
if self.config.worker.writers.events != self._instance_name:
await self.send_event(
instance_name=self.config.worker.writers.events,
event_id=event.event_id,
store=self.store,
requester=requester,
@@ -890,7 +890,7 @@ class EventCreationHandler(object):
This should only be run on master.
"""
assert not self.config.worker_app
assert self.config.worker.writers.events == self._instance_name
if ratelimit:
# We check if this is a room admin redacting an event so that we

View File

@@ -0,0 +1,998 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from typing import Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
import attr
import pymacaroons
from authlib.common.security import generate_token
from authlib.jose import JsonWebToken
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
from pymacaroons.exceptions import (
MacaroonDeserializationException,
MacaroonInvalidSignatureException,
)
from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
from synapse.http.server import finish_request
from synapse.http.site import SynapseRequest
from synapse.push.mailer import load_jinja2_templates
from synapse.server import HomeServer
from synapse.types import UserID, map_username_to_mxid_localpart
logger = logging.getLogger(__name__)
SESSION_COOKIE_NAME = b"oidc_session"
#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
#: OpenID.Core sec 3.1.3.3.
Token = TypedDict(
"Token",
{
"access_token": str,
"token_type": str,
"id_token": Optional[str],
"refresh_token": Optional[str],
"expires_in": int,
"scope": Optional[str],
},
)
#: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but
#: there is no real point of doing this in our case.
JWK = Dict[str, str]
#: A JWK Set, as per RFC7517 sec 5.
JWKS = TypedDict("JWKS", {"keys": List[JWK]})
class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint
"""
def __init__(self, error, error_description=None):
self.error = error
self.error_description = error_description
def __str__(self):
if self.error_description:
return "{}: {}".format(self.error, self.error_description)
return self.error
class MappingException(Exception):
"""Used to catch errors when mapping the UserInfo object
"""
class OidcHandler:
"""Handles requests related to the OpenID Connect login flow.
"""
def __init__(self, hs: HomeServer):
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._client_auth = ClientAuth(
hs.config.oidc_client_id,
hs.config.oidc_client_secret,
hs.config.oidc_client_auth_method,
) # type: ClientAuth
self._client_auth_method = hs.config.oidc_client_auth_method # type: str
self._subject_claim = hs.config.oidc_subject_claim
self._provider_metadata = OpenIDProviderMetadata(
issuer=hs.config.oidc_issuer,
authorization_endpoint=hs.config.oidc_authorization_endpoint,
token_endpoint=hs.config.oidc_token_endpoint,
userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
jwks_uri=hs.config.oidc_jwks_uri,
) # type: OpenIDProviderMetadata
self._provider_needs_discovery = hs.config.oidc_discover # type: bool
self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
hs.config.oidc_user_mapping_provider_config
) # type: OidcMappingProvider
self._skip_verification = hs.config.oidc_skip_verification # type: bool
self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._datastore = hs.get_datastore()
self._clock = hs.get_clock()
self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
self._error_template = load_jinja2_templates(
hs.config.sso_template_dir, ["sso_error.html"]
)[0]
# identifier for the external_ids table
self._auth_provider_id = "oidc"
def _render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Renders the error template and respond with it.
This is used to show errors to the user. The template of this page can
be found under ``synapse/res/templates/sso_error.html``.
Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error. Those include
well-known OAuth2/OIDC error types like invalid_request or
access_denied.
error_description: A human-readable description of the error.
"""
html_bytes = self._error_template.render(
error=error, error_description=error_description
).encode("utf-8")
request.setResponseCode(400)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%i" % len(html_bytes))
request.write(html_bytes)
finish_request(request)
def _validate_metadata(self):
"""Verifies the provider metadata.
This checks the validity of the currently loaded provider. Not
everything is checked, only:
- ``issuer``
- ``authorization_endpoint``
- ``token_endpoint``
- ``response_types_supported`` (checks if "code" is in it)
- ``jwks_uri``
Raises:
ValueError: if something in the provider is not valid
"""
# Skip verification to allow non-compliant providers (e.g. issuers not running on a secure origin)
if self._skip_verification is True:
return
m = self._provider_metadata
m.validate_issuer()
m.validate_authorization_endpoint()
m.validate_token_endpoint()
if m.get("token_endpoint_auth_methods_supported") is not None:
m.validate_token_endpoint_auth_methods_supported()
if (
self._client_auth_method
not in m["token_endpoint_auth_methods_supported"]
):
raise ValueError(
'"{auth_method}" not in "token_endpoint_auth_methods_supported" ({supported!r})'.format(
auth_method=self._client_auth_method,
supported=m["token_endpoint_auth_methods_supported"],
)
)
if m.get("response_types_supported") is not None:
m.validate_response_types_supported()
if "code" not in m["response_types_supported"]:
raise ValueError(
'"code" not in "response_types_supported" (%r)'
% (m["response_types_supported"],)
)
# If the openid scope was not requested, we need a userinfo endpoint to fetch user infos
if self._uses_userinfo:
if m.get("userinfo_endpoint") is None:
raise ValueError(
'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested'
)
else:
# If we're not using userinfo, we need a valid jwks to validate the ID token
if m.get("jwks") is None:
if m.get("jwks_uri") is not None:
m.validate_jwks_uri()
else:
raise ValueError('"jwks_uri" must be set')
@property
def _uses_userinfo(self) -> bool:
"""Returns True if the ``userinfo_endpoint`` should be used.
This is based on the requested scopes: if the scopes include
``openid``, the provider should give use an ID token containing the
user informations. If not, we should fetch them using the
``access_token`` with the ``userinfo_endpoint``.
"""
# Maybe that should be user-configurable and not inferred?
return "openid" not in self._scopes
async def load_metadata(self) -> OpenIDProviderMetadata:
"""Load and validate the provider metadata.
The values metadatas are discovered if ``oidc_config.discovery`` is
``True`` and then cached.
Raises:
ValueError: if something in the provider is not valid
Returns:
The provider's metadata.
"""
# If we are using the OpenID Discovery documents, it needs to be loaded once
# FIXME: should there be a lock here?
if self._provider_needs_discovery:
url = get_well_known_url(self._provider_metadata["issuer"], external=True)
metadata_response = await self._http_client.get_json(url)
# TODO: maybe update the other way around to let user override some values?
self._provider_metadata.update(metadata_response)
self._provider_needs_discovery = False
self._validate_metadata()
return self._provider_metadata
async def load_jwks(self, force: bool = False) -> JWKS:
"""Load the JSON Web Key Set used to sign ID tokens.
If we're not using the ``userinfo_endpoint``, user infos are extracted
from the ID token, which is a JWT signed by keys given by the provider.
The keys are then cached.
Args:
force: Force reloading the keys.
Returns:
The key set
Looks like this::
{
'keys': [
{
'kid': 'abcdef',
'kty': 'RSA',
'alg': 'RS256',
'use': 'sig',
'e': 'XXXX',
'n': 'XXXX',
}
]
}
"""
if self._uses_userinfo:
# We're not using jwt signing, return an empty jwk set
return {"keys": []}
# First check if the JWKS are loaded in the provider metadata.
# It can happen either if the provider gives its JWKS in the discovery
# document directly or if it was already loaded once.
metadata = await self.load_metadata()
jwk_set = metadata.get("jwks")
if jwk_set is not None and not force:
return jwk_set
# Loading the JWKS using the `jwks_uri` metadata
uri = metadata.get("jwks_uri")
if not uri:
raise RuntimeError('Missing "jwks_uri" in metadata')
jwk_set = await self._http_client.get_json(uri)
# Caching the JWKS in the provider's metadata
self._provider_metadata["jwks"] = jwk_set
return jwk_set
async def _exchange_code(self, code: str) -> Token:
"""Exchange an authorization code for a token.
This calls the ``token_endpoint`` with the authorization code we
received in the callback to exchange it for a token. The call uses the
``ClientAuth`` to authenticate with the client with its ID and secret.
Args:
code: The autorization code we got from the callback.
Returns:
A dict containing various tokens.
May look like this::
{
'token_type': 'bearer',
'access_token': 'abcdef',
'expires_in': 3599,
'id_token': 'ghijkl',
'refresh_token': 'mnopqr',
}
Raises:
OidcError: when the ``token_endpoint`` returned an error.
"""
metadata = await self.load_metadata()
token_endpoint = metadata.get("token_endpoint")
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": self._http_client.user_agent,
"Accept": "application/json",
}
args = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self._callback_url,
}
body = urlencode(args, True)
# Fill the body/headers with credentials
uri, headers, body = self._client_auth.prepare(
method="POST", uri=token_endpoint, headers=headers, body=body
)
headers = {k: [v] for (k, v) in headers.items()}
# Do the actual request
# We're not using the SimpleHttpClient util methods as we don't want to
# check the HTTP status code and we do the body encoding ourself.
response = await self._http_client.request(
method="POST", uri=uri, data=body.encode("utf-8"), headers=headers,
)
# This is used in multiple error messages below
status = "{code} {phrase}".format(
code=response.code, phrase=response.phrase.decode("utf-8")
)
resp_body = await readBody(response)
if response.code >= 500:
# In case of a server error, we should first try to decode the body
# and check for an error field. If not, we respond with a generic
# error message.
try:
resp = json.loads(resp_body.decode("utf-8"))
error = resp["error"]
description = resp.get("error_description", error)
except (ValueError, KeyError):
# Catch ValueError for the JSON decoding and KeyError for the "error" field
error = "server_error"
description = (
(
'Authorization server responded with a "{status}" error '
"while exchanging the authorization code."
).format(status=status),
)
raise OidcError(error, description)
# Since it is a not a 5xx code, body should be a valid JSON. It will
# raise if not.
resp = json.loads(resp_body.decode("utf-8"))
if "error" in resp:
error = resp["error"]
# In case the authorization server responded with an error field,
# it should be a 4xx code. If not, warn about it but don't do
# anything special and report the original error message.
if response.code < 400:
logger.debug(
"Invalid response from the authorization server: "
'responded with a "{status}" '
"but body has an error field: {error!r}".format(
status=status, error=resp["error"]
)
)
description = resp.get("error_description", error)
raise OidcError(error, description)
# Now, this should not be an error. According to RFC6749 sec 5.1, it
# should be a 200 code. We're a bit more flexible than that, and will
# only throw on a 4xx code.
if response.code >= 400:
description = (
'Authorization server responded with a "{status}" error '
'but did not include an "error" field in its response.'.format(
status=status
)
)
logger.warning(description)
# Body was still valid JSON. Might be useful to log it for debugging.
logger.warning("Code exchange response: {resp!r}".format(resp=resp))
raise OidcError("server_error", description)
return resp
async def _fetch_userinfo(self, token: Token) -> UserInfo:
"""Fetch user informations from the ``userinfo_endpoint``.
Args:
token: the token given by the ``token_endpoint``.
Must include an ``access_token`` field.
Returns:
UserInfo: an object representing the user.
"""
metadata = await self.load_metadata()
resp = await self._http_client.get_json(
metadata["userinfo_endpoint"],
headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
)
return UserInfo(resp)
async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
"""Return an instance of UserInfo from token's ``id_token``.
Args:
token: the token given by the ``token_endpoint``.
Must include an ``id_token`` field.
nonce: the nonce value originally sent in the initial authorization
request. This value should match the one inside the token.
Returns:
An object representing the user.
"""
metadata = await self.load_metadata()
claims_params = {
"nonce": nonce,
"client_id": self._client_auth.client_id,
}
if "access_token" in token:
# If we got an `access_token`, there should be an `at_hash` claim
# in the `id_token` that we can check against.
claims_params["access_token"] = token["access_token"]
claims_cls = CodeIDToken
else:
claims_cls = ImplicitIDToken
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
jwt = JsonWebToken(alg_values)
claim_options = {"iss": {"values": [metadata["issuer"]]}}
# Try to decode the keys in cache first, then retry by forcing the keys
# to be reloaded
jwk_set = await self.load_jwks()
try:
claims = jwt.decode(
token["id_token"],
key=jwk_set,
claims_cls=claims_cls,
claims_options=claim_options,
claims_params=claims_params,
)
except ValueError:
jwk_set = await self.load_jwks(force=True) # try reloading the jwks
claims = jwt.decode(
token["id_token"],
key=jwk_set,
claims_cls=claims_cls,
claims_options=claim_options,
claims_params=claims_params,
)
claims.validate(leeway=120) # allows 2 min of clock skew
return UserInfo(claims)
async def handle_redirect_request(
self, request: SynapseRequest, client_redirect_url: bytes
) -> None:
"""Handle an incoming request to /login/sso/redirect
It redirects the browser to the authorization endpoint with a few
parameters:
- ``client_id``: the client ID set in ``oidc_config.client_id``
- ``response_type``: ``code``
- ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback``
- ``scope``: the list of scopes set in ``oidc_config.scopes``
- ``state``: a random string
- ``nonce``: a random string
In addition to redirecting the client, we are setting a cookie with
a signed macaroon token containing the state, the nonce and the
client_redirect_url params. Those are then checked when the client
comes back from the provider.
Args:
request: the incoming request from the browser.
We'll respond to it with a redirect and a cookie.
client_redirect_url: the URL that we should redirect the client to
when everything is done
"""
state = generate_token()
nonce = generate_token()
cookie = self._generate_oidc_session_token(
state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(),
)
request.addCookie(
SESSION_COOKIE_NAME,
cookie,
path="/_synapse/oidc",
max_age="3600",
httpOnly=True,
sameSite="lax",
)
metadata = await self.load_metadata()
authorization_endpoint = metadata.get("authorization_endpoint")
uri = prepare_grant_uri(
authorization_endpoint,
client_id=self._client_auth.client_id,
response_type="code",
redirect_uri=self._callback_url,
scope=self._scopes,
state=state,
nonce=nonce,
)
request.redirect(uri)
finish_request(request)
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
``self._render_error`` which displays an HTML page for the error.
Most of the OpenID Connect logic happens here:
- first, we check if there was any error returned by the provider and
display it
- then we fetch the session cookie, decode and verify it
- the ``state`` query parameter should match with the one stored in the
session cookie
- once we known this session is legit, exchange the code with the
provider using the ``token_endpoint`` (see ``_exchange_code``)
- once we have the token, use it to either extract the UserInfo from
the ``id_token`` (``_parse_id_token``), or use the ``access_token``
to fetch UserInfo from the ``userinfo_endpoint``
(``_fetch_userinfo``)
- map those UserInfo to a Matrix user (``_map_userinfo_to_user``) and
finish the login
Args:
request: the incoming request from the browser.
"""
# The provider might redirect with an error.
# In that case, just display it as-is.
if b"error" in request.args:
error = request.args[b"error"][0].decode()
description = request.args.get(b"error_description", [b""])[0].decode()
# Most of the errors returned by the provider could be due by
# either the provider misbehaving or Synapse being misconfigured.
# The only exception of that is "access_denied", where the user
# probably cancelled the login flow. In other cases, log those errors.
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)
self._render_error(request, error, description)
return
# Fetch the session cookie
session = request.getCookie(SESSION_COOKIE_NAME)
if session is None:
logger.info("No session cookie found")
self._render_error(request, "missing_session", "No session cookie found")
return
# Remove the cookie. There is a good chance that if the callback failed
# once, it will fail next time and the code will already be exchanged.
# Removing it early avoids spamming the provider with token requests.
request.addCookie(
SESSION_COOKIE_NAME,
b"",
path="/_synapse/oidc",
expires="Thu, Jan 01 1970 00:00:00 UTC",
httpOnly=True,
sameSite="lax",
)
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
self._render_error(request, "invalid_request", "State parameter is missing")
return
state = request.args[b"state"][0].decode()
# Deserialize the session token and verify it.
try:
nonce, client_redirect_url = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
self._render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
self._render_error(request, "mismatching_session", str(e))
return
# Exchange the code with the provider
if b"code" not in request.args:
logger.info("Code parameter is missing")
self._render_error(request, "invalid_request", "Code parameter is missing")
return
logger.info("Exchanging code")
code = request.args[b"code"][0].decode()
try:
token = await self._exchange_code(code)
except OidcError as e:
logger.exception("Could not exchange code")
self._render_error(request, e.error, e.error_description)
return
# Now that we have a token, get the userinfo, either by decoding the
# `id_token` or by fetching the `userinfo_endpoint`.
if self._uses_userinfo:
logger.info("Fetching userinfo")
try:
userinfo = await self._fetch_userinfo(token)
except Exception as e:
logger.exception("Could not fetch userinfo")
self._render_error(request, "fetch_error", str(e))
return
else:
logger.info("Extracting userinfo from id_token")
try:
userinfo = await self._parse_id_token(token, nonce=nonce)
except Exception as e:
logger.exception("Invalid id_token")
self._render_error(request, "invalid_token", str(e))
return
# Call the mapper to register/login the user
try:
user_id = await self._map_userinfo_to_user(userinfo, token)
except MappingException as e:
logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e))
return
# and finally complete the login
await self._auth_handler.complete_sso_login(
user_id, request, client_redirect_url
)
def _generate_oidc_session_token(
self,
state: str,
nonce: str,
client_redirect_url: str,
duration_in_ms: int = (60 * 60 * 1000),
) -> str:
"""Generates a signed token storing data about an OIDC session.
When Synapse initiates an authorization flow, it creates a random state
and a random nonce. Those parameters are given to the provider and
should be verified when the client comes back from the provider.
It is also used to store the client_redirect_url, which is used to
complete the SSO login flow.
Args:
state: The ``state`` parameter passed to the OIDC provider.
nonce: The ``nonce`` parameter passed to the OIDC provider.
client_redirect_url: The URL the client gave when it initiated the
flow.
duration_in_ms: An optional duration for the token in milliseconds.
Defaults to an hour.
Returns:
A signed macaroon token with the session informations.
"""
macaroon = pymacaroons.Macaroon(
location=self._server_name, identifier="key", key=self._macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")
macaroon.add_first_party_caveat("state = %s" % (state,))
macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (client_redirect_url,)
)
now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]:
"""Verifies and extract an OIDC session token.
This verifies that a given session token was issued by this homeserver
and extract the nonce and client_redirect_url caveats.
Args:
session: The session token to verify
state: The state the OIDC provider gave back
Returns:
The nonce and the client_redirect_url for this session
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = session")
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self._macaroon_secret_key)
# Extract the `nonce` and `client_redirect_url` from the token
nonce = self._get_value_from_macaroon(macaroon, "nonce")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
return nonce, client_redirect_url
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
Args:
macaroon: the token
key: the key of the caveat to extract
Returns:
The extracted value
Raises:
Exception: if the caveat was not in the macaroon
"""
prefix = key + " = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise Exception("No %s caveat in macaroon" % (key,))
def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self._clock.time_msec()
return now < expiry
async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
"""Maps a UserInfo object to a mxid.
UserInfo should have a claim that uniquely identifies users. This claim
is usually `sub`, but can be configured with `oidc_config.subject_claim`.
It is then used as an `external_id`.
If we don't find the user that way, we should register the user,
mapping the localpart and the display name from the UserInfo.
If a user already exists with the mxid we've mapped, raise an exception.
Args:
userinfo: an object representing the user
token: a dict with the tokens obtained from the provider
Raises:
MappingException: if there was an error while mapping some properties
Returns:
The mxid of the user
"""
try:
remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
except Exception as e:
raise MappingException(
"Failed to extract subject from OIDC response: %s" % (e,)
)
logger.info(
"Looking for existing mapping for user %s:%s",
self._auth_provider_id,
remote_user_id,
)
registered_user_id = await self._datastore.get_user_by_external_id(
self._auth_provider_id, remote_user_id,
)
if registered_user_id is not None:
logger.info("Found existing mapping %s", registered_user_id)
return registered_user_id
try:
attributes = await self._user_mapping_provider.map_user_attributes(
userinfo, token
)
except Exception as e:
raise MappingException(
"Could not extract user attributes from OIDC response: " + str(e)
)
logger.debug(
"Retrieved user attributes from user mapping provider: %r", attributes
)
if not attributes["localpart"]:
raise MappingException("localpart is empty")
localpart = map_username_to_mxid_localpart(attributes["localpart"])
user_id = UserID(localpart, self._hostname)
if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()):
# This mxid is taken
raise MappingException(
"mxid '{}' is already taken".format(user_id.to_string())
)
# It's the first time this user is logging in and the mapped mxid was
# not taken, register the user
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=attributes["display_name"],
)
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id,
)
return registered_user_id
UserAttribute = TypedDict(
"UserAttribute", {"localpart": str, "display_name": Optional[str]}
)
C = TypeVar("C")
class OidcMappingProvider(Generic[C]):
"""A mapping provider maps a UserInfo object to user attributes.
It should provide the API described by this class.
"""
def __init__(self, config: C):
"""
Args:
config: A custom config object from this module, parsed by ``parse_config()``
"""
@staticmethod
def parse_config(config: dict) -> C:
"""Parse the dict provided by the homeserver's config
Args:
config: A dictionary containing configuration options for this provider
Returns:
A custom config object for this module
"""
raise NotImplementedError()
def get_remote_user_id(self, userinfo: UserInfo) -> str:
"""Get a unique user ID for this user.
Usually, in an OIDC-compliant scenario, it should be the ``sub`` claim from the UserInfo object.
Args:
userinfo: An object representing the user given by the OIDC provider
Returns:
A unique user ID
"""
raise NotImplementedError()
async def map_user_attributes(
self, userinfo: UserInfo, token: Token
) -> UserAttribute:
"""Map a ``UserInfo`` objects into user attributes.
Args:
userinfo: An object representing the user given by the OIDC provider
token: A dict with the tokens returned by the provider
Returns:
A dict containing the ``localpart`` and (optionally) the ``display_name``
"""
raise NotImplementedError()
# Used to clear out "None" values in templates
def jinja_finalize(thing):
return thing if thing is not None else ""
env = Environment(finalize=jinja_finalize)
@attr.s
class JinjaOidcMappingConfig:
subject_claim = attr.ib() # type: str
localpart_template = attr.ib() # type: Template
display_name_template = attr.ib() # type: Optional[Template]
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
"""An implementation of a mapping provider based on Jinja templates.
This is the default mapping provider.
"""
def __init__(self, config: JinjaOidcMappingConfig):
self._config = config
@staticmethod
def parse_config(config: dict) -> JinjaOidcMappingConfig:
subject_claim = config.get("subject_claim", "sub")
if "localpart_template" not in config:
raise ConfigError(
"missing key: oidc_config.user_mapping_provider.config.localpart_template"
)
try:
localpart_template = env.from_string(config["localpart_template"])
except Exception as e:
raise ConfigError(
"invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
% (e,)
)
display_name_template = None # type: Optional[Template]
if "display_name_template" in config:
try:
display_name_template = env.from_string(config["display_name_template"])
except Exception as e:
raise ConfigError(
"invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
% (e,)
)
return JinjaOidcMappingConfig(
subject_claim=subject_claim,
localpart_template=localpart_template,
display_name_template=display_name_template,
)
def get_remote_user_id(self, userinfo: UserInfo) -> str:
return userinfo[self._config.subject_claim]
async def map_user_attributes(
self, userinfo: UserInfo, token: Token
) -> UserAttribute:
localpart = self._config.localpart_template.render(user=userinfo).strip()
display_name = None # type: Optional[str]
if self._config.display_name_template is not None:
display_name = self._config.display_name_template.render(
user=userinfo
).strip()
if display_name == "":
display_name = None
return UserAttribute(localpart=localpart, display_name=display_name)

View File

@@ -25,8 +25,6 @@ from collections import OrderedDict
from six import iteritems, string_types
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -103,8 +101,7 @@ class RoomCreationHandler(BaseHandler):
self.third_party_event_rules = hs.get_third_party_event_rules()
@defer.inlineCallbacks
def upgrade_room(
async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
"""Replace a room with a new room with a different version
@@ -117,7 +114,7 @@ class RoomCreationHandler(BaseHandler):
Returns:
Deferred[unicode]: the new room id
"""
yield self.ratelimit(requester)
await self.ratelimit(requester)
user_id = requester.user.to_string()
@@ -138,7 +135,7 @@ class RoomCreationHandler(BaseHandler):
# If this user has sent multiple upgrade requests for the same room
# and one of them is not complete yet, cache the response and
# return it to all subsequent requests
ret = yield self._upgrade_response_cache.wrap(
ret = await self._upgrade_response_cache.wrap(
(old_room_id, user_id),
self._upgrade_room,
requester,
@@ -856,8 +853,7 @@ class RoomCreationHandler(BaseHandler):
for (etype, state_key), content in initial_state.items():
await send(etype=etype, state_key=state_key, content=content)
@defer.inlineCallbacks
def _generate_room_id(
async def _generate_room_id(
self, creator_id: str, is_public: str, room_version: RoomVersion,
):
# autogen room IDs and try to create it. We may clash, so just
@@ -869,7 +865,7 @@ class RoomCreationHandler(BaseHandler):
gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
if isinstance(gen_room_id, bytes):
gen_room_id = gen_room_id.decode("utf-8")
yield self.store.store_room(
await self.store.store_room(
room_id=gen_room_id,
room_creator_user_id=creator_id,
is_public=is_public,
@@ -888,8 +884,7 @@ class RoomContextHandler(object):
self.storage = hs.get_storage()
self.state_store = self.storage.state
@defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit, event_filter):
async def get_event_context(self, user, room_id, event_id, limit, event_filter):
"""Retrieves events, pagination tokens and state around a given event
in a room.
@@ -908,7 +903,7 @@ class RoomContextHandler(object):
before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit
users = yield self.store.get_users_in_room(room_id)
users = await self.store.get_users_in_room(room_id)
is_peeking = user.to_string() not in users
def filter_evts(events):
@@ -916,17 +911,17 @@ class RoomContextHandler(object):
self.storage, user.to_string(), events, is_peeking=is_peeking
)
event = yield self.store.get_event(
event = await self.store.get_event(
event_id, get_prev_content=True, allow_none=True
)
if not event:
return None
filtered = yield (filter_evts([event]))
filtered = await filter_evts([event])
if not filtered:
raise AuthError(403, "You don't have permission to access that event.")
results = yield self.store.get_events_around(
results = await self.store.get_events_around(
room_id, event_id, before_limit, after_limit, event_filter
)
@@ -934,8 +929,8 @@ class RoomContextHandler(object):
results["events_before"] = event_filter.filter(results["events_before"])
results["events_after"] = event_filter.filter(results["events_after"])
results["events_before"] = yield filter_evts(results["events_before"])
results["events_after"] = yield filter_evts(results["events_after"])
results["events_before"] = await filter_evts(results["events_before"])
results["events_after"] = await filter_evts(results["events_after"])
# filter_evts can return a pruned event in case the user is allowed to see that
# there's something there but not see the content, so use the event that's in
# `filtered` rather than the event we retrieved from the datastore.
@@ -962,7 +957,7 @@ class RoomContextHandler(object):
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
state = yield self.state_store.get_state_for_events(
state = await self.state_store.get_state_for_events(
[last_event_id], state_filter=state_filter
)
@@ -970,7 +965,7 @@ class RoomContextHandler(object):
if event_filter:
state_events = event_filter.filter(state_events)
results["state"] = yield filter_evts(state_events)
results["state"] = await filter_evts(state_events)
# We use a dummy token here as we only care about the room portion of
# the token, which we replace.
@@ -989,13 +984,12 @@ class RoomEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_new_events(
async def get_new_events(
self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
):
# We just ignore the key for now.
to_key = yield self.get_current_key()
to_key = await self.get_current_key()
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
@@ -1008,11 +1002,11 @@ class RoomEventSource(object):
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
else:
room_events = yield self.store.get_membership_changes_for_user(
room_events = await self.store.get_membership_changes_for_user(
user.to_string(), from_key, to_key
)
room_to_events = yield self.store.get_room_events_stream_for_rooms(
room_to_events = await self.store.get_room_events_stream_for_rooms(
room_ids=room_ids,
from_key=from_key,
to_key=to_key,

View File

@@ -875,8 +875,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
async def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
"""
Check if complexity of a remote room is too great.
@@ -888,7 +887,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if unable to be fetched
"""
max_complexity = self.hs.config.limit_remote_rooms.complexity
complexity = yield self.federation_handler.get_room_complexity(
complexity = await self.federation_handler.get_room_complexity(
remote_room_hosts, room_id
)

View File

@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import re
from typing import Optional, Tuple
from typing import Callable, Dict, Optional, Set, Tuple
import attr
import saml2
@@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.http.server import finish_request
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
from synapse.module_api.errors import RedirectException
from synapse.types import (
@@ -81,17 +82,19 @@ class SamlHandler:
self._error_html_content = hs.config.saml2_error_html_content
def handle_redirect_request(self, client_redirect_url, ui_auth_session_id=None):
def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
) -> bytes:
"""Handle an incoming request to /login/sso/redirect
Args:
client_redirect_url (bytes): the URL that we should redirect the
client_redirect_url: the URL that we should redirect the
client to when everything is done
ui_auth_session_id (Optional[str]): The session ID of the ongoing UI Auth (or
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
Returns:
bytes: URL to redirect to
URL to redirect to
"""
reqid, info = self._saml_client.prepare_for_authenticate(
relay_state=client_redirect_url
@@ -109,15 +112,15 @@ class SamlHandler:
# this shouldn't happen!
raise Exception("prepare_for_authenticate didn't return a Location header")
async def handle_saml_response(self, request):
async def handle_saml_response(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_matrix/saml2/authn_response
Args:
request (SynapseRequest): the incoming request from the browser. We'll
request: the incoming request from the browser. We'll
respond to it with a redirect.
Returns:
Deferred[none]: Completes once we have handled the request.
Completes once we have handled the request.
"""
resp_bytes = parse_string(request, "SAMLResponse", required=True)
relay_state = parse_string(request, "RelayState", required=True)
@@ -310,6 +313,7 @@ DOT_REPLACE_PATTERN = re.compile(
def dot_replace_for_mxid(username: str) -> str:
"""Replace any characters which are not allowed in Matrix IDs with a dot."""
username = username.lower()
username = DOT_REPLACE_PATTERN.sub(".", username)
@@ -321,7 +325,7 @@ def dot_replace_for_mxid(username: str) -> str:
MXID_MAPPER_MAP = {
"hexencode": map_username_to_mxid_localpart,
"dotreplace": dot_replace_for_mxid,
}
} # type: Dict[str, Callable[[str], str]]
@attr.s
@@ -349,7 +353,7 @@ class DefaultSamlMappingProvider(object):
def get_remote_user_id(
self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
):
) -> str:
"""Extracts the remote user id from the SAML response"""
try:
return saml_response.ava["uid"][0]
@@ -428,14 +432,14 @@ class DefaultSamlMappingProvider(object):
return SamlConfig(mxid_source_attribute, mxid_mapper)
@staticmethod
def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]:
"""Returns the required attributes of a SAML
Args:
config: A SamlConfig object containing configuration params for this provider
Returns:
tuple[set,set]: The first set equates to the saml auth response
The first set equates to the saml auth response
attributes that are required for the module to function, whereas the
second set consists of those attributes which can be used if
available, but are not necessary

View File

@@ -18,8 +18,6 @@ import logging
from unpaddedbase64 import decode_base64, encode_base64
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
@@ -39,8 +37,7 @@ class SearchHandler(BaseHandler):
self.state_store = self.storage.state
self.auth = hs.get_auth()
@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
async def get_old_rooms_from_upgraded_room(self, room_id):
"""Retrieves room IDs of old rooms in the history of an upgraded room.
We do so by checking the m.room.create event of the room for a
@@ -60,7 +57,7 @@ class SearchHandler(BaseHandler):
historical_room_ids = []
# The initial room must have been known for us to get this far
predecessor = yield self.store.get_room_predecessor(room_id)
predecessor = await self.store.get_room_predecessor(room_id)
while True:
if not predecessor:
@@ -75,7 +72,7 @@ class SearchHandler(BaseHandler):
# Don't add it to the list until we have checked that we are in the room
try:
next_predecessor_room = yield self.store.get_room_predecessor(
next_predecessor_room = await self.store.get_room_predecessor(
predecessor_room_id
)
except NotFoundError:
@@ -89,8 +86,7 @@ class SearchHandler(BaseHandler):
return historical_room_ids
@defer.inlineCallbacks
def search(self, user, content, batch=None):
async def search(self, user, content, batch=None):
"""Performs a full text search for a user.
Args:
@@ -179,7 +175,7 @@ class SearchHandler(BaseHandler):
search_filter = Filter(filter_dict)
# TODO: Search through left rooms too
rooms = yield self.store.get_rooms_for_local_user_where_membership_is(
rooms = await self.store.get_rooms_for_local_user_where_membership_is(
user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
@@ -192,7 +188,7 @@ class SearchHandler(BaseHandler):
historical_room_ids = []
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
ids = yield self.get_old_rooms_from_upgraded_room(room_id)
ids = await self.get_old_rooms_from_upgraded_room(room_id)
historical_room_ids += ids
# Prevent any historical events from being filtered
@@ -223,7 +219,7 @@ class SearchHandler(BaseHandler):
count = None
if order_by == "rank":
search_result = yield self.store.search_msgs(room_ids, search_term, keys)
search_result = await self.store.search_msgs(room_ids, search_term, keys)
count = search_result["count"]
@@ -238,7 +234,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client(
events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
)
@@ -267,7 +263,7 @@ class SearchHandler(BaseHandler):
# But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit() and i < 5:
i += 1
search_result = yield self.store.search_rooms(
search_result = await self.store.search_rooms(
room_ids,
search_term,
keys,
@@ -288,7 +284,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client(
events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
)
@@ -343,11 +339,11 @@ class SearchHandler(BaseHandler):
# If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that
if event_context is not None:
now_token = yield self.hs.get_event_sources().get_current_token()
now_token = await self.hs.get_event_sources().get_current_token()
contexts = {}
for event in allowed_events:
res = yield self.store.get_events_around(
res = await self.store.get_events_around(
event.room_id, event.event_id, before_limit, after_limit
)
@@ -357,11 +353,11 @@ class SearchHandler(BaseHandler):
len(res["events_after"]),
)
res["events_before"] = yield filter_events_for_client(
res["events_before"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_before"]
)
res["events_after"] = yield filter_events_for_client(
res["events_after"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_after"]
)
@@ -390,7 +386,7 @@ class SearchHandler(BaseHandler):
[(EventTypes.Member, sender) for sender in senders]
)
state = yield self.state_store.get_state_for_event(
state = await self.state_store.get_state_for_event(
last_event_id, state_filter
)
@@ -412,10 +408,10 @@ class SearchHandler(BaseHandler):
time_now = self.clock.time_msec()
for context in contexts.values():
context["events_before"] = yield self._event_serializer.serialize_events(
context["events_before"] = await self._event_serializer.serialize_events(
context["events_before"], time_now
)
context["events_after"] = yield self._event_serializer.serialize_events(
context["events_after"] = await self._event_serializer.serialize_events(
context["events_after"], time_now
)
@@ -423,7 +419,7 @@ class SearchHandler(BaseHandler):
if include_state:
rooms = {e.room_id for e in allowed_events}
for room_id in rooms:
state = yield self.state_handler.get_current_state(room_id)
state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())
state_results.values()
@@ -437,7 +433,7 @@ class SearchHandler(BaseHandler):
{
"rank": rank_map[e.event_id],
"result": (
yield self._event_serializer.serialize_event(e, time_now)
await self._event_serializer.serialize_event(e, time_now)
),
"context": contexts.get(e.event_id, {}),
}
@@ -452,7 +448,7 @@ class SearchHandler(BaseHandler):
if state_results:
s = {}
for room_id, state in state_results.items():
s[room_id] = yield self._event_serializer.serialize_events(
s[room_id] = await self._event_serializer.serialize_events(
state, time_now
)

View File

@@ -49,7 +49,6 @@ from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util.async_helpers import timeout_deferred
from synapse.util.caches import CACHE_SIZE_FACTOR
logger = logging.getLogger(__name__)
@@ -241,7 +240,10 @@ class SimpleHttpClient(object):
# tends to do so in batches, so we need to allow the pool to keep
# lots of idle connections around.
pool = HTTPConnectionPool(self.reactor)
pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
# XXX: The justification for using the cache factor here is that larger instances
# will need both more cache and more connections.
# Still, this should probably be a separate dial
pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5))
pool.cachedConnectionTimeout = 2 * 60
self.agent = ProxyAgent(
@@ -359,6 +361,7 @@ class SimpleHttpClient(object):
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
@@ -399,6 +402,7 @@ class SimpleHttpClient(object):
actual_headers = {
b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
@@ -434,6 +438,10 @@ class SimpleHttpClient(object):
ValueError: if the response was not JSON
"""
actual_headers = {b"Accept": [b"application/json"]}
if headers:
actual_headers.update(headers)
body = yield self.get_raw(uri, args, headers=headers)
return json.loads(body)
@@ -467,6 +475,7 @@ class SimpleHttpClient(object):
actual_headers = {
b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)

View File

@@ -33,6 +33,8 @@ from prometheus_client import REGISTRY
from twisted.web.resource import Resource
from synapse.util import caches
try:
from prometheus_client.samples import Sample
except ImportError:
@@ -103,13 +105,15 @@ def nameify_sample(sample):
def generate_latest(registry, emit_help=False):
# Trigger the cache metrics to be rescraped, which updates the common
# metrics but do not produce metrics themselves
for collector in caches.collectors_by_name.values():
collector.collect()
output = []
for metric in registry.collect():
if metric.name.startswith("__unused"):
continue
if not metric.samples:
# No samples, don't bother.
continue

View File

@@ -51,6 +51,7 @@ push_rules_delta_state_cache_metric = register_cache(
"cache",
"push_rules_delta_state_cache_metric",
cache=[], # Meaningless size, as this isn't a cache that stores values
resizable=False,
)
@@ -67,7 +68,8 @@ class BulkPushRuleEvaluator(object):
self.room_push_rule_cache_metrics = register_cache(
"cache",
"room_push_rule_cache",
cache=[], # Meaningless size, as this isn't a cache that stores values
cache=[], # Meaningless size, as this isn't a cache that stores values,
resizable=False,
)
@defer.inlineCallbacks

View File

@@ -22,7 +22,7 @@ from six import string_types
from synapse.events import EventBase
from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@@ -165,7 +165,7 @@ class PushRuleEvaluatorForEvent(object):
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
regex_cache = LruCache(50000)
register_cache("cache", "regex_push_cache", regex_cache)

View File

@@ -92,6 +92,7 @@ CONDITIONAL_REQUIREMENTS = {
'eliot<1.8.0;python_version<"3.5.3"',
],
"saml2": ["pysaml2>=4.5.0"],
"oidc": ["authlib>=0.14.0"],
"systemd": ["systemd-python>=231"],
"url_preview": ["lxml>=3.5.0"],
"test": ["mock>=2.0", "parameterized"],

View File

@@ -34,9 +34,11 @@ class ReplicationRestResource(JsonResource):
def register_servlets(self, hs):
send_event.register_servlets(hs, self)
membership.register_servlets(hs, self)
federation.register_servlets(hs, self)
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)
streams.register_servlets(hs, self)
if hs.config.worker.worker_app is None:
membership.register_servlets(hs, self)
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)
streams.register_servlets(hs, self)

View File

@@ -141,17 +141,29 @@ class ReplicationEndpoint(object):
Returns a callable that accepts the same parameters as `_serialize_payload`.
"""
clock = hs.get_clock()
host = hs.config.worker_replication_host
port = hs.config.worker_replication_http_port
client = hs.get_simple_http_client()
local_instance_name = hs.get_instance_name()
master_host = hs.config.worker_replication_host
master_port = hs.config.worker_replication_http_port
instance_map = hs.config.worker.instance_map
@trace(opname="outgoing_replication_request")
@defer.inlineCallbacks
def send_request(instance_name="master", **kwargs):
# Currently we only support sending requests to master process.
if instance_name != "master":
raise Exception("Unknown instance")
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
if instance_name == "master":
host = master_host
port = master_port
elif instance_name in instance_map:
host = instance_map[instance_name].host
port = instance_map[instance_name].port
else:
raise Exception(
"Instance %r not in 'instance_map' config" % (instance_name,)
)
data = yield cls._serialize_payload(**kwargs)

View File

@@ -52,9 +52,9 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
self._instance_name = hs.get_instance_name()
# We pull the streams from the replication steamer (if we try and make
# We pull the streams from the replication handler (if we try and make
# them ourselves we end up in an import loop).
self.streams = hs.get_replication_streamer().get_streams()
self.streams = hs.get_tcp_replication().get_streams()
@staticmethod
def _serialize_payload(stream_name, from_token, upto_token):

View File

@@ -18,14 +18,10 @@ from typing import Optional
import six
from synapse.storage.data_stores.main.cache import (
CURRENT_STATE_CACHE_NAME,
CacheInvalidationWorkerStore,
)
from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage.util.id_generators import MultiWriterIdGenerator
logger = logging.getLogger(__name__)
@@ -41,40 +37,16 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: Database, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id"
) # type: Optional[SlavedIdTracker]
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
instance_name=hs.get_instance_name(),
table="cache_invalidation_stream_by_instance",
instance_column="instance_name",
id_column="stream_id",
sequence_name="cache_invalidation_stream_seq",
) # type: Optional[MultiWriterIdGenerator]
else:
self._cache_id_gen = None
self.hs = hs
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
else:
return 0
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches":
if self._cache_id_gen:
self._cache_id_gen.advance(token)
for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME:
if row.keys is None:
raise Exception(
"Can't send an 'invalidate all' for current state cache"
)
room_id = row.keys[0]
members_changed = set(row.keys[1:])
self._invalidate_state_caches(room_id, members_changed)
else:
self._attempt_to_invalidate_cache(row.cache_func, row.keys)
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
txn.call_after(cache_func.invalidate, keys)
txn.call_after(self._send_invalidation_poke, cache_func, keys)
def _send_invalidation_poke(self, cache_func, keys):
self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)

View File

@@ -32,7 +32,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "tag_account_data":
self._account_data_id_gen.advance(token)
for row in rows:
@@ -51,6 +51,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
(row.user_id, row.room_id, row.data_type)
)
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedAccountDataStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View File

@@ -15,7 +15,6 @@
from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.storage.database import Database
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import Cache
from ._base import BaseSlavedStore
@@ -26,7 +25,7 @@ class SlavedClientIpStore(BaseSlavedStore):
super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
name="client_ip_last_seen", keylen=4, max_entries=50000
)
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):

View File

@@ -43,7 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
expiry_ms=30 * 60 * 1000,
)
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "to_device":
self._device_inbox_id_gen.advance(token)
for row in rows:
@@ -55,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
self._device_federation_outbox_stream_cache.entity_has_changed(
row.entity, token
)
return super(SlavedDeviceInboxStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View File

@@ -48,7 +48,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max
)
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
self._invalidate_caches_for_devices(token, rows)
@@ -56,9 +56,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
self._device_list_id_gen.advance(token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)
def _invalidate_caches_for_devices(self, token, rows):
for row in rows:

View File

@@ -15,11 +15,6 @@
# limitations under the License.
import logging
from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams.events import (
EventsStreamCurrentStateRow,
EventsStreamEventRow,
)
from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
from synapse.storage.data_stores.main.event_push_actions import (
EventPushActionsWorkerStore,
@@ -35,7 +30,6 @@ from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
logger = logging.getLogger(__name__)
@@ -62,11 +56,6 @@ class SlavedEventStore(
BaseSlavedStore,
):
def __init__(self, database: Database, db_conn, hs):
self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
super(SlavedEventStore, self).__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
@@ -92,83 +81,3 @@ class SlavedEventStore(
def get_room_min_stream_ordering(self):
return self._backfill_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "events":
self._stream_id_gen.advance(token)
for row in rows:
self._process_event_stream_row(token, row)
elif stream_name == "backfill":
self._backfill_id_gen.advance(-token)
for row in rows:
self.invalidate_caches_for_event(
-token,
row.event_id,
row.room_id,
row.type,
row.state_key,
row.redacts,
row.relates_to,
backfilled=True,
)
return super(SlavedEventStore, self).process_replication_rows(
stream_name, token, rows
)
def _process_event_stream_row(self, token, row):
data = row.data
if row.type == EventsStreamEventRow.TypeId:
self.invalidate_caches_for_event(
token,
data.event_id,
data.room_id,
data.type,
data.state_key,
data.redacts,
data.relates_to,
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
self._curr_state_delta_stream_cache.entity_has_changed(
row.data.room_id, token
)
if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate(
(data.state_key,)
)
else:
raise Exception("Unknown events stream row type %s" % (row.type,))
def invalidate_caches_for_event(
self,
stream_ordering,
event_id,
room_id,
etype,
state_key,
redacts,
relates_to,
backfilled,
):
self._invalidate_get_event_cache(event_id)
self.get_latest_event_ids_in_room.invalidate((room_id,))
self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
if not backfilled:
self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
if redacts:
self._invalidate_get_event_cache(redacts)
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
self.get_invited_rooms_for_local_user.invalidate((state_key,))
if relates_to:
self.get_relations_for_event.invalidate_many((relates_to,))
self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
self.get_applicable_edit.invalidate((relates_to,))

View File

@@ -37,12 +37,10 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "groups":
self._group_updates_id_gen.advance(token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedGroupServerStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View File

@@ -41,12 +41,10 @@ class SlavedPresenceStore(BaseSlavedStore):
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "presence":
self._presence_id_gen.advance(token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,))
return super(SlavedPresenceStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View File

@@ -15,19 +15,11 @@
# limitations under the License.
from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
from synapse.storage.database import Database
from ._slaved_id_tracker import SlavedIdTracker
from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def __init__(self, database: Database, db_conn, hs):
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
)
super(SlavedPushRuleStore, self).__init__(database, db_conn, hs)
def get_push_rules_stream_token(self):
return (
self._push_rules_stream_id_gen.get_current_token(),
@@ -37,13 +29,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "push_rules":
self._push_rules_stream_id_gen.advance(token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedPushRuleStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View File

@@ -31,9 +31,7 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "pushers":
self._pushers_id_gen.advance(token)
return super(SlavedPusherStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View File

@@ -51,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "receipts":
self._receipts_id_gen.advance(token)
for row in rows:
@@ -60,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
)
self._receipts_stream_cache.entity_has_changed(row.room_id, token)
return super(SlavedReceiptsStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View File

@@ -30,8 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "public_rooms":
self._public_room_id_gen.advance(token)
return super(RoomStore, self).process_replication_rows(stream_name, token, rows)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View File

@@ -16,12 +16,17 @@
"""
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.api.constants import EventTypes
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
EventsStreamRow,
)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -83,8 +88,10 @@ class ReplicationDataHandler:
to handle updates in additional ways.
"""
def __init__(self, store: BaseSlavedStore):
self.store = store
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.pusher_pool = hs.get_pusherpool()
self.notifier = hs.get_notifier()
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@@ -100,10 +107,32 @@ class ReplicationDataHandler:
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
self.store.process_replication_rows(stream_name, token, rows)
self.store.process_replication_rows(stream_name, instance_name, token, rows)
async def on_position(self, stream_name: str, token: int):
self.store.process_replication_rows(stream_name, token, [])
if stream_name == EventsStream.NAME:
# We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows.
for row in rows:
if row.type != EventsStreamEventRow.TypeId:
continue
assert isinstance(row, EventsStreamRow)
event = await self.store.get_event(
row.data.event_id, allow_rejected=True
)
if event.rejected_reason:
continue
extra_users = () # type: Tuple[str, ...]
if event.type == EventTypes.Member:
extra_users = (event.state_key,)
max_token = self.store.get_room_max_stream_ordering()
self.notifier.on_new_room_event(event, token, max_token, extra_users)
await self.pusher_pool.on_new_notifications(token, token)
async def on_position(self, stream_name: str, instance_name: str, token: int):
self.store.process_replication_rows(stream_name, instance_name, token, [])
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""

View File

@@ -341,37 +341,6 @@ class RemovePusherCommand(Command):
return " ".join((self.app_id, self.push_key, self.user_id))
class InvalidateCacheCommand(Command):
"""Sent by the client to invalidate an upstream cache.
THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
NOT DISASTROUS IF WE DROP ON THE FLOOR.
Mainly used to invalidate destination retry timing caches.
Format::
INVALIDATE_CACHE <cache_func> <keys_json>
Where <keys_json> is a json list.
"""
NAME = "INVALIDATE_CACHE"
def __init__(self, cache_func, keys):
self.cache_func = cache_func
self.keys = keys
@classmethod
def from_line(cls, line):
cache_func, keys_json = line.split(" ", 1)
return cls(cache_func, json.loads(keys_json))
def to_line(self):
return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
class UserIpCommand(Command):
"""Sent periodically when a worker sees activity from a client.
@@ -439,7 +408,6 @@ _COMMANDS = (
UserSyncCommand,
FederationAckCommand,
RemovePusherCommand,
InvalidateCacheCommand,
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
@@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = (
ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME,
RemovePusherCommand.NAME,
InvalidateCacheCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,

View File

@@ -15,18 +15,7 @@
# limitations under the License.
import logging
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
)
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
from prometheus_client import Counter
@@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command,
FederationAckCommand,
InvalidateCacheCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
@@ -48,7 +36,14 @@ from synapse.replication.tcp.commands import (
UserSyncCommand,
)
from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.replication.tcp.streams import (
STREAMS_MAP,
BackfillStream,
CachesStream,
EventsStream,
FederationStream,
Stream,
)
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@@ -85,6 +80,32 @@ class ReplicationCommandHandler:
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
# List of streams that this instance is the source of
self._streams_to_replicate = [] # type: List[Stream]
for stream in self._streams.values():
if stream.NAME == CachesStream.NAME:
# All workers can write to the cache invalidation stream.
self._streams_to_replicate.append(stream)
continue
if (
isinstance(stream, (EventsStream, BackfillStream))
and hs.config.worker.writers.events == hs.get_instance_name()
):
self._streams_to_replicate.append(stream)
# Only add any other streams if we're on master.
if hs.config.worker_app is not None:
continue
if stream.NAME == FederationStream.NAME and hs.config.send_federation:
# We only support federation stream if federation sending
# has been disabled on the master.
continue
self._streams_to_replicate.append(stream)
self._position_linearizer = Linearizer(
"replication_position", clock=self._clock
)
@@ -162,16 +183,33 @@ class ReplicationCommandHandler:
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
# We only want to announce positions by the writer of the streams.
# Currently this is just the master process.
if not self._is_master:
return
def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to all streams.
"""
return self._streams
for stream_name, stream in self._streams.items():
current_token = stream.current_token()
def get_streams_to_replicate(self) -> List[Stream]:
"""Get a list of streams that this instances replicates.
"""
return self._streams_to_replicate
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn)
def send_positions_to_connection(self, conn: AbstractConnection):
"""Send current position of all streams this process is source of to
the connection.
"""
# We respond with current position of all streams this instance
# replicates.
for stream in self.get_streams_to_replicate():
self.send_command(
PositionCommand(stream_name, self._instance_name, current_token)
PositionCommand(
stream.NAME,
self._instance_name,
stream.current_token(self._instance_name),
)
)
async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
@@ -208,18 +246,6 @@ class ReplicationCommandHandler:
self._notifier.on_new_replication_data()
async def on_INVALIDATE_CACHE(
self, conn: AbstractConnection, cmd: InvalidateCacheCommand
):
invalidate_cache_counter.inc()
if self._is_master:
# We invalidate the cache locally, but then also stream that to other
# workers.
await self._store.invalidate_cache_and_stream(
cmd.cache_func, tuple(cmd.keys)
)
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
user_ip_cache_counter.inc()
@@ -293,7 +319,7 @@ class ReplicationCommandHandler:
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
await self._replication_data_handler.on_rdata(
stream_name, instance_name, token, rows
)
@@ -324,7 +350,7 @@ class ReplicationCommandHandler:
self._pending_batches.pop(stream_name, [])
# Find where we previously streamed up to.
current_token = stream.current_token()
current_token = stream.current_token(cmd.instance_name)
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
@@ -361,7 +387,9 @@ class ReplicationCommandHandler:
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(stream_name, cmd.token)
await self._replication_data_handler.on_position(
cmd.stream_name, cmd.instance_name, cmd.token
)
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
@@ -489,12 +517,6 @@ class ReplicationCommandHandler:
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
"""Poke the master to invalidate a cache.
"""
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
self.send_command(cmd)
def send_user_ip(
self,
user_id: str,

View File

@@ -70,7 +70,6 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
logger.info("Connected to redis")
super().connectionMade()
run_as_background_process("subscribe-replication", self._send_subscribe)
self.handler.new_connection(self)
async def _send_subscribe(self):
# it's important to make sure that we only send the REPLICATE command once we
@@ -81,9 +80,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
self.handler.new_connection(self)
await self._async_send_command(ReplicateCommand())
logger.info("REPLICATE successfully sent")
# We send out our positions when there is a new connection in case the
# other side missed updates. We do this for Redis connections as the
# otherside won't know we've connected and so won't issue a REPLICATE.
self.handler.send_positions_to_connection(self)
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
"""

View File

@@ -17,7 +17,6 @@
import logging
import random
from typing import Dict, List
from prometheus_client import Counter
@@ -25,7 +24,6 @@ from twisted.internet.protocol import Factory
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
from synapse.util.metrics import Measure
stream_updates_counter = Counter(
@@ -71,26 +69,11 @@ class ReplicationStreamer(object):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self._instance_name = hs.get_instance_name()
self._replication_torture_level = hs.config.replication_torture_level
# Work out list of streams that this instance is the source of.
self.streams = [] # type: List[Stream]
if hs.config.worker_app is None:
for stream in STREAMS_MAP.values():
if stream == FederationStream and hs.config.send_federation:
# We only support federation stream if federation sending
# hase been disabled on the master.
continue
self.streams.append(stream(hs))
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
# Only bother registering the notifier callback if we have streams to
# publish.
if self.streams:
self.notifier.add_replication_callback(self.on_notifier_poke)
self.notifier.add_replication_callback(self.on_notifier_poke)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
@@ -98,10 +81,8 @@ class ReplicationStreamer(object):
self.command_handler = hs.get_tcp_replication()
def get_streams(self) -> Dict[str, Stream]:
"""Get a mapp from stream name to stream instance.
"""
return self.streams_by_name
# Set of streams to replicate.
self.streams = self.command_handler.get_streams_to_replicate()
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
@@ -145,7 +126,9 @@ class ReplicationStreamer(object):
random.shuffle(all_streams)
for stream in all_streams:
if stream.last_token == stream.current_token():
if stream.last_token == stream.current_token(
self._instance_name
):
continue
if self._replication_torture_level:
@@ -157,7 +140,7 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME,
stream.last_token,
stream.current_token(),
stream.current_token(self._instance_name),
)
try:
updates, current_token, limited = await stream.get_updates()

View File

@@ -95,19 +95,25 @@ class Stream(object):
def __init__(
self,
local_instance_name: str,
current_token_function: Callable[[], Token],
current_token_function: Callable[[str], Token],
update_function: UpdateFunction,
):
"""Instantiate a Stream
current_token_function and update_function are callbacks which should be
implemented by subclasses.
`current_token_function` and `update_function` are callbacks which
should be implemented by subclasses.
current_token_function is called to get the current token of the underlying
stream.
`current_token_function` takes an instance name, which is a writer to
the stream, and returns the position in the stream of the writer (as
viewed from the current process). On the writer process this is where
the writer has successfully written up to, whereas on other processes
this is the position which we have received updates up to over
replication. (Note that most streams have a single writer and so their
implementations ignore the instance name passed in).
update_function is called to get updates for this stream between a pair of
stream tokens. See the UpdateFunction type definition for more info.
`update_function` is called to get updates for this stream between a
pair of stream tokens. See the `UpdateFunction` type definition for more
info.
Args:
local_instance_name: The instance name of the current process
@@ -119,13 +125,13 @@ class Stream(object):
self.update_function = update_function
# The token from which we last asked for updates
self.last_token = self.current_token()
self.last_token = self.current_token(self.local_instance_name)
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
self.last_token = self.current_token()
self.last_token = self.current_token(self.local_instance_name)
async def get_updates(self) -> StreamUpdateResult:
"""Gets all updates since the last time this function was called (or
@@ -137,7 +143,7 @@ class Stream(object):
position in stream, and `limited` is whether there are more updates
to fetch.
"""
current_token = self.current_token()
current_token = self.current_token(self.local_instance_name)
updates, current_token, limited = await self.get_updates_since(
self.local_instance_name, self.last_token, current_token
)
@@ -169,6 +175,16 @@ class Stream(object):
return updates, upto_token, limited
def current_token_without_instance(
current_token: Callable[[], int]
) -> Callable[[str], int]:
"""Takes a current token callback function for a single writer stream
that doesn't take an instance name parameter and wraps it in a function that
does accept an instance name parameter but ignores it.
"""
return lambda instance_name: current_token()
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
@@ -234,7 +250,7 @@ class BackfillStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_current_backfill_token,
current_token_without_instance(store.get_current_backfill_token),
db_query_to_update_function(store.get_all_new_backfill_event_rows),
)
@@ -270,7 +286,9 @@ class PresenceStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
hs.get_instance_name(), store.get_current_presence_token, update_function
hs.get_instance_name(),
current_token_without_instance(store.get_current_presence_token),
update_function,
)
@@ -295,7 +313,9 @@ class TypingStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
hs.get_instance_name(), typing_handler.get_current_token, update_function
hs.get_instance_name(),
current_token_without_instance(typing_handler.get_current_token),
update_function,
)
@@ -318,7 +338,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_max_receipt_stream_id,
current_token_without_instance(store.get_max_receipt_stream_id),
db_query_to_update_function(store.get_all_updated_receipts),
)
@@ -338,7 +358,7 @@ class PushRulesStream(Stream):
hs.get_instance_name(), self._current_token, self._update_function
)
def _current_token(self) -> int:
def _current_token(self, instance_name: str) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
@@ -372,7 +392,7 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
store.get_pushers_stream_token,
current_token_without_instance(store.get_pushers_stream_token),
db_query_to_update_function(store.get_all_updated_pushers_rows),
)
@@ -401,13 +421,27 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_cache_stream_token,
db_query_to_update_function(store.get_all_updated_caches),
self.store.get_cache_stream_token,
self._update_function,
)
async def _update_function(
self, instance_name: str, from_token: int, upto_token: int, limit: int
):
rows = await self.store.get_all_updated_caches(
instance_name, from_token, upto_token, limit
)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
class PublicRoomsStream(Stream):
"""The public rooms list changed
@@ -430,7 +464,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_current_public_room_stream_id,
current_token_without_instance(store.get_current_public_room_stream_id),
db_query_to_update_function(store.get_all_new_public_rooms),
)
@@ -451,7 +485,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_device_stream_token,
current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
)
@@ -469,7 +503,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_to_device_stream_token,
current_token_without_instance(store.get_to_device_stream_token),
db_query_to_update_function(store.get_all_new_device_messages),
)
@@ -489,7 +523,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_max_account_data_stream_id,
current_token_without_instance(store.get_max_account_data_stream_id),
db_query_to_update_function(store.get_all_updated_tags),
)
@@ -509,7 +543,7 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
self.store.get_max_account_data_stream_id,
current_token_without_instance(self.store.get_max_account_data_stream_id),
db_query_to_update_function(self._update_function),
)
@@ -540,7 +574,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_group_stream_token,
current_token_without_instance(store.get_group_stream_token),
db_query_to_update_function(store.get_all_groups_changes),
)
@@ -558,7 +592,7 @@ class UserSignatureStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_device_stream_token,
current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes
),

View File

@@ -20,7 +20,7 @@ from typing import List, Tuple, Type
import attr
from ._base import Stream, StreamUpdateResult, Token
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
"""Handling of the 'events' replication stream
@@ -119,7 +119,7 @@ class EventsStream(Stream):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
self._store.get_current_events_token,
current_token_without_instance(self._store.get_current_events_token),
self._update_function,
)

View File

@@ -15,7 +15,11 @@
# limitations under the License.
from collections import namedtuple
from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
from synapse.replication.tcp.streams._base import (
Stream,
current_token_without_instance,
make_http_update_function,
)
class FederationStream(Stream):
@@ -35,21 +39,35 @@ class FederationStream(Stream):
ROW_TYPE = FederationStreamRow
def __init__(self, hs):
# Not all synapse instances will have a federation sender instance,
# whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
# so we stub the stream out when that is the case.
if hs.config.worker_app is None or hs.should_send_federation():
if hs.config.worker_app is None:
# master process: get updates from the FederationRemoteSendQueue.
# (if the master is configured to send federation itself, federation_sender
# will be a real FederationSender, which has stubs for current_token and
# get_replication_rows.)
federation_sender = hs.get_federation_sender()
current_token = federation_sender.get_current_token
update_function = db_query_to_update_function(
federation_sender.get_replication_rows
current_token = current_token_without_instance(
federation_sender.get_current_token
)
update_function = federation_sender.get_replication_rows
elif hs.should_send_federation():
# federation sender: Query master process
update_function = make_http_update_function(hs, self.NAME)
current_token = self._stub_current_token
else:
current_token = lambda: 0
# other worker: stub out the update function (we're not interested in
# any updates so when we get a POSITION we do nothing)
update_function = self._stub_update_function
current_token = self._stub_current_token
super().__init__(hs.get_instance_name(), current_token, update_function)
@staticmethod
def _stub_current_token(instance_name: str) -> int:
# dummy current-token method for use on workers
return 0
@staticmethod
async def _stub_update_function(instance_name, from_token, upto_token, limit):
return [], upto_token, False

View File

@@ -30,7 +30,7 @@
<tr>
<td colspan="2">
<div class="noticetext">Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.</div>
<div class="noticetext">To extend the validity of your account, please click on the link bellow (or copy and paste it into a new browser tab):</div>
<div class="noticetext">To extend the validity of your account, please click on the link below (or copy and paste it into a new browser tab):</div>
<div class="noticetext"><a href="{{ url }}">{{ url }}</a></div>
</td>
</tr>

View File

@@ -2,6 +2,6 @@ Hi {{ display_name }},
Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.
To extend the validity of your account, please click on the link bellow (or copy and paste it to a new browser tab):
To extend the validity of your account, please click on the link below (or copy and paste it to a new browser tab):
{{ url }}

View File

@@ -0,0 +1,18 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>SSO error</title>
</head>
<body>
<p>Oops! Something went wrong during authentication.</p>
<p>
Try logging in again from your Matrix client and if the problem persists
please contact the server's administrator.
</p>
<p>Error: <code>{{ error }}</code></p>
{% if error_description %}
<pre><code>{{ error_description }}</code></pre>
{% endif %}
</body>
</html>

View File

@@ -32,6 +32,7 @@ from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.rooms import (
JoinRoomAliasServlet,
ListRoomRestServlet,
RoomRestServlet,
ShutdownRoomRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
@@ -193,6 +194,7 @@ def register_servlets(hs, http_server):
"""
register_servlets_for_client_rest_resource(hs, http_server)
ListRoomRestServlet(hs).register(http_server)
RoomRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
PurgeRoomServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server)

View File

@@ -26,6 +26,7 @@ from synapse.http.servlet import (
)
from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
historical_admin_path_patterns,
)
@@ -169,7 +170,7 @@ class ListRoomRestServlet(RestServlet):
in a dictionary containing room information. Supports pagination.
"""
PATTERNS = admin_patterns("/rooms")
PATTERNS = admin_patterns("/rooms$")
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -253,6 +254,29 @@ class ListRoomRestServlet(RestServlet):
return 200, response
class RoomRestServlet(RestServlet):
"""Get room details.
TODO: Add on_POST to allow room creation without joining the room
"""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room_with_stats(room_id)
if not ret:
raise NotFoundError("Room not found")
return 200, ret
class JoinRoomAliasServlet(RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")

View File

@@ -83,6 +83,7 @@ class LoginRestServlet(RestServlet):
self.jwt_algorithm = hs.config.jwt_algorithm
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
@@ -96,9 +97,7 @@ class LoginRestServlet(RestServlet):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
if self.saml2_enabled:
flows.append({"type": LoginRestServlet.SSO_TYPE})
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
if self.cas_enabled:
flows.append({"type": LoginRestServlet.SSO_TYPE})
@@ -114,6 +113,11 @@ class LoginRestServlet(RestServlet):
# fall back to the fallback API if they don't understand one of the
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
elif self.saml2_enabled:
flows.append({"type": LoginRestServlet.SSO_TYPE})
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
elif self.oidc_enabled:
flows.append({"type": LoginRestServlet.SSO_TYPE})
flows.extend(
({"type": t} for t in self.auth_handler.get_supported_login_types())
@@ -465,6 +469,22 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
return self._saml_handler.handle_redirect_request(client_redirect_url)
class OIDCRedirectServlet(RestServlet):
"""Implementation for /login/sso/redirect for the OIDC login flow."""
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
def __init__(self, hs):
self._oidc_handler = hs.get_oidc_handler()
async def on_GET(self, request):
args = request.args
if b"redirectUrl" not in args:
return 400, "Redirect URL not specified for SSO auth"
client_redirect_url = args[b"redirectUrl"][0]
await self._oidc_handler.handle_redirect_request(request, client_redirect_url)
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
if hs.config.cas_enabled:
@@ -472,3 +492,5 @@ def register_servlets(hs, http_server):
CasTicketServlet(hs).register(http_server)
elif hs.config.saml2_enabled:
SAMLRedirectServlet(hs).register(http_server)
elif hs.config.oidc_enabled:
OIDCRedirectServlet(hs).register(http_server)

View File

@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.web.resource import Resource
from synapse.rest.oidc.callback_resource import OIDCCallbackResource
logger = logging.getLogger(__name__)
class OIDCResource(Resource):
def __init__(self, hs):
Resource.__init__(self)
self.putChild(b"callback", OIDCCallbackResource(hs))

View File

@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.http.server import DirectServeResource, wrap_html_request_handler
logger = logging.getLogger(__name__)
class OIDCCallbackResource(DirectServeResource):
isLeaf = 1
def __init__(self, hs):
super().__init__()
self._oidc_handler = hs.get_oidc_handler()
@wrap_html_request_handler
async def _async_render_GET(self, request):
return await self._oidc_handler.handle_oidc_callback(request)

Some files were not shown because too many files have changed in this diff Show More