Compare commits

52 Commits
v1.13.0rc2 ... erikj/debu

| SHA1 |
|---|
| 6b2d6fdd33 |
| d263a4de02 |
| 66c1dff3ba |
| 96b6023e3b |
| 452019064c |
| 7c8e09bcf1 |
| e7f5ac4ed8 |
| 208ab7b135 |
| 41f558ccf7 |
| 342796d6ac |
| bc3fc3927f |
| d67a8b5455 |
| 4734a7bbe4 |
| 1de36407d1 |
| dede23ff1e |
| 1124111a12 |
| 46cb2550bb |
| 18c1e52d82 |
| 00ba9c48bf |
| 782e4e64df |
| 7ee24c5674 |
| 8ca79613e6 |
| 51fb0fc2e5 |
| 1a1da60ad2 |
| 8c8858e124 |
| be309d99cf |
| 7cb8b4bc67 |
| a8580c5f19 |
| 5cf758cdd6 |
| 67feea8044 |
| 616af44137 |
| a4a5ec4096 |
| 5bb26b7c4f |
| 9e0384dd3f |
| 22246919e3 |
| d7983b63a6 |
| 2929ce29d6 |
| 62ee862119 |
| fa0b2bd28d |
| 16b67c404d |
| db5f9031b7 |
| 2e0c46ca07 |
| 79007a42b2 |
| 30a19daa02 |
| e48361545d |
| 0f6ebf393d |
| 16b1a34e80 |
| d5aa7d93ed |
| 8123b2f909 |
| 15aa09bbe6 |
| eab59d758d |
| a251e0f4ba |
changelog.d/6391.feature (Normal file, 1 line)
@@ -0,0 +1 @@
Synapse's cache factor can now be configured in `homeserver.yaml` by the `caches.global_factor` setting. Additionally, `caches.per_cache_factors` controls the cache factors for individual caches.

changelog.d/7256.feature (Normal file, 1 line)
@@ -0,0 +1 @@
Add OpenID Connect login/registration support. Contributed by Quentin Gliech, on behalf of [les Connecteurs](https://connecteu.rs).

changelog.d/7281.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Add MultiWriterIdGenerator to support multiple concurrent writers of streams.

changelog.d/7317.feature (Normal file, 1 line)
@@ -0,0 +1 @@
Add room details admin endpoint. Contributed by Awesome Technologies Innovationslabor GmbH.

changelog.d/7374.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Move catchup of replication streams logic to worker.

changelog.d/7382.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Add typing annotations in `synapse.federation`.

changelog.d/7396.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Convert the room handler to async/await.

changelog.d/7398.docker (Normal file, 1 line)
@@ -0,0 +1 @@
Update docker runtime image to Alpine v3.11. Contributed by @Starbix.

changelog.d/7428.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Improve performance of `get_e2e_cross_signing_key`.

changelog.d/7429.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Improve performance of `mark_as_sent_devices_by_remote`.

changelog.d/7435.feature (Normal file, 1 line)
@@ -0,0 +1 @@
Allow for using more than one spam checker module at once.

changelog.d/7436.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Support any process writing to cache invalidation stream.

changelog.d/7440.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Refactor event persistence database functions in preparation for allowing them to be run on non-master processes.

changelog.d/7445.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Add type hints to the SAML handler.

changelog.d/7448.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Remove storage method `get_hosts_in_room` that is no longer called anywhere.

changelog.d/7449.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Fix some typos in the notice_expiry templates.

changelog.d/7458.doc (Normal file, 1 line)
@@ -0,0 +1 @@
Update information about mapping providers for SAML and OpenID.

changelog.d/7459.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Convert the federation handler to async/await.

changelog.d/7460.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Convert the search handler to async/await.

changelog.d/7470.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Fix linting errors in new version of Flake8.

changelog.d/7475.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Have all instances correctly respond to the REPLICATE command.

changelog.d/7477.doc (Normal file, 1 line)
@@ -0,0 +1 @@
Fix copy-paste error in `ServerNoticesConfig` docstring. Contributed by @ptman.

changelog.d/7482.bugfix (Normal file, 1 line)
@@ -0,0 +1 @@
Fix Redis reconnection logic that can result in missed updates over replication if master reconnects to Redis without restarting.

changelog.d/7490.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Clean up replication unit tests.

changelog.d/7491.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Move event stream handling out of slave store.

changelog.d/7492.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Allow censoring of events to happen on workers.

changelog.d/7493.misc (Normal file, 1 line)
@@ -0,0 +1 @@
Move EventStream handling into default ReplicationDataHandler.

changelog.d/7495.feature (Normal file, 1 line)
@@ -0,0 +1 @@
Add `instance_map` config and route replication calls.

@@ -55,7 +55,7 @@ RUN pip install --prefix="/install" --no-warn-script-location \

### Stage 1: runtime
###

FROM docker.io/python:${PYTHON_VERSION}-alpine3.10
FROM docker.io/python:${PYTHON_VERSION}-alpine3.11

# xmlsec is required for saml support
RUN apk add --no-cache --virtual .runtime_deps \

@@ -264,3 +264,57 @@ Response:

Once the `next_token` parameter is no longer present, we know we've reached the
end of the list.
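
As a hedged illustration, the loop below follows that rule until the token disappears. The homeserver URL and admin access token are placeholders, and the `from` query parameter is assumed from the full List Room API description, which is truncated here:

```python
# Walk the paginated room list until `next_token` is absent.
import requests

BASE = "https://homeserver.example.com"  # placeholder
HEADERS = {"Authorization": "Bearer <admin_access_token>"}  # placeholder

def list_all_rooms():
    rooms = []
    params = {}
    while True:
        resp = requests.get(
            BASE + "/_synapse/admin/v1/rooms", headers=HEADERS, params=params
        )
        resp.raise_for_status()
        body = resp.json()
        rooms.extend(body["rooms"])
        if "next_token" not in body:
            # No next_token: we've reached the end of the list.
            return rooms
        params["from"] = body["next_token"]
```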

# DRAFT: Room Details API

The Room Details admin API allows server admins to get all details of a room.

This API is still a draft and details might change!

The following fields are possible in the JSON response body:

* `room_id` - The ID of the room.
* `name` - The name of the room.
* `canonical_alias` - The canonical (main) alias address of the room.
* `joined_members` - How many users are currently in the room.
* `joined_local_members` - How many local users are currently in the room.
* `version` - The version of the room as a string.
* `creator` - The `user_id` of the room creator.
* `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active.
* `federatable` - Whether users on other servers can join this room.
* `public` - Whether the room is visible in the room directory.
* `join_rules` - The type of rules used for users wishing to join this room. One of: ["public", "knock", "invite", "private"].
* `guest_access` - Whether guests can join the room. One of: ["can_join", "forbidden"].
* `history_visibility` - Who can see the room history. One of: ["invited", "joined", "shared", "world_readable"].
* `state_events` - Total number of state events in the room. A measure of the room's complexity.

## Usage

A standard request:

```
GET /_synapse/admin/v1/rooms/<room_id>

{}
```

Response:

```
{
  "room_id": "!mscvqgqpHYjBGDxNym:matrix.org",
  "name": "Music Theory",
  "canonical_alias": "#musictheory:matrix.org",
  "joined_members": 127,
  "joined_local_members": 2,
  "version": "1",
  "creator": "@foo:matrix.org",
  "encryption": null,
  "federatable": true,
  "public": true,
  "join_rules": "invite",
  "guest_access": null,
  "history_visibility": "shared",
  "state_events": 93534
}
```
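
For illustration only, the same request made from Python; the base URL and admin access token are placeholders:

```python
# Fetch the details of a single room via the draft admin API.
import requests

BASE = "https://homeserver.example.com"  # placeholder
HEADERS = {"Authorization": "Bearer <admin_access_token>"}  # placeholder

def get_room_details(room_id):
    resp = requests.get(
        BASE + "/_synapse/admin/v1/rooms/" + room_id, headers=HEADERS
    )
    resp.raise_for_status()
    return resp.json()

details = get_room_details("!mscvqgqpHYjBGDxNym:matrix.org")
print(details["joined_members"], details["state_events"])
```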
docs/dev/oidc.md (Normal file, 175 lines)
@@ -0,0 +1,175 @@

# How to test OpenID Connect

Any OpenID Connect Provider (OP) should work with Synapse, as long as it supports the authorization code flow.
There are a few options for that:

- start a local OP. Synapse has been tested with [Hydra][hydra] and [Dex][dex-idp].
  Note that for an OP to work, it should be served under a secure (HTTPS) origin.
  A certificate signed with a self-signed, locally trusted CA should work. In that case, start Synapse with an `SSL_CERT_FILE` environment variable set to the path of the CA.
- use a publicly available OP. Synapse has been tested with [Google][google-idp].
- set up a SaaS OP, like [Auth0][auth0] or [Okta][okta]. Auth0 has a free tier which has been tested with Synapse.

[google-idp]: https://developers.google.com/identity/protocols/OpenIDConnect#authenticatingtheuser
[auth0]: https://auth0.com/
[okta]: https://www.okta.com/
[dex-idp]: https://github.com/dexidp/dex
[hydra]: https://www.ory.sh/docs/hydra/

## Sample configs

Here are a few configs for providers that should work with Synapse.

### [Dex][dex-idp]

[Dex][dex-idp] is a simple, open-source, certified OpenID Connect Provider.
Although it is designed to help build a full-blown provider with an external database, it can be configured with static passwords in a config file.

Follow the [Getting Started guide](https://github.com/dexidp/dex/blob/master/Documentation/getting-started.md) to install Dex.

Edit the `examples/config-dev.yaml` config file from the Dex repo to add a client:

```yaml
staticClients:
- id: synapse
  secret: secret
  redirectURIs:
  - '[synapse base url]/_synapse/oidc/callback'
  name: 'Synapse'
```

Run with `dex serve examples/config-dev.yaml`.

Synapse config:

```yaml
oidc_config:
  enabled: true
  skip_verification: true # This is needed as Dex is served on an insecure endpoint
  issuer: "http://127.0.0.1:5556/dex"
  discover: true
  client_id: "synapse"
  client_secret: "secret"
  scopes:
    - openid
    - profile
  user_mapping_provider:
    config:
      localpart_template: '{{ user.name }}'
      display_name_template: '{{ user.name|capitalize }}'
```

### [Auth0][auth0]

1. Create a regular web application for Synapse
2. Set the Allowed Callback URLs to `[synapse base url]/_synapse/oidc/callback`
3. Add a rule to add the `preferred_username` claim.

<details>
<summary>Code sample</summary>

```js
function addPersistenceAttribute(user, context, callback) {
  user.user_metadata = user.user_metadata || {};
  user.user_metadata.preferred_username = user.user_metadata.preferred_username || user.user_id;
  context.idToken.preferred_username = user.user_metadata.preferred_username;

  auth0.users.updateUserMetadata(user.user_id, user.user_metadata)
    .then(function(){
      callback(null, user, context);
    })
    .catch(function(err){
      callback(err);
    });
}
```

</details>

```yaml
oidc_config:
  enabled: true
  issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
  discover: true
  client_id: "your-client-id" # TO BE FILLED
  client_secret: "your-client-secret" # TO BE FILLED
  scopes:
    - openid
    - profile
  user_mapping_provider:
    config:
      localpart_template: '{{ user.preferred_username }}'
      display_name_template: '{{ user.name }}'
```

### GitHub

GitHub is a bit special as it is not an OpenID Connect compliant provider, but just a regular OAuth2 provider.
The `/user` API endpoint can be used to retrieve information about the user.
As the OIDC login mechanism needs an attribute to uniquely identify users, and that endpoint does not return a `sub` property, an alternative `subject_claim` has to be set.

1. Create a new OAuth application: https://github.com/settings/applications/new
2. Set the callback URL to `[synapse base url]/_synapse/oidc/callback`

```yaml
oidc_config:
  enabled: true
  issuer: "https://github.com/"
  discover: false
  client_id: "your-client-id" # TO BE FILLED
  client_secret: "your-client-secret" # TO BE FILLED
  authorization_endpoint: "https://github.com/login/oauth/authorize"
  token_endpoint: "https://github.com/login/oauth/access_token"
  userinfo_endpoint: "https://api.github.com/user"
  scopes:
    - read:user
  user_mapping_provider:
    config:
      subject_claim: 'id'
      localpart_template: '{{ user.login }}'
      display_name_template: '{{ user.name }}'
```

### Google

1. Set up a project in the Google API Console
2. Obtain the OAuth 2.0 credentials (see <https://developers.google.com/identity/protocols/oauth2/openid-connect>)
3. Add this Authorized redirect URI: `[synapse base url]/_synapse/oidc/callback`

```yaml
oidc_config:
  enabled: true
  issuer: "https://accounts.google.com/"
  discover: true
  client_id: "your-client-id" # TO BE FILLED
  client_secret: "your-client-secret" # TO BE FILLED
  scopes:
    - openid
    - profile
  user_mapping_provider:
    config:
      localpart_template: '{{ user.given_name|lower }}'
      display_name_template: '{{ user.name }}'
```

### Twitch

1. Set up a developer account on [Twitch](https://dev.twitch.tv/)
2. Obtain the OAuth 2.0 credentials by [creating an app](https://dev.twitch.tv/console/apps/)
3. Add this OAuth Redirect URL: `[synapse base url]/_synapse/oidc/callback`

```yaml
oidc_config:
  enabled: true
  issuer: "https://id.twitch.tv/oauth2/"
  discover: true
  client_id: "your-client-id" # TO BE FILLED
  client_secret: "your-client-secret" # TO BE FILLED
  client_auth_method: "client_secret_post"
  scopes:
    - openid
  user_mapping_provider:
    config:
      localpart_template: '{{ user.preferred_username }}'
      display_name_template: '{{ user.name }}'
```

@@ -1,77 +0,0 @@

# SAML Mapping Providers

A SAML mapping provider is a Python class (loaded via a Python module) that
works out how to map attributes of a SAML response object to Matrix-specific
user attributes. Details such as user ID localpart, displayname, and even avatar
URLs are all things that can be mapped from talking to a SSO service.

As an example, a SSO service may return the email address
"john.smith@example.com" for a user, whereas Synapse will need to figure out how
to turn that into a displayname when creating a Matrix user for this individual.
It may choose `John Smith`, or `Smith, John [Example.com]` or any number of
variations. As each Synapse configuration may want something different, this is
where SAML mapping providers come into play.

## Enabling Providers

External mapping providers are provided to Synapse in the form of an external
Python module. Retrieve this module from [PyPi](https://pypi.org) or elsewhere,
then tell Synapse where to look for the handler class by editing the
`saml2_config.user_mapping_provider.module` config option.

`saml2_config.user_mapping_provider.config` allows you to provide custom
configuration options to the module. Check with the module's documentation for
what options it provides (if any). The options listed by default are for the
user mapping provider built in to Synapse. If using a custom module, you should
comment these options out and use those specified by the module instead.

## Building a Custom Mapping Provider

A custom mapping provider must specify the following methods:

* `__init__(self, parsed_config)`
  - Arguments:
    - `parsed_config` - A configuration object that is the return value of the
      `parse_config` method. You should set any configuration options needed by
      the module here.
* `saml_response_to_user_attributes(self, saml_response, failures)`
  - Arguments:
    - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
      information from.
    - `failures` - An `int` that represents the amount of times the returned
      mxid localpart mapping has failed. This should be used
      to create a deduplicated mxid localpart which should be
      returned instead. For example, if this method returns
      `john.doe` as the value of `mxid_localpart` in the returned
      dict, and that is already taken on the homeserver, this
      method will be called again with the same parameters but
      with failures=1. The method should then return a different
      `mxid_localpart` value, such as `john.doe1`.
  - This method must return a dictionary, which will then be used by Synapse
    to build a new user. The following keys are allowed:
    * `mxid_localpart` - Required. The mxid localpart of the new user.
    * `displayname` - The displayname of the new user. If not provided, will default to
      the value of `mxid_localpart`.
* `parse_config(config)`
  - This method should have the `@staticmethod` decoration.
  - Arguments:
    - `config` - A `dict` representing the parsed content of the
      `saml2_config.user_mapping_provider.config` homeserver config option.
      Runs on homeserver startup. Providers should extract any option values
      they need here.
  - Whatever is returned will be passed back to the user mapping provider module's
    `__init__` method during construction.
* `get_saml_attributes(config)`
  - This method should have the `@staticmethod` decoration.
  - Arguments:
    - `config` - A object resulting from a call to `parse_config`.
  - Returns a tuple of two sets. The first set equates to the saml auth
    response attributes that are required for the module to function, whereas
    the second set consists of those attributes which can be used if available,
    but are not necessary.

## Synapse's Default Provider

Synapse has a built-in SAML mapping provider if a custom provider isn't
specified in the config. It is located at
[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py).

@@ -603,6 +603,45 @@ acme:


## Caching ##

# Caching can be configured through the following options.
#
# A cache 'factor' is a multiplier that can be applied to each of
# Synapse's caches in order to increase or decrease the maximum
# number of entries that can be stored.

# The number of events to cache in memory. Not affected by
# caches.global_factor.
#
#event_cache_size: 10K

caches:
  # Controls the global cache factor, which is the default cache factor
  # for all caches if a specific factor for that cache is not otherwise
  # set.
  #
  # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment
  # variable. Setting by environment variable takes priority over
  # setting through the config file.
  #
  # Defaults to 0.5, which will halve the size of all caches.
  #
  #global_factor: 1.0

  # A dictionary of cache name to cache factor for that individual
  # cache. Overrides the global cache factor for a given cache.
  #
  # These can also be set through environment variables comprised
  # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital
  # letters and underscores. Setting by environment variable
  # takes priority over setting through the config file.
  # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0
  #
  per_cache_factors:
    #get_users_who_share_room_with_user: 2.0

## Database ##

# The 'database' setting defines the database that synapse uses to store all of

@@ -646,10 +685,6 @@ database:
  args:
    database: DATADIR/homeserver.db

# Number of events to cache in memory.
#
#event_cache_size: 10K


## Logging ##

@@ -1470,6 +1505,94 @@ saml2_config:
  #template_dir: "res/templates"

# Enable OpenID Connect for registration and login. Uses authlib.
#
oidc_config:
  # enable OpenID Connect. Defaults to false.
  #
  #enabled: true

  # use the OIDC discovery mechanism to discover endpoints. Defaults to true.
  #
  #discover: true

  # the OIDC issuer. Used to validate tokens and discover the provider's endpoints. Required.
  #
  #issuer: "https://accounts.example.com/"

  # oauth2 client id to use. Required.
  #
  #client_id: "provided-by-your-issuer"

  # oauth2 client secret to use. Required.
  #
  #client_secret: "provided-by-your-issuer"

  # auth method to use when exchanging the token.
  # Valid values are "client_secret_basic" (default), "client_secret_post" and "none".
  #
  #client_auth_method: "client_secret_basic"

  # list of scopes to request. This should include the "openid" scope. Defaults to ["openid"].
  #
  #scopes: ["openid"]

  # the oauth2 authorization endpoint. Required if provider discovery is disabled.
  #
  #authorization_endpoint: "https://accounts.example.com/oauth2/auth"

  # the oauth2 token endpoint. Required if provider discovery is disabled.
  #
  #token_endpoint: "https://accounts.example.com/oauth2/token"

  # the OIDC userinfo endpoint. Required if discovery is disabled and the "openid" scope is not requested.
  #
  #userinfo_endpoint: "https://accounts.example.com/userinfo"

  # URI from which to fetch the JWKS. Required if discovery is disabled and the "openid" scope is used.
  #
  #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"

  # skip metadata verification. Defaults to false.
  # Use this if you are connecting to a provider that is not OpenID Connect compliant.
  # Avoid this in production.
  #
  #skip_verification: false


  # An external module can be provided here as a custom solution to mapping
  # attributes returned from an OIDC provider onto a matrix user.
  #
  user_mapping_provider:
    # The custom module's class. Uncomment to use a custom module.
    # Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
    #
    #module: mapping_provider.OidcMappingProvider

    # Custom configuration values for the module. The options below are
    # intended for the built-in provider; they should be changed if using
    # a custom module. This section will be passed as a Python dictionary
    # to the module's `parse_config` method.
    #
    # Below is the config of the default mapping provider, based on Jinja2
    # templates. Those templates are used to render user attributes, where the
    # userinfo object is available through the `user` variable.
    #
    config:
      # name of the claim containing a unique identifier for the user.
      # Defaults to `sub`, which OpenID Connect compliant providers should provide.
      #
      #subject_claim: "sub"

      # Jinja2 template for the localpart of the MXID.
      #
      localpart_template: "{{ user.preferred_username }}"

      # Jinja2 template for the display name to set on first login. Optional.
      #
      #display_name_template: "{{ user.given_name }} {{ user.last_name }}"


# Enable CAS for registration and login.
#

@@ -1554,6 +1677,13 @@ sso:
  #
  # This template has no additional variables.
  #
  # * HTML page to display to users if something goes wrong during the
  #   OpenID Connect authentication process: 'sso_error.html'.
  #
  #   When rendering, this template is given two variables:
  #     * error: the technical name of the error
  #     * error_description: a human-readable message for the error
  #
  # You can see the default templates at:
  # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
  #

@@ -1772,10 +1902,17 @@ password_providers:
#   include_content: true


#spam_checker:
#  module: "my_custom_project.SuperSpamChecker"
#  config:
#    example_option: 'things'
# Spam checkers are third-party modules that can block specific actions
# of local users, such as creating rooms and registering undesirable
# usernames, as well as remote users by redacting incoming events.
#
spam_checker:
  #- module: "my_custom_project.SuperSpamChecker"
  #  config:
  #    example_option: 'things'
  #- module: "some_other_project.BadEventStopper"
  #  config:
  #    example_stop_events_from: ['@bad:example.com']


# Uncomment to allow non-server-admin users to create groups on this server

@@ -64,10 +64,12 @@ class ExampleSpamChecker:
Modify the `spam_checker` section of your `homeserver.yaml` in the following
manner:

`module` should point to the fully qualified Python class that implements your
custom logic, e.g. `my_module.ExampleSpamChecker`.
Create a list entry with the keys `module` and `config`.

`config` is a dictionary that gets passed to the spam checker class.
* `module` should point to the fully qualified Python class that implements your
  custom logic, e.g. `my_module.ExampleSpamChecker`.

* `config` is a dictionary that gets passed to the spam checker class.

### Example

@@ -75,12 +77,15 @@ This section might look like:

```yaml
spam_checker:
  module: my_module.ExampleSpamChecker
  config:
    # Enable or disable a specific option in ExampleSpamChecker.
    my_custom_option: true
  - module: my_module.ExampleSpamChecker
    config:
      # Enable or disable a specific option in ExampleSpamChecker.
      my_custom_option: true
```

More spam checkers can be added in tandem by appending more items to the list. An
action is blocked when at least one of the configured spam checkers flags it.
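
To make the list form concrete, here is a minimal sketch of a second checker that could be appended alongside `ExampleSpamChecker`. The class name matches the `some_other_project.BadEventStopper` entry from the sample config, but the constructor and callback signatures are assumptions based on the spam checker interface shown only partially above:

```python
# Hypothetical second spam checker for the `spam_checker` list.
class BadEventStopper:
    def __init__(self, config):
        # `config` is the dictionary from the matching `config:` block.
        self.stop_events_from = config.get("example_stop_events_from", [])

    def check_event_for_spam(self, event):
        # Returning True flags the event as spam; a single flagging
        # checker is enough to block the action.
        return event.sender in self.stop_events_from
```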

## Examples

The [Mjolnir](https://github.com/matrix-org/mjolnir) project is a full-fledged

docs/sso_mapping_providers.md (Normal file, 146 lines)
@@ -0,0 +1,146 @@

# SSO Mapping Providers

A mapping provider is a Python class (loaded via a Python module) that
works out how to map attributes of an SSO response to Matrix-specific
user attributes. Details such as user ID localpart, displayname, and even avatar
URLs are all things that can be mapped from talking to an SSO service.

As an example, an SSO service may return the email address
"john.smith@example.com" for a user, whereas Synapse will need to figure out how
to turn that into a displayname when creating a Matrix user for this individual.
It may choose `John Smith`, or `Smith, John [Example.com]` or any number of
variations. As each Synapse configuration may want something different, this is
where SSO mapping providers come into play.

SSO mapping providers are currently supported for OpenID and SAML SSO
configurations. Please see the details below for how to implement your own.

External mapping providers are provided to Synapse in the form of an external
Python module. You can retrieve this module from [PyPi](https://pypi.org) or elsewhere,
but it must be importable via Synapse (e.g. it must be in the same virtualenv
as Synapse). The Synapse config is then modified to point to the mapping provider
(and optionally provide additional configuration for it).

## OpenID Mapping Providers

The OpenID mapping provider can be customized by editing the
`oidc_config.user_mapping_provider.module` config option.

`oidc_config.user_mapping_provider.config` allows you to provide custom
configuration options to the module. Check with the module's documentation for
what options it provides (if any). The options listed by default are for the
user mapping provider built in to Synapse. If using a custom module, you should
comment these options out and use those specified by the module instead.

### Building a Custom OpenID Mapping Provider

A custom mapping provider must specify the following methods (a short sketch follows this list):

* `__init__(self, parsed_config)`
  - Arguments:
    - `parsed_config` - A configuration object that is the return value of the
      `parse_config` method. You should set any configuration options needed by
      the module here.
* `parse_config(config)`
  - This method should have the `@staticmethod` decoration.
  - Arguments:
    - `config` - A `dict` representing the parsed content of the
      `oidc_config.user_mapping_provider.config` homeserver config option.
      Runs on homeserver startup. Providers should extract and validate
      any option values they need here.
  - Whatever is returned will be passed back to the user mapping provider module's
    `__init__` method during construction.
* `get_remote_user_id(self, userinfo)`
  - Arguments:
    - `userinfo` - An `authlib.oidc.core.claims.UserInfo` object to extract user
      information from.
  - This method must return a string, which is the unique identifier for the
    user. Commonly the `sub` claim of the response.
* `map_user_attributes(self, userinfo, token)`
  - This method should be async.
  - Arguments:
    - `userinfo` - An `authlib.oidc.core.claims.UserInfo` object to extract user
      information from.
    - `token` - A dictionary which includes information necessary to make
      further requests to the OpenID provider.
  - Returns a dictionary with two keys:
    - `localpart`: A required string, used to generate the Matrix ID.
    - `displayname`: An optional string, the display name for the user.
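
The following is a bare-bones sketch of a custom OpenID mapping provider implementing these methods. The claim names used (`sub`, `preferred_username`, `name`) are illustrative assumptions about what a given provider returns:

```python
# Hypothetical custom OpenID mapping provider.
class MyOidcMappingProvider:
    def __init__(self, parsed_config):
        self._config = parsed_config

    @staticmethod
    def parse_config(config):
        # Extract and validate options from
        # oidc_config.user_mapping_provider.config here.
        return config

    def get_remote_user_id(self, userinfo):
        # Commonly the `sub` claim of the response.
        return userinfo["sub"]

    async def map_user_attributes(self, userinfo, token):
        return {
            "localpart": userinfo["preferred_username"],
            "displayname": userinfo.get("name"),
        }
```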

### Default OpenID Mapping Provider

Synapse has a built-in OpenID mapping provider if a custom provider isn't
specified in the config. It is located at
[`synapse.handlers.oidc_handler.JinjaOidcMappingProvider`](../synapse/handlers/oidc_handler.py).

## SAML Mapping Providers

The SAML mapping provider can be customized by editing the
`saml2_config.user_mapping_provider.module` config option.

`saml2_config.user_mapping_provider.config` allows you to provide custom
configuration options to the module. Check with the module's documentation for
what options it provides (if any). The options listed by default are for the
user mapping provider built in to Synapse. If using a custom module, you should
comment these options out and use those specified by the module instead.

### Building a Custom SAML Mapping Provider

A custom mapping provider must specify the following methods (a short sketch follows this list):

* `__init__(self, parsed_config)`
  - Arguments:
    - `parsed_config` - A configuration object that is the return value of the
      `parse_config` method. You should set any configuration options needed by
      the module here.
* `parse_config(config)`
  - This method should have the `@staticmethod` decoration.
  - Arguments:
    - `config` - A `dict` representing the parsed content of the
      `saml_config.user_mapping_provider.config` homeserver config option.
      Runs on homeserver startup. Providers should extract and validate
      any option values they need here.
  - Whatever is returned will be passed back to the user mapping provider module's
    `__init__` method during construction.
* `get_saml_attributes(config)`
  - This method should have the `@staticmethod` decoration.
  - Arguments:
    - `config` - An object resulting from a call to `parse_config`.
  - Returns a tuple of two sets. The first set equates to the SAML auth
    response attributes that are required for the module to function, whereas
    the second set consists of those attributes which can be used if available,
    but are not necessary.
* `get_remote_user_id(self, saml_response, client_redirect_url)`
  - Arguments:
    - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
      information from.
    - `client_redirect_url` - A string, the URL that the client will be
      redirected to.
  - This method must return a string, which is the unique identifier for the
    user. Commonly the `uid` claim of the response.
* `saml_response_to_user_attributes(self, saml_response, failures, client_redirect_url)`
  - Arguments:
    - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
      information from.
    - `failures` - An `int` that represents the number of times the returned
      mxid localpart mapping has failed. This should be used
      to create a deduplicated mxid localpart which should be
      returned instead. For example, if this method returns
      `john.doe` as the value of `mxid_localpart` in the returned
      dict, and that is already taken on the homeserver, this
      method will be called again with the same parameters but
      with `failures=1`. The method should then return a different
      `mxid_localpart` value, such as `john.doe1`.
    - `client_redirect_url` - A string, the URL that the client will be
      redirected to.
  - This method must return a dictionary, which will then be used by Synapse
    to build a new user. The following keys are allowed:
    * `mxid_localpart` - Required. The mxid localpart of the new user.
    * `displayname` - The displayname of the new user. If not provided, will default to
      the value of `mxid_localpart`.
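
The following is a bare-bones sketch of a custom SAML mapping provider implementing these methods. The SAML attribute names (`uid`, `displayName`) are illustrative assumptions about what a given IdP returns:

```python
# Hypothetical custom SAML mapping provider.
class MySamlMappingProvider:
    def __init__(self, parsed_config):
        self._config = parsed_config

    @staticmethod
    def parse_config(config):
        # Extract and validate options from
        # saml2_config.user_mapping_provider.config here.
        return config

    @staticmethod
    def get_saml_attributes(config):
        # (required attributes, optional attributes)
        return {"uid"}, {"displayName"}

    def get_remote_user_id(self, saml_response, client_redirect_url):
        return saml_response.ava["uid"][0]

    def saml_response_to_user_attributes(
        self, saml_response, failures, client_redirect_url
    ):
        localpart = saml_response.ava["uid"][0]
        if failures:
            # De-duplicate a taken localpart, e.g. john.doe -> john.doe1
            localpart += str(failures)
        return {
            "mxid_localpart": localpart,
            "displayname": saml_response.ava.get("displayName", [None])[0],
        }
```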

### Default SAML Mapping Provider

Synapse has a built-in SAML mapping provider if a custom provider isn't
specified in the config. It is located at
[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py).

@@ -219,10 +219,6 @@ Asks the server for the current position of all streams.

Inform the server a pusher should be removed

#### INVALIDATE_CACHE (C)

Inform the server a cache should be invalidated

### REMOTE_SERVER_UP (S, C)

Inform other processes that a remote server may have come back online.

mypy.ini (3 lines changed)
@@ -75,3 +75,6 @@ ignore_missing_imports = True

[mypy-jwt.*]
ignore_missing_imports = True

[mypy-authlib.*]
ignore_missing_imports = True

@@ -122,7 +122,7 @@ APPEND_ONLY_TABLES = [
    "presence_stream",
    "push_rules_stream",
    "ex_outlier_stream",
    "cache_invalidation_stream",
    "cache_invalidation_stream_by_instance",
    "public_room_list_stream",
    "state_group_edges",
    "stream_ordering_to_exterm",

@@ -188,7 +188,7 @@ class MockHomeserver:
        self.clock = Clock(reactor)
        self.config = config
        self.hostname = config.server_name
        self.version_string = "Synapse/"+get_version_string(synapse)
        self.version_string = "Synapse/" + get_version_string(synapse)

    def get_clock(self):
        return self.clock

@@ -37,7 +37,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.types import StateMap, UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure

@@ -73,7 +73,7 @@ class Auth(object):
        self.store = hs.get_datastore()
        self.state = hs.get_state_handler()

        self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
        self.token_cache = LruCache(10000)
        register_cache("cache", "token_cache", self.token_cache)

        self._auth_blocking = AuthBlocking(self.hs)

@@ -26,7 +26,6 @@ from twisted.web.resource import NoResource

import synapse
import synapse.events
from synapse.api.constants import EventTypes
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.api.urls import (
    CLIENT_API_PREFIX,

@@ -48,6 +47,7 @@ from synapse.http.site import SynapseSite
from synapse.logging.context import LoggingContext
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore

@@ -81,11 +81,6 @@ from synapse.replication.tcp.streams import (
    ToDeviceStream,
    TypingStream,
)
from synapse.replication.tcp.streams.events import (
    EventsStream,
    EventsStreamEventRow,
    EventsStreamRow,
)
from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client.v1 import events
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet

@@ -122,11 +117,13 @@ from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.rest.client.versions import VersionsRestServlet
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer
from synapse.storage.data_stores.main.censor_events import CensorEventsStore
from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
from synapse.storage.data_stores.main.monthly_active_users import (
    MonthlyActiveUsersWorkerStore,
)
from synapse.storage.data_stores.main.presence import UserPresenceState
from synapse.storage.data_stores.main.search import SearchWorkerStore
from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt

@@ -442,6 +439,7 @@ class GenericWorkerSlavedStore(
    SlavedGroupServerStore,
    SlavedAccountDataStore,
    SlavedPusherStore,
    CensorEventsStore,
    SlavedEventStore,
    SlavedKeyStore,
    RoomStore,

@@ -455,6 +453,7 @@ class GenericWorkerSlavedStore(
    SlavedFilteringStore,
    MonthlyActiveUsersWorkerStore,
    MediaRepositoryStore,
    SearchWorkerStore,
    BaseSlavedStore,
):
    def __init__(self, database, db_conn, hs):

@@ -572,6 +571,9 @@ class GenericWorkerServer(HomeServer):
                if name in ["keys", "federation"]:
                    resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)

                if name == "replication":
                    resources[REPLICATION_PREFIX] = ReplicationRestResource(self)

        root_resource = create_resource_tree(resources, NoResource())

        _base.listen_tcp(

@@ -631,7 +633,7 @@ class GenericWorkerServer(HomeServer):

class GenericWorkerReplicationHandler(ReplicationDataHandler):
    def __init__(self, hs):
        super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore())
        super(GenericWorkerReplicationHandler, self).__init__(hs)

        self.store = hs.get_datastore()
        self.typing_handler = hs.get_typing_handler()

@@ -657,30 +659,7 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
            stream_name, token, rows
        )

        if stream_name == EventsStream.NAME:
            # We shouldn't get multiple rows per token for events stream, so
            # we don't need to optimise this for multiple rows.
            for row in rows:
                if row.type != EventsStreamEventRow.TypeId:
                    continue
                assert isinstance(row, EventsStreamRow)

                event = await self.store.get_event(
                    row.data.event_id, allow_rejected=True
                )
                if event.rejected_reason:
                    continue

                extra_users = ()
                if event.type == EventTypes.Member:
                    extra_users = (event.state_key,)
                max_token = self.store.get_room_max_stream_ordering()
                self.notifier.on_new_room_event(
                    event, token, max_token, extra_users
                )

            await self.pusher_pool.on_new_notifications(token, token)
        elif stream_name == PushRulesStream.NAME:
        if stream_name == PushRulesStream.NAME:
            self.notifier.on_new_event(
                "push_rules_key", token, users=[row.user_id for row in rows]
            )

@@ -69,7 +69,6 @@ from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.module_loader import load_module

@@ -192,6 +191,11 @@ class SynapseHomeServer(HomeServer):
                }
            )

        if self.get_config().oidc_enabled:
            from synapse.rest.oidc import OIDCResource

            resources["/_synapse/oidc"] = OIDCResource(self)

        if self.get_config().saml2_enabled:
            from synapse.rest.saml2 import SAML2Resource

@@ -422,6 +426,13 @@ def setup(config_options):
            # Check if it needs to be reprovisioned every day.
            hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)

        # Load the OIDC provider metadatas, if OIDC is enabled.
        if hs.config.oidc_enabled:
            oidc = hs.get_oidc_handler()
            # Loading the provider metadata also ensures the provider config is valid.
            yield defer.ensureDeferred(oidc.load_metadata())
            yield defer.ensureDeferred(oidc.load_jwks())

        _base.start(hs, config.listeners)

        hs.get_datastore().db.updates.start_doing_background_updates()

@@ -504,8 +515,8 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):

    daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
    stats["daily_sent_messages"] = daily_sent_messages
    stats["cache_factor"] = CACHE_SIZE_FACTOR
    stats["event_cache_size"] = hs.config.event_cache_size
    stats["cache_factor"] = hs.config.caches.global_factor
    stats["event_cache_size"] = hs.config.caches.event_cache_size

    #
    # Performance statistics

@@ -13,6 +13,7 @@ from synapse.config import (
    key,
    logger,
    metrics,
    oidc_config,
    password,
    password_auth_providers,
    push,

@@ -59,6 +60,7 @@ class RootConfig:
    saml2: saml2_config.SAML2Config
    cas: cas.CasConfig
    sso: sso.SSOConfig
    oidc: oidc_config.OIDCConfig
    jwt: jwt_config.JWTConfig
    password: password.PasswordConfig
    email: emailconfig.EmailConfig

synapse/config/cache.py (Normal file, 164 lines)
@@ -0,0 +1,164 @@

# -*- coding: utf-8 -*-
# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Callable, Dict

from ._base import Config, ConfigError

# The prefix for all cache factor-related environment variables
_CACHES = {}
_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"
_DEFAULT_FACTOR_SIZE = 0.5
_DEFAULT_EVENT_CACHE_SIZE = "10K"


class CacheProperties(object):
    def __init__(self):
        # The default factor size for all caches
        self.default_factor_size = float(
            os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
        )
        self.resize_all_caches_func = None


properties = CacheProperties()

def add_resizable_cache(cache_name: str, cache_resize_callback: Callable):
    """Register a cache whose size can dynamically change

    Args:
        cache_name: A reference to the cache
        cache_resize_callback: A callback function that will be run whenever
            the cache needs to be resized
    """
    _CACHES[cache_name.lower()] = cache_resize_callback

    # Ensure all loaded caches are sized appropriately
    #
    # This method should only run once the config has been read,
    # as it uses values read from it
    if properties.resize_all_caches_func:
        properties.resize_all_caches_func()


class CacheConfig(Config):
    section = "caches"
    _environ = os.environ

    @staticmethod
    def reset():
        """Resets the caches to their defaults. Used for tests."""
        properties.default_factor_size = float(
            os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
        )
        properties.resize_all_caches_func = None
        _CACHES.clear()

    def generate_config_section(self, **kwargs):
        return """\
        ## Caching ##

        # Caching can be configured through the following options.
        #
        # A cache 'factor' is a multiplier that can be applied to each of
        # Synapse's caches in order to increase or decrease the maximum
        # number of entries that can be stored.

        # The number of events to cache in memory. Not affected by
        # caches.global_factor.
        #
        #event_cache_size: 10K

        caches:
          # Controls the global cache factor, which is the default cache factor
          # for all caches if a specific factor for that cache is not otherwise
          # set.
          #
          # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment
          # variable. Setting by environment variable takes priority over
          # setting through the config file.
          #
          # Defaults to 0.5, which will halve the size of all caches.
          #
          #global_factor: 1.0

          # A dictionary of cache name to cache factor for that individual
          # cache. Overrides the global cache factor for a given cache.
          #
          # These can also be set through environment variables comprised
          # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital
          # letters and underscores. Setting by environment variable
          # takes priority over setting through the config file.
          # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0
          #
          per_cache_factors:
            #get_users_who_share_room_with_user: 2.0
        """

    def read_config(self, config, **kwargs):
        self.event_cache_size = self.parse_size(
            config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
        )
        self.cache_factors = {}  # type: Dict[str, float]

        cache_config = config.get("caches") or {}
        self.global_factor = cache_config.get(
            "global_factor", properties.default_factor_size
        )
        if not isinstance(self.global_factor, (int, float)):
            raise ConfigError("caches.global_factor must be a number.")

        # Set the global one so that it's reflected in new caches
        properties.default_factor_size = self.global_factor

        # Load cache factors from the config
        individual_factors = cache_config.get("per_cache_factors") or {}
        if not isinstance(individual_factors, dict):
            raise ConfigError("caches.per_cache_factors must be a dictionary")

        # Override factors from environment if necessary
        individual_factors.update(
            {
                key[len(_CACHE_PREFIX) + 1 :].lower(): float(val)
                for key, val in self._environ.items()
                if key.startswith(_CACHE_PREFIX + "_")
            }
        )

        for cache, factor in individual_factors.items():
            if not isinstance(factor, (int, float)):
                raise ConfigError(
                    "caches.per_cache_factors.%s must be a number" % (cache.lower(),)
                )
            self.cache_factors[cache.lower()] = factor

        # Resize all caches (if necessary) with the new factors we've loaded
        self.resize_all_caches()

        # Store this function so that it can be called from other classes without
        # needing an instance of Config
        properties.resize_all_caches_func = self.resize_all_caches

    def resize_all_caches(self):
        """Ensure all cache sizes are up to date

        For each cache, run the mapped callback function with either
        a specific cache factor or the default, global one.
        """
        for cache_name, callback in _CACHES.items():
            new_factor = self.cache_factors.get(cache_name, self.global_factor)
            callback(new_factor)

@@ -68,10 +68,6 @@ database:
          name: sqlite3
          args:
            database: %(database_path)s

        # Number of events to cache in memory.
        #
        #event_cache_size: 10K
        """


@@ -116,8 +112,6 @@ class DatabaseConfig(Config):
        self.databases = []

    def read_config(self, config, **kwargs):
        self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))

        # We *experimentally* support specifying multiple databases via the
        # `databases` key. This is a map from a label to database config in the
        # same format as the `database` config option, plus an extra

@@ -17,6 +17,7 @@
from ._base import RootConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
from .cache import CacheConfig
from .captcha import CaptchaConfig
from .cas import CasConfig
from .consent_config import ConsentConfig

@@ -27,6 +28,7 @@ from .jwt_config import JWTConfig
from .key import KeyConfig
from .logger import LoggingConfig
from .metrics import MetricsConfig
from .oidc_config import OIDCConfig
from .password import PasswordConfig
from .password_auth_providers import PasswordAuthProviderConfig
from .push import PushConfig

@@ -54,6 +56,7 @@ class HomeServerConfig(RootConfig):
    config_classes = [
        ServerConfig,
        TlsConfig,
        CacheConfig,
        DatabaseConfig,
        LoggingConfig,
        RatelimitConfig,

@@ -66,6 +69,7 @@ class HomeServerConfig(RootConfig):
        AppServiceConfig,
        KeyConfig,
        SAML2Config,
        OIDCConfig,
        CasConfig,
        SSOConfig,
        JWTConfig,

@@ -257,5 +257,6 @@ def setup_logging(
    logging.warning("***** STARTING SERVER *****")
    logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
    logging.info("Server hostname: %s", config.server_name)
    logging.info("Instance name: %s", hs.get_instance_name())

    return logger

synapse/config/oidc_config.py (Normal file, 177 lines)
@@ -0,0 +1,177 @@

# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module

from ._base import Config, ConfigError

DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"


class OIDCConfig(Config):
    section = "oidc"

    def read_config(self, config, **kwargs):
        self.oidc_enabled = False

        oidc_config = config.get("oidc_config")

        if not oidc_config or not oidc_config.get("enabled", False):
            return

        try:
            check_requirements("oidc")
        except DependencyException as e:
            raise ConfigError(e.message)

        public_baseurl = self.public_baseurl
        if public_baseurl is None:
            raise ConfigError("oidc_config requires a public_baseurl to be set")
        self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"

        self.oidc_enabled = True
        self.oidc_discover = oidc_config.get("discover", True)
        self.oidc_issuer = oidc_config["issuer"]
        self.oidc_client_id = oidc_config["client_id"]
        self.oidc_client_secret = oidc_config["client_secret"]
        self.oidc_client_auth_method = oidc_config.get(
            "client_auth_method", "client_secret_basic"
        )
        self.oidc_scopes = oidc_config.get("scopes", ["openid"])
        self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint")
        self.oidc_token_endpoint = oidc_config.get("token_endpoint")
        self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
        self.oidc_jwks_uri = oidc_config.get("jwks_uri")
        self.oidc_subject_claim = oidc_config.get("subject_claim", "sub")
        self.oidc_skip_verification = oidc_config.get("skip_verification", False)

        ump_config = oidc_config.get("user_mapping_provider", {})
        ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
        ump_config.setdefault("config", {})

        (
            self.oidc_user_mapping_provider_class,
            self.oidc_user_mapping_provider_config,
        ) = load_module(ump_config)

        # Ensure loaded user mapping module has defined all necessary methods
        required_methods = [
            "get_remote_user_id",
            "map_user_attributes",
        ]
        missing_methods = [
            method
            for method in required_methods
            if not hasattr(self.oidc_user_mapping_provider_class, method)
        ]
        if missing_methods:
            raise ConfigError(
                "Class specified by oidc_config."
                "user_mapping_provider.module is missing required "
                "methods: %s" % (", ".join(missing_methods),)
            )
def generate_config_section(self, config_dir_path, server_name, **kwargs):
|
||||
return """\
|
||||
# Enable OpenID Connect for registration and login. Uses authlib.
|
||||
#
|
||||
oidc_config:
|
||||
# enable OpenID Connect. Defaults to false.
|
||||
#
|
||||
#enabled: true
|
||||
|
||||
# use the OIDC discovery mechanism to discover endpoints. Defaults to true.
|
||||
#
|
||||
#discover: true
|
||||
|
||||
# the OIDC issuer. Used to validate tokens and discover the providers endpoints. Required.
|
||||
#
|
||||
#issuer: "https://accounts.example.com/"
|
||||
|
||||
# oauth2 client id to use. Required.
|
||||
#
|
||||
#client_id: "provided-by-your-issuer"
|
||||
|
||||
# oauth2 client secret to use. Required.
|
||||
#
|
||||
#client_secret: "provided-by-your-issuer"
|
||||
|
||||
# auth method to use when exchanging the token.
|
||||
# Valid values are "client_secret_basic" (default), "client_secret_post" and "none".
|
||||
#
|
||||
#client_auth_method: "client_auth_basic"
|
||||
|
||||
# list of scopes to ask. This should include the "openid" scope. Defaults to ["openid"].
|
||||
#
|
||||
#scopes: ["openid"]
|
||||
|
||||
# the oauth2 authorization endpoint. Required if provider discovery is disabled.
|
||||
#
|
||||
#authorization_endpoint: "https://accounts.example.com/oauth2/auth"
|
||||
|
||||
# the oauth2 token endpoint. Required if provider discovery is disabled.
|
||||
#
|
||||
#token_endpoint: "https://accounts.example.com/oauth2/token"
|
||||
|
||||
# the OIDC userinfo endpoint. Required if discovery is disabled and the "openid" scope is not asked.
|
||||
#
|
||||
#userinfo_endpoint: "https://accounts.example.com/userinfo"
|
||||
|
||||
# URI where to fetch the JWKS. Required if discovery is disabled and the "openid" scope is used.
|
||||
#
|
||||
#jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
|
||||
|
||||
# skip metadata verification. Defaults to false.
|
||||
# Use this if you are connecting to a provider that is not OpenID Connect compliant.
|
||||
# Avoid this in production.
|
||||
#
|
||||
#skip_verification: false
|
||||
|
||||
|
||||
# An external module can be provided here as a custom solution to mapping
|
||||
# attributes returned from a OIDC provider onto a matrix user.
|
||||
#
|
||||
user_mapping_provider:
|
||||
# The custom module's class. Uncomment to use a custom module.
|
||||
# Default is {mapping_provider!r}.
|
||||
#
|
||||
#module: mapping_provider.OidcMappingProvider
|
||||
|
||||
# Custom configuration values for the module. Below options are intended
|
||||
# for the built-in provider, they should be changed if using a custom
|
||||
# module. This section will be passed as a Python dictionary to the
|
||||
# module's `parse_config` method.
|
||||
#
|
||||
# Below is the config of the default mapping provider, based on Jinja2
|
||||
# templates. Those templates are used to render user attributes, where the
|
||||
# userinfo object is available through the `user` variable.
|
||||
#
|
||||
config:
|
||||
# name of the claim containing a unique identifier for the user.
|
||||
# Defaults to `sub`, which OpenID Connect compliant providers should provide.
|
||||
#
|
||||
#subject_claim: "sub"
|
||||
|
||||
# Jinja2 template for the localpart of the MXID
|
||||
#
|
||||
localpart_template: "{{{{ user.preferred_username }}}}"
|
||||
|
||||
# Jinja2 template for the display name to set on first login. Optional.
|
||||
#
|
||||
#display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
|
||||
""".format(
|
||||
mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
|
||||
)
|
||||
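For illustration, a custom mapping provider satisfying the `required_methods` check above might look like the following minimal sketch. The class name is hypothetical; the constructor/config handshake assumes the `parse_config` convention used by `load_module`, and the method signatures are inferred from the default `JinjaOidcMappingProvider`:

    from synapse.types import map_username_to_mxid_localpart


    class MyOidcMappingProvider:
        def __init__(self, parsed_config):
            self._config = parsed_config

        @staticmethod
        def parse_config(config: dict) -> dict:
            # Validate/normalise the `user_mapping_provider.config` block here.
            return config or {}

        def get_remote_user_id(self, userinfo) -> str:
            # A unique, stable identifier for the user at the provider.
            return userinfo["sub"]

        async def map_user_attributes(self, userinfo, token) -> dict:
            # Derive the Matrix localpart (and optional display name).
            return {
                "localpart": map_username_to_mxid_localpart(
                    userinfo["preferred_username"]
                ),
                "display_name": userinfo.get("name"),
            }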
@@ -51,7 +51,7 @@ class ServerNoticesConfig(Config):
None if server notices are not enabled.

server_notices_mxid_avatar_url (str|None):
The display name to use for the server notices user.
The MXC URL for the avatar of the server notices user.
None if server notices are not enabled.

server_notices_room_name (str|None):

@@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Tuple

from synapse.config import ConfigError
from synapse.util.module_loader import load_module

from ._base import Config
@@ -22,16 +25,35 @@ class SpamCheckerConfig(Config):
section = "spamchecker"

def read_config(self, config, **kwargs):
self.spam_checker = None
self.spam_checkers = [] # type: List[Tuple[Any, Dict]]

provider = config.get("spam_checker", None)
if provider is not None:
self.spam_checker = load_module(provider)
spam_checkers = config.get("spam_checker") or []
if isinstance(spam_checkers, dict):
# The spam_checker config option used to only support one
# spam checker, and thus was simply a dictionary with module
# and config keys. Support this old behaviour by checking
# to see if the option resolves to a dictionary
self.spam_checkers.append(load_module(spam_checkers))
elif isinstance(spam_checkers, list):
for spam_checker in spam_checkers:
if not isinstance(spam_checker, dict):
raise ConfigError("spam_checker syntax is incorrect")

self.spam_checkers.append(load_module(spam_checker))
else:
raise ConfigError("spam_checker syntax is incorrect")

def generate_config_section(self, **kwargs):
return """\
#spam_checker:
# module: "my_custom_project.SuperSpamChecker"
# config:
# example_option: 'things'
# Spam checkers are third-party modules that can block specific actions
# of local users, such as creating rooms and registering undesirable
# usernames, as well as remote users by redacting incoming events.
#
spam_checker:
#- module: "my_custom_project.SuperSpamChecker"
# config:
# example_option: 'things'
#- module: "some_other_project.BadEventStopper"
# config:
# example_stop_events_from: ['@bad:example.com']
"""

@@ -36,17 +36,13 @@ class SSOConfig(Config):
if not template_dir:
template_dir = pkg_resources.resource_filename("synapse", "res/templates",)

self.sso_redirect_confirm_template_dir = template_dir
self.sso_template_dir = template_dir
self.sso_account_deactivated_template = self.read_file(
os.path.join(
self.sso_redirect_confirm_template_dir, "sso_account_deactivated.html"
),
os.path.join(self.sso_template_dir, "sso_account_deactivated.html"),
"sso_account_deactivated_template",
)
self.sso_auth_success_template = self.read_file(
os.path.join(
self.sso_redirect_confirm_template_dir, "sso_auth_success.html"
),
os.path.join(self.sso_template_dir, "sso_auth_success.html"),
"sso_auth_success_template",
)

@@ -137,6 +133,13 @@ class SSOConfig(Config):
#
# This template has no additional variables.
#
# * HTML page to display to users if something goes wrong during the
# OpenID Connect authentication process: 'sso_error.html'.
#
# When rendering, this template is given two variables:
# * error: the technical name of the error
# * error_description: a human-readable message for the error
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#

@@ -13,9 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import attr

from ._base import Config


@attr.s
class InstanceLocationConfig:
"""The host and port to talk to an instance via HTTP replication.
"""

host = attr.ib(type=str)
port = attr.ib(type=int)


@attr.s
class WriterLocations:
"""Specifies the instances that write various streams.

Attributes:
events: The instance that writes to the event and backfill streams.
"""

events = attr.ib(default="master", type=str)


class WorkerConfig(Config):
"""The workers are processes run separately from the main synapse process.
They have their own pid_file and listener configuration. They use the
@@ -71,6 +93,16 @@ class WorkerConfig(Config):
elif not bind_addresses:
bind_addresses.append("")

# A map from instance name to host/port of their HTTP replication endpoint.
instance_map = config.get("instance_map", {}) or {}
self.instance_map = {
name: InstanceLocationConfig(**c) for name, c in instance_map.items()
}

# Map from type of streams to source, c.f. WriterLocations.
writers = config.get("writers", {}) or {}
self.writers = WriterLocations(**writers)

def read_arguments(self, args):
# We support a bunch of command line arguments that override options in
# the config. A lot of these options have a worker_* prefix when running

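As a rough illustration of how these two new options fit together, here is a minimal sketch against the classes above. The instance name "event_writer" is hypothetical; the dict mirrors an `instance_map`/`writers` fragment of homeserver.yaml:

    import attr


    @attr.s
    class InstanceLocationConfig:
        host = attr.ib(type=str)
        port = attr.ib(type=int)


    @attr.s
    class WriterLocations:
        events = attr.ib(default="master", type=str)


    # Equivalent of a homeserver.yaml fragment such as:
    #   instance_map:
    #     event_writer:
    #       host: localhost
    #       port: 9091
    #   writers:
    #     events: event_writer
    config = {
        "instance_map": {"event_writer": {"host": "localhost", "port": 9091}},
        "writers": {"events": "event_writer"},
    }

    instance_map = {
        name: InstanceLocationConfig(**c) for name, c in config["instance_map"].items()
    }
    writers = WriterLocations(**config["writers"])

    # The process named "event_writer" would then persist events itself, while
    # every other instance forwards event writes to it over HTTP replication.
    assert writers.events == "event_writer"
    assert instance_map[writers.events].port == 9091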
@@ -15,7 +15,7 @@
# limitations under the License.

import inspect
from typing import Dict
from typing import Any, Dict, List

from synapse.spam_checker_api import SpamCheckerApi

@@ -26,24 +26,17 @@ if MYPY:

class SpamChecker(object):
def __init__(self, hs: "synapse.server.HomeServer"):
self.spam_checker = None
self.spam_checkers = [] # type: List[Any]

module = None
config = None
try:
module, config = hs.config.spam_checker
except Exception:
pass

if module is not None:
for module, config in hs.config.spam_checkers:
# Older spam checkers don't accept the `api` argument, so we
# try and detect support.
spam_args = inspect.getfullargspec(module)
if "api" in spam_args.args:
api = SpamCheckerApi(hs)
self.spam_checker = module(config=config, api=api)
self.spam_checkers.append(module(config=config, api=api))
else:
self.spam_checker = module(config=config)
self.spam_checkers.append(module(config=config))

def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
"""Checks if a given event is considered "spammy" by this server.
@@ -58,10 +51,11 @@ class SpamChecker(object):
Returns:
True if the event is spammy.
"""
if self.spam_checker is None:
return False
for spam_checker in self.spam_checkers:
if spam_checker.check_event_for_spam(event):
return True

return self.spam_checker.check_event_for_spam(event)
return False

def user_may_invite(
self, inviter_userid: str, invitee_userid: str, room_id: str
@@ -78,12 +72,14 @@ class SpamChecker(object):
Returns:
True if the user may send an invite, otherwise False
"""
if self.spam_checker is None:
return True
for spam_checker in self.spam_checkers:
if (
spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
is False
):
return False

return self.spam_checker.user_may_invite(
inviter_userid, invitee_userid, room_id
)
return True

def user_may_create_room(self, userid: str) -> bool:
"""Checks if a given user may create a room
@@ -96,10 +92,11 @@ class SpamChecker(object):
Returns:
True if the user may create a room, otherwise False
"""
if self.spam_checker is None:
return True
for spam_checker in self.spam_checkers:
if spam_checker.user_may_create_room(userid) is False:
return False

return self.spam_checker.user_may_create_room(userid)
return True

def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
"""Checks if a given user may create a room alias
@@ -113,10 +110,11 @@ class SpamChecker(object):
Returns:
True if the user may create a room alias, otherwise False
"""
if self.spam_checker is None:
return True
for spam_checker in self.spam_checkers:
if spam_checker.user_may_create_room_alias(userid, room_alias) is False:
return False

return self.spam_checker.user_may_create_room_alias(userid, room_alias)
return True

def user_may_publish_room(self, userid: str, room_id: str) -> bool:
"""Checks if a given user may publish a room to the directory
@@ -130,10 +128,11 @@ class SpamChecker(object):
Returns:
True if the user may publish the room, otherwise False
"""
if self.spam_checker is None:
return True
for spam_checker in self.spam_checkers:
if spam_checker.user_may_publish_room(userid, room_id) is False:
return False

return self.spam_checker.user_may_publish_room(userid, room_id)
return True

def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
@@ -150,13 +149,14 @@ class SpamChecker(object):
Returns:
True if the user is spammy.
"""
if self.spam_checker is None:
return False
for spam_checker in self.spam_checkers:
# For backwards compatibility, only run if the method exists on the
# spam checker
checker = getattr(spam_checker, "check_username_for_spam", None)
if checker:
# Make a copy of the user profile object to ensure the spam checker
# cannot modify it.
if checker(user_profile.copy()):
return True

# For backwards compatibility, if the method does not exist on the spam checker, fallback to not interfering.
checker = getattr(self.spam_checker, "check_username_for_spam", None)
if not checker:
return False
# Make a copy of the user profile object to ensure the spam checker
# cannot modify it.
return checker(user_profile.copy())
return False

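To make the new list behaviour concrete, here is a minimal sketch of a spam checker module that would work with the list-based `spam_checker` config above. All names are hypothetical; the optional `api` keyword argument and the optional `check_username_for_spam` hook follow the detection logic shown in `__init__`, and `parse_config` assumes the convention used by `load_module`:

    class SuperSpamChecker:
        # Accepting `api` opts in to the SpamCheckerApi, per the
        # inspect.getfullargspec() detection in SpamChecker.__init__.
        def __init__(self, config, api=None):
            self._config = config
            self._api = api

        @staticmethod
        def parse_config(config):
            # Validate/normalise the per-module `config` block here.
            return config or {}

        def check_event_for_spam(self, event) -> bool:
            # Return True to block the event; every configured checker is
            # consulted and the first True wins.
            return False

        def user_may_invite(self, inviter_userid, invitee_userid, room_id) -> bool:
            return True

        def user_may_create_room(self, userid) -> bool:
            return True

        def user_may_create_room_alias(self, userid, room_alias) -> bool:
            return True

        def user_may_publish_room(self, userid, room_id) -> bool:
            return True

        # Optional: only called if present, for backwards compatibility.
        def check_username_for_spam(self, user_profile) -> bool:
            return False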
@@ -31,6 +31,7 @@ Events are replicated via a separate events stream.

import logging
from collections import namedtuple
from typing import Dict, List, Tuple, Type

from six import iteritems

@@ -56,21 +57,35 @@ class FederationRemoteSendQueue(object):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id

self.presence_map = {} # Pending presence map user_id -> UserPresenceState
self.presence_changed = SortedDict() # Stream position -> list[user_id]
# Pending presence map user_id -> UserPresenceState
self.presence_map = {} # type: Dict[str, UserPresenceState]

# Stream position -> list[user_id]
self.presence_changed = SortedDict() # type: SortedDict[int, List[str]]

# Stores the destinations we need to explicitly send presence to about a
# given user.
# Stream position -> (user_id, destinations)
self.presence_destinations = SortedDict()
self.presence_destinations = (
SortedDict()
) # type: SortedDict[int, Tuple[str, List[str]]]

self.keyed_edu = {} # (destination, key) -> EDU
self.keyed_edu_changed = SortedDict() # stream position -> (destination, key)
# (destination, key) -> EDU
self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu]

self.edus = SortedDict() # stream position -> Edu
# stream position -> (destination, key)
self.keyed_edu_changed = (
SortedDict()
) # type: SortedDict[int, Tuple[str, tuple]]

self.edus = SortedDict() # type: SortedDict[int, Edu]

# stream ID for the next entry into presence_changed/keyed_edu_changed/edus.
self.pos = 1
self.pos_time = SortedDict()

# map from stream ID to the time that stream entry was generated, so that we
# can clear out entries after a while
self.pos_time = SortedDict() # type: SortedDict[int, int]

# EVERYTHING IS SAD. In particular, python only makes new scopes when
# we make a new function, so we need to make a new function so the inner
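The typed `SortedDict`s matter because the queue is sliced by stream position further down (see the `bisect_right(from_token)` call in `get_replication_rows`). A minimal standalone sketch of that access pattern, using the same `sortedcontainers` library with illustrative values:

    from sortedcontainers import SortedDict

    presence_changed = SortedDict()  # stream position -> list[user_id]
    presence_changed[1] = ["@alice:example.com"]
    presence_changed[5] = ["@bob:example.com"]
    presence_changed[9] = ["@carol:example.com"]

    from_token, to_token = 1, 9

    # Index of the first key strictly greater than from_token, and of the
    # first key greater than to_token: everything in between is "new".
    i = presence_changed.bisect_right(from_token)
    j = presence_changed.bisect_right(to_token)
    updated = presence_changed.items()[i:j]

    assert updated == [(5, ["@bob:example.com"]), (9, ["@carol:example.com"])]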
@@ -158,8 +173,10 @@ class FederationRemoteSendQueue(object):
for edu_key in self.keyed_edu_changed.values():
live_keys.add(edu_key)

to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
for edu_key in to_del:
keys_to_del = [
edu_key for edu_key in self.keyed_edu if edu_key not in live_keys
]
for edu_key in keys_to_del:
del self.keyed_edu[edu_key]

# Delete things out of edu map
@@ -250,19 +267,23 @@ class FederationRemoteSendQueue(object):
self._clear_queue_before_pos(token)

async def get_replication_rows(
self, from_token, to_token, limit, federation_ack=None
):
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
"""Get rows to be sent over federation between the two tokens

Args:
from_token (int)
to_token(int)
limit (int)
federation_ack (int): Optional. The position where the worker is
explicitly acknowledged it has handled. Allows us to drop
data from before that point
instance_name: the name of the current process
from_token: the previous stream token: the starting point for fetching the
updates
to_token: the new stream token: the point to get updates up to
target_row_count: a target for the number of rows to be returned.

Returns: a triplet `(updates, new_last_token, limited)`, where:
* `updates` is a list of `(token, row)` entries.
* `new_last_token` is the new position in stream.
* `limited` is whether there are more updates to fetch.
"""
# TODO: Handle limit.
# TODO: Handle target_row_count.

# To handle restarts where we wrap around
if from_token > self.pos:
@@ -270,12 +291,7 @@ class FederationRemoteSendQueue(object):

# list of tuple(int, BaseFederationRow), where the first is the position
# of the federation stream.
rows = []

# There should be only one reader, so lets delete everything its
# acknowledged its seen.
if federation_ack:
self._clear_queue_before_pos(federation_ack)
rows = [] # type: List[Tuple[int, BaseFederationRow]]

# Fetch changed presence
i = self.presence_changed.bisect_right(from_token)
@@ -332,7 +348,11 @@ class FederationRemoteSendQueue(object):
# Sort rows based on pos
rows.sort()

return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
return (
[(pos, (row.TypeId, row.to_data())) for pos, row in rows],
to_token,
False,
)

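A caller-side sketch of the new contract (a hypothetical loop, shown only to illustrate the triple): replication code can page through a stream by feeding `new_last_token` back in as the next `from_token` until `limited` comes back False:

    async def fetch_all_updates(stream, instance_name, from_token, to_token, batch_size=100):
        """Drain `stream` between two tokens using the (updates, new_last_token,
        limited) triple returned by get_replication_rows()."""
        all_updates = []
        limited = True
        while limited:
            updates, from_token, limited = await stream.get_replication_rows(
                instance_name, from_token, to_token, batch_size
            )
            all_updates.extend(updates)
        return all_updates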
class BaseFederationRow(object):
@@ -341,7 +361,7 @@ class BaseFederationRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""

TypeId = None # Unique string that ids the type. Must be overridden in sub classes.
TypeId = "" # Unique string that ids the type. Must be overridden in sub classes.

@staticmethod
def from_data(data):
@@ -454,10 +474,14 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
buff.edus.setdefault(self.edu.destination, []).append(self.edu)


TypeToRow = {
Row.TypeId: Row
for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow,)
}
_rowtypes = (
PresenceRow,
PresenceDestinationsRow,
KeyedEduRow,
EduRow,
) # type: Tuple[Type[BaseFederationRow], ...]

TypeToRow = {Row.TypeId: Row for Row in _rowtypes}


ParsedFederationStreamData = namedtuple(

@@ -14,7 +14,7 @@
# limitations under the License.

import logging
from typing import Dict, Hashable, Iterable, List, Optional, Set
from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple

from six import itervalues

@@ -498,14 +498,16 @@ class FederationSender(object):

self._get_per_destination_queue(destination).attempt_new_transaction()

def get_current_token(self) -> int:
@staticmethod
def get_current_token() -> int:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return 0

@staticmethod
async def get_replication_rows(
self, from_token, to_token, limit, federation_ack=None
):
instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return []
return [], 0, False

@@ -15,11 +15,10 @@
# limitations under the License.
import datetime
import logging
from typing import Dict, Hashable, Iterable, List, Tuple
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple

from prometheus_client import Counter

import synapse.server
from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
@@ -34,6 +33,9 @@ from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter

if TYPE_CHECKING:
import synapse.server

# This is defined in the Matrix spec and enforced by the receiver.
MAX_EDUS_PER_TRANSACTION = 100


@@ -13,11 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List
from typing import TYPE_CHECKING, List

from canonicaljson import json

import synapse.server
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
@@ -31,6 +30,9 @@ from synapse.logging.opentracing import (
)
from synapse.util.metrics import measure_func

if TYPE_CHECKING:
import synapse.server

logger = logging.getLogger(__name__)


@@ -126,13 +126,13 @@ class AuthHandler(BaseHandler):
# It notifies the user they are about to give access to their matrix account
# to the client.
self._sso_redirect_confirm_template = load_jinja2_templates(
hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
hs.config.sso_template_dir, ["sso_redirect_confirm.html"],
)[0]
# The following template is shown during user interactive authentication
# in the fallback auth scenario. It notifies the user that they are
# authenticating for an operation to occur on their account.
self._sso_auth_confirm_template = load_jinja2_templates(
hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"],
hs.config.sso_template_dir, ["sso_auth_confirm.html"],
)[0]
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.

@@ -125,10 +125,9 @@ class FederationHandler(BaseHandler):
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_simple_http_client()
self._instance_name = hs.get_instance_name()

self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client(
hs
)
self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
hs
)
@@ -2681,8 +2680,7 @@ class FederationHandler(BaseHandler):
member_handler = self.hs.get_room_member_handler()
await member_handler.send_membership_event(None, event, context)

@defer.inlineCallbacks
def add_display_name_to_third_party_invite(
async def add_display_name_to_third_party_invite(
self, room_version, event_dict, event, context
):
key = (
@@ -2690,10 +2688,10 @@ class FederationHandler(BaseHandler):
event.content["third_party_invite"]["signed"]["token"],
)
original_invite = None
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()
original_invite_id = prev_state_ids.get(key)
if original_invite_id:
original_invite = yield self.store.get_event(
original_invite = await self.store.get_event(
original_invite_id, allow_none=True
)
if original_invite:
@@ -2714,14 +2712,13 @@ class FederationHandler(BaseHandler):

builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder)
event, context = yield self.event_creation_handler.create_new_client_event(
event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
)
EventValidator().validate_new(event, self.config)
return (event, context)

@defer.inlineCallbacks
def _check_signature(self, event, context):
async def _check_signature(self, event, context):
"""
Checks that the signature in the event is consistent with its invite.

@@ -2738,12 +2735,12 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]

prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))

invite_event = None
if invite_event_id:
invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
invite_event = await self.store.get_event(invite_event_id, allow_none=True)

if not invite_event:
raise AuthError(403, "Could not find invite")
@@ -2792,7 +2789,7 @@ class FederationHandler(BaseHandler):
raise
try:
if "key_validity_url" in public_key_object:
yield self._check_key_revocation(
await self._check_key_revocation(
public_key, public_key_object["key_validity_url"]
)
except Exception:
@@ -2806,8 +2803,7 @@ class FederationHandler(BaseHandler):
last_exception = e
raise last_exception

@defer.inlineCallbacks
def _check_key_revocation(self, public_key, url):
async def _check_key_revocation(self, public_key, url):
"""
Checks whether public_key has been revoked.

@@ -2821,7 +2817,7 @@ class FederationHandler(BaseHandler):
for revocation.
"""
try:
response = yield self.http_client.get_json(url, {"public_key": public_key})
response = await self.http_client.get_json(url, {"public_key": public_key})
except Exception:
raise SynapseError(502, "Third party certificate could not be checked")
if "valid" not in response or not response["valid"]:
@@ -2840,8 +2836,9 @@ class FederationHandler(BaseHandler):
backfilled: Whether these events are a result of
backfilling or not
"""
if self.config.worker_app:
await self._send_events_to_master(
if self.config.worker.writers.events != self._instance_name:
await self._send_events(
instance_name=self.config.worker.writers.events,
store=self.store,
event_and_contexts=event_and_contexts,
backfilled=backfilled,
@@ -2916,8 +2913,7 @@ class FederationHandler(BaseHandler):
else:
user_joined_room(self.distributor, user, room_id)

@defer.inlineCallbacks
def get_room_complexity(self, remote_room_hosts, room_id):
async def get_room_complexity(self, remote_room_hosts, room_id):
"""
Fetch the complexity of a remote room over federation.

@@ -2931,12 +2927,12 @@ class FederationHandler(BaseHandler):
"""

for host in remote_room_hosts:
res = yield self.federation_client.get_room_complexity(host, room_id)
res = await self.federation_client.get_room_complexity(host, room_id)

# We got a result, return it.
if res:
defer.returnValue(res)
return res

# We fell off the bottom, couldn't get the complexity from anyone. Oh
# well.
defer.returnValue(None)
return None

@@ -72,7 +72,6 @@ class MessageHandler(object):
self.state_store = self.storage.state
self._event_serializer = hs.get_event_client_serializer()
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
self._is_worker_app = bool(hs.config.worker_app)

# The scheduled call to self._expire_event. None if no call is currently
# scheduled.
@@ -260,7 +259,6 @@ class MessageHandler(object):
Args:
event (EventBase): The event to schedule the expiry of.
"""
assert not self._is_worker_app

expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
if not isinstance(expiry_ts, int) or event.is_state():
@@ -367,10 +365,11 @@ class EventCreationHandler(object):
self.notifier = hs.get_notifier()
self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
self._instance_name = hs.get_instance_name()

self.room_invite_state_types = self.hs.config.room_invite_state_types

self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs)
self.send_event = ReplicationSendEventRestServlet.make_client(hs)

# This is only used to get at ratelimit function, and maybe_kick_guest_users
self.base_handler = BaseHandler(hs)
@@ -824,8 +823,9 @@ class EventCreationHandler(object):
success = False
try:
# If we're a worker we need to hit out to the master.
if self.config.worker_app:
await self.send_event_to_master(
if self.config.worker.writers.events != self._instance_name:
await self.send_event(
instance_name=self.config.worker.writers.events,
event_id=event.event_id,
store=self.store,
requester=requester,
@@ -890,7 +890,7 @@ class EventCreationHandler(object):

This should only be run on master.
"""
assert not self.config.worker_app
assert self.config.worker.writers.events == self._instance_name

if ratelimit:
# We check if this is a room admin redacting an event so that we

998
synapse/handlers/oidc_handler.py
Normal file
@@ -0,0 +1,998 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from typing import Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode

import attr
import pymacaroons
from authlib.common.security import generate_token
from authlib.jose import JsonWebToken
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
from pymacaroons.exceptions import (
MacaroonDeserializationException,
MacaroonInvalidSignatureException,
)
from typing_extensions import TypedDict

from twisted.web.client import readBody

from synapse.config import ConfigError
from synapse.http.server import finish_request
from synapse.http.site import SynapseRequest
from synapse.push.mailer import load_jinja2_templates
from synapse.server import HomeServer
from synapse.types import UserID, map_username_to_mxid_localpart

logger = logging.getLogger(__name__)

SESSION_COOKIE_NAME = b"oidc_session"

#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
#: OpenID.Core sec 3.1.3.3.
Token = TypedDict(
"Token",
{
"access_token": str,
"token_type": str,
"id_token": Optional[str],
"refresh_token": Optional[str],
"expires_in": int,
"scope": Optional[str],
},
)

#: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but
#: there is no real point in doing this in our case.
JWK = Dict[str, str]

#: A JWK Set, as per RFC7517 sec 5.
JWKS = TypedDict("JWKS", {"keys": List[JWK]})


class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint
"""

def __init__(self, error, error_description=None):
self.error = error
self.error_description = error_description

def __str__(self):
if self.error_description:
return "{}: {}".format(self.error, self.error_description)
return self.error


class MappingException(Exception):
"""Used to catch errors when mapping the UserInfo object
"""


class OidcHandler:
"""Handles requests related to the OpenID Connect login flow.
"""

def __init__(self, hs: HomeServer):
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._client_auth = ClientAuth(
hs.config.oidc_client_id,
hs.config.oidc_client_secret,
hs.config.oidc_client_auth_method,
) # type: ClientAuth
self._client_auth_method = hs.config.oidc_client_auth_method # type: str
self._subject_claim = hs.config.oidc_subject_claim
self._provider_metadata = OpenIDProviderMetadata(
issuer=hs.config.oidc_issuer,
authorization_endpoint=hs.config.oidc_authorization_endpoint,
token_endpoint=hs.config.oidc_token_endpoint,
userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
jwks_uri=hs.config.oidc_jwks_uri,
) # type: OpenIDProviderMetadata
self._provider_needs_discovery = hs.config.oidc_discover # type: bool
self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
hs.config.oidc_user_mapping_provider_config
) # type: OidcMappingProvider
self._skip_verification = hs.config.oidc_skip_verification # type: bool

self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._datastore = hs.get_datastore()
self._clock = hs.get_clock()
self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
self._error_template = load_jinja2_templates(
hs.config.sso_template_dir, ["sso_error.html"]
)[0]

# identifier for the external_ids table
self._auth_provider_id = "oidc"

def _render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Renders the error template and responds with it.

This is used to show errors to the user. The template of this page can
be found under ``synapse/res/templates/sso_error.html``.

Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error. Those include
well-known OAuth2/OIDC error types like invalid_request or
access_denied.
error_description: A human-readable description of the error.
"""
html_bytes = self._error_template.render(
error=error, error_description=error_description
).encode("utf-8")

request.setResponseCode(400)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%i" % len(html_bytes))
request.write(html_bytes)
finish_request(request)

def _validate_metadata(self):
"""Verifies the provider metadata.

This checks the validity of the currently loaded provider. Not
everything is checked, only:

- ``issuer``
- ``authorization_endpoint``
- ``token_endpoint``
- ``response_types_supported`` (checks if "code" is in it)
- ``jwks_uri``

Raises:
ValueError: if something in the provider is not valid
"""
# Skip verification to allow non-compliant providers (e.g. issuers not running on a secure origin)
if self._skip_verification is True:
return

m = self._provider_metadata
m.validate_issuer()
m.validate_authorization_endpoint()
m.validate_token_endpoint()

if m.get("token_endpoint_auth_methods_supported") is not None:
m.validate_token_endpoint_auth_methods_supported()
if (
self._client_auth_method
not in m["token_endpoint_auth_methods_supported"]
):
raise ValueError(
'"{auth_method}" not in "token_endpoint_auth_methods_supported" ({supported!r})'.format(
auth_method=self._client_auth_method,
supported=m["token_endpoint_auth_methods_supported"],
)
)

if m.get("response_types_supported") is not None:
m.validate_response_types_supported()

if "code" not in m["response_types_supported"]:
raise ValueError(
'"code" not in "response_types_supported" (%r)'
% (m["response_types_supported"],)
)

# If the openid scope was not requested, we need a userinfo endpoint to fetch user info
if self._uses_userinfo:
if m.get("userinfo_endpoint") is None:
raise ValueError(
'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested'
)
else:
# If we're not using userinfo, we need a valid jwks to validate the ID token
if m.get("jwks") is None:
if m.get("jwks_uri") is not None:
m.validate_jwks_uri()
else:
raise ValueError('"jwks_uri" must be set')

@property
def _uses_userinfo(self) -> bool:
"""Returns True if the ``userinfo_endpoint`` should be used.

This is based on the requested scopes: if the scopes include
``openid``, the provider should give us an ID token containing the
user information. If not, we should fetch it using the
``access_token`` with the ``userinfo_endpoint``.
"""

# Maybe that should be user-configurable and not inferred?
return "openid" not in self._scopes

async def load_metadata(self) -> OpenIDProviderMetadata:
"""Load and validate the provider metadata.

The metadata is discovered if ``oidc_config.discovery`` is
``True``, and is then cached.

Raises:
ValueError: if something in the provider is not valid

Returns:
The provider's metadata.
"""
# If we are using the OpenID Discovery documents, they need to be loaded once
# FIXME: should there be a lock here?
if self._provider_needs_discovery:
url = get_well_known_url(self._provider_metadata["issuer"], external=True)
metadata_response = await self._http_client.get_json(url)
# TODO: maybe update the other way around to let user override some values?
self._provider_metadata.update(metadata_response)
self._provider_needs_discovery = False

self._validate_metadata()

return self._provider_metadata

async def load_jwks(self, force: bool = False) -> JWKS:
"""Load the JSON Web Key Set used to sign ID tokens.

If we're not using the ``userinfo_endpoint``, user info is extracted
from the ID token, which is a JWT signed by keys given by the provider.
The keys are then cached.

Args:
force: Force reloading the keys.

Returns:
The key set

Looks like this::

{
'keys': [
{
'kid': 'abcdef',
'kty': 'RSA',
'alg': 'RS256',
'use': 'sig',
'e': 'XXXX',
'n': 'XXXX',
}
]
}
"""
if self._uses_userinfo:
# We're not using jwt signing, return an empty jwk set
return {"keys": []}

# First check if the JWKS are loaded in the provider metadata.
# It can happen either if the provider gives its JWKS in the discovery
# document directly or if it was already loaded once.
metadata = await self.load_metadata()
jwk_set = metadata.get("jwks")
if jwk_set is not None and not force:
return jwk_set

# Loading the JWKS using the `jwks_uri` metadata
uri = metadata.get("jwks_uri")
if not uri:
raise RuntimeError('Missing "jwks_uri" in metadata')

jwk_set = await self._http_client.get_json(uri)

# Caching the JWKS in the provider's metadata
self._provider_metadata["jwks"] = jwk_set
return jwk_set

async def _exchange_code(self, code: str) -> Token:
"""Exchange an authorization code for a token.

This calls the ``token_endpoint`` with the authorization code we
received in the callback to exchange it for a token. The call uses
``ClientAuth`` to authenticate as the client with its ID and secret.

Args:
code: The authorization code we got from the callback.

Returns:
A dict containing various tokens.

May look like this::

{
'token_type': 'bearer',
'access_token': 'abcdef',
'expires_in': 3599,
'id_token': 'ghijkl',
'refresh_token': 'mnopqr',
}

Raises:
OidcError: when the ``token_endpoint`` returned an error.
"""
metadata = await self.load_metadata()
token_endpoint = metadata.get("token_endpoint")
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": self._http_client.user_agent,
"Accept": "application/json",
}

args = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self._callback_url,
}
body = urlencode(args, True)

# Fill the body/headers with credentials
uri, headers, body = self._client_auth.prepare(
method="POST", uri=token_endpoint, headers=headers, body=body
)
headers = {k: [v] for (k, v) in headers.items()}

# Do the actual request
# We're not using the SimpleHttpClient util methods as we don't want to
# check the HTTP status code, and we do the body encoding ourselves.
response = await self._http_client.request(
method="POST", uri=uri, data=body.encode("utf-8"), headers=headers,
)

# This is used in multiple error messages below
status = "{code} {phrase}".format(
code=response.code, phrase=response.phrase.decode("utf-8")
)

resp_body = await readBody(response)

if response.code >= 500:
# In case of a server error, we should first try to decode the body
# and check for an error field. If not, we respond with a generic
# error message.
try:
resp = json.loads(resp_body.decode("utf-8"))
error = resp["error"]
description = resp.get("error_description", error)
except (ValueError, KeyError):
# Catch ValueError for the JSON decoding and KeyError for the "error" field
error = "server_error"
description = (
'Authorization server responded with a "{status}" error '
"while exchanging the authorization code."
).format(status=status)

raise OidcError(error, description)

# Since it is not a 5xx code, the body should be valid JSON. It will
# raise if not.
resp = json.loads(resp_body.decode("utf-8"))

if "error" in resp:
error = resp["error"]
# In case the authorization server responded with an error field,
# it should be a 4xx code. If not, warn about it but don't do
# anything special and report the original error message.
if response.code < 400:
logger.debug(
"Invalid response from the authorization server: "
'responded with a "{status}" '
"but body has an error field: {error!r}".format(
status=status, error=resp["error"]
)
)

description = resp.get("error_description", error)
raise OidcError(error, description)

# Now, this should not be an error. According to RFC6749 sec 5.1, it
# should be a 200 code. We're a bit more flexible than that, and will
# only throw on a 4xx code.
if response.code >= 400:
description = (
'Authorization server responded with a "{status}" error '
'but did not include an "error" field in its response.'.format(
status=status
)
)
logger.warning(description)
# Body was still valid JSON. Might be useful to log it for debugging.
logger.warning("Code exchange response: {resp!r}".format(resp=resp))
raise OidcError("server_error", description)

return resp

async def _fetch_userinfo(self, token: Token) -> UserInfo:
"""Fetch user information from the ``userinfo_endpoint``.

Args:
token: the token given by the ``token_endpoint``.
Must include an ``access_token`` field.

Returns:
UserInfo: an object representing the user.
"""
metadata = await self.load_metadata()

resp = await self._http_client.get_json(
metadata["userinfo_endpoint"],
headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
)

return UserInfo(resp)

async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
"""Return an instance of UserInfo from token's ``id_token``.

Args:
token: the token given by the ``token_endpoint``.
Must include an ``id_token`` field.
nonce: the nonce value originally sent in the initial authorization
request. This value should match the one inside the token.

Returns:
An object representing the user.
"""
metadata = await self.load_metadata()
claims_params = {
"nonce": nonce,
"client_id": self._client_auth.client_id,
}
if "access_token" in token:
# If we got an `access_token`, there should be an `at_hash` claim
# in the `id_token` that we can check against.
claims_params["access_token"] = token["access_token"]
claims_cls = CodeIDToken
else:
claims_cls = ImplicitIDToken

alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])

jwt = JsonWebToken(alg_values)

claim_options = {"iss": {"values": [metadata["issuer"]]}}

# Try to decode the keys in cache first, then retry by forcing the keys
# to be reloaded
jwk_set = await self.load_jwks()
try:
claims = jwt.decode(
token["id_token"],
key=jwk_set,
claims_cls=claims_cls,
claims_options=claim_options,
claims_params=claims_params,
)
except ValueError:
jwk_set = await self.load_jwks(force=True) # try reloading the jwks
claims = jwt.decode(
token["id_token"],
key=jwk_set,
claims_cls=claims_cls,
claims_options=claim_options,
claims_params=claims_params,
)

claims.validate(leeway=120) # allows 2 min of clock skew
return UserInfo(claims)

async def handle_redirect_request(
self, request: SynapseRequest, client_redirect_url: bytes
) -> None:
"""Handle an incoming request to /login/sso/redirect

It redirects the browser to the authorization endpoint with a few
parameters:

- ``client_id``: the client ID set in ``oidc_config.client_id``
- ``response_type``: ``code``
- ``redirect_uri``: the callback URL; ``{base url}/_synapse/oidc/callback``
- ``scope``: the list of scopes set in ``oidc_config.scopes``
- ``state``: a random string
- ``nonce``: a random string

In addition to redirecting the client, we are setting a cookie with
a signed macaroon token containing the state, the nonce and the
client_redirect_url params. Those are then checked when the client
comes back from the provider.


Args:
request: the incoming request from the browser.
We'll respond to it with a redirect and a cookie.
client_redirect_url: the URL that we should redirect the client to
when everything is done
"""

state = generate_token()
nonce = generate_token()

cookie = self._generate_oidc_session_token(
state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(),
)
request.addCookie(
SESSION_COOKIE_NAME,
cookie,
path="/_synapse/oidc",
max_age="3600",
httpOnly=True,
sameSite="lax",
)

metadata = await self.load_metadata()
authorization_endpoint = metadata.get("authorization_endpoint")
uri = prepare_grant_uri(
authorization_endpoint,
client_id=self._client_auth.client_id,
response_type="code",
redirect_uri=self._callback_url,
scope=self._scopes,
state=state,
nonce=nonce,
)
request.redirect(uri)
finish_request(request)

async def handle_oidc_callback(self, request: SynapseRequest) -> None:
|
||||
"""Handle an incoming request to /_synapse/oidc/callback
|
||||
|
||||
Since we might want to display OIDC-related errors in a user-friendly
|
||||
way, we don't raise SynapseError from here. Instead, we call
|
||||
``self._render_error`` which displays an HTML page for the error.
|
||||
|
||||
Most of the OpenID Connect logic happens here:
|
||||
|
||||
        - first, we check if there was any error returned by the provider and
          display it
        - then we fetch the session cookie, decode and verify it
        - the ``state`` query parameter should match the one stored in the
          session cookie
        - once we know this session is legit, exchange the code with the
          provider using the ``token_endpoint`` (see ``_exchange_code``)
        - once we have the token, use it to either extract the UserInfo from
          the ``id_token`` (``_parse_id_token``), or use the ``access_token``
          to fetch UserInfo from the ``userinfo_endpoint``
          (``_fetch_userinfo``)
        - map the UserInfo to a Matrix user (``_map_userinfo_to_user``) and
          finish the login

        Args:
            request: the incoming request from the browser.
        """

        # The provider might redirect with an error.
        # In that case, just display it as-is.
        if b"error" in request.args:
            error = request.args[b"error"][0].decode()
            description = request.args.get(b"error_description", [b""])[0].decode()

            # Most of the errors returned by the provider could be due to
            # either the provider misbehaving or Synapse being misconfigured.
            # The only exception to that is "access_denied", where the user
            # probably cancelled the login flow. In other cases, log those errors.
            if error != "access_denied":
                logger.error("Error from the OIDC provider: %s %s", error, description)

            self._render_error(request, error, description)
            return

        # Fetch the session cookie
        session = request.getCookie(SESSION_COOKIE_NAME)
        if session is None:
            logger.info("No session cookie found")
            self._render_error(request, "missing_session", "No session cookie found")
            return

        # Remove the cookie. There is a good chance that if the callback failed
        # once, it will fail next time and the code will already be exchanged.
        # Removing it early avoids spamming the provider with token requests.
        request.addCookie(
            SESSION_COOKIE_NAME,
            b"",
            path="/_synapse/oidc",
            expires="Thu, Jan 01 1970 00:00:00 UTC",
            httpOnly=True,
            sameSite="lax",
        )

        # Check for the state query parameter
        if b"state" not in request.args:
            logger.info("State parameter is missing")
            self._render_error(request, "invalid_request", "State parameter is missing")
            return

        state = request.args[b"state"][0].decode()

        # Deserialize the session token and verify it.
        try:
            nonce, client_redirect_url = self._verify_oidc_session_token(session, state)
        except MacaroonDeserializationException as e:
            logger.exception("Invalid session")
            self._render_error(request, "invalid_session", str(e))
            return
        except MacaroonInvalidSignatureException as e:
            logger.exception("Could not verify session")
            self._render_error(request, "mismatching_session", str(e))
            return

        # Exchange the code with the provider
        if b"code" not in request.args:
            logger.info("Code parameter is missing")
            self._render_error(request, "invalid_request", "Code parameter is missing")
            return

        logger.info("Exchanging code")
        code = request.args[b"code"][0].decode()
        try:
            token = await self._exchange_code(code)
        except OidcError as e:
            logger.exception("Could not exchange code")
            self._render_error(request, e.error, e.error_description)
            return

        # Now that we have a token, get the userinfo, either by decoding the
        # `id_token` or by fetching the `userinfo_endpoint`.
        if self._uses_userinfo:
            logger.info("Fetching userinfo")
            try:
                userinfo = await self._fetch_userinfo(token)
            except Exception as e:
                logger.exception("Could not fetch userinfo")
                self._render_error(request, "fetch_error", str(e))
                return
        else:
            logger.info("Extracting userinfo from id_token")
            try:
                userinfo = await self._parse_id_token(token, nonce=nonce)
            except Exception as e:
                logger.exception("Invalid id_token")
                self._render_error(request, "invalid_token", str(e))
                return

        # Call the mapper to register/login the user
        try:
            user_id = await self._map_userinfo_to_user(userinfo, token)
        except MappingException as e:
            logger.exception("Could not map user")
            self._render_error(request, "mapping_error", str(e))
            return

        # and finally complete the login
        await self._auth_handler.complete_sso_login(
            user_id, request, client_redirect_url
        )

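    # For reference, a typical callback redirect from the provider carries
    # ``code`` and ``state`` query parameters, e.g. (illustrative values only):
    #
    #   GET /_synapse/oidc/callback?code=abc123&state=xyz789
    #
    # while an error redirect carries ``error`` (and usually
    # ``error_description``) instead of ``code``. The exact callback path is
    # whatever endpoint under /_synapse/oidc the provider was configured to
    # redirect to.
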
    def _generate_oidc_session_token(
        self,
        state: str,
        nonce: str,
        client_redirect_url: str,
        duration_in_ms: int = (60 * 60 * 1000),
    ) -> str:
        """Generates a signed token storing data about an OIDC session.

        When Synapse initiates an authorization flow, it creates a random state
        and a random nonce. Those parameters are given to the provider and
        should be verified when the client comes back from the provider.
        The token is also used to store the client_redirect_url, which is used
        to complete the SSO login flow.

        Args:
            state: The ``state`` parameter passed to the OIDC provider.
            nonce: The ``nonce`` parameter passed to the OIDC provider.
            client_redirect_url: The URL the client gave when it initiated the
                flow.
            duration_in_ms: An optional duration for the token in milliseconds.
                Defaults to an hour.

        Returns:
            A signed macaroon token with the session information.
        """
        macaroon = pymacaroons.Macaroon(
            location=self._server_name, identifier="key", key=self._macaroon_secret_key,
        )
        macaroon.add_first_party_caveat("gen = 1")
        macaroon.add_first_party_caveat("type = session")
        macaroon.add_first_party_caveat("state = %s" % (state,))
        macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
        macaroon.add_first_party_caveat(
            "client_redirect_url = %s" % (client_redirect_url,)
        )
        now = self._clock.time_msec()
        expiry = now + duration_in_ms
        macaroon.add_first_party_caveat("time < %d" % (expiry,))
        return macaroon.serialize()

    def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]:
        """Verifies and extracts an OIDC session token.

        This verifies that a given session token was issued by this homeserver
        and extracts the nonce and client_redirect_url caveats.

        Args:
            session: The session token to verify
            state: The state the OIDC provider gave back

        Returns:
            The nonce and the client_redirect_url for this session
        """
        macaroon = pymacaroons.Macaroon.deserialize(session)

        v = pymacaroons.Verifier()
        v.satisfy_exact("gen = 1")
        v.satisfy_exact("type = session")
        v.satisfy_exact("state = %s" % (state,))
        v.satisfy_general(lambda c: c.startswith("nonce = "))
        v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
        v.satisfy_general(self._verify_expiry)

        v.verify(macaroon, self._macaroon_secret_key)

        # Extract the `nonce` and `client_redirect_url` from the token
        nonce = self._get_value_from_macaroon(macaroon, "nonce")
        client_redirect_url = self._get_value_from_macaroon(
            macaroon, "client_redirect_url"
        )

        return nonce, client_redirect_url

    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
        """Extracts a caveat value from a macaroon token.

        Args:
            macaroon: the token
            key: the key of the caveat to extract

        Returns:
            The extracted value

        Raises:
            Exception: if the caveat was not in the macaroon
        """
        prefix = key + " = "
        for caveat in macaroon.caveats:
            if caveat.caveat_id.startswith(prefix):
                return caveat.caveat_id[len(prefix) :]
        raise Exception("No %s caveat in macaroon" % (key,))

    def _verify_expiry(self, caveat: str) -> bool:
        prefix = "time < "
        if not caveat.startswith(prefix):
            return False
        expiry = int(caveat[len(prefix) :])
        now = self._clock.time_msec()
        return now < expiry

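    # A minimal sketch of the round-trip between the helpers above
    # (hypothetical values; ``self`` is the handler):
    #
    #   session = self._generate_oidc_session_token(
    #       state="abc", nonce="def", client_redirect_url="https://client.example",
    #   )
    #   nonce, url = self._verify_oidc_session_token(session, state="abc")
    #
    # Verification fails if the supplied state differs from the "state" caveat,
    # if the signature does not match, or if the "time < ..." caveat (checked
    # by _verify_expiry) has expired.
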
    async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
        """Maps a UserInfo object to an mxid.

        UserInfo should have a claim that uniquely identifies users. This claim
        is usually `sub`, but can be configured with `oidc_config.subject_claim`.
        It is then used as an `external_id`.

        If we don't find the user that way, we should register the user,
        mapping the localpart and the display name from the UserInfo.

        If a user already exists with the mxid we've mapped, raise an exception.

        Args:
            userinfo: an object representing the user
            token: a dict with the tokens obtained from the provider

        Raises:
            MappingException: if there was an error while mapping some properties

        Returns:
            The mxid of the user
        """
        try:
            remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
        except Exception as e:
            raise MappingException(
                "Failed to extract subject from OIDC response: %s" % (e,)
            )

        logger.info(
            "Looking for existing mapping for user %s:%s",
            self._auth_provider_id,
            remote_user_id,
        )

        registered_user_id = await self._datastore.get_user_by_external_id(
            self._auth_provider_id, remote_user_id,
        )

        if registered_user_id is not None:
            logger.info("Found existing mapping %s", registered_user_id)
            return registered_user_id

        try:
            attributes = await self._user_mapping_provider.map_user_attributes(
                userinfo, token
            )
        except Exception as e:
            raise MappingException(
                "Could not extract user attributes from OIDC response: " + str(e)
            )

        logger.debug(
            "Retrieved user attributes from user mapping provider: %r", attributes
        )

        if not attributes["localpart"]:
            raise MappingException("localpart is empty")

        localpart = map_username_to_mxid_localpart(attributes["localpart"])

        user_id = UserID(localpart, self._hostname)
        if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()):
            # This mxid is taken
            raise MappingException(
                "mxid '{}' is already taken".format(user_id.to_string())
            )

        # It's the first time this user is logging in and the mapped mxid was
        # not taken, register the user
        registered_user_id = await self._registration_handler.register_user(
            localpart=localpart, default_display_name=attributes["display_name"],
        )

        await self._datastore.record_user_external_id(
            self._auth_provider_id, remote_user_id, registered_user_id,
        )
        return registered_user_id


UserAttribute = TypedDict(
    "UserAttribute", {"localpart": str, "display_name": Optional[str]}
)
C = TypeVar("C")


class OidcMappingProvider(Generic[C]):
    """A mapping provider maps a UserInfo object to user attributes.

    It should provide the API described by this class.
    """

    def __init__(self, config: C):
        """
        Args:
            config: A custom config object from this module, parsed by ``parse_config()``
        """

    @staticmethod
    def parse_config(config: dict) -> C:
        """Parse the dict provided by the homeserver's config

        Args:
            config: A dictionary containing configuration options for this provider

        Returns:
            A custom config object for this module
        """
        raise NotImplementedError()

    def get_remote_user_id(self, userinfo: UserInfo) -> str:
        """Get a unique user ID for this user.

        Usually, in an OIDC-compliant scenario, it should be the ``sub`` claim from the UserInfo object.

        Args:
            userinfo: An object representing the user given by the OIDC provider

        Returns:
            A unique user ID
        """
        raise NotImplementedError()

    async def map_user_attributes(
        self, userinfo: UserInfo, token: Token
    ) -> UserAttribute:
        """Map a ``UserInfo`` object into user attributes.

        Args:
            userinfo: An object representing the user given by the OIDC provider
            token: A dict with the tokens returned by the provider

        Returns:
            A dict containing the ``localpart`` and (optionally) the ``display_name``
        """
        raise NotImplementedError()


# Used to clear out "None" values in templates
def jinja_finalize(thing):
    return thing if thing is not None else ""


env = Environment(finalize=jinja_finalize)

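# With this environment, a template such as ``{{ user.middle_name }}``
# (illustrative claim name) renders as "" rather than "None" when the
# attribute is missing, so optional claims cannot leak the literal string
# "None" into localparts or display names.
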
@attr.s
class JinjaOidcMappingConfig:
    subject_claim = attr.ib()  # type: str
    localpart_template = attr.ib()  # type: Template
    display_name_template = attr.ib()  # type: Optional[Template]


class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
    """An implementation of a mapping provider based on Jinja templates.

    This is the default mapping provider.
    """

    def __init__(self, config: JinjaOidcMappingConfig):
        self._config = config

    @staticmethod
    def parse_config(config: dict) -> JinjaOidcMappingConfig:
        subject_claim = config.get("subject_claim", "sub")

        if "localpart_template" not in config:
            raise ConfigError(
                "missing key: oidc_config.user_mapping_provider.config.localpart_template"
            )

        try:
            localpart_template = env.from_string(config["localpart_template"])
        except Exception as e:
            raise ConfigError(
                "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
                % (e,)
            )

        display_name_template = None  # type: Optional[Template]
        if "display_name_template" in config:
            try:
                display_name_template = env.from_string(config["display_name_template"])
            except Exception as e:
                raise ConfigError(
                    "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
                    % (e,)
                )

        return JinjaOidcMappingConfig(
            subject_claim=subject_claim,
            localpart_template=localpart_template,
            display_name_template=display_name_template,
        )

    def get_remote_user_id(self, userinfo: UserInfo) -> str:
        return userinfo[self._config.subject_claim]

    async def map_user_attributes(
        self, userinfo: UserInfo, token: Token
    ) -> UserAttribute:
        localpart = self._config.localpart_template.render(user=userinfo).strip()

        display_name = None  # type: Optional[str]
        if self._config.display_name_template is not None:
            display_name = self._config.display_name_template.render(
                user=userinfo
            ).strip()

            if display_name == "":
                display_name = None

        return UserAttribute(localpart=localpart, display_name=display_name)
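
# A minimal sketch of how this default provider might be configured in
# homeserver.yaml (claim names in the templates are illustrative; only
# ``localpart_template`` is required by ``parse_config`` above):
#
#   oidc_config:
#     user_mapping_provider:
#       config:
#         subject_claim: "sub"
#         localpart_template: "{{ user.preferred_username }}"
#         display_name_template: "{{ user.name }}"
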
@@ -25,8 +25,6 @@ from collections import OrderedDict
 
 from six import iteritems, string_types
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
 from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -103,8 +101,7 @@ class RoomCreationHandler(BaseHandler):
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
-    @defer.inlineCallbacks
-    def upgrade_room(
+    async def upgrade_room(
         self, requester: Requester, old_room_id: str, new_version: RoomVersion
     ):
         """Replace a room with a new room with a different version
@@ -117,7 +114,7 @@ class RoomCreationHandler(BaseHandler):
         Returns:
             Deferred[unicode]: the new room id
         """
-        yield self.ratelimit(requester)
+        await self.ratelimit(requester)
 
         user_id = requester.user.to_string()
 
@@ -138,7 +135,7 @@ class RoomCreationHandler(BaseHandler):
         # If this user has sent multiple upgrade requests for the same room
         # and one of them is not complete yet, cache the response and
         # return it to all subsequent requests
-        ret = yield self._upgrade_response_cache.wrap(
+        ret = await self._upgrade_response_cache.wrap(
            (old_room_id, user_id),
            self._upgrade_room,
            requester,
@@ -856,8 +853,7 @@ class RoomCreationHandler(BaseHandler):
         for (etype, state_key), content in initial_state.items():
             await send(etype=etype, state_key=state_key, content=content)
 
-    @defer.inlineCallbacks
-    def _generate_room_id(
+    async def _generate_room_id(
         self, creator_id: str, is_public: str, room_version: RoomVersion,
     ):
         # autogen room IDs and try to create it. We may clash, so just
@@ -869,7 +865,7 @@ class RoomCreationHandler(BaseHandler):
             gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
             if isinstance(gen_room_id, bytes):
                 gen_room_id = gen_room_id.decode("utf-8")
-            yield self.store.store_room(
+            await self.store.store_room(
                 room_id=gen_room_id,
                 room_creator_user_id=creator_id,
                 is_public=is_public,
@@ -888,8 +884,7 @@ class RoomContextHandler(object):
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
-    @defer.inlineCallbacks
-    def get_event_context(self, user, room_id, event_id, limit, event_filter):
+    async def get_event_context(self, user, room_id, event_id, limit, event_filter):
         """Retrieves events, pagination tokens and state around a given event
         in a room.
 
@@ -908,7 +903,7 @@ class RoomContextHandler(object):
         before_limit = math.floor(limit / 2.0)
         after_limit = limit - before_limit
 
-        users = yield self.store.get_users_in_room(room_id)
+        users = await self.store.get_users_in_room(room_id)
         is_peeking = user.to_string() not in users
 
         def filter_evts(events):
@@ -916,17 +911,17 @@ class RoomContextHandler(object):
                 self.storage, user.to_string(), events, is_peeking=is_peeking
             )
 
-        event = yield self.store.get_event(
+        event = await self.store.get_event(
            event_id, get_prev_content=True, allow_none=True
        )
        if not event:
            return None
 
-        filtered = yield (filter_evts([event]))
+        filtered = await filter_evts([event])
        if not filtered:
            raise AuthError(403, "You don't have permission to access that event.")
 
-        results = yield self.store.get_events_around(
+        results = await self.store.get_events_around(
            room_id, event_id, before_limit, after_limit, event_filter
        )
 
@@ -934,8 +929,8 @@ class RoomContextHandler(object):
            results["events_before"] = event_filter.filter(results["events_before"])
            results["events_after"] = event_filter.filter(results["events_after"])
 
-        results["events_before"] = yield filter_evts(results["events_before"])
-        results["events_after"] = yield filter_evts(results["events_after"])
+        results["events_before"] = await filter_evts(results["events_before"])
+        results["events_after"] = await filter_evts(results["events_after"])
        # filter_evts can return a pruned event in case the user is allowed to see that
        # there's something there but not see the content, so use the event that's in
        # `filtered` rather than the event we retrieved from the datastore.
@@ -962,7 +957,7 @@ class RoomContextHandler(object):
        # first? Shouldn't we be consistent with /sync?
        # https://github.com/matrix-org/matrix-doc/issues/687
 
-        state = yield self.state_store.get_state_for_events(
+        state = await self.state_store.get_state_for_events(
            [last_event_id], state_filter=state_filter
        )
 
@@ -970,7 +965,7 @@ class RoomContextHandler(object):
        if event_filter:
            state_events = event_filter.filter(state_events)
 
-        results["state"] = yield filter_evts(state_events)
+        results["state"] = await filter_evts(state_events)
 
        # We use a dummy token here as we only care about the room portion of
        # the token, which we replace.
@@ -989,13 +984,12 @@ class RoomEventSource(object):
    def __init__(self, hs):
        self.store = hs.get_datastore()
 
-    @defer.inlineCallbacks
-    def get_new_events(
+    async def get_new_events(
        self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
    ):
        # We just ignore the key for now.
 
-        to_key = yield self.get_current_key()
+        to_key = await self.get_current_key()
 
        from_token = RoomStreamToken.parse(from_key)
        if from_token.topological:
@@ -1008,11 +1002,11 @@ class RoomEventSource(object):
            # See https://github.com/matrix-org/matrix-doc/issues/1144
            raise NotImplementedError()
        else:
-            room_events = yield self.store.get_membership_changes_for_user(
+            room_events = await self.store.get_membership_changes_for_user(
                user.to_string(), from_key, to_key
            )
 
-            room_to_events = yield self.store.get_room_events_stream_for_rooms(
+            room_to_events = await self.store.get_room_events_stream_for_rooms(
                room_ids=room_ids,
                from_key=from_key,
                to_key=to_key,
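The room handler hunks above all apply the same mechanical conversion from
Twisted's generator-based coroutines to native ones. A minimal sketch of the
pattern, using a made-up method:

    # Before: inlineCallbacks generator, driven by yield.
    @defer.inlineCallbacks
    def get_thing(self):
        thing = yield self.store.fetch_thing()
        return thing

    # After: native coroutine; callers must now await it.
    async def get_thing(self):
        return await self.store.fetch_thing()
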
@@ -875,8 +875,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         self.distributor.declare("user_joined_room")
         self.distributor.declare("user_left_room")
 
-    @defer.inlineCallbacks
-    def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
+    async def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
         """
         Check if complexity of a remote room is too great.
 
@@ -888,7 +887,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             if unable to be fetched
         """
         max_complexity = self.hs.config.limit_remote_rooms.complexity
-        complexity = yield self.federation_handler.get_room_complexity(
+        complexity = await self.federation_handler.get_room_complexity(
             remote_room_hosts, room_id
         )
 
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import re
-from typing import Optional, Tuple
+from typing import Callable, Dict, Optional, Set, Tuple
 
 import attr
 import saml2
@@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.http.server import finish_request
 from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
 from synapse.module_api.errors import RedirectException
 from synapse.types import (
@@ -81,17 +82,19 @@ class SamlHandler:
 
         self._error_html_content = hs.config.saml2_error_html_content
 
-    def handle_redirect_request(self, client_redirect_url, ui_auth_session_id=None):
+    def handle_redirect_request(
+        self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
+    ) -> bytes:
         """Handle an incoming request to /login/sso/redirect
 
         Args:
-            client_redirect_url (bytes): the URL that we should redirect the
+            client_redirect_url: the URL that we should redirect the
                 client to when everything is done
-            ui_auth_session_id (Optional[str]): The session ID of the ongoing UI Auth (or
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
         Returns:
-            bytes: URL to redirect to
+            URL to redirect to
         """
         reqid, info = self._saml_client.prepare_for_authenticate(
             relay_state=client_redirect_url
@@ -109,15 +112,15 @@ class SamlHandler:
             # this shouldn't happen!
             raise Exception("prepare_for_authenticate didn't return a Location header")
 
-    async def handle_saml_response(self, request):
+    async def handle_saml_response(self, request: SynapseRequest) -> None:
         """Handle an incoming request to /_matrix/saml2/authn_response
 
         Args:
-            request (SynapseRequest): the incoming request from the browser. We'll
+            request: the incoming request from the browser. We'll
                 respond to it with a redirect.
 
         Returns:
-            Deferred[none]: Completes once we have handled the request.
+            Completes once we have handled the request.
         """
         resp_bytes = parse_string(request, "SAMLResponse", required=True)
         relay_state = parse_string(request, "RelayState", required=True)
@@ -310,6 +313,7 @@ DOT_REPLACE_PATTERN = re.compile(
 
 
 def dot_replace_for_mxid(username: str) -> str:
+    """Replace any characters which are not allowed in Matrix IDs with a dot."""
     username = username.lower()
     username = DOT_REPLACE_PATTERN.sub(".", username)
 
@@ -321,7 +325,7 @@ def dot_replace_for_mxid(username: str) -> str:
 MXID_MAPPER_MAP = {
     "hexencode": map_username_to_mxid_localpart,
     "dotreplace": dot_replace_for_mxid,
-}
+}  # type: Dict[str, Callable[[str], str]]
 
 
 @attr.s
@@ -349,7 +353,7 @@ class DefaultSamlMappingProvider(object):
 
     def get_remote_user_id(
         self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
-    ):
+    ) -> str:
         """Extracts the remote user id from the SAML response"""
         try:
             return saml_response.ava["uid"][0]
@@ -428,14 +432,14 @@ class DefaultSamlMappingProvider(object):
         return SamlConfig(mxid_source_attribute, mxid_mapper)
 
     @staticmethod
-    def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+    def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]:
         """Returns the required attributes of a SAML
 
         Args:
             config: A SamlConfig object containing configuration params for this provider
 
         Returns:
-            tuple[set,set]: The first set equates to the saml auth response
+            The first set equates to the saml auth response
             attributes that are required for the module to function, whereas the
             second set consists of those attributes which can be used if
             available, but are not necessary
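For illustration: "dotreplace" lowercases the source attribute and replaces
any character that is not allowed in a Matrix ID localpart with a dot, so an
attribute value such as "Jane Doe" would be expected to map to "jane.doe"
(assuming the space is in the disallowed set matched by DOT_REPLACE_PATTERN).
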
@@ -18,8 +18,6 @@ import logging
 
 from unpaddedbase64 import decode_base64, encode_base64
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
@@ -39,8 +37,7 @@ class SearchHandler(BaseHandler):
         self.state_store = self.storage.state
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def get_old_rooms_from_upgraded_room(self, room_id):
+    async def get_old_rooms_from_upgraded_room(self, room_id):
         """Retrieves room IDs of old rooms in the history of an upgraded room.
 
         We do so by checking the m.room.create event of the room for a
@@ -60,7 +57,7 @@ class SearchHandler(BaseHandler):
         historical_room_ids = []
 
         # The initial room must have been known for us to get this far
-        predecessor = yield self.store.get_room_predecessor(room_id)
+        predecessor = await self.store.get_room_predecessor(room_id)
 
         while True:
             if not predecessor:
@@ -75,7 +72,7 @@ class SearchHandler(BaseHandler):
 
             # Don't add it to the list until we have checked that we are in the room
             try:
-                next_predecessor_room = yield self.store.get_room_predecessor(
+                next_predecessor_room = await self.store.get_room_predecessor(
                     predecessor_room_id
                 )
             except NotFoundError:
@@ -89,8 +86,7 @@ class SearchHandler(BaseHandler):
 
         return historical_room_ids
 
-    @defer.inlineCallbacks
-    def search(self, user, content, batch=None):
+    async def search(self, user, content, batch=None):
         """Performs a full text search for a user.
 
         Args:
@@ -179,7 +175,7 @@ class SearchHandler(BaseHandler):
         search_filter = Filter(filter_dict)
 
         # TODO: Search through left rooms too
-        rooms = yield self.store.get_rooms_for_local_user_where_membership_is(
+        rooms = await self.store.get_rooms_for_local_user_where_membership_is(
             user.to_string(),
             membership_list=[Membership.JOIN],
             # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
@@ -192,7 +188,7 @@ class SearchHandler(BaseHandler):
         historical_room_ids = []
         for room_id in search_filter.rooms:
             # Add any previous rooms to the search if they exist
-            ids = yield self.get_old_rooms_from_upgraded_room(room_id)
+            ids = await self.get_old_rooms_from_upgraded_room(room_id)
             historical_room_ids += ids
 
         # Prevent any historical events from being filtered
@@ -223,7 +219,7 @@ class SearchHandler(BaseHandler):
         count = None
 
         if order_by == "rank":
-            search_result = yield self.store.search_msgs(room_ids, search_term, keys)
+            search_result = await self.store.search_msgs(room_ids, search_term, keys)
 
             count = search_result["count"]
 
@@ -238,7 +234,7 @@ class SearchHandler(BaseHandler):
 
             filtered_events = search_filter.filter([r["event"] for r in results])
 
-            events = yield filter_events_for_client(
+            events = await filter_events_for_client(
                 self.storage, user.to_string(), filtered_events
             )
 
@@ -267,7 +263,7 @@ class SearchHandler(BaseHandler):
             # But only go around 5 times since otherwise synapse will be sad.
             while len(room_events) < search_filter.limit() and i < 5:
                 i += 1
-                search_result = yield self.store.search_rooms(
+                search_result = await self.store.search_rooms(
                     room_ids,
                     search_term,
                     keys,
@@ -288,7 +284,7 @@ class SearchHandler(BaseHandler):
 
                 filtered_events = search_filter.filter([r["event"] for r in results])
 
-                events = yield filter_events_for_client(
+                events = await filter_events_for_client(
                     self.storage, user.to_string(), filtered_events
                 )
 
@@ -343,11 +339,11 @@ class SearchHandler(BaseHandler):
         # If client has asked for "context" for each event (i.e. some surrounding
         # events and state), fetch that
         if event_context is not None:
-            now_token = yield self.hs.get_event_sources().get_current_token()
+            now_token = await self.hs.get_event_sources().get_current_token()
 
             contexts = {}
             for event in allowed_events:
-                res = yield self.store.get_events_around(
+                res = await self.store.get_events_around(
                     event.room_id, event.event_id, before_limit, after_limit
                 )
 
@@ -357,11 +353,11 @@ class SearchHandler(BaseHandler):
                     len(res["events_after"]),
                 )
 
-                res["events_before"] = yield filter_events_for_client(
+                res["events_before"] = await filter_events_for_client(
                     self.storage, user.to_string(), res["events_before"]
                 )
 
-                res["events_after"] = yield filter_events_for_client(
+                res["events_after"] = await filter_events_for_client(
                     self.storage, user.to_string(), res["events_after"]
                 )
 
@@ -390,7 +386,7 @@ class SearchHandler(BaseHandler):
                 [(EventTypes.Member, sender) for sender in senders]
             )
 
-            state = yield self.state_store.get_state_for_event(
+            state = await self.state_store.get_state_for_event(
                 last_event_id, state_filter
             )
 
@@ -412,10 +408,10 @@ class SearchHandler(BaseHandler):
         time_now = self.clock.time_msec()
 
         for context in contexts.values():
-            context["events_before"] = yield self._event_serializer.serialize_events(
+            context["events_before"] = await self._event_serializer.serialize_events(
                 context["events_before"], time_now
            )
-            context["events_after"] = yield self._event_serializer.serialize_events(
+            context["events_after"] = await self._event_serializer.serialize_events(
                 context["events_after"], time_now
            )
 
@@ -423,7 +419,7 @@ class SearchHandler(BaseHandler):
         if include_state:
             rooms = {e.room_id for e in allowed_events}
             for room_id in rooms:
-                state = yield self.state_handler.get_current_state(room_id)
+                state = await self.state_handler.get_current_state(room_id)
                 state_results[room_id] = list(state.values())
 
         state_results.values()
@@ -437,7 +433,7 @@ class SearchHandler(BaseHandler):
                 {
                     "rank": rank_map[e.event_id],
                     "result": (
-                        yield self._event_serializer.serialize_event(e, time_now)
+                        await self._event_serializer.serialize_event(e, time_now)
                     ),
                     "context": contexts.get(e.event_id, {}),
                 }
@@ -452,7 +448,7 @@ class SearchHandler(BaseHandler):
         if state_results:
             s = {}
             for room_id, state in state_results.items():
-                s[room_id] = yield self._event_serializer.serialize_events(
+                s[room_id] = await self._event_serializer.serialize_events(
                     state, time_now
                 )
 
@@ -49,7 +49,6 @@ from synapse.http.proxyagent import ProxyAgent
 from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.opentracing import set_tag, start_active_span, tags
 from synapse.util.async_helpers import timeout_deferred
-from synapse.util.caches import CACHE_SIZE_FACTOR
 
 logger = logging.getLogger(__name__)
 
@@ -241,7 +240,10 @@ class SimpleHttpClient(object):
         # tends to do so in batches, so we need to allow the pool to keep
         # lots of idle connections around.
         pool = HTTPConnectionPool(self.reactor)
-        pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
+        # XXX: The justification for using the cache factor here is that larger instances
+        # will need both more cache and more connections.
+        # Still, this should probably be a separate dial
+        pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5))
         pool.cachedConnectionTimeout = 2 * 60
 
         self.agent = ProxyAgent(
@@ -359,6 +361,7 @@ class SimpleHttpClient(object):
         actual_headers = {
             b"Content-Type": [b"application/x-www-form-urlencoded"],
             b"User-Agent": [self.user_agent],
+            b"Accept": [b"application/json"],
         }
         if headers:
             actual_headers.update(headers)
@@ -399,6 +402,7 @@ class SimpleHttpClient(object):
         actual_headers = {
             b"Content-Type": [b"application/json"],
             b"User-Agent": [self.user_agent],
+            b"Accept": [b"application/json"],
         }
         if headers:
             actual_headers.update(headers)
@@ -434,6 +438,10 @@ class SimpleHttpClient(object):
 
             ValueError: if the response was not JSON
         """
+        actual_headers = {b"Accept": [b"application/json"]}
+        if headers:
+            actual_headers.update(headers)
+
         body = yield self.get_raw(uri, args, headers=headers)
         return json.loads(body)
 
@@ -467,6 +475,7 @@ class SimpleHttpClient(object):
         actual_headers = {
             b"Content-Type": [b"application/json"],
             b"User-Agent": [self.user_agent],
+            b"Accept": [b"application/json"],
         }
         if headers:
             actual_headers.update(headers)
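The connection-pool hunk above sizes the pool from the new cache-factor
configuration instead of the old CACHE_SIZE_FACTOR constant. A minimal sketch
of the corresponding homeserver.yaml section, assuming the YAML layout mirrors
the ``hs.config.caches.global_factor`` attribute path (value illustrative):

    caches:
      global_factor: 1.0
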
@@ -33,6 +33,8 @@ from prometheus_client import REGISTRY
 
 from twisted.web.resource import Resource
 
+from synapse.util import caches
+
 try:
     from prometheus_client.samples import Sample
 except ImportError:
@@ -103,13 +105,15 @@ def nameify_sample(sample):
 
 
 def generate_latest(registry, emit_help=False):
 
+    # Trigger the cache metrics to be rescraped, which updates the common
+    # metrics but do not produce metrics themselves
+    for collector in caches.collectors_by_name.values():
+        collector.collect()
+
     output = []
 
     for metric in registry.collect():
 
-        if metric.name.startswith("__unused"):
-            continue
-
         if not metric.samples:
             # No samples, don't bother.
             continue
 
@@ -51,6 +51,7 @@ push_rules_delta_state_cache_metric = register_cache(
     "cache",
     "push_rules_delta_state_cache_metric",
     cache=[],  # Meaningless size, as this isn't a cache that stores values
+    resizable=False,
 )
 
 
@@ -67,7 +68,8 @@ class BulkPushRuleEvaluator(object):
         self.room_push_rule_cache_metrics = register_cache(
             "cache",
             "room_push_rule_cache",
-            cache=[],  # Meaningless size, as this isn't a cache that stores values
+            cache=[],  # Meaningless size, as this isn't a cache that stores values,
+            resizable=False,
         )
 
     @defer.inlineCallbacks
 
@@ -22,7 +22,7 @@ from six import string_types
 
 from synapse.events import EventBase
 from synapse.types import UserID
-from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
+from synapse.util.caches import register_cache
 from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
@@ -165,7 +165,7 @@ class PushRuleEvaluatorForEvent(object):
 
 
 # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
-regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
+regex_cache = LruCache(50000)
 register_cache("cache", "regex_push_cache", regex_cache)
 
@@ -92,6 +92,7 @@ CONDITIONAL_REQUIREMENTS = {
         'eliot<1.8.0;python_version<"3.5.3"',
     ],
     "saml2": ["pysaml2>=4.5.0"],
+    "oidc": ["authlib>=0.14.0"],
    "systemd": ["systemd-python>=231"],
    "url_preview": ["lxml>=3.5.0"],
    "test": ["mock>=2.0", "parameterized"],
@@ -34,9 +34,11 @@ class ReplicationRestResource(JsonResource):
 
     def register_servlets(self, hs):
         send_event.register_servlets(hs, self)
-        membership.register_servlets(hs, self)
         federation.register_servlets(hs, self)
-        login.register_servlets(hs, self)
-        register.register_servlets(hs, self)
-        devices.register_servlets(hs, self)
-        streams.register_servlets(hs, self)
+
+        if hs.config.worker.worker_app is None:
+            membership.register_servlets(hs, self)
+            login.register_servlets(hs, self)
+            register.register_servlets(hs, self)
+            devices.register_servlets(hs, self)
+            streams.register_servlets(hs, self)
 
@@ -141,17 +141,29 @@ class ReplicationEndpoint(object):
         Returns a callable that accepts the same parameters as `_serialize_payload`.
         """
         clock = hs.get_clock()
-        host = hs.config.worker_replication_host
-        port = hs.config.worker_replication_http_port
-
         client = hs.get_simple_http_client()
+        local_instance_name = hs.get_instance_name()
+
+        master_host = hs.config.worker_replication_host
+        master_port = hs.config.worker_replication_http_port
+
+        instance_map = hs.config.worker.instance_map
 
         @trace(opname="outgoing_replication_request")
         @defer.inlineCallbacks
         def send_request(instance_name="master", **kwargs):
-            # Currently we only support sending requests to master process.
-            if instance_name != "master":
-                raise Exception("Unknown instance")
+            if instance_name == local_instance_name:
+                raise Exception("Trying to send HTTP request to self")
+            if instance_name == "master":
+                host = master_host
+                port = master_port
+            elif instance_name in instance_map:
+                host = instance_map[instance_name].host
+                port = instance_map[instance_name].port
+            else:
+                raise Exception(
+                    "Instance %r not in 'instance_map' config" % (instance_name,)
+                )
 
             data = yield cls._serialize_payload(**kwargs)
 
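With this change, ``send_request`` can be pointed at any worker named in the
worker configuration's ``instance_map`` rather than only at the master. A
minimal sketch of that section (the instance name, host and port are
illustrative):

    instance_map:
      event_persister1:
        host: localhost
        port: 9091
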
@@ -52,9 +52,9 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
 
         self._instance_name = hs.get_instance_name()
 
-        # We pull the streams from the replication steamer (if we try and make
+        # We pull the streams from the replication handler (if we try and make
         # them ourselves we end up in an import loop).
-        self.streams = hs.get_replication_streamer().get_streams()
+        self.streams = hs.get_tcp_replication().get_streams()
 
     @staticmethod
     def _serialize_payload(stream_name, from_token, upto_token):
 
@@ -18,14 +18,10 @@ from typing import Optional
 
 import six
 
-from synapse.storage.data_stores.main.cache import (
-    CURRENT_STATE_CACHE_NAME,
-    CacheInvalidationWorkerStore,
-)
+from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 
-from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
 logger = logging.getLogger(__name__)
 
@@ -41,40 +37,16 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: Database, db_conn, hs):
         super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen = SlavedIdTracker(
-                db_conn, "cache_invalidation_stream", "stream_id"
-            )  # type: Optional[SlavedIdTracker]
+            self._cache_id_gen = MultiWriterIdGenerator(
+                db_conn,
+                database,
+                instance_name=hs.get_instance_name(),
+                table="cache_invalidation_stream_by_instance",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="cache_invalidation_stream_seq",
+            )  # type: Optional[MultiWriterIdGenerator]
         else:
             self._cache_id_gen = None
 
         self.hs = hs
 
-    def get_cache_stream_token(self):
-        if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token()
-        else:
-            return 0
-
-    def process_replication_rows(self, stream_name, token, rows):
-        if stream_name == "caches":
-            if self._cache_id_gen:
-                self._cache_id_gen.advance(token)
-            for row in rows:
-                if row.cache_func == CURRENT_STATE_CACHE_NAME:
-                    if row.keys is None:
-                        raise Exception(
-                            "Can't send an 'invalidate all' for current state cache"
-                        )
-
-                    room_id = row.keys[0]
-                    members_changed = set(row.keys[1:])
-                    self._invalidate_state_caches(room_id, members_changed)
-                else:
-                    self._attempt_to_invalidate_cache(row.cache_func, row.keys)
-
-    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
-        txn.call_after(cache_func.invalidate, keys)
-        txn.call_after(self._send_invalidation_poke, cache_func, keys)
-
-    def _send_invalidation_poke(self, cache_func, keys):
-        self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
 
@@ -32,7 +32,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
     def get_max_account_data_stream_id(self):
         return self._account_data_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "tag_account_data":
             self._account_data_id_gen.advance(token)
             for row in rows:
@@ -51,6 +51,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
                 (row.user_id, row.room_id, row.data_type)
             )
             self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedAccountDataStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
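Each of the slaved stores below receives the same treatment: an
``instance_name`` argument is threaded through ``process_replication_rows`` so
that a batch of rows can be attributed to the worker that wrote it. A minimal
sketch of an override under the new signature (stream and generator names are
hypothetical):

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == "my_stream":
            self._my_id_gen.advance(token)
        return super().process_replication_rows(stream_name, instance_name, token, rows)
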
@@ -15,7 +15,6 @@
 
 from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.storage.database import Database
-from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.descriptors import Cache
 
 from ._base import BaseSlavedStore
@@ -26,7 +25,7 @@ class SlavedClientIpStore(BaseSlavedStore):
         super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
 
         self.client_ip_last_seen = Cache(
-            name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
+            name="client_ip_last_seen", keylen=4, max_entries=50000
         )
 
     def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
 
@@ -43,7 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
             expiry_ms=30 * 60 * 1000,
         )
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "to_device":
             self._device_inbox_id_gen.advance(token)
             for row in rows:
@@ -55,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
             self._device_federation_outbox_stream_cache.entity_has_changed(
                 row.entity, token
             )
-        return super(SlavedDeviceInboxStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
@@ -48,7 +48,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             "DeviceListFederationStreamChangeCache", device_list_max
         )
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(token)
             self._invalidate_caches_for_devices(token, rows)
@@ -56,9 +56,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             self._device_list_id_gen.advance(token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedDeviceStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def _invalidate_caches_for_devices(self, token, rows):
         for row in rows:
 
@@ -15,11 +15,6 @@
 # limitations under the License.
 import logging
 
-from synapse.api.constants import EventTypes
-from synapse.replication.tcp.streams.events import (
-    EventsStreamCurrentStateRow,
-    EventsStreamEventRow,
-)
 from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
 from synapse.storage.data_stores.main.event_push_actions import (
     EventPushActionsWorkerStore,
@@ -35,7 +30,6 @@ from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
 
 logger = logging.getLogger(__name__)
 
@@ -62,11 +56,6 @@ class SlavedEventStore(
     BaseSlavedStore,
 ):
     def __init__(self, database: Database, db_conn, hs):
-        self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
-        self._backfill_id_gen = SlavedIdTracker(
-            db_conn, "events", "stream_ordering", step=-1
-        )
-
         super(SlavedEventStore, self).__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
@@ -92,83 +81,3 @@ class SlavedEventStore(
 
     def get_room_min_stream_ordering(self):
         return self._backfill_id_gen.get_current_token()
-
-    def process_replication_rows(self, stream_name, token, rows):
-        if stream_name == "events":
-            self._stream_id_gen.advance(token)
-            for row in rows:
-                self._process_event_stream_row(token, row)
-        elif stream_name == "backfill":
-            self._backfill_id_gen.advance(-token)
-            for row in rows:
-                self.invalidate_caches_for_event(
-                    -token,
-                    row.event_id,
-                    row.room_id,
-                    row.type,
-                    row.state_key,
-                    row.redacts,
-                    row.relates_to,
-                    backfilled=True,
-                )
-        return super(SlavedEventStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
-
-    def _process_event_stream_row(self, token, row):
-        data = row.data
-
-        if row.type == EventsStreamEventRow.TypeId:
-            self.invalidate_caches_for_event(
-                token,
-                data.event_id,
-                data.room_id,
-                data.type,
-                data.state_key,
-                data.redacts,
-                data.relates_to,
-                backfilled=False,
-            )
-        elif row.type == EventsStreamCurrentStateRow.TypeId:
-            self._curr_state_delta_stream_cache.entity_has_changed(
-                row.data.room_id, token
-            )
-
-            if data.type == EventTypes.Member:
-                self.get_rooms_for_user_with_stream_ordering.invalidate(
-                    (data.state_key,)
-                )
-        else:
-            raise Exception("Unknown events stream row type %s" % (row.type,))
-
-    def invalidate_caches_for_event(
-        self,
-        stream_ordering,
-        event_id,
-        room_id,
-        etype,
-        state_key,
-        redacts,
-        relates_to,
-        backfilled,
-    ):
-        self._invalidate_get_event_cache(event_id)
-
-        self.get_latest_event_ids_in_room.invalidate((room_id,))
-
-        self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
-
-        if not backfilled:
-            self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
-
-        if redacts:
-            self._invalidate_get_event_cache(redacts)
-
-        if etype == EventTypes.Member:
-            self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
-            self.get_invited_rooms_for_local_user.invalidate((state_key,))
-
-        if relates_to:
-            self.get_relations_for_event.invalidate_many((relates_to,))
-            self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
-            self.get_applicable_edit.invalidate((relates_to,))
 
|
||||
def get_group_stream_token(self):
|
||||
return self._group_updates_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == "groups":
|
||||
self._group_updates_id_gen.advance(token)
|
||||
for row in rows:
|
||||
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
|
||||
|
||||
return super(SlavedGroupServerStore, self).process_replication_rows(
|
||||
stream_name, token, rows
|
||||
)
|
||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
@@ -41,12 +41,10 @@ class SlavedPresenceStore(BaseSlavedStore):
     def get_current_presence_token(self):
         return self._presence_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "presence":
             self._presence_id_gen.advance(token)
             for row in rows:
                 self.presence_stream_cache.entity_has_changed(row.user_id, token)
                 self._get_presence_for_user.invalidate((row.user_id,))
-        return super(SlavedPresenceStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
@@ -15,19 +15,11 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
-from synapse.storage.database import Database
-
-from ._slaved_id_tracker import SlavedIdTracker
 from .events import SlavedEventStore
 
 
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def __init__(self, database: Database, db_conn, hs):
-        self._push_rules_stream_id_gen = SlavedIdTracker(
-            db_conn, "push_rules_stream", "stream_id"
-        )
-        super(SlavedPushRuleStore, self).__init__(database, db_conn, hs)
-
     def get_push_rules_stream_token(self):
         return (
             self._push_rules_stream_id_gen.get_current_token(),
@@ -37,13 +29,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
     def get_max_push_rules_stream_id(self):
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "push_rules":
             self._push_rules_stream_id_gen.advance(token)
             for row in rows:
                 self.get_push_rules_for_user.invalidate((row.user_id,))
                 self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
                 self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedPushRuleStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
@@ -31,9 +31,7 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
     def get_pushers_stream_token(self):
         return self._pushers_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "pushers":
             self._pushers_id_gen.advance(token)
-        return super(SlavedPusherStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
@@ -51,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
         self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
         self.get_receipts_for_room.invalidate((room_id, receipt_type))
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "receipts":
             self._receipts_id_gen.advance(token)
             for row in rows:
@@ -60,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
             )
             self._receipts_stream_cache.entity_has_changed(row.room_id, token)
 
-        return super(SlavedReceiptsStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
@@ -30,8 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "public_rooms":
             self._public_room_id_gen.advance(token)
 
-        return super(RoomStore, self).process_replication_rows(stream_name, token, rows)
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
@@ -16,12 +16,17 @@
 """

 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Tuple

 from twisted.internet.protocol import ReconnectingClientFactory

-from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.api.constants import EventTypes
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+from synapse.replication.tcp.streams.events import (
+    EventsStream,
+    EventsStreamEventRow,
+    EventsStreamRow,
+)

 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -83,8 +88,10 @@ class ReplicationDataHandler:
     to handle updates in additional ways.
     """

-    def __init__(self, store: BaseSlavedStore):
-        self.store = store
+    def __init__(self, hs: "HomeServer"):
+        self.store = hs.get_datastore()
+        self.pusher_pool = hs.get_pusherpool()
+        self.notifier = hs.get_notifier()

     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
@@ -100,10 +107,32 @@ class ReplicationDataHandler:
             token: stream token for this batch of rows
             rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
-        self.store.process_replication_rows(stream_name, token, rows)
+        self.store.process_replication_rows(stream_name, instance_name, token, rows)

-    async def on_position(self, stream_name: str, token: int):
-        self.store.process_replication_rows(stream_name, token, [])
+        if stream_name == EventsStream.NAME:
+            # We shouldn't get multiple rows per token for events stream, so
+            # we don't need to optimise this for multiple rows.
+            for row in rows:
+                if row.type != EventsStreamEventRow.TypeId:
+                    continue
+                assert isinstance(row, EventsStreamRow)
+
+                event = await self.store.get_event(
+                    row.data.event_id, allow_rejected=True
+                )
+                if event.rejected_reason:
+                    continue
+
+                extra_users = ()  # type: Tuple[str, ...]
+                if event.type == EventTypes.Member:
+                    extra_users = (event.state_key,)
+                max_token = self.store.get_room_max_stream_ordering()
+                self.notifier.on_new_room_event(event, token, max_token, extra_users)
+
+        await self.pusher_pool.on_new_notifications(token, token)
+
+    async def on_position(self, stream_name: str, instance_name: str, token: int):
+        self.store.process_replication_rows(stream_name, instance_name, token, [])

     def on_remote_server_up(self, server: str):
         """Called when get a new REMOTE_SERVER_UP command."""
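The rewritten ReplicationDataHandler is the extension point its docstring describes: workers now react to replicated rows directly. A minimal sketch of a subclass that hooks into replicated batches (the subclass name and logging are illustrative, not part of this diff):

    import logging

    from synapse.replication.tcp.client import ReplicationDataHandler

    logger = logging.getLogger(__name__)


    class LoggingReplicationDataHandler(ReplicationDataHandler):
        """Hypothetical subclass: observe every replicated batch of rows."""

        async def on_rdata(self, stream_name, instance_name, token, rows):
            # Keep the default behaviour (cache updates, event notification,
            # pusher pokes) and then log what arrived.
            await super().on_rdata(stream_name, instance_name, token, rows)
            logger.info(
                "replicated %d row(s) on %s from writer %s at token %s",
                len(rows), stream_name, instance_name, token,
            )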
@@ -341,37 +341,6 @@ class RemovePusherCommand(Command):
         return " ".join((self.app_id, self.push_key, self.user_id))


-class InvalidateCacheCommand(Command):
-    """Sent by the client to invalidate an upstream cache.
-
-    THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
-    NOT DISASTROUS IF WE DROP ON THE FLOOR.
-
-    Mainly used to invalidate destination retry timing caches.
-
-    Format::
-
-        INVALIDATE_CACHE <cache_func> <keys_json>
-
-    Where <keys_json> is a json list.
-    """
-
-    NAME = "INVALIDATE_CACHE"
-
-    def __init__(self, cache_func, keys):
-        self.cache_func = cache_func
-        self.keys = keys
-
-    @classmethod
-    def from_line(cls, line):
-        cache_func, keys_json = line.split(" ", 1)
-
-        return cls(cache_func, json.loads(keys_json))
-
-    def to_line(self):
-        return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
-
-
 class UserIpCommand(Command):
     """Sent periodically when a worker sees activity from a client.

@@ -439,7 +408,6 @@ _COMMANDS = (
     UserSyncCommand,
     FederationAckCommand,
     RemovePusherCommand,
-    InvalidateCacheCommand,
     UserIpCommand,
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
@@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = (
     ClearUserSyncsCommand.NAME,
     FederationAckCommand.NAME,
     RemovePusherCommand.NAME,
-    InvalidateCacheCommand.NAME,
     UserIpCommand.NAME,
     ErrorCommand.NAME,
     RemoteServerUpCommand.NAME,
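With INVALIDATE_CACHE gone, cache invalidations ride the ordinary caches stream as RDATA rather than a dedicated command. For reference, the from_line/to_line pattern the removed class followed is the one every Command subclass uses; a toy sketch of that pattern (the ECHO command is invented purely for illustration):

    import json


    class EchoCommand:
        """Toy command: serialised on the wire as 'ECHO <json payload>'."""

        NAME = "ECHO"

        def __init__(self, payload):
            self.payload = payload

        @classmethod
        def from_line(cls, line):
            # `line` is everything after the command name and separating space.
            return cls(json.loads(line))

        def to_line(self):
            # The protocol prepends the command NAME when writing to the wire.
            return json.dumps(self.payload)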
@@ -15,18 +15,7 @@
 # limitations under the License.

 import logging
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    TypeVar,
-)
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar

 from prometheus_client import Counter

@@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import (
     ClearUserSyncsCommand,
     Command,
     FederationAckCommand,
-    InvalidateCacheCommand,
     PositionCommand,
     RdataCommand,
     RemoteServerUpCommand,
@@ -48,7 +36,14 @@ from synapse.replication.tcp.commands import (
     UserSyncCommand,
 )
 from synapse.replication.tcp.protocol import AbstractConnection
-from synapse.replication.tcp.streams import STREAMS_MAP, Stream
+from synapse.replication.tcp.streams import (
+    STREAMS_MAP,
+    BackfillStream,
+    CachesStream,
+    EventsStream,
+    FederationStream,
+    Stream,
+)
 from synapse.util.async_helpers import Linearizer

 logger = logging.getLogger(__name__)
@@ -85,6 +80,32 @@ class ReplicationCommandHandler:
             stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
         }  # type: Dict[str, Stream]

+        # List of streams that this instance is the source of
+        self._streams_to_replicate = []  # type: List[Stream]
+
+        for stream in self._streams.values():
+            if stream.NAME == CachesStream.NAME:
+                # All workers can write to the cache invalidation stream.
+                self._streams_to_replicate.append(stream)
+                continue
+
+            if (
+                isinstance(stream, (EventsStream, BackfillStream))
+                and hs.config.worker.writers.events == hs.get_instance_name()
+            ):
+                self._streams_to_replicate.append(stream)
+
+            # Only add any other streams if we're on master.
+            if hs.config.worker_app is not None:
+                continue
+
+            if stream.NAME == FederationStream.NAME and hs.config.send_federation:
+                # We only support federation stream if federation sending
+                # has been disabled on the master.
+                continue
+
+            self._streams_to_replicate.append(stream)
+
         self._position_linearizer = Linearizer(
             "replication_position", clock=self._clock
         )
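The loop above encodes which streams a given instance may announce and write. Restated as a standalone predicate (a sketch that mirrors the diff; the original appends to a list rather than returning):

    def is_writer_of(stream, hs) -> bool:
        """Sketch of the stream-selection rules in ReplicationCommandHandler.__init__."""
        if stream.NAME == CachesStream.NAME:
            # Every worker may write cache invalidations.
            return True
        if isinstance(stream, (EventsStream, BackfillStream)):
            # Only the configured events writer handles events/backfill.
            return hs.config.worker.writers.events == hs.get_instance_name()
        if hs.config.worker_app is not None:
            # All remaining streams are written by the master only.
            return False
        if stream.NAME == FederationStream.NAME and hs.config.send_federation:
            # The federation stream is owned by the separate federation sender.
            return False
        return True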
@@ -162,16 +183,33 @@ class ReplicationCommandHandler:
             port = hs.config.worker_replication_port
             hs.get_reactor().connectTCP(host, port, self._factory)

-    async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
-        # We only want to announce positions by the writer of the streams.
-        # Currently this is just the master process.
-        if not self._is_master:
-            return
+    def get_streams(self) -> Dict[str, Stream]:
+        """Get a map from stream name to all streams.
+        """
+        return self._streams

-        for stream_name, stream in self._streams.items():
-            current_token = stream.current_token()
+    def get_streams_to_replicate(self) -> List[Stream]:
+        """Get a list of streams that this instances replicates.
+        """
+        return self._streams_to_replicate
+
+    async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+        self.send_positions_to_connection(conn)
+
+    def send_positions_to_connection(self, conn: AbstractConnection):
+        """Send current position of all streams this process is source of to
+        the connection.
+        """
+
+        # We respond with current position of all streams this instance
+        # replicates.
+        for stream in self.get_streams_to_replicate():
             self.send_command(
-                PositionCommand(stream_name, self._instance_name, current_token)
+                PositionCommand(
+                    stream.NAME,
+                    self._instance_name,
+                    stream.current_token(self._instance_name),
+                )
             )

     async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
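On receiving REPLICATE, the handler now answers with one POSITION per stream this instance writes. A sketch of the payload that implies (the wire field order is an assumption inferred from the constructor arguments above; the real class lives in synapse.replication.tcp.commands):

    class PositionCommandSketch:
        """Illustrative stand-in for PositionCommand."""

        NAME = "POSITION"

        def __init__(self, stream_name, instance_name, token):
            self.stream_name = stream_name
            self.instance_name = instance_name
            self.token = token

        def to_line(self):
            # e.g. "events master 12345"; the protocol prepends "POSITION ".
            return " ".join((self.stream_name, self.instance_name, str(self.token)))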
@@ -208,18 +246,6 @@ class ReplicationCommandHandler:

         self._notifier.on_new_replication_data()

-    async def on_INVALIDATE_CACHE(
-        self, conn: AbstractConnection, cmd: InvalidateCacheCommand
-    ):
-        invalidate_cache_counter.inc()
-
-        if self._is_master:
-            # We invalidate the cache locally, but then also stream that to other
-            # workers.
-            await self._store.invalidate_cache_and_stream(
-                cmd.cache_func, tuple(cmd.keys)
-            )
-
     async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
         user_ip_cache_counter.inc()

@@ -293,7 +319,7 @@ class ReplicationCommandHandler:
             rows: a list of Stream.ROW_TYPE objects as returned by
                 Stream.parse_row.
         """
-        logger.debug("Received rdata %s -> %s", stream_name, token)
+        logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
         await self._replication_data_handler.on_rdata(
             stream_name, instance_name, token, rows
         )
@@ -324,7 +350,7 @@ class ReplicationCommandHandler:
             self._pending_batches.pop(stream_name, [])

             # Find where we previously streamed up to.
-            current_token = stream.current_token()
+            current_token = stream.current_token(cmd.instance_name)

             # If the position token matches our current token then we're up to
             # date and there's nothing to do. Otherwise, fetch all updates
@@ -361,7 +387,9 @@ class ReplicationCommandHandler:
             logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)

             # We've now caught up to position sent to us, notify handler.
-            await self._replication_data_handler.on_position(stream_name, cmd.token)
+            await self._replication_data_handler.on_position(
+                cmd.stream_name, cmd.instance_name, cmd.token
+            )

             self._streams_by_connection.setdefault(conn, set()).add(stream_name)

@@ -489,12 +517,6 @@ class ReplicationCommandHandler:
         cmd = RemovePusherCommand(app_id, push_key, user_id)
         self.send_command(cmd)

-    def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
-        """Poke the master to invalidate a cache.
-        """
-        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
-        self.send_command(cmd)
-
     def send_user_ip(
         self,
         user_id: str,
@@ -70,7 +70,6 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         logger.info("Connected to redis")
         super().connectionMade()
         run_as_background_process("subscribe-replication", self._send_subscribe)
-        self.handler.new_connection(self)

     async def _send_subscribe(self):
         # it's important to make sure that we only send the REPLICATE command once we
@@ -81,9 +80,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         logger.info(
             "Successfully subscribed to redis stream, sending REPLICATE command"
         )
+        self.handler.new_connection(self)
         await self._async_send_command(ReplicateCommand())
         logger.info("REPLICATE successfully sent")

+        # We send out our positions when there is a new connection in case the
+        # other side missed updates. We do this for Redis connections as the
+        # otherside won't know we've connected and so won't issue a REPLICATE.
+        self.handler.send_positions_to_connection(self)
+
     def messageReceived(self, pattern: str, channel: str, message: str):
         """Received a message from redis.
         """
@@ -17,7 +17,6 @@

 import logging
 import random
-from typing import Dict, List

 from prometheus_client import Counter

@@ -25,7 +24,6 @@ from twisted.internet.protocol import Factory

 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
-from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
 from synapse.util.metrics import Measure

 stream_updates_counter = Counter(
@@ -71,26 +69,11 @@ class ReplicationStreamer(object):
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
+        self._instance_name = hs.get_instance_name()

         self._replication_torture_level = hs.config.replication_torture_level

-        # Work out list of streams that this instance is the source of.
-        self.streams = []  # type: List[Stream]
-        if hs.config.worker_app is None:
-            for stream in STREAMS_MAP.values():
-                if stream == FederationStream and hs.config.send_federation:
-                    # We only support federation stream if federation sending
-                    # hase been disabled on the master.
-                    continue
-
-                self.streams.append(stream(hs))
-
-        self.streams_by_name = {stream.NAME: stream for stream in self.streams}
-
-        # Only bother registering the notifier callback if we have streams to
-        # publish.
-        if self.streams:
-            self.notifier.add_replication_callback(self.on_notifier_poke)
+        self.notifier.add_replication_callback(self.on_notifier_poke)

         # Keeps track of whether we are currently checking for updates
         self.is_looping = False
@@ -98,10 +81,8 @@ class ReplicationStreamer(object):

         self.command_handler = hs.get_tcp_replication()

-    def get_streams(self) -> Dict[str, Stream]:
-        """Get a mapp from stream name to stream instance.
-        """
-        return self.streams_by_name
+        # Set of streams to replicate.
+        self.streams = self.command_handler.get_streams_to_replicate()

     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
@@ -145,7 +126,9 @@ class ReplicationStreamer(object):
                     random.shuffle(all_streams)

                 for stream in all_streams:
-                    if stream.last_token == stream.current_token():
+                    if stream.last_token == stream.current_token(
+                        self._instance_name
+                    ):
                         continue

                     if self._replication_torture_level:
@@ -157,7 +140,7 @@ class ReplicationStreamer(object):
                         "Getting stream: %s: %s -> %s",
                         stream.NAME,
                         stream.last_token,
-                        stream.current_token(),
+                        stream.current_token(self._instance_name),
                     )
                     try:
                         updates, current_token, limited = await stream.get_updates()
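Condensed, the streamer's poll amounts to: skip any stream whose cached last_token already matches the writer's current token, otherwise fetch and fan out its updates. A sketch of one pass (send_updates stands in for the real RDATA fan-out, which is not shown in this hunk):

    async def poke_streams_once(streams, instance_name, send_updates):
        """Sketch of one pass of ReplicationStreamer.on_notifier_poke."""
        for stream in streams:
            if stream.last_token == stream.current_token(instance_name):
                continue  # nothing new on this stream
            updates, current_token, limited = await stream.get_updates()
            if updates:
                await send_updates(stream.NAME, updates, current_token)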
@@ -95,19 +95,25 @@ class Stream(object):
     def __init__(
         self,
         local_instance_name: str,
-        current_token_function: Callable[[], Token],
+        current_token_function: Callable[[str], Token],
         update_function: UpdateFunction,
     ):
         """Instantiate a Stream

-        current_token_function and update_function are callbacks which should be
-        implemented by subclasses.
+        `current_token_function` and `update_function` are callbacks which
+        should be implemented by subclasses.

-        current_token_function is called to get the current token of the underlying
-        stream.
+        `current_token_function` takes an instance name, which is a writer to
+        the stream, and returns the position in the stream of the writer (as
+        viewed from the current process). On the writer process this is where
+        the writer has successfully written up to, whereas on other processes
+        this is the position which we have received updates up to over
+        replication. (Note that most streams have a single writer and so their
+        implementations ignore the instance name passed in).

-        update_function is called to get updates for this stream between a pair of
-        stream tokens. See the UpdateFunction type definition for more info.
+        `update_function` is called to get updates for this stream between a
+        pair of stream tokens. See the `UpdateFunction` type definition for more
+        info.

         Args:
             local_instance_name: The instance name of the current process
@@ -119,13 +125,13 @@ class Stream(object):
         self.update_function = update_function

         # The token from which we last asked for updates
-        self.last_token = self.current_token()
+        self.last_token = self.current_token(self.local_instance_name)

     def discard_updates_and_advance(self):
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
-        self.last_token = self.current_token()
+        self.last_token = self.current_token(self.local_instance_name)

     async def get_updates(self) -> StreamUpdateResult:
         """Gets all updates since the last time this function was called (or
@@ -137,7 +143,7 @@ class Stream(object):
             position in stream, and `limited` is whether there are more updates
             to fetch.
         """
-        current_token = self.current_token()
+        current_token = self.current_token(self.local_instance_name)
         updates, current_token, limited = await self.get_updates_since(
             self.local_instance_name, self.last_token, current_token
         )
@@ -169,6 +175,16 @@ class Stream(object):
         return updates, upto_token, limited


+def current_token_without_instance(
+    current_token: Callable[[], int]
+) -> Callable[[str], int]:
+    """Takes a current token callback function for a single writer stream
+    that doesn't take an instance name parameter and wraps it in a function that
+    does accept an instance name parameter but ignores it.
+    """
+    return lambda instance_name: current_token()
+
+
 def db_query_to_update_function(
     query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
 ) -> UpdateFunction:
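current_token_without_instance is the adapter that lets the many single-writer streams keep their zero-argument token getters while satisfying the new per-instance signature. For example:

    def get_max_receipt_stream_id() -> int:
        return 42  # stand-in for the real datastore method

    current_token = current_token_without_instance(get_max_receipt_stream_id)
    assert current_token("worker1") == 42  # the instance name is accepted but ignored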
@@ -234,7 +250,7 @@ class BackfillStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_current_backfill_token,
+            current_token_without_instance(store.get_current_backfill_token),
             db_query_to_update_function(store.get_all_new_backfill_event_rows),
         )

@@ -270,7 +286,9 @@ class PresenceStream(Stream):
             update_function = make_http_update_function(hs, self.NAME)

         super().__init__(
-            hs.get_instance_name(), store.get_current_presence_token, update_function
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_current_presence_token),
+            update_function,
         )

@@ -295,7 +313,9 @@ class TypingStream(Stream):
             update_function = make_http_update_function(hs, self.NAME)

         super().__init__(
-            hs.get_instance_name(), typing_handler.get_current_token, update_function
+            hs.get_instance_name(),
+            current_token_without_instance(typing_handler.get_current_token),
+            update_function,
         )

@@ -318,7 +338,7 @@ class ReceiptsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_max_receipt_stream_id,
+            current_token_without_instance(store.get_max_receipt_stream_id),
             db_query_to_update_function(store.get_all_updated_receipts),
         )

@@ -338,7 +358,7 @@ class PushRulesStream(Stream):
             hs.get_instance_name(), self._current_token, self._update_function
         )

-    def _current_token(self) -> int:
+    def _current_token(self, instance_name: str) -> int:
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token

@@ -372,7 +392,7 @@ class PushersStream(Stream):

         super().__init__(
             hs.get_instance_name(),
-            store.get_pushers_stream_token,
+            current_token_without_instance(store.get_pushers_stream_token),
             db_query_to_update_function(store.get_all_updated_pushers_rows),
         )
@@ -401,13 +421,27 @@ class CachesStream(Stream):
     ROW_TYPE = CachesStreamRow

     def __init__(self, hs):
-        store = hs.get_datastore()
+        self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_cache_stream_token,
-            db_query_to_update_function(store.get_all_updated_caches),
+            self.store.get_cache_stream_token,
+            self._update_function,
         )

+    async def _update_function(
+        self, instance_name: str, from_token: int, upto_token: int, limit: int
+    ):
+        rows = await self.store.get_all_updated_caches(
+            instance_name, from_token, upto_token, limit
+        )
+        updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) >= limit:
+            upto_token = updates[-1][0]
+            limited = True
+
+        return updates, upto_token, limited
+
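The new _update_function illustrates the general contract of a stream update function: return (updates, upto_token, limited) and, when the limit is hit, clamp upto_token to the last row actually returned. A toy version against an in-memory row source (names invented for illustration):

    async def toy_update_function(instance_name, from_token, upto_token, limit):
        """Sketch of the (updates, upto_token, limited) contract."""
        # Pretend every token in (from_token, upto_token] has one row of data.
        rows = [(t, ("payload", t)) for t in range(from_token + 1, upto_token + 1)]
        updates = rows[:limit]
        limited = False
        if len(updates) >= limit:
            upto_token = updates[-1][0]  # only claim what we actually returned
            limited = True
        return updates, upto_token, limited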
 class PublicRoomsStream(Stream):
     """The public rooms list changed

@@ -430,7 +464,7 @@ class PublicRoomsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_current_public_room_stream_id,
+            current_token_without_instance(store.get_current_public_room_stream_id),
             db_query_to_update_function(store.get_all_new_public_rooms),
         )

@@ -451,7 +485,7 @@ class DeviceListsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_device_stream_token,
+            current_token_without_instance(store.get_device_stream_token),
             db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
         )

@@ -469,7 +503,7 @@ class ToDeviceStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_to_device_stream_token,
+            current_token_without_instance(store.get_to_device_stream_token),
             db_query_to_update_function(store.get_all_new_device_messages),
         )

@@ -489,7 +523,7 @@ class TagAccountDataStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_max_account_data_stream_id,
+            current_token_without_instance(store.get_max_account_data_stream_id),
             db_query_to_update_function(store.get_all_updated_tags),
         )

@@ -509,7 +543,7 @@ class AccountDataStream(Stream):
         self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            self.store.get_max_account_data_stream_id,
+            current_token_without_instance(self.store.get_max_account_data_stream_id),
             db_query_to_update_function(self._update_function),
         )

@@ -540,7 +574,7 @@ class GroupServerStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_group_stream_token,
+            current_token_without_instance(store.get_group_stream_token),
             db_query_to_update_function(store.get_all_groups_changes),
         )

@@ -558,7 +592,7 @@ class UserSignatureStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_device_stream_token,
+            current_token_without_instance(store.get_device_stream_token),
             db_query_to_update_function(
                 store.get_all_user_signature_changes_for_remotes
             ),
@@ -20,7 +20,7 @@ from typing import List, Tuple, Type

 import attr

-from ._base import Stream, StreamUpdateResult, Token
+from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance


 """Handling of the 'events' replication stream
@@ -119,7 +119,7 @@ class EventsStream(Stream):
         self._store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            self._store.get_current_events_token,
+            current_token_without_instance(self._store.get_current_events_token),
             self._update_function,
         )
@@ -15,7 +15,11 @@
 # limitations under the License.
 from collections import namedtuple

-from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
+from synapse.replication.tcp.streams._base import (
+    Stream,
+    current_token_without_instance,
+    make_http_update_function,
+)


 class FederationStream(Stream):
@@ -35,21 +39,35 @@ class FederationStream(Stream):
     ROW_TYPE = FederationStreamRow

     def __init__(self, hs):
-        # Not all synapse instances will have a federation sender instance,
-        # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
-        # so we stub the stream out when that is the case.
-        if hs.config.worker_app is None or hs.should_send_federation():
+        if hs.config.worker_app is None:
+            # master process: get updates from the FederationRemoteSendQueue.
+            # (if the master is configured to send federation itself, federation_sender
+            # will be a real FederationSender, which has stubs for current_token and
+            # get_replication_rows.)
             federation_sender = hs.get_federation_sender()
-            current_token = federation_sender.get_current_token
-            update_function = db_query_to_update_function(
-                federation_sender.get_replication_rows
+            current_token = current_token_without_instance(
+                federation_sender.get_current_token
             )
+            update_function = federation_sender.get_replication_rows
+
+        elif hs.should_send_federation():
+            # federation sender: Query master process
+            update_function = make_http_update_function(hs, self.NAME)
+            current_token = self._stub_current_token
+
         else:
-            current_token = lambda: 0
+            # other worker: stub out the update function (we're not interested in
+            # any updates so when we get a POSITION we do nothing)
             update_function = self._stub_update_function
+            current_token = self._stub_current_token

         super().__init__(hs.get_instance_name(), current_token, update_function)

+    @staticmethod
+    def _stub_current_token(instance_name: str) -> int:
+        # dummy current-token method for use on workers
+        return 0
+
+    @staticmethod
+    async def _stub_update_function(instance_name, from_token, upto_token, limit):
+        return [], upto_token, False
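The three branches above pick a (current_token, update_function) pair per deployment shape; summarised as a sketch (a restatement of the diff, not an alternative implementation):

    def pick_federation_callbacks(hs, stream):
        """Sketch of FederationStream.__init__'s three configurations."""
        if hs.config.worker_app is None:
            # Master: real positions from the federation sender / send queue.
            sender = hs.get_federation_sender()
            current_token = current_token_without_instance(sender.get_current_token)
            return current_token, sender.get_replication_rows
        if hs.should_send_federation():
            # Federation sender worker: fetch updates from master over HTTP.
            return stream._stub_current_token, make_http_update_function(hs, stream.NAME)
        # Any other worker: not interested in federation updates at all.
        return stream._stub_current_token, stream._stub_update_function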
@@ -30,7 +30,7 @@
         <tr>
             <td colspan="2">
                 <div class="noticetext">Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.</div>
-                <div class="noticetext">To extend the validity of your account, please click on the link bellow (or copy and paste it into a new browser tab):</div>
+                <div class="noticetext">To extend the validity of your account, please click on the link below (or copy and paste it into a new browser tab):</div>
                 <div class="noticetext"><a href="{{ url }}">{{ url }}</a></div>
             </td>
         </tr>

@@ -2,6 +2,6 @@ Hi {{ display_name }},

 Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.

-To extend the validity of your account, please click on the link bellow (or copy and paste it to a new browser tab):
+To extend the validity of your account, please click on the link below (or copy and paste it to a new browser tab):

 {{ url }}
18  synapse/res/templates/sso_error.html  Normal file
@@ -0,0 +1,18 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <title>SSO error</title>
+</head>
+<body>
+    <p>Oops! Something went wrong during authentication.</p>
+    <p>
+        Try logging in again from your Matrix client and if the problem persists
+        please contact the server's administrator.
+    </p>
+    <p>Error: <code>{{ error }}</code></p>
+    {% if error_description %}
+    <pre><code>{{ error_description }}</code></pre>
+    {% endif %}
+</body>
+</html>
@@ -32,6 +32,7 @@ from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
 from synapse.rest.admin.rooms import (
     JoinRoomAliasServlet,
     ListRoomRestServlet,
+    RoomRestServlet,
     ShutdownRoomRestServlet,
 )
 from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
@@ -193,6 +194,7 @@ def register_servlets(hs, http_server):
     """
    register_servlets_for_client_rest_resource(hs, http_server)
     ListRoomRestServlet(hs).register(http_server)
+    RoomRestServlet(hs).register(http_server)
     JoinRoomAliasServlet(hs).register(http_server)
     PurgeRoomServlet(hs).register(http_server)
     SendServerNoticeServlet(hs).register(http_server)
@@ -26,6 +26,7 @@ from synapse.http.servlet import (
 )
 from synapse.rest.admin._base import (
     admin_patterns,
     assert_requester_is_admin,
     assert_user_is_admin,
     historical_admin_path_patterns,
 )
@@ -169,7 +170,7 @@ class ListRoomRestServlet(RestServlet):
     in a dictionary containing room information. Supports pagination.
     """

-    PATTERNS = admin_patterns("/rooms")
+    PATTERNS = admin_patterns("/rooms$")

     def __init__(self, hs):
         self.store = hs.get_datastore()
@@ -253,6 +254,29 @@ class ListRoomRestServlet(RestServlet):
         return 200, response


+class RoomRestServlet(RestServlet):
+    """Get room details.
+
+    TODO: Add on_POST to allow room creation without joining the room
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
+
+    def __init__(self, hs):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request, room_id):
+        await assert_requester_is_admin(self.auth, request)
+
+        ret = await self.store.get_room_with_stats(room_id)
+        if not ret:
+            raise NotFoundError("Room not found")
+
+        return 200, ret
+
+
 class JoinRoomAliasServlet(RestServlet):

     PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
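A hedged usage example for the new room details endpoint (the /_synapse/admin/v1 prefix is assumed to be what admin_patterns applies; the server name, room ID, and token are placeholders):

    import requests

    resp = requests.get(
        "https://synapse.example.com/_synapse/admin/v1/rooms/!abcdefgh:example.com",
        headers={"Authorization": "Bearer <admin_access_token>"},
    )
    if resp.status_code == 404:
        print("Room not found")
    else:
        resp.raise_for_status()
        print(resp.json())  # the dict returned by get_room_with_stats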
@@ -83,6 +83,7 @@ class LoginRestServlet(RestServlet):
         self.jwt_algorithm = hs.config.jwt_algorithm
         self.saml2_enabled = hs.config.saml2_enabled
         self.cas_enabled = hs.config.cas_enabled
+        self.oidc_enabled = hs.config.oidc_enabled
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
         self.handlers = hs.get_handlers()
@@ -96,9 +97,7 @@ class LoginRestServlet(RestServlet):
         flows = []
         if self.jwt_enabled:
             flows.append({"type": LoginRestServlet.JWT_TYPE})
-        if self.saml2_enabled:
-            flows.append({"type": LoginRestServlet.SSO_TYPE})
-            flows.append({"type": LoginRestServlet.TOKEN_TYPE})

         if self.cas_enabled:
             flows.append({"type": LoginRestServlet.SSO_TYPE})

@@ -114,6 +113,11 @@ class LoginRestServlet(RestServlet):
             # fall back to the fallback API if they don't understand one of the
             # login flow types returned.
             flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+        elif self.saml2_enabled:
+            flows.append({"type": LoginRestServlet.SSO_TYPE})
+            flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+        elif self.oidc_enabled:
+            flows.append({"type": LoginRestServlet.SSO_TYPE})

         flows.extend(
             ({"type": t} for t in self.auth_handler.get_supported_login_types())
@@ -465,6 +469,22 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
         return self._saml_handler.handle_redirect_request(client_redirect_url)


+class OIDCRedirectServlet(RestServlet):
+    """Implementation for /login/sso/redirect for the OIDC login flow."""
+
+    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
+
+    def __init__(self, hs):
+        self._oidc_handler = hs.get_oidc_handler()
+
+    async def on_GET(self, request):
+        args = request.args
+        if b"redirectUrl" not in args:
+            return 400, "Redirect URL not specified for SSO auth"
+        client_redirect_url = args[b"redirectUrl"][0]
+        await self._oidc_handler.handle_redirect_request(request, client_redirect_url)
+
+
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
     if hs.config.cas_enabled:
@@ -472,3 +492,5 @@ def register_servlets(hs, http_server):
         CasTicketServlet(hs).register(http_server)
     elif hs.config.saml2_enabled:
         SAMLRedirectServlet(hs).register(http_server)
+    elif hs.config.oidc_enabled:
+        OIDCRedirectServlet(hs).register(http_server)
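With only OIDC enabled, the flows advertised by GET /_matrix/client/r0/login should now include the SSO flow. A sketch of the expected response body (SSO_TYPE is "m.login.sso"; the exact extra entries depend on auth_handler.get_supported_login_types(), so this is an assumption, not a captured response):

    expected_login_flows = {
        "flows": [
            {"type": "m.login.sso"},
            # ...plus whatever get_supported_login_types() contributes,
            # typically {"type": "m.login.password"}.
        ]
    }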
27  synapse/rest/oidc/__init__.py  Normal file
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.web.resource import Resource
+
+from synapse.rest.oidc.callback_resource import OIDCCallbackResource
+
+logger = logging.getLogger(__name__)
+
+
+class OIDCResource(Resource):
+    def __init__(self, hs):
+        Resource.__init__(self)
+        self.putChild(b"callback", OIDCCallbackResource(hs))
31  synapse/rest/oidc/callback_resource.py  Normal file
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.http.server import DirectServeResource, wrap_html_request_handler
+
+logger = logging.getLogger(__name__)
+
+
+class OIDCCallbackResource(DirectServeResource):
+    isLeaf = 1
+
+    def __init__(self, hs):
+        super().__init__()
+        self._oidc_handler = hs.get_oidc_handler()
+
+    @wrap_html_request_handler
+    async def _async_render_GET(self, request):
+        return await self._oidc_handler.handle_oidc_callback(request)
Some files were not shown because too many files have changed in this diff.