Compare commits


1 commit

Erik Johnston · aa2fe082ae · Allow testing old deps with postgres · 2021-01-22 10:48:56 +00:00

146 changed files with 1225 additions and 4243 deletions


@@ -9,8 +9,3 @@ apt-get update
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox
export LANG="C.UTF-8"
# Prevent virtualenv from auto-updating pip to an incompatible version
export VIRTUALENV_NO_DOWNLOAD=1
exec tox -e py35-old,combine


@@ -1,106 +1,3 @@
Synapse 1.27.0rc1 (2021-02-02)
==============================
Note that this release includes a change in Synapse to use Redis as a cache, as well as a pub/sub mechanism, if Redis support is enabled. No action is needed by server administrators, and we do not expect resource usage of the Redis instance to change dramatically.
This release also changes the callback URI for OpenID Connect (OIDC) identity providers. If your server is configured to use single sign-on via an OIDC/OAuth2 IdP, you may need to make configuration changes. Please review [UPGRADE.rst](UPGRADE.rst) for more details on these changes.
This release also changes escaping of variables in the HTML templates for SSO or email notifications. If you have customised these templates, please review [UPGRADE.rst](UPGRADE.rst) for more details on these changes.
Features
--------
- Add an admin API for getting and deleting forward extremities for a room. ([\#9062](https://github.com/matrix-org/synapse/issues/9062))
- Add an admin API for retrieving the current room state of a room. ([\#9168](https://github.com/matrix-org/synapse/issues/9168))
- Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). ([\#9183](https://github.com/matrix-org/synapse/issues/9183), [\#9242](https://github.com/matrix-org/synapse/issues/9242))
- Add an admin API endpoint for shadow-banning users. ([\#9209](https://github.com/matrix-org/synapse/issues/9209))
- Add ratelimits to the 3PID `/requestToken` APIs. ([\#9238](https://github.com/matrix-org/synapse/issues/9238))
- Add support to the OpenID Connect integration for adding the user's email address. ([\#9245](https://github.com/matrix-org/synapse/issues/9245))
- Add ratelimits to invites in rooms and to specific users. ([\#9258](https://github.com/matrix-org/synapse/issues/9258))
- Improve the user experience of setting up an account via single-sign on. ([\#9262](https://github.com/matrix-org/synapse/issues/9262), [\#9272](https://github.com/matrix-org/synapse/issues/9272), [\#9275](https://github.com/matrix-org/synapse/issues/9275), [\#9276](https://github.com/matrix-org/synapse/issues/9276), [\#9277](https://github.com/matrix-org/synapse/issues/9277), [\#9286](https://github.com/matrix-org/synapse/issues/9286), [\#9287](https://github.com/matrix-org/synapse/issues/9287))
- Add phone home stats for encrypted messages. ([\#9283](https://github.com/matrix-org/synapse/issues/9283))
- Update the redirect URI for OIDC authentication. ([\#9288](https://github.com/matrix-org/synapse/issues/9288))
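The ratelimits added above follow Synapse's usual `per_second`/`burst_count` pattern, i.e. token-bucket semantics: a burst of up to `burst_count` requests is allowed, then requests are admitted at the sustained rate. A minimal standalone sketch of that pattern (illustrative only, not Synapse's implementation):

```python
import time

class TokenBucket:
    """Minimal token bucket mirroring per_second/burst_count semantics."""

    def __init__(self, per_second: float, burst_count: int):
        self.rate = per_second          # sustained refill rate
        self.capacity = burst_count     # maximum burst size
        self.tokens = float(burst_count)
        self.last = time.monotonic()

    def allow(self) -> bool:
        # Refill proportionally to elapsed time, capped at the burst size.
        now = time.monotonic()
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False

# With per_second=0.003 and burst_count=5 (the documented 3PID defaults),
# a rapid burst of 10 requests admits only the first 5.
bucket = TokenBucket(per_second=0.003, burst_count=5)
print(sum(bucket.allow() for _ in range(10)))  # prints 5
```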
Bugfixes
--------
- Fix spurious errors in logs when deleting a non-existent pusher. ([\#9121](https://github.com/matrix-org/synapse/issues/9121))
- Fix a long-standing bug where Synapse would return a 500 error when a thumbnail did not exist (and auto-generation of thumbnails was not enabled). ([\#9163](https://github.com/matrix-org/synapse/issues/9163))
- Fix a long-standing bug where an internal server error was raised when attempting to preview an HTML document in an unknown character encoding. ([\#9164](https://github.com/matrix-org/synapse/issues/9164))
- Fix a long-standing bug where invalid data could cause errors when calculating the presentable room name for push. ([\#9165](https://github.com/matrix-org/synapse/issues/9165))
- Fix bug where we sometimes didn't detect that Redis connections had died, causing workers to not see new data. ([\#9218](https://github.com/matrix-org/synapse/issues/9218))
- Fix a bug where `None` was passed to Synapse modules instead of an empty dictionary if an empty module `config` block was provided in the homeserver config. ([\#9229](https://github.com/matrix-org/synapse/issues/9229))
- Fix a bug in the `make_room_admin` admin API where it failed if the admin with the greatest power level was not in the room. Contributed by Pankaj Yadav. ([\#9235](https://github.com/matrix-org/synapse/issues/9235))
- Prevent password hashes from getting dropped if a client failed threepid validation during a User Interactive Auth stage. Removes a workaround for an ancient bug in Riot Web <v0.7.4. ([\#9265](https://github.com/matrix-org/synapse/issues/9265))
- Fix single-sign-on when the endpoints are routed to synapse workers. ([\#9271](https://github.com/matrix-org/synapse/issues/9271))
Improved Documentation
----------------------
- Add docs for using Gitea as OpenID provider. ([\#9134](https://github.com/matrix-org/synapse/issues/9134))
- Add link to Matrix VoIP tester for turn-howto. ([\#9135](https://github.com/matrix-org/synapse/issues/9135))
- Add notes on integrating with Facebook for SSO login. ([\#9244](https://github.com/matrix-org/synapse/issues/9244))
Deprecations and Removals
-------------------------
- The `service_url` parameter in `cas_config` is deprecated in favor of `public_baseurl`. ([\#9199](https://github.com/matrix-org/synapse/issues/9199))
- Add new endpoint `/_synapse/client/saml2` for SAML2 authentication callbacks, and deprecate the old endpoint `/_matrix/saml2`. ([\#9289](https://github.com/matrix-org/synapse/issues/9289))
Internal Changes
----------------
- Add tests to `test_user.UsersListTestCase` for List Users Admin API. ([\#9045](https://github.com/matrix-org/synapse/issues/9045))
- Various improvements to the federation client. ([\#9129](https://github.com/matrix-org/synapse/issues/9129))
- Speed up chain cover calculation when persisting a batch of state events at once. ([\#9176](https://github.com/matrix-org/synapse/issues/9176))
- Add a `long_description_type` to the package metadata. ([\#9180](https://github.com/matrix-org/synapse/issues/9180))
- Speed up batch insertion when using PostgreSQL. ([\#9181](https://github.com/matrix-org/synapse/issues/9181), [\#9188](https://github.com/matrix-org/synapse/issues/9188))
- Emit an error at startup if different Identity Providers are configured with the same `idp_id`. ([\#9184](https://github.com/matrix-org/synapse/issues/9184))
- Improve performance of concurrent use of `StreamIDGenerators`. ([\#9190](https://github.com/matrix-org/synapse/issues/9190))
- Add some missing source directories to the automatic linting script. ([\#9191](https://github.com/matrix-org/synapse/issues/9191))
- Precompute joined hosts and store in Redis. ([\#9198](https://github.com/matrix-org/synapse/issues/9198), [\#9227](https://github.com/matrix-org/synapse/issues/9227))
- Clean-up template loading code. ([\#9200](https://github.com/matrix-org/synapse/issues/9200))
- Fix the Python 3.5 old dependencies build. ([\#9217](https://github.com/matrix-org/synapse/issues/9217))
- Update `isort` to v5.7.0 to bypass a bug where it would disagree with `black` about formatting. ([\#9222](https://github.com/matrix-org/synapse/issues/9222))
- Add type hints to handlers code. ([\#9223](https://github.com/matrix-org/synapse/issues/9223), [\#9232](https://github.com/matrix-org/synapse/issues/9232))
- Fix Debian package building on Ubuntu 16.04 LTS (Xenial). ([\#9254](https://github.com/matrix-org/synapse/issues/9254))
- Minor performance improvement during TLS handshake. ([\#9255](https://github.com/matrix-org/synapse/issues/9255))
- Refactor the generation of summary text for email notifications. ([\#9260](https://github.com/matrix-org/synapse/issues/9260))
- Restore PyPy compatibility by not calling CPython-specific GC methods when under PyPy. ([\#9270](https://github.com/matrix-org/synapse/issues/9270))
Synapse 1.26.0 (2021-01-27)
===========================
This release brings a new schema version for Synapse and rolling back to a previous
version is not trivial. Please review [UPGRADE.rst](UPGRADE.rst) for more details
on these changes and for general upgrade guidance.
No significant changes since 1.26.0rc2.
Synapse 1.26.0rc2 (2021-01-25)
==============================
Bugfixes
--------
- Fix receipts and account data not being sent down sync. Introduced in v1.26.0rc1. ([\#9193](https://github.com/matrix-org/synapse/issues/9193), [\#9195](https://github.com/matrix-org/synapse/issues/9195))
- Fix chain cover update to handle events with duplicate auth events. Introduced in v1.26.0rc1. ([\#9210](https://github.com/matrix-org/synapse/issues/9210))
Internal Changes
----------------
- Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration. ([\#9189](https://github.com/matrix-org/synapse/issues/9189))
- Bump minimum `psycopg2` version to v2.8. ([\#9204](https://github.com/matrix-org/synapse/issues/9204))
Synapse 1.26.0rc1 (2021-01-20)
==============================


@@ -85,58 +85,6 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
Upgrading to v1.27.0
====================
Changes to callback URI for OAuth2 / OpenID Connect
---------------------------------------------------
This version changes the URI used for callbacks from OAuth2 identity providers. If
your server is configured for single sign-on via an OpenID Connect or OAuth2 identity
provider, you will need to add ``[synapse public baseurl]/_synapse/client/oidc/callback``
to the list of permitted "redirect URIs" at the identity provider.
See `docs/openid.md <docs/openid.md>`_ for more information on setting up OpenID
Connect.
(Note: a similar change is being made for SAML2; in this case the old URI
``[synapse public baseurl]/_matrix/saml2`` is being deprecated, but will continue to
work, so no immediate changes are required for existing installations.)
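When updating redirect URIs at many identity providers, the new callback can be derived from the configured public base URL; a small sketch using the standard library (the base URL is a placeholder):

```python
from urllib.parse import urljoin

# public_baseurl as configured in homeserver.yaml (placeholder value)
public_baseurl = "https://matrix.example.com/"

# New OIDC callback introduced in v1.27.0
new_callback = urljoin(public_baseurl, "_synapse/client/oidc/callback")
# Previous OIDC callback, for comparison against existing IdP configuration
old_callback = urljoin(public_baseurl, "_synapse/oidc/callback")

print(new_callback)  # https://matrix.example.com/_synapse/client/oidc/callback
```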
Changes to HTML templates
-------------------------
The HTML templates for SSO and email notifications now have `Jinja2's autoescape <https://jinja.palletsprojects.com/en/2.11.x/api/#autoescaping>`_
enabled for files ending in ``.html``, ``.htm``, and ``.xml``. If you have customised
these templates and see issues when viewing them you might need to update them.
It is expected that most configurations will need no changes.
If you have customised the *names* of these templates, it is recommended
to verify they end in ``.html`` to ensure autoescape is enabled.
The above applies to the following templates:
* ``add_threepid.html``
* ``add_threepid_failure.html``
* ``add_threepid_success.html``
* ``notice_expiry.html``
* ``notif_mail.html`` (which, by default, includes ``room.html`` and ``notif.html``)
* ``password_reset.html``
* ``password_reset_confirmation.html``
* ``password_reset_failure.html``
* ``password_reset_success.html``
* ``registration.html``
* ``registration_failure.html``
* ``registration_success.html``
* ``sso_account_deactivated.html``
* ``sso_auth_bad_user.html``
* ``sso_auth_confirm.html``
* ``sso_auth_success.html``
* ``sso_error.html``
* ``sso_login_idp_picker.html``
* ``sso_redirect_confirm.html``
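Autoescaping means that variables substituted into the templates above are HTML-escaped by default. The escaping behaviour itself can be illustrated with the standard library (this illustrates the effect, not the Jinja2 code Synapse runs):

```python
from html import escape

# A value that, without autoescape, would be inserted into the page verbatim
value = '<script>alert(1)</script>'

# With autoescape enabled, the rendered output is escaped like this:
print(escape(value))  # &lt;script&gt;alert(1)&lt;/script&gt;
```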
Upgrading to v1.26.0
====================

changelog.d/9045.misc (new file)

@@ -0,0 +1 @@
Add tests to `test_user.UsersListTestCase` for List Users Admin API.

changelog.d/9129.misc (new file)

@@ -0,0 +1 @@
Various improvements to the federation client.

changelog.d/9135.doc (new file)

@@ -0,0 +1 @@
Add link to Matrix VoIP tester for turn-howto.

changelog.d/9163.bugfix (new file)

@@ -0,0 +1 @@
Fix a long-standing bug where Synapse would return a 500 error when a thumbnail did not exist (and auto-generation of thumbnails was not enabled).

changelog.d/9176.misc (new file)

@@ -0,0 +1 @@
Speed up chain cover calculation when persisting a batch of state events at once.

changelog.d/9180.misc (new file)

@@ -0,0 +1 @@
Add a `long_description_type` to the package metadata.

changelog.d/9181.misc (new file)

@@ -0,0 +1 @@
Speed up batch insertion when using PostgreSQL.

changelog.d/9184.misc (new file)

@@ -0,0 +1 @@
Emit an error at startup if different Identity Providers are configured with the same `idp_id`.

changelog.d/9188.misc (new file)

@@ -0,0 +1 @@
Speed up batch insertion when using PostgreSQL.

changelog.d/9189.misc (new file)

@@ -0,0 +1 @@
Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration.

changelog.d/9190.misc (new file)

@@ -0,0 +1 @@
Improve performance of concurrent use of `StreamIDGenerators`.

changelog.d/9191.misc (new file)

@@ -0,0 +1 @@
Add some missing source directories to the automatic linting script.

changelog.d/9193.bugfix (new file)

@@ -0,0 +1 @@
Fix receipts or account data not being sent down sync. Introduced in v1.26.0rc1.

changelog.d/9195.bugfix (new file)

@@ -0,0 +1 @@
Fix receipts or account data not being sent down sync. Introduced in v1.26.0rc1.


@@ -33,13 +33,11 @@ esac
# Use --builtin-venv to use the better `venv` module from CPython 3.4+ rather
# than the 2/3 compatible `virtualenv`.
# Pin pip to 20.3.4 to fix breakage in 21.0 on py3.5 (xenial)
dh_virtualenv \
    --install-suffix "matrix-synapse" \
    --builtin-venv \
    --python "$SNAKE" \
-   --upgrade-pip-to="20.3.4" \
+   --upgrade-pip \
    --preinstall="lxml" \
    --preinstall="mock" \
    --extra-pip-arg="--no-cache-dir" \

debian/changelog (vendored, 14 lines changed)

@@ -1,18 +1,8 @@
-matrix-synapse-py3 (1.26.0+nmu1) UNRELEASED; urgency=medium
+matrix-synapse-py3 (1.25.0ubuntu1) UNRELEASED; urgency=medium
  * Fix build on Ubuntu 16.04 LTS (Xenial).
 -- Dan Callahan <danc@element.io>  Thu, 28 Jan 2021 16:21:03 +0000
matrix-synapse-py3 (1.26.0) stable; urgency=medium
  [ Richard van der Hoff ]
  * Remove dependency on `python3-distutils`.
-  [ Synapse Packaging team ]
+ -- Richard van der Hoff <richard@matrix.org>  Fri, 15 Jan 2021 12:44:19 +0000
  * New synapse release 1.26.0.
 -- Synapse Packaging team <packages@matrix.org>  Wed, 27 Jan 2021 12:43:35 -0500
matrix-synapse-py3 (1.25.0) stable; urgency=medium


@@ -27,7 +27,6 @@ RUN env DEBIAN_FRONTEND=noninteractive apt-get install \
wget
# fetch and unpack the package
# TODO: Upgrade to 1.2.2 once xenial is dropped
RUN mkdir /dh-virtualenv
RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/ac6e1b1.tar.gz
RUN tar -xv --strip-components=1 -C /dh-virtualenv -f /dh-virtualenv.tar.gz


@@ -9,7 +9,6 @@
* [Response](#response)
* [Undoing room shutdowns](#undoing-room-shutdowns)
- [Make Room Admin API](#make-room-admin-api)
- [Forward Extremities Admin API](#forward-extremities-admin-api)
# List Room API # List Room API
@@ -368,36 +367,6 @@ Response:
}
```
# Room State API
The Room State admin API allows server admins to get a list of all state events in a room.
The response includes the following fields:
* `state` - The current state of the room at the time of request.
## Usage
A standard request:
```
GET /_synapse/admin/v1/rooms/<room_id>/state
{}
```
Response:
```json
{
"state": [
{"type": "m.room.create", "state_key": "", "etc": true},
{"type": "m.room.power_levels", "state_key": "", "etc": true},
{"type": "m.room.name", "state_key": "", "etc": true}
]
}
```
# Delete Room API
The Delete Room admin API allows server admins to remove rooms from server
@@ -542,55 +511,3 @@ optionally be specified, e.g.:
"user_id": "@foo:example.com"
}
```
# Forward Extremities Admin API
Enables querying and deleting forward extremities from rooms. When a lot of forward
extremities accumulate in a room, performance can become degraded. For details, see
[#1760](https://github.com/matrix-org/synapse/issues/1760).
## Check for forward extremities
To check the status of forward extremities for a room:
```
GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
```
A response as follows will be returned:
```json
{
"count": 1,
"results": [
{
"event_id": "$M5SP266vsnxctfwFgFLNceaCo3ujhRtg_NiiHabcdefgh",
"state_group": 439,
"depth": 123,
"received_ts": 1611263016761
}
]
}
```
## Deleting forward extremities
**WARNING**: Please ensure you know what you're doing and have read
the related issue [#1760](https://github.com/matrix-org/synapse/issues/1760).
Under no situations should this API be executed as an automated maintenance task!
If a room has lots of forward extremities, the extra ones can be
deleted as follows:
```
DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
```
A response as follows will be returned, indicating the number of forward extremities
that were deleted.
```json
{
"deleted": 1
}
```


@@ -760,33 +760,3 @@ The following fields are returned in the JSON response body:
- ``total`` - integer - Number of pushers.
See also `Client-Server API Spec <https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers>`_
Shadow-banning users
====================
Shadow-banning is a useful tool for moderating malicious or egregiously abusive users.
A shadow-banned user receives successful responses to their client-server API requests,
but the events are not propagated into rooms. This can be an effective tool as it
(hopefully) takes the user longer to realise they are being moderated before
pivoting to another account.
Shadow-banning a user should be used as a tool of last resort and may lead to confusing
or broken behaviour for the client. A shadow-banned user will not receive any
notification and it is generally more appropriate to ban or kick abusive users.
A shadow-banned user will be unable to contact anyone on the server.
The API is::
POST /_synapse/admin/v1/users/<user_id>/shadow_ban
To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see `README.rst <README.rst>`_.
An empty JSON dict is returned.
**Parameters**
The following parameters should be set in the URL:
- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
be local.
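Because the MXID appears in the URL path, it must be percent-encoded. A sketch of building the request with Python's standard library (the homeserver URL and token are placeholders, and the request is only constructed here, not sent):

```python
import urllib.request
from urllib.parse import quote

homeserver = "https://matrix.example.com"  # placeholder
access_token = "YOUR_ADMIN_TOKEN"          # placeholder

# The fully qualified MXID of the local user to shadow-ban
user_id = "@user:server.com"
path = f"/_synapse/admin/v1/users/{quote(user_id, safe='')}/shadow_ban"

req = urllib.request.Request(
    homeserver + path,
    data=b"{}",  # the API takes an empty JSON body
    method="POST",
    headers={"Authorization": f"Bearer {access_token}"},
)
print(req.full_url)
```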


@@ -54,8 +54,7 @@ Here are a few configs for providers that should work with Synapse.
### Microsoft Azure Active Directory
Azure AD can act as an OpenID Connect Provider. Register a new application under
*App registrations* in the Azure AD management console. The RedirectURI for your
-application should point to your matrix server:
-`[synapse public baseurl]/_synapse/client/oidc/callback`
+application should point to your matrix server: `[synapse public baseurl]/_synapse/oidc/callback`
Go to *Certificates & secrets* and register a new client secret. Make note of your
Directory (tenant) ID as it will be used in the Azure links.
@@ -95,7 +94,7 @@ staticClients:
  - id: synapse
    secret: secret
    redirectURIs:
-     - '[synapse public baseurl]/_synapse/client/oidc/callback'
+     - '[synapse public baseurl]/_synapse/oidc/callback'
    name: 'Synapse'
```
@@ -141,7 +140,7 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to
| Enabled | `On` |
| Client Protocol | `openid-connect` |
| Access Type | `confidential` |
-| Valid Redirect URIs | `[synapse public baseurl]/_synapse/client/oidc/callback` |
+| Valid Redirect URIs | `[synapse public baseurl]/_synapse/oidc/callback` |
5. Click `Save`
6. On the Credentials tab, update the fields:
@@ -169,7 +168,7 @@ oidc_providers:
### [Auth0][auth0]
1. Create a regular web application for Synapse
-2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/client/oidc/callback`
+2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/oidc/callback`
3. Add a rule to add the `preferred_username` claim.
<details>
<summary>Code sample</summary>
@@ -218,7 +217,7 @@ login mechanism needs an attribute to uniquely identify users, and that endpoint
does not return a `sub` property, an alternative `subject_claim` has to be set.
1. Create a new OAuth application: https://github.com/settings/applications/new.
-2. Set the callback URL to `[synapse public baseurl]/_synapse/client/oidc/callback`.
+2. Set the callback URL to `[synapse public baseurl]/_synapse/oidc/callback`.
Synapse config:
@@ -226,7 +225,6 @@ Synapse config:
oidc_providers:
  - idp_id: github
    idp_name: Github
idp_brand: "org.matrix.github" # optional: styling hint for clients
    discover: false
    issuer: "https://github.com/"
    client_id: "your-client-id" # TO BE FILLED
@@ -252,7 +250,6 @@ oidc_providers:
oidc_providers:
  - idp_id: google
    idp_name: Google
idp_brand: "org.matrix.google" # optional: styling hint for clients
    issuer: "https://accounts.google.com/"
    client_id: "your-client-id" # TO BE FILLED
    client_secret: "your-client-secret" # TO BE FILLED
@@ -263,13 +260,13 @@ oidc_providers:
    display_name_template: "{{ user.name }}"
```
-4. Back in the Google console, add this Authorized redirect URI: `[synapse public baseurl]/_synapse/client/oidc/callback`.
+4. Back in the Google console, add this Authorized redirect URI: `[synapse public baseurl]/_synapse/oidc/callback`.
### Twitch
1. Setup a developer account on [Twitch](https://dev.twitch.tv/)
2. Obtain the OAuth 2.0 credentials by [creating an app](https://dev.twitch.tv/console/apps/)
-3. Add this OAuth Redirect URL: `[synapse public baseurl]/_synapse/client/oidc/callback`
+3. Add this OAuth Redirect URL: `[synapse public baseurl]/_synapse/oidc/callback`
Synapse config:
@@ -291,7 +288,7 @@ oidc_providers:
1. Create a [new application](https://gitlab.com/profile/applications).
2. Add the `read_user` and `openid` scopes.
-3. Add this Callback URL: `[synapse public baseurl]/_synapse/client/oidc/callback`
+3. Add this Callback URL: `[synapse public baseurl]/_synapse/oidc/callback`
Synapse config:
@@ -299,7 +296,6 @@ Synapse config:
oidc_providers:
  - idp_id: gitlab
    idp_name: Gitlab
idp_brand: "org.matrix.gitlab" # optional: styling hint for clients
    issuer: "https://gitlab.com/"
    client_id: "your-client-id" # TO BE FILLED
    client_secret: "your-client-secret" # TO BE FILLED
@@ -311,80 +307,3 @@ oidc_providers:
    localpart_template: '{{ user.nickname }}'
    display_name_template: '{{ user.name }}'
```
### Facebook
Like Github, Facebook provides a custom OAuth2 API rather than an OIDC-compliant
one, so it requires a little more configuration.
0. You will need a Facebook developer account. You can register for one
[here](https://developers.facebook.com/async/registration/).
1. On the [apps](https://developers.facebook.com/apps/) page of the developer
console, "Create App", and choose "Build Connected Experiences".
2. Once the app is created, add "Facebook Login" and choose "Web". You don't
need to go through the whole form here.
3. In the left-hand menu, open "Products"/"Facebook Login"/"Settings".
* Add `[synapse public baseurl]/_synapse/client/oidc/callback` as an OAuth Redirect
URL.
4. In the left-hand menu, open "Settings/Basic". Here you can copy the "App ID"
and "App Secret" for use below.
Synapse config:
```yaml
- idp_id: facebook
idp_name: Facebook
idp_brand: "org.matrix.facebook" # optional: styling hint for clients
discover: false
issuer: "https://facebook.com"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
scopes: ["openid", "email"]
authorization_endpoint: https://facebook.com/dialog/oauth
token_endpoint: https://graph.facebook.com/v9.0/oauth/access_token
user_profile_method: "userinfo_endpoint"
userinfo_endpoint: "https://graph.facebook.com/v9.0/me?fields=id,name,email,picture"
user_mapping_provider:
config:
subject_claim: "id"
display_name_template: "{{ user.name }}"
```
Relevant documents:
* https://developers.facebook.com/docs/facebook-login/manually-build-a-login-flow
* Using Facebook's Graph API: https://developers.facebook.com/docs/graph-api/using-graph-api/
* Reference to the User endpoint: https://developers.facebook.com/docs/graph-api/reference/user
### Gitea
Gitea is, like Github, not an OpenID provider, but just an OAuth2 provider.
The [`/user` API endpoint](https://try.gitea.io/api/swagger#/user/userGetCurrent)
can be used to retrieve information on the authenticated user. As the Synapse
login mechanism needs an attribute to uniquely identify users, and that endpoint
does not return a `sub` property, an alternative `subject_claim` has to be set.
1. Create a new application.
2. Add this Callback URL: `[synapse public baseurl]/_synapse/oidc/callback`
Synapse config:
```yaml
oidc_providers:
- idp_id: gitea
idp_name: Gitea
discover: false
issuer: "https://your-gitea.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
client_auth_method: client_secret_post
scopes: [] # Gitea doesn't support Scopes
authorization_endpoint: "https://your-gitea.com/login/oauth/authorize"
token_endpoint: "https://your-gitea.com/login/oauth/access_token"
userinfo_endpoint: "https://your-gitea.com/api/v1/user"
user_mapping_provider:
config:
subject_claim: "id"
localpart_template: "{{ user.login }}"
display_name_template: "{{ user.full_name }}"
```


@@ -824,9 +824,6 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# users are joining rooms the server is already in (this is cheap) vs
# "remote" for when users are trying to join rooms not on the server (which
# can be more expensive)
# - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
# - two for ratelimiting how often invites can be sent in a room or to a
# specific user.
#
# The defaults are as shown below.
#
@@ -860,18 +857,7 @@ log_config: "CONFDIR/SERVERNAME.log.config"
#  remote:
#    per_second: 0.01
#    burst_count: 3
#
#rc_3pid_validation:
# per_second: 0.003
# burst_count: 5
#
#rc_invites:
# per_room:
# per_second: 0.3
# burst_count: 10
# per_user:
# per_second: 0.003
# burst_count: 5
# Ratelimiting settings for incoming federation # Ratelimiting settings for incoming federation
#
@@ -1566,10 +1552,10 @@ trusted_key_servers:
# enable SAML login.
#
# Once SAML support is enabled, a metadata file will be exposed at
-# https://<server>:<port>/_synapse/client/saml2/metadata.xml, which you may be able to
+# https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
# use to configure your SAML IdP with. Alternatively, you can manually configure
# the IdP to use an ACS location of
-# https://<server>:<port>/_synapse/client/saml2/authn_response.
+# https://<server>:<port>/_matrix/saml2/authn_response.
#
saml2_config:
  # `sp_config` is the configuration for the pysaml2 Service Provider.
@@ -1741,14 +1727,10 @@ saml2_config:
# offer the user a choice of login mechanisms.
#
-# idp_icon: An optional icon for this identity provider, which is presented
-#     by clients and Synapse's own IdP picker page. If given, must be an
-#     MXC URI of the format mxc://<server-name>/<media-id>. (An easy way to
-#     obtain such an MXC URI is to upload an image to an (unencrypted) room
-#     and then copy the "url" from the source of the event.)
+# idp_icon: An optional icon for this identity provider, which is presented
+#     by identity picker pages. If given, must be an MXC URI of the format
+#     mxc://<server-name>/<media-id>. (An easy way to obtain such an MXC URI
+#     is to upload an image to an (unencrypted) room and then copy the "url"
+#     from the source of the event.)
#
# idp_brand: An optional brand for this identity provider, allowing clients
# to style the login flow according to the identity provider in question.
# See the spec for possible options here.
#
# discover: set to 'false' to disable the use of the OIDC discovery mechanism
#     to discover endpoints. Defaults to true.
@@ -1809,21 +1791,17 @@ saml2_config:
#
# For the default provider, the following settings are available:
#
-#   subject_claim: name of the claim containing a unique identifier
-#       for the user. Defaults to 'sub', which OpenID Connect
-#       compliant providers should provide.
+#   sub: name of the claim containing a unique identifier for the
+#       user. Defaults to 'sub', which OpenID Connect compliant
+#       providers should provide.
#
#   localpart_template: Jinja2 template for the localpart of the MXID.
-#       If this is not set, the user will be prompted to choose their
-#       own username (see 'sso_auth_account_details.html' in the 'sso'
-#       section of this file).
+#       If this is not set, the user will be prompted to choose their
+#       own username.
#
#   display_name_template: Jinja2 template for the display name to set
#       on first login. If unset, no displayname will be set.
#
# email_template: Jinja2 template for the email address of the user.
# If unset, no email address will be added to the account.
#
# extra_attributes: a map of Jinja2 templates for extra attributes # extra_attributes: a map of Jinja2 templates for extra attributes
# to send back to the client during login. # to send back to the client during login.
# Note that these are non-standard and clients will ignore them # Note that these are non-standard and clients will ignore them
@@ -1859,12 +1837,6 @@ oidc_providers:
# userinfo_endpoint: "https://accounts.example.com/userinfo" # userinfo_endpoint: "https://accounts.example.com/userinfo"
# jwks_uri: "https://accounts.example.com/.well-known/jwks.json" # jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
# skip_verification: true # skip_verification: true
# user_mapping_provider:
# config:
# subject_claim: "id"
# localpart_template: "{ user.login }"
# display_name_template: "{ user.name }"
# email_template: "{ user.email }"
# For use with Keycloak # For use with Keycloak
# #
@@ -1879,7 +1851,6 @@ oidc_providers:
# #
#- idp_id: github #- idp_id: github
# idp_name: Github # idp_name: Github
# idp_brand: org.matrix.github
# discover: false # discover: false
# issuer: "https://github.com/" # issuer: "https://github.com/"
# client_id: "your-client-id" # TO BE FILLED # client_id: "your-client-id" # TO BE FILLED
@@ -1907,6 +1878,10 @@ cas_config:
# #
#server_url: "https://cas-server.com" #server_url: "https://cas-server.com"
# The public URL of the homeserver.
#
#service_url: "https://homeserver.domain.com:8448"
# The attribute of the CAS response to use as the display name. # The attribute of the CAS response to use as the display name.
# #
# If unset, no displayname will be set. # If unset, no displayname will be set.
@@ -1968,13 +1943,8 @@ sso:
# #
# * providers: a list of available Identity Providers. Each element is # * providers: a list of available Identity Providers. Each element is
# an object with the following attributes: # an object with the following attributes:
#
# * idp_id: unique identifier for the IdP # * idp_id: unique identifier for the IdP
# * idp_name: user-facing name for the IdP # * idp_name: user-facing name for the IdP
# * idp_icon: if specified in the IdP config, an MXC URI for an icon
# for the IdP
# * idp_brand: if specified in the IdP config, a textual identifier
# for the brand of the IdP
# #
# The rendered HTML page should contain a form which submits its results # The rendered HTML page should contain a form which submits its results
# back as a GET request, with the following query parameters: # back as a GET request, with the following query parameters:
@@ -1984,62 +1954,10 @@ sso:
# #
# * idp: the 'idp_id' of the chosen IDP. # * idp: the 'idp_id' of the chosen IDP.
# #
# * HTML page to prompt new users to enter a userid and confirm other
# details: 'sso_auth_account_details.html'. This is only shown if the
# SSO implementation (with any user_mapping_provider) does not return
# a localpart.
#
# When rendering, this template is given the following variables:
#
# * server_name: the homeserver's name.
#
# * idp: details of the SSO Identity Provider that the user logged in
# with: an object with the following attributes:
#
# * idp_id: unique identifier for the IdP
# * idp_name: user-facing name for the IdP
# * idp_icon: if specified in the IdP config, an MXC URI for an icon
# for the IdP
# * idp_brand: if specified in the IdP config, a textual identifier
# for the brand of the IdP
#
# * user_attributes: an object containing details about the user that
# we received from the IdP. May have the following attributes:
#
# * display_name: the user's display_name
# * emails: a list of email addresses
#
# The template should render a form which submits the following fields:
#
# * username: the localpart of the user's chosen user id
#
# * HTML page allowing the user to consent to the server's terms and
# conditions. This is only shown for new users, and only if
# `user_consent.require_at_registration` is set.
#
# When rendering, this template is given the following variables:
#
# * server_name: the homeserver's name.
#
# * user_id: the user's matrix proposed ID.
#
# * user_profile.display_name: the user's proposed display name, if any.
#
# * consent_version: the version of the terms that the user will be
# shown
#
# * terms_url: a link to the page showing the terms.
#
# The template should render a form which submits the following fields:
#
# * accepted_version: the version of the terms accepted by the user
# (ie, 'consent_version' from the input variables).
#
# * HTML page for a confirmation step before redirecting back to the client # * HTML page for a confirmation step before redirecting back to the client
# with the login token: 'sso_redirect_confirm.html'. # with the login token: 'sso_redirect_confirm.html'.
# #
# When rendering, this template is given the following variables: # When rendering, this template is given three variables:
#
# * redirect_url: the URL the user is about to be redirected to. Needs # * redirect_url: the URL the user is about to be redirected to. Needs
# manual escaping (see # manual escaping (see
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
@@ -2052,17 +1970,6 @@ sso:
# #
# * server_name: the homeserver's name. # * server_name: the homeserver's name.
# #
# * new_user: a boolean indicating whether this is the user's first time
# logging in.
#
# * user_id: the user's matrix ID.
#
# * user_profile.avatar_url: an MXC URI for the user's avatar, if any.
# None if the user has not set an avatar.
#
# * user_profile.display_name: the user's display name. None if the user
# has not set a display name.
#
# * HTML page which notifies the user that they are authenticating to confirm # * HTML page which notifies the user that they are authenticating to confirm
# an operation on their account during the user interactive authentication # an operation on their account during the user interactive authentication
# process: 'sso_auth_confirm.html'. # process: 'sso_auth_confirm.html'.
@@ -2074,16 +1981,6 @@ sso:
# #
# * description: the operation which the user is being asked to confirm # * description: the operation which the user is being asked to confirm
# #
# * idp: details of the Identity Provider that we will use to confirm
# the user's identity: an object with the following attributes:
#
# * idp_id: unique identifier for the IdP
# * idp_name: user-facing name for the IdP
# * idp_icon: if specified in the IdP config, an MXC URI for an icon
# for the IdP
# * idp_brand: if specified in the IdP config, a textual identifier
# for the brand of the IdP
#
# * HTML page shown after a successful user interactive authentication session: # * HTML page shown after a successful user interactive authentication session:
# 'sso_auth_success.html'. # 'sso_auth_success.html'.
# #


@@ -40,9 +40,6 @@ which relays replication commands between processes. This can give a significant
cpu saving on the main process and will be a prerequisite for upcoming
performance improvements.
-If Redis support is enabled Synapse will use it as a shared cache, as well as a
-pub/sub mechanism.
See the [Architectural diagram](#architectural-diagram) section at the end for
a visualisation of what this looks like.
@@ -228,6 +225,7 @@ expressions:
^/_matrix/client/(api/v1|r0|unstable)/joined_groups$
^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$
^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/
+^/_synapse/client/password_reset/email/submit_token$
# Registration/login requests
^/_matrix/client/(api/v1|r0|unstable)/login$
@@ -258,29 +256,25 @@ Additionally, the following endpoints should be included if Synapse is configure
to use SSO (you only need to include the ones for whichever SSO provider you're
using):
-# for all SSO providers
-^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect
-^/_synapse/client/pick_idp$
-^/_synapse/client/pick_username
-^/_synapse/client/new_user_consent$
-^/_synapse/client/sso_register$
# OpenID Connect requests.
-^/_synapse/client/oidc/callback$
+^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$
+^/_synapse/oidc/callback$
# SAML requests.
-^/_synapse/client/saml2/authn_response$
+^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$
+^/_matrix/saml2/authn_response$
# CAS requests.
+^/_matrix/client/(api/v1|r0|unstable)/login/(cas|sso)/redirect$
^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$
-Ensure that all SSO logins go to a single process.
-For multiple workers not handling the SSO endpoints properly, see
-[#7530](https://github.com/matrix-org/synapse/issues/7530).
Note that a HTTP listener with `client` and `federation` resources must be
configured in the `worker_listeners` option in the worker config.
+Ensure that all SSO logins go to a single process (usually the main process).
+For multiple workers not handling the SSO endpoints properly, see
+[#7530](https://github.com/matrix-org/synapse/issues/7530).
#### Load balancing
It is possible to run multiple instances of this worker app, with incoming requests
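The endpoint lists above are regular expressions matched against request paths: a reverse proxy (or a small routing shim) dispatches each matching request to a worker. A minimal sketch of that dispatch logic, where the worker names and the two-entry routing table are purely illustrative:

```python
import re

# Illustrative routing table: compiled pattern -> worker name.
# The worker names here are hypothetical; real deployments configure
# this in their reverse proxy (nginx, haproxy, etc.).
ROUTES = [
    (re.compile(r"^/_matrix/client/(api/v1|r0|unstable)/login$"), "client_worker"),
    (re.compile(r"^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$"), "sso_worker"),
]


def route(path: str, default: str = "main") -> str:
    """Return the worker that should handle `path`; first matching pattern wins."""
    for pattern, worker in ROUTES:
        if pattern.search(path):
            return worker
    return default


print(route("/_matrix/client/r0/login"))  # -> client_worker
```

Note that because the patterns are anchored with `^` and often `$`, `/login` and `/login/cas/ticket` route independently, which is why SSO endpoints must all be pinned to one process.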


@@ -23,7 +23,39 @@ files =
synapse/events/validator.py,
synapse/events/spamcheck.py,
synapse/federation,
-synapse/handlers,
+synapse/handlers/_base.py,
+synapse/handlers/account_data.py,
+synapse/handlers/account_validity.py,
+synapse/handlers/admin.py,
+synapse/handlers/appservice.py,
+synapse/handlers/auth.py,
+synapse/handlers/cas_handler.py,
+synapse/handlers/deactivate_account.py,
+synapse/handlers/device.py,
+synapse/handlers/devicemessage.py,
+synapse/handlers/directory.py,
+synapse/handlers/events.py,
+synapse/handlers/federation.py,
+synapse/handlers/identity.py,
+synapse/handlers/initial_sync.py,
+synapse/handlers/message.py,
+synapse/handlers/oidc_handler.py,
+synapse/handlers/pagination.py,
+synapse/handlers/password_policy.py,
+synapse/handlers/presence.py,
+synapse/handlers/profile.py,
+synapse/handlers/read_marker.py,
+synapse/handlers/receipts.py,
+synapse/handlers/register.py,
+synapse/handlers/room.py,
+synapse/handlers/room_list.py,
+synapse/handlers/room_member.py,
+synapse/handlers/room_member_worker.py,
+synapse/handlers/saml_handler.py,
+synapse/handlers/sso.py,
+synapse/handlers/sync.py,
+synapse/handlers/user_directory.py,
+synapse/handlers/ui_auth,
synapse/http/client.py,
synapse/http/federation/matrix_federation_agent.py,
synapse/http/federation/well_known_resolver.py,
@@ -162,9 +194,3 @@ ignore_missing_imports = True
[mypy-hiredis]
ignore_missing_imports = True
-[mypy-josepy.*]
-ignore_missing_imports = True
-[mypy-txacme.*]
-ignore_missing_imports = True


@@ -96,7 +96,7 @@ CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
#
# We pin black so that our tests don't start failing on new releases.
CONDITIONAL_REQUIREMENTS["lint"] = [
-    "isort==5.7.0",
+    "isort==5.0.3",
    "black==19.10b0",
    "flake8-comprehensions",
    "flake8",


@@ -15,23 +15,13 @@
"""Contains *incomplete* type hints for txredisapi.
"""
-from typing import Any, List, Optional, Type, Union
+from typing import List, Optional, Type, Union

class RedisProtocol:
    def publish(self, channel: str, message: bytes): ...
-    async def ping(self) -> None: ...
-    async def set(
-        self,
-        key: str,
-        value: Any,
-        expire: Optional[int] = None,
-        pexpire: Optional[int] = None,
-        only_if_not_exists: bool = False,
-        only_if_exists: bool = False,
-    ) -> None: ...
-    async def get(self, key: str) -> Any: ...

-class SubscriberProtocol(RedisProtocol):
+class SubscriberProtocol:
    def __init__(self, *args, **kwargs): ...
    password: Optional[str]
    def subscribe(self, channels: Union[str, List[str]]): ...
@@ -50,13 +40,14 @@ def lazyConnection(
    convertNumbers: bool = ...,
) -> RedisProtocol: ...

+class SubscriberFactory:
+    def buildProtocol(self, addr): ...

class ConnectionHandler: ...

class RedisFactory:
    continueTrying: bool
    handler: RedisProtocol
-    pool: List[RedisProtocol]
-    replyTimeout: Optional[int]
    def __init__(
        self,
        uuid: str,
@@ -69,7 +60,3 @@ class RedisFactory:
        replyTimeout: Optional[int] = None,
        convertNumbers: Optional[int] = True,
    ): ...
-    def buildProtocol(self, addr) -> RedisProtocol: ...

-class SubscriberFactory(RedisFactory):
-    def __init__(self): ...
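The stub methods removed in this comparison (`ping`, `set` with expiry flags, `get`) mirror the Redis commands Synapse 1.27 uses for its shared cache. A rough, synchronous in-memory sketch of those semantics for illustration only (the real txredisapi methods are async and talk to a Redis server; `FakeRedisCache` is not a real class from either project):

```python
import time
from typing import Any, Dict, Optional, Tuple


class FakeRedisCache:
    """In-memory stand-in illustrating the set/get semantics of the stub."""

    def __init__(self) -> None:
        # key -> (value, absolute monotonic deadline or None)
        self._data: Dict[str, Tuple[Any, Optional[float]]] = {}

    def set(
        self,
        key: str,
        value: Any,
        expire: Optional[int] = None,   # seconds, like Redis SET ... EX
        pexpire: Optional[int] = None,  # milliseconds, like Redis SET ... PX
        only_if_not_exists: bool = False,  # like Redis SET ... NX
        only_if_exists: bool = False,      # like Redis SET ... XX
    ) -> None:
        exists = self.get(key) is not None
        if only_if_not_exists and exists:
            return
        if only_if_exists and not exists:
            return
        deadline = None
        if expire is not None:
            deadline = time.monotonic() + expire
        elif pexpire is not None:
            deadline = time.monotonic() + pexpire / 1000
        self._data[key] = (value, deadline)

    def get(self, key: str) -> Any:
        entry = self._data.get(key)
        if entry is None:
            return None
        value, deadline = entry
        if deadline is not None and time.monotonic() >= deadline:
            del self._data[key]
            return None
        return value


cache = FakeRedisCache()
cache.set("k", "v")
cache.set("k", "other", only_if_not_exists=True)  # no-op: "k" already exists
print(cache.get("k"))  # -> v
```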


@@ -48,7 +48,7 @@ try:
except ImportError:
    pass

-__version__ = "1.27.0rc1"
+__version__ = "1.26.0rc1"

if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
    # We import here so that we don't have to install a bunch of deps when

@@ -16,7 +16,6 @@
import gc
import logging
import os
-import platform
import signal
import socket
import sys
@@ -340,7 +339,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
    # rest of time. Doing so means less work each GC (hopefully).
    #
    # This only works on Python 3.7
-    if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7):
+    if sys.version_info >= (3, 7):
        gc.collect()
        gc.freeze()


@@ -22,7 +22,6 @@ from typing import Dict, Iterable, Optional, Set
from typing_extensions import ContextManager

from twisted.internet import address
-from twisted.web.resource import IResource

import synapse
import synapse.events
@@ -91,8 +90,9 @@ from synapse.replication.tcp.streams import (
    ToDeviceStream,
)
from synapse.rest.admin import register_servlets_for_media_repo
-from synapse.rest.client.v1 import events, login, room
+from synapse.rest.client.v1 import events, room
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
+from synapse.rest.client.v1.login import LoginRestServlet
from synapse.rest.client.v1.profile import (
    ProfileAvatarURLRestServlet,
    ProfileDisplaynameRestServlet,
@@ -127,7 +127,6 @@ from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet
from synapse.rest.client.versions import VersionsRestServlet
from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
-from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer, cache_in_self
from synapse.storage.databases.main.censor_events import CensorEventsStore
from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
@@ -508,7 +507,7 @@ class GenericWorkerServer(HomeServer):
        site_tag = port

        # We always include a health resource.
-        resources = {"/health": HealthResource()}  # type: Dict[str, IResource]
+        resources = {"/health": HealthResource()}

        for res in listener_config.http_options.resources:
            for name in res.names:
@@ -518,7 +517,7 @@ class GenericWorkerServer(HomeServer):
                    resource = JsonResource(self, canonical_json=False)

                    RegisterRestServlet(self).register(resource)
-                    login.register_servlets(self, resource)
+                    LoginRestServlet(self).register(resource)
                    ThreepidRestServlet(self).register(resource)
                    DevicesRestServlet(self).register(resource)
                    KeyQueryServlet(self).register(resource)
@@ -558,8 +557,6 @@ class GenericWorkerServer(HomeServer):
                    groups.register_servlets(self, resource)

                    resources.update({CLIENT_API_PREFIX: resource})
-
-                    resources.update(build_synapse_client_resource_tree(self))
                elif name == "federation":
                    resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
                elif name == "media":


@@ -60,7 +60,8 @@ from synapse.rest import ClientRestResource
from synapse.rest.admin import AdminRestResource
from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
-from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.rest.synapse.client.pick_idp import PickIdpResource
+from synapse.rest.synapse.client.pick_username import pick_username_resource
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
from synapse.storage import DataStore
@@ -189,10 +190,21 @@ class SynapseHomeServer(HomeServer):
                    "/_matrix/client/versions": client_resource,
                    "/.well-known/matrix/client": WellKnownResource(self),
                    "/_synapse/admin": AdminRestResource(self),
-                    **build_synapse_client_resource_tree(self),
+                    "/_synapse/client/pick_username": pick_username_resource(self),
+                    "/_synapse/client/pick_idp": PickIdpResource(self),
                }
            )

+            if self.get_config().oidc_enabled:
+                from synapse.rest.oidc import OIDCResource
+
+                resources["/_synapse/oidc"] = OIDCResource(self)
+
+            if self.get_config().saml2_enabled:
+                from synapse.rest.saml2 import SAML2Resource
+
+                resources["/_matrix/saml2"] = SAML2Resource(self)
+
        if self.get_config().threepid_behaviour_email == ThreepidBehaviour.LOCAL:
            from synapse.rest.synapse.client.password_reset import (
                PasswordResetSubmitTokenResource,

@@ -93,20 +93,15 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
    stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
    stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
-    daily_active_e2ee_rooms = await hs.get_datastore().count_daily_active_e2ee_rooms()
-    stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms
-    stats["daily_e2ee_messages"] = await hs.get_datastore().count_daily_e2ee_messages()
-    daily_sent_e2ee_messages = await hs.get_datastore().count_daily_sent_e2ee_messages()
-    stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages
    stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
    stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
-    daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
-    stats["daily_sent_messages"] = daily_sent_messages

    r30_results = await hs.get_datastore().count_r30_users()
    for name, count in r30_results.items():
        stats["r30_users_" + name] = count

+    daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
+    stats["daily_sent_messages"] = daily_sent_messages
    stats["cache_factor"] = hs.config.caches.global_factor
    stats["event_cache_size"] = hs.config.caches.event_cache_size


@@ -18,18 +18,18 @@
import argparse
import errno
import os
+import time
+import urllib.parse
from collections import OrderedDict
from hashlib import sha256
from textwrap import dedent
-from typing import Any, Iterable, List, MutableMapping, Optional
+from typing import Any, Callable, Iterable, List, MutableMapping, Optional

import attr
import jinja2
import pkg_resources
import yaml

-from synapse.util.templates import _create_mxc_to_http_filter, _format_ts_filter

class ConfigError(Exception):
    """Represents a problem parsing the configuration
@@ -203,28 +203,11 @@ class Config:
        with open(file_path) as file_stream:
            return file_stream.read()

-    def read_template(self, filename: str) -> jinja2.Template:
-        """Load a template file from disk.
-
-        This function will attempt to load the given template from the default Synapse
-        template directory.
-
-        Files read are treated as Jinja templates. The templates is not rendered yet
-        and has autoescape enabled.
-
-        Args:
-            filename: A template filename to read.
-
-        Raises:
-            ConfigError: if the file's path is incorrect or otherwise cannot be read.
-
-        Returns:
-            A jinja2 template.
-        """
-        return self.read_templates([filename])[0]
-
    def read_templates(
-        self, filenames: List[str], custom_template_directory: Optional[str] = None,
+        self,
+        filenames: List[str],
+        custom_template_directory: Optional[str] = None,
+        autoescape: bool = False,
    ) -> List[jinja2.Template]:
        """Load a list of template files from disk using the given variables.
@@ -232,8 +215,7 @@ class Config:
        template directory. If `custom_template_directory` is supplied, that directory
        is tried first.

-        Files read are treated as Jinja templates. The templates are not rendered yet
-        and have autoescape enabled.
+        Files read are treated as Jinja templates. These templates are not rendered yet.

        Args:
            filenames: A list of template filenames to read.
@@ -241,12 +223,16 @@ class Config:
            custom_template_directory: A directory to try to look for the templates
                before using the default Synapse template directory instead.
+            autoescape: Whether to autoescape variables before inserting them into the
+                template.

        Raises:
            ConfigError: if the file's path is incorrect or otherwise cannot be read.

        Returns:
            A list of jinja2 templates.
        """
+        templates = []
        search_directories = [self.default_template_dir]

        # The loader will first look in the custom template directory (if specified) for the
@@ -262,9 +248,8 @@ class Config:
            # Search the custom template directory as well
            search_directories.insert(0, custom_template_directory)

-        # TODO: switch to synapse.util.templates.build_jinja_env
        loader = jinja2.FileSystemLoader(search_directories)
-        env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape(),)
+        env = jinja2.Environment(loader=loader, autoescape=autoescape)

        # Update the environment with our custom filters
        env.filters.update(
@@ -274,8 +259,44 @@ class Config:
            }
        )

-        # Load the templates
-        return [env.get_template(filename) for filename in filenames]
+        for filename in filenames:
+            # Load the template
+            template = env.get_template(filename)
+            templates.append(template)
+
+        return templates
+
+
+def _format_ts_filter(value: int, format: str):
+    return time.strftime(format, time.localtime(value / 1000))
+
+
+def _create_mxc_to_http_filter(public_baseurl: str) -> Callable:
+    """Create and return a jinja2 filter that converts MXC urls to HTTP
+
+    Args:
+        public_baseurl: The public, accessible base URL of the homeserver
+    """
+
+    def mxc_to_http_filter(value, width, height, resize_method="crop"):
+        if value[0:6] != "mxc://":
+            return ""
+
+        server_and_media_id = value[6:]
+        fragment = None
+        if "#" in server_and_media_id:
+            server_and_media_id, fragment = server_and_media_id.split("#", 1)
+            fragment = "#" + fragment
+
+        params = {"width": width, "height": height, "method": resize_method}
+        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
+            public_baseurl,
+            server_and_media_id,
+            urllib.parse.urlencode(params),
+            fragment or "",
+        )
+
+    return mxc_to_http_filter

class RootConfig:


@@ -9,7 +9,6 @@ from synapse.config import (
    consent_config,
    database,
    emailconfig,
-    experimental,
    groups,
    jwt_config,
    key,
@@ -19,7 +18,6 @@ from synapse.config import (
    password_auth_providers,
    push,
    ratelimiting,
-    redis,
    registration,
    repository,
    room_directory,
@@ -50,11 +48,10 @@ def path_exists(file_path: str): ...

class RootConfig:
    server: server.ServerConfig
-    experimental: experimental.ExperimentalConfig
    tls: tls.TlsConfig
    database: database.DatabaseConfig
    logging: logger.LoggingConfig
-    ratelimiting: ratelimiting.RatelimitConfig
+    ratelimit: ratelimiting.RatelimitConfig
    media: repository.ContentRepositoryConfig
    captcha: captcha.CaptchaConfig
    voip: voip.VoipConfig
@@ -82,7 +79,6 @@ class RootConfig:
    roomdirectory: room_directory.RoomDirectoryConfig
    thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
    tracer: tracer.TracerConfig
-    redis: redis.RedisConfig
    config_classes: List = ...
    def __init__(self) -> None: ...


@@ -28,7 +28,9 @@ class CaptchaConfig(Config):
            "recaptcha_siteverify_api",
            "https://www.recaptcha.net/recaptcha/api/siteverify",
        )
-        self.recaptcha_template = self.read_template("recaptcha.html")
+        self.recaptcha_template = self.read_templates(
+            ["recaptcha.html"], autoescape=True
+        )[0]

    def generate_config_section(self, **kwargs):
        return """\

@@ -30,13 +30,7 @@ class CasConfig(Config):
        if self.cas_enabled:
            self.cas_server_url = cas_config["server_url"]
-            public_base_url = cas_config.get("service_url") or self.public_baseurl
-            if public_base_url[-1] != "/":
-                public_base_url += "/"
-            # TODO Update this to a _synapse URL.
-            self.cas_service_url = (
-                public_base_url + "_matrix/client/r0/login/cas/ticket"
-            )
+            self.cas_service_url = cas_config["service_url"]
            self.cas_displayname_attribute = cas_config.get("displayname_attribute")
            self.cas_required_attributes = cas_config.get("required_attributes") or {}
        else:
@@ -59,6 +53,10 @@ class CasConfig(Config):
        #
        #server_url: "https://cas-server.com"

+        # The public URL of the homeserver.
+        #
+        #service_url: "https://homeserver.domain.com:8448"
+
        # The attribute of the CAS response to use as the display name.
        #
        # If unset, no displayname will be set.

View File

@@ -89,7 +89,7 @@ class ConsentConfig(Config):
    def read_config(self, config, **kwargs):
        consent_config = config.get("user_consent")
-        self.terms_template = self.read_template("terms.html")
+        self.terms_template = self.read_templates(["terms.html"], autoescape=True)[0]

        if consent_config is None:
            return

@@ -1,29 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.config._base import Config
-from synapse.types import JsonDict
-
-
-class ExperimentalConfig(Config):
-    """Config section for enabling experimental features"""
-
-    section = "experimental"
-
-    def read_config(self, config: JsonDict, **kwargs):
-        experimental = config.get("experimental_features") or {}
-
-        # MSC2858 (multiple SSO identity providers)
-        self.msc2858_enabled = experimental.get("msc2858_enabled", False)  # type: bool

@@ -24,7 +24,6 @@ from .cas import CasConfig
 from .consent_config import ConsentConfig
 from .database import DatabaseConfig
 from .emailconfig import EmailConfig
-from .experimental import ExperimentalConfig
 from .federation import FederationConfig
 from .groups import GroupsConfig
 from .jwt_config import JWTConfig
@@ -58,7 +57,6 @@ class HomeServerConfig(RootConfig):
     config_classes = [
         ServerConfig,
-        ExperimentalConfig,
         TlsConfig,
         FederationConfig,
         CacheConfig,

@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import string

 from collections import Counter
 from typing import Iterable, Optional, Tuple, Type
@@ -53,7 +54,8 @@ class OIDCConfig(Config):
                     "Multiple OIDC providers have the idp_id %r." % idp_id
                 )

-        self.oidc_callback_url = self.public_baseurl + "_synapse/client/oidc/callback"
+        public_baseurl = self.public_baseurl
+        self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"

     @property
     def oidc_enabled(self) -> bool:
@@ -77,14 +79,10 @@ class OIDCConfig(Config):
        #       offer the user a choice of login mechanisms.
        #
-       #   idp_icon: An optional icon for this identity provider, which is presented
-       #       by clients and Synapse's own IdP picker page. If given, must be an
-       #       MXC URI of the format mxc://<server-name>/<media-id>. (An easy way to
-       #       obtain such an MXC URI is to upload an image to an (unencrypted) room
-       #       and then copy the "url" from the source of the event.)
-       #
-       #   idp_brand: An optional brand for this identity provider, allowing clients
-       #       to style the login flow according to the identity provider in question.
-       #       See the spec for possible options here.
+       #   idp_icon: An optional icon for this identity provider, which is presented
+       #       by identity picker pages. If given, must be an MXC URI of the format
+       #       mxc://<server-name>/<media-id>. (An easy way to obtain such an MXC URI
+       #       is to upload an image to an (unencrypted) room and then copy the "url"
+       #       from the source of the event.)
        #
        #   discover: set to 'false' to disable the use of the OIDC discovery mechanism
        #       to discover endpoints. Defaults to true.
@@ -145,21 +143,17 @@ class OIDCConfig(Config):
        #
        # For the default provider, the following settings are available:
        #
-       #   subject_claim: name of the claim containing a unique identifier
-       #       for the user. Defaults to 'sub', which OpenID Connect
-       #       compliant providers should provide.
+       #   sub: name of the claim containing a unique identifier for the
+       #       user. Defaults to 'sub', which OpenID Connect compliant
+       #       providers should provide.
        #
        #   localpart_template: Jinja2 template for the localpart of the MXID.
        #       If this is not set, the user will be prompted to choose their
-       #       own username (see 'sso_auth_account_details.html' in the 'sso'
-       #       section of this file).
+       #       own username.
        #
        #   display_name_template: Jinja2 template for the display name to set
        #       on first login. If unset, no displayname will be set.
        #
-       #   email_template: Jinja2 template for the email address of the user.
-       #       If unset, no email address will be added to the account.
-       #
        #   extra_attributes: a map of Jinja2 templates for extra attributes
        #       to send back to the client during login.
        #       Note that these are non-standard and clients will ignore them
@@ -195,12 +189,6 @@ class OIDCConfig(Config):
        #       userinfo_endpoint: "https://accounts.example.com/userinfo"
        #       jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
        #       skip_verification: true
-       #     user_mapping_provider:
-       #       config:
-       #         subject_claim: "id"
-       #         localpart_template: "{{ user.login }}"
-       #         display_name_template: "{{ user.name }}"
-       #         email_template: "{{ user.email }}"

        # For use with Keycloak
        #
@@ -215,7 +203,6 @@ class OIDCConfig(Config):
        #
        #- idp_id: github
        #  idp_name: Github
-       #  idp_brand: org.matrix.github
        #  discover: false
        #  issuer: "https://github.com/"
        #  client_id: "your-client-id"  # TO BE FILLED
@@ -239,22 +226,11 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
     "type": "object",
     "required": ["issuer", "client_id", "client_secret"],
     "properties": {
-        "idp_id": {
-            "type": "string",
-            "minLength": 1,
-            # MSC2858 allows a maxlen of 255, but we prefix with "oidc-"
-            "maxLength": 250,
-            "pattern": "^[A-Za-z0-9._~-]+$",
-        },
+        # TODO: fix the maxLength here depending on what MSC2528 decides
+        # remember that we prefix the ID given here with `oidc-`
+        "idp_id": {"type": "string", "minLength": 1, "maxLength": 128},
         "idp_name": {"type": "string"},
         "idp_icon": {"type": "string"},
-        "idp_brand": {
-            "type": "string",
-            # MSC2758-style namespaced identifier
-            "minLength": 1,
-            "maxLength": 255,
-            "pattern": "^[a-z][a-z0-9_.-]*$",
-        },
        "discover": {"type": "boolean"},
        "issuer": {"type": "string"},
        "client_id": {"type": "string"},
@@ -373,8 +349,25 @@ def _parse_oidc_config_dict(
             config_path + ("user_mapping_provider", "module"),
         )

-    # MSC2858 will apply certain limits in what can be used as an IdP id, so let's
-    # enforce those limits now.
-    # TODO: factor out this stuff to a generic function
     idp_id = oidc_config.get("idp_id", "oidc")
+
+    # TODO: update this validity check based on what MSC2858 decides.
+    valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._")
+
+    if any(c not in valid_idp_chars for c in idp_id):
+        raise ConfigError(
+            'idp_id may only contain a-z, 0-9, "-", ".", "_"',
+            config_path + ("idp_id",),
+        )
+
+    if idp_id[0] not in string.ascii_lowercase:
+        raise ConfigError(
+            "idp_id must start with a-z", config_path + ("idp_id",),
+        )
+
     # prefix the given IDP with a prefix specific to the SSO mechanism, to avoid
     # clashes with other mechs (such as SAML, CAS).
     #
@@ -400,7 +393,6 @@ def _parse_oidc_config_dict(
         idp_id=idp_id,
         idp_name=oidc_config.get("idp_name", "OIDC"),
         idp_icon=idp_icon,
-        idp_brand=oidc_config.get("idp_brand"),
         discover=oidc_config.get("discover", True),
         issuer=oidc_config["issuer"],
         client_id=oidc_config["client_id"],
@@ -431,9 +423,6 @@ class OidcProviderConfig:
     # Optional MXC URI for icon for this IdP.
     idp_icon = attr.ib(type=Optional[str])

-    # Optional brand identifier for this IdP.
-    idp_brand = attr.ib(type=Optional[str])
-
     # whether the OIDC discovery mechanism is used to discover endpoints
     discover = attr.ib(type=bool)
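The older branch of this file validates `idp_id` with the `string` module before prefixing it. A standalone sketch of that check (`ValueError` stands in for Synapse's `ConfigError` so the example runs without the synapse package):

```python
import string

# Sketch of the idp_id validity check shown in the diff above: lowercase
# start, restricted character set, then an "oidc-" prefix to avoid clashes
# with other SSO mechanisms such as SAML or CAS.
def check_idp_id(idp_id: str) -> str:
    valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._")
    if any(c not in valid_idp_chars for c in idp_id):
        raise ValueError('idp_id may only contain a-z, 0-9, "-", ".", "_"')
    if idp_id[0] not in string.ascii_lowercase:
        raise ValueError("idp_id must start with a-z")
    return "oidc-" + idp_id

print(check_idp_id("keycloak"))  # oidc-keycloak
```

The newer branch expresses the same constraint declaratively, as a `pattern` in the JSON schema instead of imperative checks.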

@@ -24,7 +24,7 @@ class RateLimitConfig:
         defaults={"per_second": 0.17, "burst_count": 3.0},
     ):
         self.per_second = config.get("per_second", defaults["per_second"])
-        self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
+        self.burst_count = config.get("burst_count", defaults["burst_count"])


 class FederationRateLimitConfig:
@@ -102,20 +102,6 @@ class RatelimitConfig(Config):
             defaults={"per_second": 0.01, "burst_count": 3},
         )

-        self.rc_3pid_validation = RateLimitConfig(
-            config.get("rc_3pid_validation") or {},
-            defaults={"per_second": 0.003, "burst_count": 5},
-        )
-
-        self.rc_invites_per_room = RateLimitConfig(
-            config.get("rc_invites", {}).get("per_room", {}),
-            defaults={"per_second": 0.3, "burst_count": 10},
-        )
-        self.rc_invites_per_user = RateLimitConfig(
-            config.get("rc_invites", {}).get("per_user", {}),
-            defaults={"per_second": 0.003, "burst_count": 5},
-        )
-
     def generate_config_section(self, **kwargs):
         return """\
         ## Ratelimiting ##
@@ -145,9 +131,6 @@ class RatelimitConfig(Config):
        #     users are joining rooms the server is already in (this is cheap) vs
        #     "remote" for when users are trying to join rooms not on the server (which
        #     can be more expensive)
-       #   - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
-       #   - two for ratelimiting how often invites can be sent in a room or to a
-       #     specific user.
        #
        # The defaults are as shown below.
        #
@@ -181,18 +164,7 @@ class RatelimitConfig(Config):
        #  remote:
        #    per_second: 0.01
        #    burst_count: 3
-       #
-       #rc_3pid_validation:
-       #  per_second: 0.003
-       #  burst_count: 5
-       #
-       #rc_invites:
-       #  per_room:
-       #    per_second: 0.3
-       #    burst_count: 10
-       #  per_user:
-       #    per_second: 0.003
-       #    burst_count: 5

        # Ratelimiting settings for incoming federation
        #
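All of the `per_second`/`burst_count` pairs in this section describe token-bucket style limits: up to `burst_count` requests may pass at once, refilling at `per_second`. A hedged sketch of those semantics (names are illustrative, not Synapse's `Ratelimiter` API), which also shows why the newer branch coerces `burst_count` to `int`:

```python
# Token bucket: holds at most burst_count tokens, refilled at per_second
# tokens per second; each allowed request consumes one token.
class TokenBucket:
    def __init__(self, per_second: float, burst_count: float):
        # coerce to int so YAML values like 3.0 behave as the documented counts
        self.per_second = per_second
        self.burst_count = int(burst_count)
        self.tokens = float(self.burst_count)
        self.last = 0.0

    def allow(self, now: float) -> bool:
        self.tokens = min(
            self.burst_count, self.tokens + (now - self.last) * self.per_second
        )
        self.last = now
        if self.tokens >= 1.0:
            self.tokens -= 1.0
            return True
        return False

rl = TokenBucket(per_second=0.17, burst_count=3.0)
results = [rl.allow(now=0.0) for _ in range(4)]
print(results)  # first 3 requests pass, the 4th is limited
```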

@@ -176,7 +176,9 @@ class RegistrationConfig(Config):
         self.session_lifetime = session_lifetime

         # The success template used during fallback auth.
-        self.fallback_success_template = self.read_template("auth_success.html")
+        self.fallback_success_template = self.read_templates(
+            ["auth_success.html"], autoescape=True
+        )[0]

     def generate_config_section(self, generate_secrets=False, **kwargs):
         if generate_secrets:

@@ -194,8 +194,8 @@ class SAML2Config(Config):
             optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
         optional_attributes -= required_attributes

-        metadata_url = public_baseurl + "_synapse/client/saml2/metadata.xml"
-        response_url = public_baseurl + "_synapse/client/saml2/authn_response"
+        metadata_url = public_baseurl + "_matrix/saml2/metadata.xml"
+        response_url = public_baseurl + "_matrix/saml2/authn_response"
         return {
             "entityid": metadata_url,
             "service": {
@@ -233,10 +233,10 @@ class SAML2Config(Config):
        # enable SAML login.
        #
        # Once SAML support is enabled, a metadata file will be exposed at
-       # https://<server>:<port>/_synapse/client/saml2/metadata.xml, which you may be able to
+       # https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
        # use to configure your SAML IdP with. Alternatively, you can manually configure
        # the IdP to use an ACS location of
-       # https://<server>:<port>/_synapse/client/saml2/authn_response.
+       # https://<server>:<port>/_matrix/saml2/authn_response.
        #
        saml2_config:
          # `sp_config` is the configuration for the pysaml2 Service Provider.

@@ -27,7 +27,7 @@ class SSOConfig(Config):
         sso_config = config.get("sso") or {}  # type: Dict[str, Any]

         # The sso-specific template_dir
-        self.sso_template_dir = sso_config.get("template_dir")
+        template_dir = sso_config.get("template_dir")

         # Read templates from disk
         (
@@ -48,7 +48,7 @@ class SSOConfig(Config):
                 "sso_auth_success.html",
                 "sso_auth_bad_user.html",
             ],
-            self.sso_template_dir,
+            template_dir,
         )

         # These templates have no placeholders, so render them here
@@ -113,13 +113,8 @@ class SSOConfig(Config):
        #   * providers: a list of available Identity Providers. Each element is
        #     an object with the following attributes:
-       #
        #     * idp_id: unique identifier for the IdP
        #     * idp_name: user-facing name for the IdP
-       #     * idp_icon: if specified in the IdP config, an MXC URI for an icon
-       #          for the IdP
-       #     * idp_brand: if specified in the IdP config, a textual identifier
-       #          for the brand of the IdP
        #
        # The rendered HTML page should contain a form which submits its results
        # back as a GET request, with the following query parameters:
@@ -129,62 +124,10 @@ class SSOConfig(Config):
        #
        #   * idp: the 'idp_id' of the chosen IDP.
        #
-       #   * HTML page to prompt new users to enter a userid and confirm other
-       #     details: 'sso_auth_account_details.html'. This is only shown if the
-       #     SSO implementation (with any user_mapping_provider) does not return
-       #     a localpart.
-       #
-       #     When rendering, this template is given the following variables:
-       #
-       #     * server_name: the homeserver's name.
-       #
-       #     * idp: details of the SSO Identity Provider that the user logged in
-       #       with: an object with the following attributes:
-       #
-       #       * idp_id: unique identifier for the IdP
-       #       * idp_name: user-facing name for the IdP
-       #       * idp_icon: if specified in the IdP config, an MXC URI for an icon
-       #            for the IdP
-       #       * idp_brand: if specified in the IdP config, a textual identifier
-       #            for the brand of the IdP
-       #
-       #     * user_attributes: an object containing details about the user that
-       #       we received from the IdP. May have the following attributes:
-       #
-       #       * display_name: the user's display_name
-       #       * emails: a list of email addresses
-       #
-       #     The template should render a form which submits the following fields:
-       #
-       #     * username: the localpart of the user's chosen user id
-       #
-       #   * HTML page allowing the user to consent to the server's terms and
-       #     conditions. This is only shown for new users, and only if
-       #     `user_consent.require_at_registration` is set.
-       #
-       #     When rendering, this template is given the following variables:
-       #
-       #     * server_name: the homeserver's name.
-       #
-       #     * user_id: the user's matrix proposed ID.
-       #
-       #     * user_profile.display_name: the user's proposed display name, if any.
-       #
-       #     * consent_version: the version of the terms that the user will be
-       #       shown
-       #
-       #     * terms_url: a link to the page showing the terms.
-       #
-       #     The template should render a form which submits the following fields:
-       #
-       #     * accepted_version: the version of the terms accepted by the user
-       #       (ie, 'consent_version' from the input variables).
-       #
        #   * HTML page for a confirmation step before redirecting back to the client
        #     with the login token: 'sso_redirect_confirm.html'.
        #
-       #     When rendering, this template is given the following variables:
-       #
+       #     When rendering, this template is given three variables:
        #     * redirect_url: the URL the user is about to be redirected to. Needs
        #       manual escaping (see
        #       https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
@@ -197,17 +140,6 @@ class SSOConfig(Config):
        #
        #     * server_name: the homeserver's name.
        #
-       #     * new_user: a boolean indicating whether this is the user's first time
-       #       logging in.
-       #
-       #     * user_id: the user's matrix ID.
-       #
-       #     * user_profile.avatar_url: an MXC URI for the user's avatar, if any.
-       #       None if the user has not set an avatar.
-       #
-       #     * user_profile.display_name: the user's display name. None if the user
-       #       has not set a display name.
-       #
        #   * HTML page which notifies the user that they are authenticating to confirm
        #     an operation on their account during the user interactive authentication
        #     process: 'sso_auth_confirm.html'.
@@ -219,16 +151,6 @@ class SSOConfig(Config):
        #
        #     * description: the operation which the user is being asked to confirm
        #
-       #     * idp: details of the Identity Provider that we will use to confirm
-       #       the user's identity: an object with the following attributes:
-       #
-       #       * idp_id: unique identifier for the IdP
-       #       * idp_name: user-facing name for the IdP
-       #       * idp_icon: if specified in the IdP config, an MXC URI for an icon
-       #            for the IdP
-       #       * idp_brand: if specified in the IdP config, a textual identifier
-       #            for the brand of the IdP
-       #
        #   * HTML page shown after a successful user interactive authentication session:
        #     'sso_auth_success.html'.
        #
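The `sso_redirect_confirm.html` comments above say `redirect_url` needs manual escaping while `display_url` is the redirect URL with query parameters stripped. A hedged sketch of preparing those variables (the `confirm_page_vars` helper is illustrative, not Synapse's renderer):

```python
import html
import urllib.parse

# Build the template variables described in the config comments: an
# HTML-escaped redirect_url, a display_url without query parameters,
# and the homeserver's name.
def confirm_page_vars(redirect_url: str, server_name: str) -> dict:
    split = urllib.parse.urlsplit(redirect_url)
    display_url = urllib.parse.urlunsplit(
        (split.scheme, split.netloc, split.path, "", "")
    )
    return {
        "redirect_url": html.escape(redirect_url),
        "display_url": display_url,
        "server_name": server_name,
    }

page_vars = confirm_page_vars("https://client.example/?token=a&b=<x>", "example.com")
print(page_vars["redirect_url"])
print(page_vars["display_url"])  # https://client.example/
```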

@@ -125,24 +125,19 @@ class FederationPolicyForHTTPS:
         self._no_verify_ssl_context = _no_verify_ssl.getContext()
         self._no_verify_ssl_context.set_info_callback(_context_info_cb)

-        self._should_verify = self._config.federation_verify_certificates
-
-        self._federation_certificate_verification_whitelist = (
-            self._config.federation_certificate_verification_whitelist
-        )
-
     def get_options(self, host: bytes):
         # IPolicyForHTTPS.get_options takes bytes, but we want to compare
         # against the str whitelist. The hostnames in the whitelist are already
         # IDNA-encoded like the hosts will be here.
         ascii_host = host.decode("ascii")

         # Check if certificate verification has been enabled
-        should_verify = self._should_verify
+        should_verify = self._config.federation_verify_certificates

         # Check if we've disabled certificate verification for this host
-        if self._should_verify:
-            for regex in self._federation_certificate_verification_whitelist:
+        if should_verify:
+            for regex in self._config.federation_certificate_verification_whitelist:
                 if regex.match(ascii_host):
                     should_verify = False
                     break
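The logic in `get_options` is the same on both sides of the diff: skip certificate verification for any host matching a whitelist entry. A self-contained sketch of that check (the precompiled regexes here stand in for the compiled `federation_certificate_verification_whitelist` config entries):

```python
import re

# Hypothetical whitelist: hosts matching any entry skip certificate
# verification, mirroring the loop in get_options above.
whitelist = [re.compile(r"^.*\.onion$"), re.compile(r"^internal\.example$")]

def should_verify(host: bytes, verify_certs: bool = True) -> bool:
    # hostnames are already IDNA-encoded, so a plain ascii decode suffices
    ascii_host = host.decode("ascii")
    if verify_certs:
        for regex in whitelist:
            if regex.match(ascii_host):
                return False
        return True
    return False

print(should_verify(b"matrix.org"))    # verified: not whitelisted
print(should_verify(b"abcdef.onion"))  # skipped: matches the .onion entry
```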

@@ -810,7 +810,7 @@ class FederationClient(FederationBase):
                     "User's homeserver does not support this room version",
                     Codes.UNSUPPORTED_ROOM_VERSION,
                 )
-            elif e.code in (403, 429):
+            elif e.code == 403:
                 raise e.to_synapse_error()
             else:
                 raise

@@ -142,8 +142,6 @@ class FederationSender:
             self._wake_destinations_needing_catchup,
         )

-        self._external_cache = hs.get_external_cache()
-
     def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
         """Get or create a PerDestinationQueue for the given destination
@@ -199,24 +197,6 @@ class FederationSender:
             if not event.internal_metadata.should_proactively_send():
                 return

-            destinations = None  # type: Optional[Set[str]]
-            if not event.prev_event_ids():
-                # If there are no prev event IDs then the state is empty
-                # and so no remote servers in the room
-                destinations = set()
-            else:
-                # We check the external cache for the destinations, which is
-                # stored per state group.
-                sg = await self._external_cache.get(
-                    "event_to_prev_state_group", event.event_id
-                )
-                if sg:
-                    destinations = await self._external_cache.get(
-                        "get_joined_hosts", str(sg)
-                    )
-
-            if destinations is None:
             try:
                 # Get the state from before the event.
                 # We need to make sure that this is the state from before
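The newer branch above consults an external (Redis-backed) cache in two steps before falling back to reading room state: event id to state group, then state group to joined hosts. A minimal synchronous sketch of that lookup chain, with plain dicts standing in for the external cache:

```python
# Illustrative caches keyed as in the diff above: one maps an event to its
# previous state group, the other maps a state group to the joined hosts.
event_to_prev_state_group = {"$event1": 42}
joined_hosts_for_group = {"42": {"matrix.org", "example.com"}}

def get_destinations(event_id: str):
    sg = event_to_prev_state_group.get(event_id)
    if sg is not None:
        hosts = joined_hosts_for_group.get(str(sg))
        if hosts is not None:
            return hosts
    # cache miss on either step: the caller falls back to computing the
    # destinations from the room state, as in the try block above
    return None

print(get_destinations("$event1"))
print(get_destinations("$unknown"))
```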

@@ -14,7 +14,6 @@
 # limitations under the License.

 import logging
-from typing import TYPE_CHECKING

 import twisted
 import twisted.internet.error
@@ -23,9 +22,6 @@ from twisted.web.resource import Resource

 from synapse.app import check_bind_error

-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer
-
 logger = logging.getLogger(__name__)

 ACME_REGISTER_FAIL_ERROR = """
@@ -39,12 +35,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
 class AcmeHandler:
-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs):
         self.hs = hs
         self.reactor = hs.get_reactor()
         self._acme_domain = hs.config.acme_domain

-    async def start_listening(self) -> None:
+    async def start_listening(self):
         from synapse.handlers import acme_issuing_service

         # Configure logging for txacme, if you need to debug
@@ -89,7 +85,7 @@ class AcmeHandler:
             logger.error(ACME_REGISTER_FAIL_ERROR)
             raise

-    async def provision_certificate(self) -> None:
+    async def provision_certificate(self):
         logger.warning("Reprovisioning %s", self._acme_domain)

@@ -114,3 +110,5 @@ class AcmeHandler:
         except Exception:
             logger.exception("Failed saving!")
             raise
+
+        return True

@@ -22,10 +22,8 @@ only need (and may only have available) if we are doing ACME, so is designed to
 imported conditionally.
 """
 import logging
-from typing import Dict, Iterable, List

 import attr
-import pem
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
 from josepy import JWKRSA
@@ -38,27 +36,20 @@ from txacme.util import generate_private_key
 from zope.interface import implementer

 from twisted.internet import defer
-from twisted.internet.interfaces import IReactorTCP
 from twisted.python.filepath import FilePath
 from twisted.python.url import URL
-from twisted.web.resource import IResource

 logger = logging.getLogger(__name__)


-def create_issuing_service(
-    reactor: IReactorTCP,
-    acme_url: str,
-    account_key_file: str,
-    well_known_resource: IResource,
-) -> AcmeIssuingService:
+def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
     """Create an ACME issuing service, and attach it to a web Resource

     Args:
         reactor: twisted reactor
-        acme_url: URL to use to request certificates
-        account_key_file: where to store the account key
-        well_known_resource: web resource for .well-known.
+        acme_url (str): URL to use to request certificates
+        account_key_file (str): where to store the account key
+        well_known_resource (twisted.web.IResource): web resource for .well-known.
             we will attach a child resource for "acme-challenge".

     Returns:
@@ -92,20 +83,18 @@ class ErsatzStore:
     A store that only stores in memory.
     """

-    certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))
+    certs = attr.ib(default=attr.Factory(dict))

-    def store(
-        self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
-    ) -> defer.Deferred:
+    def store(self, server_name, pem_objects):
         self.certs[server_name] = [o.as_bytes() for o in pem_objects]
         return defer.succeed(None)


-def load_or_create_client_key(key_file: str) -> JWKRSA:
+def load_or_create_client_key(key_file):
     """Load the ACME account key from a file, creating it if it does not exist.

     Args:
-        key_file: name of the file to use as the account key
+        key_file (str): name of the file to use as the account key
     """
     # this is based on txacme.endpoint.load_or_create_client_key, but doesn't
     # hardcode the 'client.key' filename

@@ -61,7 +61,6 @@ from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import ModuleApi
-from synapse.storage.roommember import ProfileInfo
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import maybe_awaitable
@@ -568,6 +567,16 @@ class AuthHandler(BaseHandler):
                     session.session_id, login_type, result
                 )
             except LoginError as e:
+                if login_type == LoginType.EMAIL_IDENTITY:
+                    # riot used to have a bug where it would request a new
+                    # validation token (thus sending a new email) each time it
+                    # got a 401 with a 'flows' field.
+                    # (https://github.com/vector-im/vector-web/issues/2447).
+                    #
+                    # Grandfather in the old behaviour for now to avoid
+                    # breaking old riot deployments.
+                    raise
+
                 # this step failed. Merge the error dict into the response
                 # so that the client can have another go.
                 errordict = e.error_dict()
@@ -1378,9 +1387,7 @@ class AuthHandler(BaseHandler):
         )

         return self._sso_auth_confirm_template.render(
-            description=session.description,
-            redirect_url=redirect_url,
-            idp=sso_auth_provider,
+            description=session.description, redirect_url=redirect_url,
         )

     async def complete_sso_login(
@@ -1389,7 +1396,6 @@ class AuthHandler(BaseHandler):
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
-        new_user: bool = False,
     ):
         """Having figured out a mxid for this user, complete the HTTP request
@@ -1400,8 +1406,6 @@ class AuthHandler(BaseHandler):
                 process.
             extra_attributes: Extra attributes which will be passed to the client
                 during successful login. Must be JSON serializable.
-            new_user: True if we should use wording appropriate to a user who has just
-                registered.
         """
         # If the account has been deactivated, do not proceed with the login
         # flow.
@@ -1410,17 +1414,8 @@ class AuthHandler(BaseHandler):
             respond_with_html(request, 403, self._sso_account_deactivated_template)
             return

-        profile = await self.store.get_profileinfo(
-            UserID.from_string(registered_user_id).localpart
-        )
-
         self._complete_sso_login(
-            registered_user_id,
-            request,
-            client_redirect_url,
-            extra_attributes,
-            new_user=new_user,
-            user_profile_data=profile,
+            registered_user_id, request, client_redirect_url, extra_attributes
         )

     def _complete_sso_login(
@@ -1429,18 +1424,12 @@ class AuthHandler(BaseHandler):
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
-        new_user: bool = False,
-        user_profile_data: Optional[ProfileInfo] = None,
     ):
         """
         The synchronous portion of complete_sso_login.

         This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
         """
-        if user_profile_data is None:
-            user_profile_data = ProfileInfo(None, None)
-
         # Store any extra attributes which will be passed in the login response.
         # Note that this is per-user so it may overwrite a previous value, this
         # is considered OK since the newest SSO attributes should be most valid.
@@ -1478,9 +1467,6 @@ class AuthHandler(BaseHandler):
             display_url=redirect_url_no_params,
             redirect_url=redirect_url,
             server_name=self._server_name,
-            new_user=new_user,
-            user_id=registered_user_id,
-            user_profile=user_profile_data,
         )
         respond_with_html(request, 200, html)


@@ -80,10 +80,9 @@ class CasHandler:
         # user-facing name of this auth provider
         self.idp_name = "CAS"
-        # we do not currently support brands/icons for CAS auth, but this is required by
+        # we do not currently support icons for CAS auth, but this is required by
         # the SsoIdentityProvider protocol type.
         self.idp_icon = None
-        self.idp_brand = None
         self._sso_handler = hs.get_sso_handler()
@@ -100,7 +99,11 @@ class CasHandler:
         Returns:
             The URL to use as a "service" parameter.
         """
-        return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
+        return "%s%s?%s" % (
+            self._cas_service_url,
+            "/_matrix/client/r0/login/cas/ticket",
+            urllib.parse.urlencode(args),
+        )
     async def _validate_ticket(
         self, ticket: str, service_args: Dict[str, str]
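The second hunk above changes how the CAS `service` parameter is built: the newer code inserts the Matrix ticket endpoint between the configured base URL and the encoded query string. A minimal standalone sketch of that construction (the base URL and `args` values here are illustrative, not from the diff):

```python
import urllib.parse


def build_cas_service_url(cas_service_url: str, args: dict) -> str:
    # Append the Matrix CAS ticket endpoint to the configured base URL,
    # then the encoded query parameters, mirroring the "%s%s?%s" format
    # used in the hunk above.
    return "%s%s?%s" % (
        cas_service_url,
        "/_matrix/client/r0/login/cas/ticket",
        urllib.parse.urlencode(args),
    )


url = build_cas_service_url(
    "https://cas.example.com", {"redirectUrl": "https://client.example/"}
)
print(url)
```

`urlencode` percent-encodes the redirect URL, so the result is a single well-formed `service` URL.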

View File

@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
 from synapse.api import errors
 from synapse.api.constants import EventTypes
@@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler):
         self._auth_handler = hs.get_auth_handler()
     @trace
-    async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
+    async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
         """
         Retrieve the given user's devices
@@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler):
         return devices
     @trace
-    async def get_device(self, user_id: str, device_id: str) -> JsonDict:
+    async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
         """ Retrieve the given device
         Args:
@@ -598,7 +598,7 @@ class DeviceHandler(DeviceWorkerHandler):
 def _update_device_from_client_ips(
-    device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
+    device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
 ) -> None:
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
     device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
@@ -946,8 +946,8 @@ class DeviceListUpdater:
     async def process_cross_signing_key_update(
         self,
         user_id: str,
-        master_key: Optional[JsonDict],
-        self_signing_key: Optional[JsonDict],
+        master_key: Optional[Dict[str, Any]],
+        self_signing_key: Optional[Dict[str, Any]],
     ) -> List[str]:
         """Process the given new master and self-signing key for the given remote user.
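The `_update_device_from_client_ips` hunk above only retypes the signature; the merge logic is identical on both sides. A self-contained sketch of that logic, with made-up sample data:

```python
from typing import Any, Dict, Tuple


def update_device_from_client_ips(
    device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
) -> None:
    # Look up the (user_id, device_id) pair in the client-IP map and fold
    # the last-seen data into the device dict, as the hunk above does.
    ip = client_ips.get((device["user_id"], device["device_id"]), {})
    device.update(
        {"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}
    )


device = {"user_id": "@alice:example.com", "device_id": "DEV1"}
client_ips = {("@alice:example.com", "DEV1"): {"last_seen": 1000, "ip": "10.0.0.1"}}
update_device_from_client_ips(device, client_ips)
print(device)
```

Because `client_ips.get(..., {})` defaults to an empty dict, devices with no recorded IP still get `last_seen_ts`/`last_seen_ip` keys, set to `None`.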


@@ -16,7 +16,7 @@
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 import attr
 from canonicaljson import encode_canonical_json
@@ -31,7 +31,6 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.types import (
-    JsonDict,
     UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
@@ -41,14 +40,11 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer
 logger = logging.getLogger(__name__)
 class E2eKeysHandler:
-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs):
         self.store = hs.get_datastore()
         self.federation = hs.get_federation_client()
         self.device_handler = hs.get_device_handler()
@@ -82,9 +78,7 @@ class E2eKeysHandler:
         )
     @trace
-    async def query_devices(
-        self, query_body: JsonDict, timeout: int, from_user_id: str
-    ) -> JsonDict:
+    async def query_devices(self, query_body, timeout, from_user_id):
         """ Handle a device key query from a client
         {
@@ -104,14 +98,12 @@ class E2eKeysHandler:
         }
         Args:
-            from_user_id: the user making the query. This is used when
+            from_user_id (str): the user making the query. This is used when
                 adding cross-signing signatures to limit what signatures users
                 can see.
         """
-        device_keys_query = query_body.get(
-            "device_keys", {}
-        )  # type: Dict[str, Iterable[str]]
+        device_keys_query = query_body.get("device_keys", {})
         # separate users by domain.
         # make a map from domain to user_id to device_ids
@@ -129,8 +121,7 @@ class E2eKeysHandler:
         set_tag("remote_key_query", remote_queries)
         # First get local devices.
-        # A map of destination -> failure response.
-        failures = {}  # type: Dict[str, JsonDict]
+        failures = {}
         results = {}
         if local_query:
             local_result = await self.query_local_devices(local_query)
@@ -144,10 +135,9 @@ class E2eKeysHandler:
             )
         # Now attempt to get any remote devices from our local cache.
-        # A map of destination -> user ID -> device IDs.
-        remote_queries_not_in_cache = {}  # type: Dict[str, Dict[str, Iterable[str]]]
+        remote_queries_not_in_cache = {}
         if remote_queries:
-            query_list = []  # type: List[Tuple[str, Optional[str]]]
+            query_list = []
             for user_id, device_ids in remote_queries.items():
                 if device_ids:
                     query_list.extend((user_id, device_id) for device_id in device_ids)
@@ -294,15 +284,15 @@ class E2eKeysHandler:
         return ret
     async def get_cross_signing_keys_from_cache(
-        self, query: Iterable[str], from_user_id: Optional[str]
+        self, query, from_user_id
     ) -> Dict[str, Dict[str, dict]]:
         """Get cross-signing keys for users from the database
         Args:
-            query: an iterable of user IDs. A dict whose keys
+            query (Iterable[string]) an iterable of user IDs. A dict whose keys
                 are user IDs satisfies this, so the query format used for
                 query_devices can be used here.
-            from_user_id: the user making the query. This is used when
+            from_user_id (str): the user making the query. This is used when
                 adding cross-signing signatures to limit what signatures users
                 can see.
@@ -325,12 +315,14 @@ class E2eKeysHandler:
             if "self_signing" in user_info:
                 self_signing_keys[user_id] = user_info["self_signing"]
+        if (
+            from_user_id in keys
+            and keys[from_user_id] is not None
+            and "user_signing" in keys[from_user_id]
+        ):
             # users can see other users' master and self-signing keys, but can
             # only see their own user-signing keys
-        if from_user_id:
-            from_user_key = keys.get(from_user_id)
-            if from_user_key and "user_signing" in from_user_key:
-                user_signing_keys[from_user_id] = from_user_key["user_signing"]
+            user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
         return {
             "master_keys": master_keys,
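Both sides of the hunk above enforce the same rule: a user may see other users' master and self-signing keys, but only their own user-signing key. A sketch of the newer (left-hand) guarded lookup, with illustrative data:

```python
# Sample cross-signing key cache; the key material is a placeholder.
keys = {
    "@alice:example.com": {"user_signing": {"keys": {"ed25519:usk": "base64"}}},
}


def visible_user_signing_keys(keys, from_user_id):
    # Only expose the requesting user's own user-signing key, as the
    # guarded lookup in the hunk above does.
    user_signing_keys = {}
    if from_user_id:
        from_user_key = keys.get(from_user_id)
        if from_user_key and "user_signing" in from_user_key:
            user_signing_keys[from_user_id] = from_user_key["user_signing"]
    return user_signing_keys


print(visible_user_signing_keys(keys, "@alice:example.com"))
print(visible_user_signing_keys(keys, "@bob:example.com"))
```

The `keys.get(...)` form avoids the `KeyError` that the older `from_user_id in keys` chain guards against explicitly.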
@@ -352,9 +344,9 @@ class E2eKeysHandler:
             A map from user_id -> device_id -> device details
         """
         set_tag("local_query", query)
-        local_query = []  # type: List[Tuple[str, Optional[str]]]
-        result_dict = {}  # type: Dict[str, Dict[str, dict]]
+        local_query = []
+        result_dict = {}
         for user_id, device_ids in query.items():
             # we use UserID.from_string to catch invalid user ids
             if not self.is_mine(UserID.from_string(user_id)):
@@ -388,14 +380,10 @@ class E2eKeysHandler:
         log_kv(results)
         return result_dict
-    async def on_federation_query_client_keys(
-        self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
-    ) -> JsonDict:
+    async def on_federation_query_client_keys(self, query_body):
         """ Handle a device key query from a federated server
         """
-        device_keys_query = query_body.get(
-            "device_keys", {}
-        )  # type: Dict[str, Optional[List[str]]]
+        device_keys_query = query_body.get("device_keys", {})
         res = await self.query_local_devices(device_keys_query)
         ret = {"device_keys": res}
@@ -409,34 +397,31 @@ class E2eKeysHandler:
         return ret
     @trace
-    async def claim_one_time_keys(
-        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
-    ) -> JsonDict:
-        local_query = []  # type: List[Tuple[str, str, str]]
-        remote_queries = {}  # type: Dict[str, Dict[str, Dict[str, str]]]
+    async def claim_one_time_keys(self, query, timeout):
+        local_query = []
+        remote_queries = {}
-        for user_id, one_time_keys in query.get("one_time_keys", {}).items():
+        for user_id, device_keys in query.get("one_time_keys", {}).items():
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):
-                for device_id, algorithm in one_time_keys.items():
+                for device_id, algorithm in device_keys.items():
                     local_query.append((user_id, device_id, algorithm))
             else:
                 domain = get_domain_from_id(user_id)
-                remote_queries.setdefault(domain, {})[user_id] = one_time_keys
+                remote_queries.setdefault(domain, {})[user_id] = device_keys
         set_tag("local_key_query", local_query)
         set_tag("remote_key_query", remote_queries)
         results = await self.store.claim_e2e_one_time_keys(local_query)
-        # A map of user ID -> device ID -> key ID -> key.
-        json_result = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
-        failures = {}  # type: Dict[str, JsonDict]
+        json_result = {}
+        failures = {}
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
-                for key_id, json_str in keys.items():
+                for key_id, json_bytes in keys.items():
                     json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json_decoder.decode(json_str)
+                        key_id: json_decoder.decode(json_bytes)
                     }
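The tail of the `claim_one_time_keys` hunk builds a nested user → device → key mapping with `setdefault`. A standalone sketch of that pattern, using the stdlib `json` module in place of Synapse's `json_decoder` (an assumption for illustration; the sample data is also made up):

```python
import json

# Shape of the store result: user -> device -> {key_id: serialized key JSON}.
results = {
    "@alice:example.com": {
        "DEVICE1": {"signed_curve25519:AAAA": '{"key": "abc"}'},
    },
}

json_result = {}
for user_id, device_keys in results.items():
    for device_id, keys in device_keys.items():
        for key_id, json_str in keys.items():
            # setdefault creates the per-user dict on first sight, then
            # assigns the decoded key under the device ID.
            json_result.setdefault(user_id, {})[device_id] = {
                key_id: json.loads(json_str)
            }

print(json_result)
```

Note that the inner assignment replaces the whole per-device dict each iteration, so only one key per device survives, matching the structure shown in the hunk.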
     @trace
@@ -483,9 +468,7 @@ class E2eKeysHandler:
         return {"one_time_keys": json_result, "failures": failures}
     @tag_args
-    async def upload_keys_for_user(
-        self, user_id: str, device_id: str, keys: JsonDict
-    ) -> JsonDict:
+    async def upload_keys_for_user(self, user_id, device_id, keys):
         time_now = self.clock.time_msec()
@@ -560,8 +543,8 @@ class E2eKeysHandler:
         return {"one_time_key_counts": result}
     async def _upload_one_time_keys_for_user(
-        self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
-    ) -> None:
+        self, user_id, device_id, time_now, one_time_keys
+    ):
         logger.info(
             "Adding one_time_keys %r for device %r for user %r at %d",
             one_time_keys.keys(),
@@ -602,14 +585,12 @@ class E2eKeysHandler:
         log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
         await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
-    async def upload_signing_keys_for_user(
-        self, user_id: str, keys: JsonDict
-    ) -> JsonDict:
+    async def upload_signing_keys_for_user(self, user_id, keys):
         """Upload signing keys for cross-signing
         Args:
-            user_id: the user uploading the keys
-            keys: the signing keys
+            user_id (string): the user uploading the keys
+            keys (dict[string, dict]): the signing keys
         """
         # if a master key is uploaded, then check it. Otherwise, load the
@@ -686,17 +667,16 @@ class E2eKeysHandler:
         return {}
-    async def upload_signatures_for_device_keys(
-        self, user_id: str, signatures: JsonDict
-    ) -> JsonDict:
+    async def upload_signatures_for_device_keys(self, user_id, signatures):
         """Upload device signatures for cross-signing
         Args:
-            user_id: the user uploading the signatures
-            signatures: map of users to devices to signed keys. This is the submission
-                from the user; an exception will be raised if it is malformed.
+            user_id (string): the user uploading the signatures
+            signatures (dict[string, dict[string, dict]]): map of users to
+                devices to signed keys. This is the submission from the user; an
+                exception will be raised if it is malformed.
         Returns:
-            The response to be sent back to the client. The response will have
+            dict: response to be sent back to the client. The response will have
                 a "failures" key, which will be a dict mapping users to devices
                 to errors for the signatures that failed.
         Raises:
@@ -739,9 +719,7 @@ class E2eKeysHandler:
         return {"failures": failures}
-    async def _process_self_signatures(
-        self, user_id: str, signatures: JsonDict
-    ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
+    async def _process_self_signatures(self, user_id, signatures):
         """Process uploaded signatures of the user's own keys.
         Signatures of the user's own keys from this API come in two forms:
@@ -753,14 +731,15 @@ class E2eKeysHandler:
             signatures (dict[string, dict]): map of devices to signed keys
         Returns:
-            A tuple of a list of signatures to store, and a map of users to
-                devices to failure reasons
+            (list[SignatureListItem], dict[string, dict[string, dict]]):
+                a list of signatures to store, and a map of users to devices to failure
+                reasons
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []  # type: List[SignatureListItem]
-        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
+        signature_list = []
+        failures = {}
         if not signatures:
             return signature_list, failures
@@ -855,24 +834,19 @@ class E2eKeysHandler:
         return signature_list, failures
     def _check_master_key_signature(
-        self,
-        user_id: str,
-        master_key_id: str,
-        signed_master_key: JsonDict,
-        stored_master_key: JsonDict,
-        devices: Dict[str, Dict[str, JsonDict]],
-    ) -> List["SignatureListItem"]:
+        self, user_id, master_key_id, signed_master_key, stored_master_key, devices
+    ):
         """Check signatures of a user's master key made by their devices.
         Args:
-            user_id: the user whose master key is being checked
-            master_key_id: the ID of the user's master key
-            signed_master_key: the user's signed master key that was uploaded
-            stored_master_key: our previously-stored copy of the user's master key
-            devices: the user's devices
+            user_id (string): the user whose master key is being checked
+            master_key_id (string): the ID of the user's master key
+            signed_master_key (dict): the user's signed master key that was uploaded
+            stored_master_key (dict): our previously-stored copy of the user's master key
+            devices (iterable(dict)): the user's devices
         Returns:
-            A list of signatures to store
+            list[SignatureListItem]: a list of signatures to store
         Raises:
             SynapseError: if a signature is invalid
@@ -903,26 +877,25 @@ class E2eKeysHandler:
         return master_key_signature_list
-    async def _process_other_signatures(
-        self, user_id: str, signatures: Dict[str, dict]
-    ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
+    async def _process_other_signatures(self, user_id, signatures):
         """Process uploaded signatures of other users' keys. These will be the
         target user's master keys, signed by the uploading user's user-signing
        key.
         Args:
-            user_id: the user uploading the keys
-            signatures: map of users to devices to signed keys
+            user_id (string): the user uploading the keys
+            signatures (dict[string, dict]): map of users to devices to signed keys
         Returns:
-            A list of signatures to store, and a map of users to devices to failure
+            (list[SignatureListItem], dict[string, dict[string, dict]]):
+                a list of signatures to store, and a map of users to devices to failure
                 reasons
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []  # type: List[SignatureListItem]
-        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
+        signature_list = []
+        failures = {}
         if not signatures:
             return signature_list, failures
@@ -1010,7 +983,7 @@ class E2eKeysHandler:
     async def _get_e2e_cross_signing_verify_key(
         self, user_id: str, key_type: str, from_user_id: str = None
-    ) -> Tuple[JsonDict, str, VerifyKey]:
+    ):
         """Fetch locally or remotely query for a cross-signing public key.
         First, attempt to fetch the cross-signing public key from storage.
@@ -1024,7 +997,8 @@ class E2eKeysHandler:
             This affects what signatures are fetched.
         Returns:
-            The raw key data, the key ID, and the signedjson verify key
+            dict, str, VerifyKey: the raw key data, the key ID, and the
+                signedjson verify key
         Raises:
             NotFoundError: if the key is not found
@@ -1161,18 +1135,16 @@ class E2eKeysHandler:
         return desired_key, desired_key_id, desired_verify_key
-def _check_cross_signing_key(
-    key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
-) -> None:
+def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
     """Check a cross-signing key uploaded by a user. Performs some basic sanity
     checking, and ensures that it is signed, if a signature is required.
     Args:
-        key: the key data to verify
-        user_id: the user whose key is being checked
-        key_type: the type of key that the key should be
-        signing_key: the signing key that the key should be signed with. If
-            omitted, signatures will not be checked.
+        key (dict): the key data to verify
+        user_id (str): the user whose key is being checked
+        key_type (str): the type of key that the key should be
+        signing_key (VerifyKey): (optional) the signing key that the key should
+            be signed with. If omitted, signatures will not be checked.
     """
     if (
         key.get("user_id") != user_id
@@ -1190,21 +1162,16 @@ def _check_cross_signing_key(
     )
-def _check_device_signature(
-    user_id: str,
-    verify_key: VerifyKey,
-    signed_device: JsonDict,
-    stored_device: JsonDict,
-) -> None:
+def _check_device_signature(user_id, verify_key, signed_device, stored_device):
     """Check that a signature on a device or cross-signing key is correct and
     matches the copy of the device/key that we have stored. Throws an
     exception if an error is detected.
     Args:
-        user_id: the user ID whose signature is being checked
-        verify_key: the key to verify the device with
-        signed_device: the uploaded signed device data
-        stored_device: our previously stored copy of the device
+        user_id (str): the user ID whose signature is being checked
+        verify_key (VerifyKey): the key to verify the device with
+        signed_device (dict): the uploaded signed device data
+        stored_device (dict): our previously stored copy of the device
     Raises:
         SynapseError: if the signature was invalid or the sent device is not the
@@ -1234,7 +1201,7 @@ def _check_device_signature(
     raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
-def _exception_to_failure(e: Exception) -> JsonDict:
+def _exception_to_failure(e):
     if isinstance(e, SynapseError):
         return {"status": e.code, "errcode": e.errcode, "message": str(e)}
@@ -1251,7 +1218,7 @@ def _exception_to_failure(e: Exception) -> JsonDict:
     return {"status": 503, "message": str(e)}
-def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
+def _one_time_keys_match(old_key_json, new_key):
     old_key = json_decoder.decode(old_key_json)
     # if either is a string rather than an object, they must match exactly
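`_one_time_keys_match` loses only its annotations in this hunk; just the decode step and the first comparison rule are visible above. A hedged sketch of that visible part (the helper name and the fallback comparison are illustrative, and the real function's object comparison is not shown in this diff):

```python
import json


def one_time_keys_match_sketch(old_key_json, new_key):
    # Decode the stored key, as the hunk above does with json_decoder.
    old_key = json.loads(old_key_json)
    # if either is a string rather than an object, they must match exactly
    if isinstance(old_key, str) or isinstance(new_key, str):
        return old_key == new_key
    # Fallback for this illustration only: plain equality of the decoded
    # objects. The production comparison is outside the diff shown above.
    return old_key == new_key


print(one_time_keys_match_sketch('"curve25519-key"', "curve25519-key"))
```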
@@ -1272,16 +1239,16 @@ class SignatureListItem:
     """An item in the signature list as used by upload_signatures_for_device_keys.
     """
-    signing_key_id = attr.ib(type=str)
-    target_user_id = attr.ib(type=str)
-    target_device_id = attr.ib(type=str)
-    signature = attr.ib(type=JsonDict)
+    signing_key_id = attr.ib()
+    target_user_id = attr.ib()
+    target_device_id = attr.ib()
+    signature = attr.ib()
 class SigningKeyEduUpdater:
     """Handles incoming signing key updates from federation and updates the DB"""
-    def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
+    def __init__(self, hs, e2e_keys_handler):
         self.store = hs.get_datastore()
         self.federation = hs.get_federation_client()
         self.clock = hs.get_clock()
@@ -1290,7 +1257,7 @@ class SigningKeyEduUpdater:
         self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
         # user_id -> list of updates waiting to be handled.
-        self._pending_updates = {}  # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
+        self._pending_updates = {}
         # Recently seen stream ids. We don't bother keeping these in the DB,
         # but they're useful to have them about to reduce the number of spurious
@@ -1303,15 +1270,13 @@ class SigningKeyEduUpdater:
             iterable=True,
         )
-    async def incoming_signing_key_update(
-        self, origin: str, edu_content: JsonDict
-    ) -> None:
+    async def incoming_signing_key_update(self, origin, edu_content):
         """Called on incoming signing key update from federation. Responsible for
         parsing the EDU and adding to pending updates list.
         Args:
-            origin: the server that sent the EDU
-            edu_content: the contents of the EDU
+            origin (string): the server that sent the EDU
+            edu_content (dict): the contents of the EDU
         """
         user_id = edu_content.pop("user_id")
@@ -1334,11 +1299,11 @@ class SigningKeyEduUpdater:
             await self._handle_signing_key_updates(user_id)
-    async def _handle_signing_key_updates(self, user_id: str) -> None:
+    async def _handle_signing_key_updates(self, user_id):
         """Actually handle pending updates.
         Args:
-            user_id: the user whose updates we are processing
+            user_id (string): the user whose updates we are processing
         """
         device_handler = self.e2e_keys_handler.device_handler
@@ -1350,7 +1315,7 @@ class SigningKeyEduUpdater:
             # This can happen since we batch updates
             return
-        device_ids = []  # type: List[str]
+        device_ids = []
         logger.info("pending updates: %r", pending_updates)


@@ -15,7 +15,6 @@
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, List, Optional
 from synapse.api.errors import (
     Codes,
@@ -25,12 +24,8 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.logging.opentracing import log_kv, trace
-from synapse.types import JsonDict
 from synapse.util.async_helpers import Linearizer
-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer
 logger = logging.getLogger(__name__)
@@ -42,7 +37,7 @@ class E2eRoomKeysHandler:
     The actual payload of the encrypted keys is completely opaque to the handler.
     """
-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs):
         self.store = hs.get_datastore()
         # Used to lock whenever a client is uploading key data. This prevents collisions
@@ -53,27 +48,21 @@ class E2eRoomKeysHandler:
         self._upload_linearizer = Linearizer("upload_room_keys_lock")
     @trace
-    async def get_room_keys(
-        self,
-        user_id: str,
-        version: str,
-        room_id: Optional[str] = None,
-        session_id: Optional[str] = None,
-    ) -> List[JsonDict]:
+    async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
         """Bulk get the E2E room keys for a given backup, optionally filtered to a given
         room, or a given session.
         See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
         Args:
-            user_id: the user whose keys we're getting
-            version: the version ID of the backup we're getting keys from
-            room_id: room ID to get keys for, for None to get keys for all rooms
-            session_id: session ID to get keys for, for None to get keys for all
+            user_id(str): the user whose keys we're getting
+            version(str): the version ID of the backup we're getting keys from
+            room_id(string): room ID to get keys for, for None to get keys for all rooms
+            session_id(string): session ID to get keys for, for None to get keys for all
                 sessions
         Raises:
             NotFoundError: if the backup version does not exist
         Returns:
-            A list of dicts giving the session_data and message metadata for
+            A deferred list of dicts giving the session_data and message metadata for
                 these room keys.
         """
@@ -97,23 +86,17 @@ class E2eRoomKeysHandler:
         return results
     @trace
-    async def delete_room_keys(
-        self,
-        user_id: str,
-        version: str,
-        room_id: Optional[str] = None,
-        session_id: Optional[str] = None,
-    ) -> JsonDict:
+    async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
         """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
         room or a given session.
         See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
         Args:
-            user_id: the user whose backup we're deleting
-            version: the version ID of the backup we're deleting
-            room_id: room ID to delete keys for, for None to delete keys for all
+            user_id(str): the user whose backup we're deleting
+            version(str): the version ID of the backup we're deleting
+            room_id(string): room ID to delete keys for, for None to delete keys for all
                 rooms
-            session_id: session ID to delete keys for, for None to delete keys
+            session_id(string): session ID to delete keys for, for None to delete keys
                 for all sessions
         Raises:
             NotFoundError: if the backup version does not exist
@@ -145,17 +128,15 @@ class E2eRoomKeysHandler:
         return {"etag": str(version_etag), "count": count}
     @trace
-    async def upload_room_keys(
-        self, user_id: str, version: str, room_keys: JsonDict
-    ) -> JsonDict:
+    async def upload_room_keys(self, user_id, version, room_keys):
         """Bulk upload a list of room keys into a given backup version, asserting
         that the given version is the current backup version. room_keys are merged
         into the current backup as described in RoomKeysServlet.on_PUT().
         Args:
-            user_id: the user whose backup we're setting
-            version: the version ID of the backup we're updating
-            room_keys: a nested dict describing the room_keys we're setting:
+            user_id(str): the user whose backup we're setting
+            version(str): the version ID of the backup we're updating
+            room_keys(dict): a nested dict describing the room_keys we're setting:
         {
             "rooms": {
@@ -273,16 +254,14 @@ class E2eRoomKeysHandler:
         return {"etag": str(version_etag), "count": count}
     @staticmethod
-    def _should_replace_room_key(
-        current_room_key: Optional[JsonDict], room_key: JsonDict
+    def _should_replace_room_key(current_room_key, room_key):
) -> bool:
""" """
Determine whether to replace a given current_room_key (if any) Determine whether to replace a given current_room_key (if any)
with a newly uploaded room_key backup with a newly uploaded room_key backup
Args: Args:
current_room_key: Optional, the current room_key dict if any current_room_key (dict): Optional, the current room_key dict if any
room_key : The new room_key dict which may or may not be fit to room_key (dict): The new room_key dict which may or may not be fit to
replace the current_room_key replace the current_room_key
Returns: Returns:
@@ -307,14 +286,14 @@ class E2eRoomKeysHandler:
return True return True
@trace @trace
async def create_version(self, user_id: str, version_info: JsonDict) -> str: async def create_version(self, user_id, version_info):
"""Create a new backup version. This automatically becomes the new """Create a new backup version. This automatically becomes the new
backup version for the user's keys; previous backups will no longer be backup version for the user's keys; previous backups will no longer be
writeable to. writeable to.
Args: Args:
user_id: the user whose backup version we're creating user_id(str): the user whose backup version we're creating
version_info: metadata about the new version being created version_info(dict): metadata about the new version being created
{ {
"algorithm": "m.megolm_backup.v1", "algorithm": "m.megolm_backup.v1",
@@ -322,7 +301,7 @@ class E2eRoomKeysHandler:
} }
Returns: Returns:
The new version number. A deferred of a string that gives the new version number.
""" """
# TODO: Validate the JSON to make sure it has the right keys. # TODO: Validate the JSON to make sure it has the right keys.
@@ -334,19 +313,17 @@ class E2eRoomKeysHandler:
) )
return new_version return new_version
async def get_version_info( async def get_version_info(self, user_id, version=None):
self, user_id: str, version: Optional[str] = None
) -> JsonDict:
"""Get the info about a given version of the user's backup """Get the info about a given version of the user's backup
Args: Args:
user_id: the user whose current backup version we're querying user_id(str): the user whose current backup version we're querying
version: Optional; if None gives the most recent version version(str): Optional; if None gives the most recent version
otherwise a historical one. otherwise a historical one.
Raises: Raises:
NotFoundError: if the requested backup version doesn't exist NotFoundError: if the requested backup version doesn't exist
Returns: Returns:
A info dict that gives the info about the new version. A deferred of a info dict that gives the info about the new version.
{ {
"version": "1234", "version": "1234",
@@ -369,7 +346,7 @@ class E2eRoomKeysHandler:
return res return res
@trace @trace
async def delete_version(self, user_id: str, version: Optional[str] = None) -> None: async def delete_version(self, user_id, version=None):
"""Deletes a given version of the user's e2e_room_keys backup """Deletes a given version of the user's e2e_room_keys backup
Args: Args:
@@ -389,19 +366,17 @@ class E2eRoomKeysHandler:
raise raise
@trace @trace
async def update_version( async def update_version(self, user_id, version, version_info):
self, user_id: str, version: str, version_info: JsonDict
) -> JsonDict:
"""Update the info about a given version of the user's backup """Update the info about a given version of the user's backup
Args: Args:
user_id: the user whose current backup version we're updating user_id(str): the user whose current backup version we're updating
version: the backup version we're updating version(str): the backup version we're updating
version_info: the new information about the backup version_info(dict): the new information about the backup
Raises: Raises:
NotFoundError: if the requested backup version doesn't exist NotFoundError: if the requested backup version doesn't exist
Returns: Returns:
An empty dict. A deferred of an empty dict.
""" """
if "version" not in version_info: if "version" not in version_info:
version_info["version"] = version version_info["version"] = version

View File

@@ -1617,10 +1617,6 @@ class FederationHandler(BaseHandler):
             if event.state_key == self._server_notices_mxid:
                 raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
-            # We retrieve the room member handler here as to not cause a cyclic dependency
-            member_handler = self.hs.get_room_member_handler()
-            member_handler.ratelimit_invite(event.room_id, event.state_key)
         # keep a record of the room version, if we don't yet know it.
         # (this may get overwritten if we later get a different room version in a
         # join dance).
@@ -2097,11 +2093,6 @@ class FederationHandler(BaseHandler):
         if event.type == EventTypes.GuestAccess and not context.rejected:
             await self.maybe_kick_guest_users(event)
-        # If we are going to send this event over federation we precaclculate
-        # the joined hosts.
-        if event.internal_metadata.get_send_on_behalf_of():
-            await self.event_creation_handler.cache_joined_hosts_for_event(event)
         return context
     async def _check_for_soft_fail(

View File

@@ -15,13 +15,9 @@
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Set
 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.types import GroupID, JsonDict, get_domain_from_id
+from synapse.types import GroupID, get_domain_from_id
-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer
 logger = logging.getLogger(__name__)
@@ -60,7 +56,7 @@ def _create_rerouter(func_name):
 class GroupsLocalWorkerHandler:
-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs):
         self.hs = hs
         self.store = hs.get_datastore()
         self.room_list_handler = hs.get_room_list_handler()
@@ -88,9 +84,7 @@ class GroupsLocalWorkerHandler:
     get_group_role = _create_rerouter("get_group_role")
     get_group_roles = _create_rerouter("get_group_roles")
-    async def get_group_summary(
-        self, group_id: str, requester_user_id: str
-    ) -> JsonDict:
+    async def get_group_summary(self, group_id, requester_user_id):
         """Get the group summary for a group.
         If the group is remote we check that the users have valid attestations.
@@ -143,15 +137,14 @@ class GroupsLocalWorkerHandler:
         return res
-    async def get_users_in_group(
-        self, group_id: str, requester_user_id: str
-    ) -> JsonDict:
+    async def get_users_in_group(self, group_id, requester_user_id):
         """Get users in a group
         """
         if self.is_mine_id(group_id):
-            return await self.groups_server_handler.get_users_in_group(
+            res = await self.groups_server_handler.get_users_in_group(
                 group_id, requester_user_id
             )
+            return res
         group_server_name = get_domain_from_id(group_id)
@@ -185,11 +178,11 @@ class GroupsLocalWorkerHandler:
         return res
-    async def get_joined_groups(self, user_id: str) -> JsonDict:
+    async def get_joined_groups(self, user_id):
         group_ids = await self.store.get_joined_groups(user_id)
         return {"groups": group_ids}
-    async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
+    async def get_publicised_groups_for_user(self, user_id):
         if self.hs.is_mine_id(user_id):
             result = await self.store.get_publicised_groups_for_user(user_id)
@@ -213,10 +206,8 @@ class GroupsLocalWorkerHandler:
         # TODO: Verify attestations
         return {"groups": result}
-    async def bulk_get_publicised_groups(
-        self, user_ids: Iterable[str], proxy: bool = True
-    ) -> JsonDict:
-        destinations = {}  # type: Dict[str, Set[str]]
+    async def bulk_get_publicised_groups(self, user_ids, proxy=True):
+        destinations = {}
         local_users = set()
         for user_id in user_ids:
@@ -229,7 +220,7 @@ class GroupsLocalWorkerHandler:
             raise SynapseError(400, "Some user_ids are not local")
         results = {}
-        failed_results = []  # type: List[str]
+        failed_results = []
         for destination, dest_user_ids in destinations.items():
             try:
                 r = await self.transport_client.bulk_get_publicised_groups(
@@ -251,7 +242,7 @@ class GroupsLocalWorkerHandler:
 class GroupsLocalHandler(GroupsLocalWorkerHandler):
-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs):
         super().__init__(hs)
         # Ensure attestations get renewed
@@ -280,9 +271,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
     set_group_join_policy = _create_rerouter("set_group_join_policy")
-    async def create_group(
-        self, group_id: str, user_id: str, content: JsonDict
-    ) -> JsonDict:
+    async def create_group(self, group_id, user_id, content):
         """Create a group
         """
@@ -295,7 +284,27 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
             local_attestation = None
             remote_attestation = None
         else:
-            raise SynapseError(400, "Unable to create remote groups")
+            local_attestation = self.attestations.create_attestation(group_id, user_id)
+            content["attestation"] = local_attestation
+            content["user_profile"] = await self.profile_handler.get_profile(user_id)
+            try:
+                res = await self.transport_client.create_group(
+                    get_domain_from_id(group_id), group_id, user_id, content
+                )
+            except HttpResponseException as e:
+                raise e.to_synapse_error()
+            except RequestSendFailed:
+                raise SynapseError(502, "Failed to contact group server")
+            remote_attestation = res["attestation"]
+            await self.attestations.verify_attestation(
+                remote_attestation,
+                group_id=group_id,
+                user_id=user_id,
+                server_name=get_domain_from_id(group_id),
+            )
         is_publicised = content.get("publicise", False)
         token = await self.store.register_user_group_membership(
@@ -311,9 +320,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
         return res
-    async def join_group(
-        self, group_id: str, user_id: str, content: JsonDict
-    ) -> JsonDict:
+    async def join_group(self, group_id, user_id, content):
         """Request to join a group
         """
         if self.is_mine_id(group_id):
@@ -358,9 +365,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
         return {}
-    async def accept_invite(
-        self, group_id: str, user_id: str, content: JsonDict
-    ) -> JsonDict:
+    async def accept_invite(self, group_id, user_id, content):
         """Accept an invite to a group
         """
         if self.is_mine_id(group_id):
@@ -405,9 +410,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
         return {}
-    async def invite(
-        self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
-    ) -> JsonDict:
+    async def invite(self, group_id, user_id, requester_user_id, config):
         """Invite a user to a group
         """
         content = {"requester_user_id": requester_user_id, "config": config}
@@ -431,9 +434,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
         return res
-    async def on_invite(
-        self, group_id: str, user_id: str, content: JsonDict
-    ) -> JsonDict:
+    async def on_invite(self, group_id, user_id, content):
         """One of our users were invited to a group
         """
         # TODO: Support auto join and rejection
@@ -464,8 +465,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
         return {"state": "invite", "user_profile": user_profile}
     async def remove_user_from_group(
-        self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
-    ) -> JsonDict:
+        self, group_id, user_id, requester_user_id, content
+    ):
         """Remove a user from a group
         """
         if user_id == requester_user_id:
@@ -498,9 +499,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
         return res
-    async def user_removed_from_group(
-        self, group_id: str, user_id: str, content: JsonDict
-    ) -> None:
+    async def user_removed_from_group(self, group_id, user_id, content):
         """One of our users was removed/kicked from a group
         """
         # TODO: Check if user in group
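
`bulk_get_publicised_groups` in the hunks above buckets user IDs by homeserver so that a single request goes to each remote destination, with local users handled separately. That grouping step can be sketched on its own; the helper below is illustrative, not Synapse's actual implementation:

```python
from typing import Dict, Iterable, Set, Tuple


def get_domain_from_id(user_id: str) -> str:
    # Matrix IDs look like @localpart:domain; everything after the
    # first colon is the server name.
    return user_id.split(":", 1)[1]


def bucket_users_by_destination(
    user_ids: Iterable[str], local_server: str
) -> Tuple[Dict[str, Set[str]], Set[str]]:
    """Split user IDs into local users and per-remote-server buckets."""
    destinations = {}  # type: Dict[str, Set[str]]
    local_users = set()
    for user_id in user_ids:
        domain = get_domain_from_id(user_id)
        if domain == local_server:
            local_users.add(user_id)
        else:
            destinations.setdefault(domain, set()).add(user_id)
    return destinations, local_users


dests, local = bucket_users_by_destination(
    ["@a:example.org", "@b:matrix.org", "@c:matrix.org"], "example.org"
)
```

One transport request per key of `dests` then replaces one request per user, which is the point of the bulk endpoint.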

View File

@@ -27,11 +27,9 @@ from synapse.api.errors import (
     HttpResponseException,
     SynapseError,
 )
-from synapse.api.ratelimiting import Ratelimiter
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http import RequestTimedOutError
 from synapse.http.client import SimpleHttpClient
-from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, Requester
 from synapse.util import json_decoder
 from synapse.util.hash import sha256_and_url_safe_base64
@@ -59,32 +57,6 @@ class IdentityHandler(BaseHandler):
         self._web_client_location = hs.config.invite_client_location
-        # Ratelimiters for `/requestToken` endpoints.
-        self._3pid_validation_ratelimiter_ip = Ratelimiter(
-            clock=hs.get_clock(),
-            rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
-            burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
-        )
-        self._3pid_validation_ratelimiter_address = Ratelimiter(
-            clock=hs.get_clock(),
-            rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
-            burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
-        )
-    def ratelimit_request_token_requests(
-        self, request: SynapseRequest, medium: str, address: str,
-    ):
-        """Used to ratelimit requests to `/requestToken` by IP and address.
-        Args:
-            request: The associated request
-            medium: The type of threepid, e.g. "msisdn" or "email"
-            address: The actual threepid ID, e.g. the phone number or email address
-        """
-        self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
-        self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
     async def threepid_from_creds(
         self, id_server: str, creds: Dict[str, str]
     ) -> Optional[JsonDict]:
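
The ratelimiters removed above follow Synapse's token-bucket pattern: each key (here, a `(medium, IP)` or `(medium, address)` tuple) may burst up to `burst_count` actions and then refills at `rate_hz` tokens per second. A rough standalone sketch of that behaviour; the class below is illustrative, not Synapse's actual `Ratelimiter`:

```python
from typing import Dict, Tuple


class SimpleRatelimiter:
    """Token-bucket limiter: each key may burst up to burst_count actions,
    refilling at rate_hz tokens per second."""

    def __init__(self, rate_hz: float, burst_count: int) -> None:
        self.rate_hz = rate_hz
        self.burst_count = burst_count
        # key -> (remaining tokens, timestamp of last update)
        self._buckets = {}  # type: Dict[Tuple[str, str], Tuple[float, float]]

    def can_do_action(self, key: Tuple[str, str], now: float) -> bool:
        tokens, last = self._buckets.get(key, (float(self.burst_count), now))
        # Refill tokens for the elapsed time, capped at the burst size.
        tokens = min(self.burst_count, tokens + (now - last) * self.rate_hz)
        if tokens < 1:
            self._buckets[key] = (tokens, now)
            return False
        self._buckets[key] = (tokens - 1, now)
        return True


limiter = SimpleRatelimiter(rate_hz=0.003, burst_count=5)
key = ("email", "203.0.113.7")
# Six back-to-back attempts: the first five pass, the sixth is limited.
results = [limiter.can_do_action(key, now=100.0) for _ in range(6)]
```

Synapse's real `Ratelimiter` raises `LimitExceededError` with a retry interval rather than returning a bool, but the bucket arithmetic is the same idea.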

View File

@@ -174,7 +174,7 @@ class MessageHandler:
             raise NotFoundError("Can't find event for token %s" % (at_token,))
         visible_events = await filter_events_for_client(
-            self.storage, user_id, last_events, filter_send_to_client=False,
+            self.storage, user_id, last_events, filter_send_to_client=False
         )
         event = last_events[0]
@@ -432,8 +432,6 @@ class EventCreationHandler:
         self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
-        self._external_cache = hs.get_external_cache()
     async def create_event(
         self,
         requester: Requester,
@@ -941,8 +939,6 @@ class EventCreationHandler:
         await self.action_generator.handle_push_actions_for_event(event, context)
-        await self.cache_joined_hosts_for_event(event)
         try:
             # If we're a worker we need to hit out to the master.
             writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -982,44 +978,6 @@ class EventCreationHandler:
             await self.store.remove_push_actions_from_staging(event.event_id)
             raise
-    async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
-        """Precalculate the joined hosts at the event, when using Redis, so that
-        external federation senders don't have to recalculate it themselves.
-        """
-        if not self._external_cache.is_enabled():
-            return
-        # We actually store two mappings, event ID -> prev state group,
-        # state group -> joined hosts, which is much more space efficient
-        # than event ID -> joined hosts.
-        #
-        # Note: We have to cache event ID -> prev state group, as we don't
-        # store that in the DB.
-        #
-        # Note: We always set the state group -> joined hosts cache, even if
-        # we already set it, so that the expiry time is reset.
-        state_entry = await self.state.resolve_state_groups_for_events(
-            event.room_id, event_ids=event.prev_event_ids()
-        )
-        if state_entry.state_group:
-            joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
-            await self._external_cache.set(
-                "event_to_prev_state_group",
-                event.event_id,
-                state_entry.state_group,
-                expiry_ms=60 * 60 * 1000,
-            )
-            await self._external_cache.set(
-                "get_joined_hosts",
-                str(state_entry.state_group),
-                list(joined_hosts),
-                expiry_ms=60 * 60 * 1000,
-            )
     async def _validate_canonical_alias(
         self, directory_handler, room_alias_str: str, expected_room_id: str
     ) -> None:
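
The `cache_joined_hosts_for_event` method removed in the hunk above stores two mappings (event ID → prev state group, state group → joined hosts) instead of event ID → hosts directly, because many events share a state group and the hosts then only need to be stored once per group. A minimal self-contained sketch of that indirection, using an in-memory dict with expiry in place of Synapse's Redis-backed external cache (all names below are illustrative):

```python
import time
from typing import Any, Dict, List, Optional, Tuple


class ExpiringCache:
    """Toy stand-in for an external cache: entries expire after expiry_ms."""

    def __init__(self) -> None:
        self._data = {}  # type: Dict[Tuple[str, str], Tuple[Any, float]]

    def set(self, table: str, key: str, value: Any, expiry_ms: int) -> None:
        # Setting again always resets the expiry, mirroring the comment
        # in cache_joined_hosts_for_event.
        self._data[(table, key)] = (value, time.monotonic() + expiry_ms / 1000)

    def get(self, table: str, key: str) -> Optional[Any]:
        entry = self._data.get((table, key))
        if entry is None or entry[1] < time.monotonic():
            return None
        return entry[0]


cache = ExpiringCache()


def cache_joined_hosts(event_id: str, state_group: int, hosts: List[str]) -> None:
    # Two mappings instead of event_id -> hosts: hosts are stored once
    # per state group, which many events share.
    cache.set("event_to_prev_state_group", event_id, state_group, expiry_ms=3600000)
    cache.set("get_joined_hosts", str(state_group), hosts, expiry_ms=3600000)


def lookup_joined_hosts(event_id: str) -> Optional[List[str]]:
    group = cache.get("event_to_prev_state_group", event_id)
    if group is None:
        return None
    return cache.get("get_joined_hosts", str(group))


cache_joined_hosts("$evt1", 42, ["example.org", "matrix.org"])
cache_joined_hosts("$evt2", 42, ["example.org", "matrix.org"])
```

Both events resolve through state group 42, so the host list exists only once in the cache.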

View File

@@ -102,7 +102,7 @@ class OidcHandler:
         ) from e
     async def handle_oidc_callback(self, request: SynapseRequest) -> None:
-        """Handle an incoming request to /_synapse/client/oidc/callback
+        """Handle an incoming request to /_synapse/oidc/callback
         Since we might want to display OIDC-related errors in a user-friendly
         way, we don't raise SynapseError from here. Instead, we call
@@ -274,9 +274,6 @@ class OidcProvider:
         # MXC URI for icon for this auth provider
         self.idp_icon = provider.idp_icon
-        # optional brand identifier for this auth provider
-        self.idp_brand = provider.idp_brand
         self._sso_handler = hs.get_sso_handler()
         self._sso_handler.register_identity_provider(self)
@@ -643,7 +640,7 @@ class OidcProvider:
         - ``client_id``: the client ID set in ``oidc_config.client_id``
         - ``response_type``: ``code``
-        - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback``
+        - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback``
         - ``scope``: the list of scopes set in ``oidc_config.scopes``
         - ``state``: a random string
         - ``nonce``: a random string
@@ -684,7 +681,7 @@ class OidcProvider:
         request.addCookie(
             SESSION_COOKIE_NAME,
             cookie,
-            path="/_synapse/client/oidc",
+            path="/_synapse/oidc",
             max_age="3600",
             httpOnly=True,
             sameSite="lax",
@@ -705,7 +702,7 @@ class OidcProvider:
     async def handle_oidc_callback(
         self, request: SynapseRequest, session_data: "OidcSessionData", code: str
     ) -> None:
-        """Handle an incoming request to /_synapse/client/oidc/callback
+        """Handle an incoming request to /_synapse/oidc/callback
         By this time we have already validated the session on the synapse side, and
         now need to do the provider-specific operations. This includes:
@@ -1059,8 +1056,7 @@ class OidcSessionData:
 UserAttributeDict = TypedDict(
-    "UserAttributeDict",
-    {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
+    "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
 )
 C = TypeVar("C")
@@ -1139,12 +1135,11 @@ def jinja_finalize(thing):
 env = Environment(finalize=jinja_finalize)
-@attr.s(slots=True, frozen=True)
+@attr.s
 class JinjaOidcMappingConfig:
     subject_claim = attr.ib(type=str)
     localpart_template = attr.ib(type=Optional[Template])
     display_name_template = attr.ib(type=Optional[Template])
-    email_template = attr.ib(type=Optional[Template])
     extra_attributes = attr.ib(type=Dict[str, Template])
@@ -1161,17 +1156,23 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     def parse_config(config: dict) -> JinjaOidcMappingConfig:
         subject_claim = config.get("subject_claim", "sub")
-        def parse_template_config(option_name: str) -> Optional[Template]:
-            if option_name not in config:
-                return None
-            try:
-                return env.from_string(config[option_name])
-            except Exception as e:
-                raise ConfigError("invalid jinja template", path=[option_name]) from e
-        localpart_template = parse_template_config("localpart_template")
-        display_name_template = parse_template_config("display_name_template")
-        email_template = parse_template_config("email_template")
+        localpart_template = None  # type: Optional[Template]
+        if "localpart_template" in config:
+            try:
+                localpart_template = env.from_string(config["localpart_template"])
+            except Exception as e:
+                raise ConfigError(
+                    "invalid jinja template", path=["localpart_template"]
+                ) from e
+        display_name_template = None  # type: Optional[Template]
+        if "display_name_template" in config:
+            try:
+                display_name_template = env.from_string(config["display_name_template"])
+            except Exception as e:
+                raise ConfigError(
+                    "invalid jinja template", path=["display_name_template"]
+                ) from e
         extra_attributes = {}  # type Dict[str, Template]
         if "extra_attributes" in config:
@@ -1191,7 +1192,6 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
             subject_claim=subject_claim,
             localpart_template=localpart_template,
             display_name_template=display_name_template,
-            email_template=email_template,
             extra_attributes=extra_attributes,
         )
@@ -1213,23 +1213,16 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
         # a usable mxid.
         localpart += str(failures) if failures else ""
-        def render_template_field(template: Optional[Template]) -> Optional[str]:
-            if template is None:
-                return None
-            return template.render(user=userinfo).strip()
-        display_name = render_template_field(self._config.display_name_template)
+        display_name = None  # type: Optional[str]
+        if self._config.display_name_template is not None:
+            display_name = self._config.display_name_template.render(
+                user=userinfo
+            ).strip()
         if display_name == "":
             display_name = None
-        emails = []  # type: List[str]
-        email = render_template_field(self._config.email_template)
-        if email:
-            emails.append(email)
-        return UserAttributeDict(
-            localpart=localpart, display_name=display_name, emails=emails
-        )
+        return UserAttributeDict(localpart=localpart, display_name=display_name)
     async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
         extras = {}  # type: Dict[str, str]

View File

@@ -14,9 +14,8 @@
 # limitations under the License.
 """Contains functions for registering clients."""
 import logging
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 from synapse import types
 from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
@@ -153,7 +152,7 @@ class RegistrationHandler(BaseHandler):
         user_type: Optional[str] = None,
         default_display_name: Optional[str] = None,
         address: Optional[str] = None,
-        bind_emails: Iterable[str] = [],
+        bind_emails: List[str] = [],
         by_admin: bool = False,
         user_agent_ips: Optional[List[Tuple[str, str]]] = None,
     ) -> str:
@@ -694,8 +693,6 @@ class RegistrationHandler(BaseHandler):
             access_token: The access token of the newly logged in device, or
                 None if `inhibit_login` enabled.
         """
-        # TODO: 3pid registration can actually happen on the workers. Consider
-        # refactoring it.
         if self.hs.config.worker_app:
             await self._post_registration_client(
                 user_id=user_id, auth_result=auth_result, access_token=access_token

View File

@@ -126,10 +126,6 @@ class RoomCreationHandler(BaseHandler):
         self.third_party_event_rules = hs.get_third_party_event_rules()
-        self._invite_burst_count = (
-            hs.config.ratelimiting.rc_invites_per_room.burst_count
-        )
     async def upgrade_room(
         self, requester: Requester, old_room_id: str, new_version: RoomVersion
     ) -> str:
@@ -666,9 +662,6 @@ class RoomCreationHandler(BaseHandler):
         invite_3pid_list = []
         invite_list = []
-        if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count:
-            raise SynapseError(400, "Cannot invite so many users at once")
         await self.event_creation_handler.assert_accepted_privacy_policy(requester)
         power_level_content_override = config.get("power_level_content_override")
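
The invite cap removed in the second hunk above is a plain size check against the configured burst count before room creation proceeds: Matrix user invites and third-party (email/phone) invites count against the same limit. Sketched in isolation, with `ValueError` standing in for Synapse's `SynapseError` and the limit value illustrative:

```python
from typing import List


def check_invite_cap(
    invite_list: List[str], invite_3pid_list: List[str], burst_count: int
) -> None:
    # Reject room-creation requests that try to invite more users
    # (Matrix IDs plus third-party IDs) than the configured burst.
    if len(invite_list) + len(invite_3pid_list) > burst_count:
        raise ValueError("Cannot invite so many users at once")


# Within the cap: no exception.
check_invite_cap(["@a:example.org"], ["alice@example.com"], burst_count=10)
```

The check runs once up front so an oversized request fails before any invite events are created.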

View File

@@ -85,17 +85,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
         )
-        self._invites_per_room_limiter = Ratelimiter(
-            clock=self.clock,
-            rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
-            burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
-        )
-        self._invites_per_user_limiter = Ratelimiter(
-            clock=self.clock,
-            rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
-            burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
-        )

         # This is only used to get at ratelimit function, and
         # maybe_kick_guest_users. It's fine there are multiple of these as
         # it doesn't store state.
@@ -155,12 +144,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         """
         raise NotImplementedError()

-    def ratelimit_invite(self, room_id: str, invitee_user_id: str):
-        """Ratelimit invites by room and by target user.
-        """
-        self._invites_per_room_limiter.ratelimit(room_id)
-        self._invites_per_user_limiter.ratelimit(invitee_user_id)

     async def _local_membership_update(
         self,
         requester: Requester,
@@ -404,12 +387,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             raise SynapseError(403, "This room has been blocked on this server")

         if effective_membership_state == Membership.INVITE:
-            target_id = target.to_string()
-            if ratelimit:
-                self.ratelimit_invite(room_id, target_id)
-
             # block any attempts to invite the server notices mxid
-            if target_id == self._server_notices_mxid:
+            if target.to_string() == self._server_notices_mxid:
                 raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")

             block_invite = False
@@ -433,7 +412,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     block_invite = True

             if not await self.spam_checker.user_may_invite(
-                requester.user.to_string(), target_id, room_id
+                requester.user.to_string(), target.to_string(), room_id
             ):
                 logger.info("Blocking invite due to spam checker")
                 block_invite = True

View File

@@ -78,10 +78,9 @@ class SamlHandler(BaseHandler):
         # user-facing name of this auth provider
         self.idp_name = "SAML"

-        # we do not currently support icons/brands for SAML auth, but this is required by
+        # we do not currently support icons for SAML auth, but this is required by
         # the SsoIdentityProvider protocol type.
         self.idp_icon = None
-        self.idp_brand = None

         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
@@ -133,7 +132,7 @@ class SamlHandler(BaseHandler):
             raise Exception("prepare_for_authenticate didn't return a Location header")

     async def handle_saml_response(self, request: SynapseRequest) -> None:
-        """Handle an incoming request to /_synapse/client/saml2/authn_response
+        """Handle an incoming request to /_matrix/saml2/authn_response

         Args:
             request: the incoming request from the browser. We'll
View File

@@ -15,28 +15,23 @@
 import itertools
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
+from typing import Iterable

 from unpaddedbase64 import decode_base64, encode_base64

 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
-from synapse.events import EventBase
 from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, UserID
 from synapse.visibility import filter_events_for_client

 from ._base import BaseHandler

-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer
-
 logger = logging.getLogger(__name__)


 class SearchHandler(BaseHandler):
-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs):
         super().__init__(hs)
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
@@ -92,15 +87,13 @@ class SearchHandler(BaseHandler):
         return historical_room_ids

-    async def search(
-        self, user: UserID, content: JsonDict, batch: Optional[str] = None
-    ) -> JsonDict:
+    async def search(self, user, content, batch=None):
         """Performs a full text search for a user.

         Args:
-            user
-            content: Search parameters
-            batch: The next_batch parameter. Used for pagination.
+            user (UserID)
+            content (dict): Search parameters
+            batch (str): The next_batch parameter. Used for pagination.

         Returns:
             dict to be returned to the client with results of search
@@ -193,7 +186,7 @@ class SearchHandler(BaseHandler):
         # If doing a subset of all rooms seearch, check if any of the rooms
         # are from an upgraded room, and search their contents as well
         if search_filter.rooms:
-            historical_room_ids = []  # type: List[str]
+            historical_room_ids = []
             for room_id in search_filter.rooms:
                 # Add any previous rooms to the search if they exist
                 ids = await self.get_old_rooms_from_upgraded_room(room_id)
@@ -216,10 +209,8 @@ class SearchHandler(BaseHandler):
         rank_map = {}  # event_id -> rank of event
         allowed_events = []
-        # Holds result of grouping by room, if applicable
-        room_groups = {}  # type: Dict[str, JsonDict]
-        # Holds result of grouping by sender, if applicable
-        sender_group = {}  # type: Dict[str, JsonDict]
+        room_groups = {}  # Holds result of grouping by room, if applicable
+        sender_group = {}  # Holds result of grouping by sender, if applicable

         # Holds the next_batch for the entire result set if one of those exists
         global_next_batch = None
@@ -263,7 +254,7 @@ class SearchHandler(BaseHandler):
                     s["results"].append(e.event_id)

             elif order_by == "recent":
-                room_events = []  # type: List[EventBase]
+                room_events = []
                 i = 0

                 pagination_token = batch_token
@@ -427,10 +418,13 @@ class SearchHandler(BaseHandler):
         state_results = {}
         if include_state:
-            for room_id in {e.room_id for e in allowed_events}:
+            rooms = {e.room_id for e in allowed_events}
+            for room_id in rooms:
                 state = await self.state_handler.get_current_state(room_id)
                 state_results[room_id] = list(state.values())

+        state_results.values()
+
         # We're now about to serialize the events. We should not make any
         # blocking calls after this. Otherwise the 'age' will be wrong
@@ -454,9 +448,9 @@ class SearchHandler(BaseHandler):
         if state_results:
             s = {}
-            for room_id, state_events in state_results.items():
+            for room_id, state in state_results.items():
                 s[room_id] = await self._event_serializer.serialize_events(
-                    state_events, time_now
+                    state, time_now
                 )

             rooms_cat_res["state"] = s
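The `room_groups` / `sender_group` dicts in the search hunks above are filled with a `setdefault`-and-append pattern. A self-contained sketch of that grouping (the event shape and function name here are illustrative, not Synapse's API):

```python
from typing import Any, Dict, List


def group_results_by_room(events: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Group matching search results by room ID, mirroring the
    `room_groups` accumulation pattern in the diff above."""
    room_groups: Dict[str, Dict[str, Any]] = {}
    for e in events:
        # create the group lazily on first sight of each room
        group = room_groups.setdefault(e["room_id"], {"results": []})
        group["results"].append(e["event_id"])
    return room_groups
```

The same pattern with the sender as the key produces `sender_group`; `setdefault` keeps the loop to a single dict lookup per event.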

View File

@@ -13,26 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Optional
+from typing import Optional

 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.types import Requester

 from ._base import BaseHandler

-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer
-
 logger = logging.getLogger(__name__)


 class SetPasswordHandler(BaseHandler):
     """Handler which deals with changing user account passwords"""

-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs):
         super().__init__(hs)
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
+        self._password_policy_handler = hs.get_password_policy_handler()

     async def set_password(
         self,
@@ -40,7 +38,7 @@ class SetPasswordHandler(BaseHandler):
         password_hash: str,
         logout_devices: bool,
         requester: Optional[Requester] = None,
-    ) -> None:
+    ):
         if not self.hs.config.password_localdb_enabled:
             raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)

View File

@@ -14,31 +14,21 @@
 # limitations under the License.
 import abc
 import logging
-from typing import (
-    TYPE_CHECKING,
-    Awaitable,
-    Callable,
-    Dict,
-    Iterable,
-    Mapping,
-    Optional,
-    Set,
-)
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional
 from urllib.parse import urlencode

 import attr
 from typing_extensions import NoReturn, Protocol

 from twisted.web.http import Request
-from twisted.web.iweb import IRequest

 from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
+from synapse.api.errors import Codes, RedirectException, SynapseError
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
-from synapse.http.server import respond_with_html, respond_with_redirect
+from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
-from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters
+from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
 from synapse.util.async_helpers import Linearizer
 from synapse.util.stringutils import random_string
@@ -90,11 +80,6 @@ class SsoIdentityProvider(Protocol):
         """Optional MXC URI for user-facing icon"""
         return None

-    @property
-    def idp_brand(self) -> Optional[str]:
-        """Optional branding identifier"""
-        return None
-
     @abc.abstractmethod
     async def handle_redirect_request(
         self,
@@ -124,7 +109,7 @@ class UserAttributes:
     # enter one.
     localpart = attr.ib(type=Optional[str])
     display_name = attr.ib(type=Optional[str], default=None)
-    emails = attr.ib(type=Collection[str], default=attr.Factory(list))
+    emails = attr.ib(type=List[str], default=attr.Factory(list))


 @attr.s(slots=True)
@@ -139,7 +124,7 @@ class UsernameMappingSession:
     # attributes returned by the ID mapper
     display_name = attr.ib(type=Optional[str])
-    emails = attr.ib(type=Collection[str])
+    emails = attr.ib(type=List[str])

     # An optional dictionary of extra attributes to be provided to the client in the
     # login response.
@@ -151,12 +136,6 @@ class UsernameMappingSession:
     # expiry time for the session, in milliseconds
     expiry_time_ms = attr.ib(type=int)

-    # choices made by the user
-    chosen_localpart = attr.ib(type=Optional[str], default=None)
-    use_display_name = attr.ib(type=bool, default=True)
-    emails_to_use = attr.ib(type=Collection[str], default=())
-    terms_accepted_version = attr.ib(type=Optional[str], default=None)
-
 # the HTTP cookie used to track the mapping session id
 USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
@@ -191,8 +170,6 @@ class SsoHandler:
         # map from idp_id to SsoIdentityProvider
         self._identity_providers = {}  # type: Dict[str, SsoIdentityProvider]

-        self._consent_at_registration = hs.config.consent.user_consent_at_registration
-
     def register_identity_provider(self, p: SsoIdentityProvider):
         p_id = p.idp_id
         assert p_id not in self._identity_providers
@@ -258,10 +235,7 @@ class SsoHandler:
         respond_with_html(request, code, html)

     async def handle_redirect_request(
-        self,
-        request: SynapseRequest,
-        client_redirect_url: bytes,
-        idp_id: Optional[str],
+        self, request: SynapseRequest, client_redirect_url: bytes,
     ) -> str:
         """Handle a request to /login/sso/redirect
@@ -269,7 +243,6 @@ class SsoHandler:
             request: incoming HTTP request
             client_redirect_url: the URL that we should redirect the
                 client to after login.
-            idp_id: optional identity provider chosen by the client

         Returns:
             the URI to redirect to
@@ -279,19 +252,10 @@ class SsoHandler:
                 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
             )

-        # if the client chose an IdP, use that
-        idp = None  # type: Optional[SsoIdentityProvider]
-        if idp_id:
-            idp = self._identity_providers.get(idp_id)
-            if not idp:
-                raise NotFoundError("Unknown identity provider")
-
         # if we only have one auth provider, redirect to it directly
-        elif len(self._identity_providers) == 1:
-            idp = next(iter(self._identity_providers.values()))
-
-        if idp:
-            return await idp.handle_redirect_request(request, client_redirect_url)
+        if len(self._identity_providers) == 1:
+            ap = next(iter(self._identity_providers.values()))
+            return await ap.handle_redirect_request(request, client_redirect_url)

         # otherwise, redirect to the IDP picker
         return "/_synapse/client/pick_idp?" + urlencode(
@@ -405,8 +369,6 @@ class SsoHandler:
             to an additional page. (e.g. to prompt for more information)
         """
-        new_user = False
-
         # grab a lock while we try to find a mapping for this user. This seems...
         # optimistic, especially for implementations that end up redirecting to
         # interstitial pages.
@@ -447,14 +409,9 @@ class SsoHandler:
                 get_request_user_agent(request),
                 request.getClientIP(),
             )
-            new_user = True

         await self._auth_handler.complete_sso_login(
-            user_id,
-            request,
-            client_redirect_url,
-            extra_login_attributes,
-            new_user=new_user,
+            user_id, request, client_redirect_url, extra_login_attributes
         )

     async def _call_attribute_mapper(
@@ -544,7 +501,7 @@ class SsoHandler:
         logger.info("Recorded registration session id %s", session_id)

         # Set the cookie and redirect to the username picker
-        e = RedirectException(b"/_synapse/client/pick_username/account_details")
+        e = RedirectException(b"/_synapse/client/pick_username")
         e.cookies.append(
             b"%s=%s; path=/"
             % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
@@ -672,25 +629,6 @@ class SsoHandler:
         )
         respond_with_html(request, 200, html)

-    def get_mapping_session(self, session_id: str) -> UsernameMappingSession:
-        """Look up the given username mapping session
-
-        If it is not found, raises a SynapseError with an http code of 400
-
-        Args:
-            session_id: session to look up
-        Returns:
-            active mapping session
-        Raises:
-            SynapseError if the session is not found/has expired
-        """
-        self._expire_old_sessions()
-        session = self._username_mapping_sessions.get(session_id)
-        if session:
-            return session
-
-        logger.info("Couldn't find session id %s", session_id)
-        raise SynapseError(400, "unknown session")

     async def check_username_availability(
         self, localpart: str, session_id: str,
     ) -> bool:
@@ -707,7 +645,12 @@ class SsoHandler:
         # make sure that there is a valid mapping session, to stop people dictionary-
         # scanning for accounts
-        self.get_mapping_session(session_id)
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if not session:
+            logger.info("Couldn't find session id %s", session_id)
+            raise SynapseError(400, "unknown session")

         logger.info(
             "[session %s] Checking for availability of username %s",
@@ -724,12 +667,7 @@ class SsoHandler:
         return not user_infos

     async def handle_submit_username_request(
-        self,
-        request: SynapseRequest,
-        session_id: str,
-        localpart: str,
-        use_display_name: bool,
-        emails_to_use: Iterable[str],
+        self, request: SynapseRequest, localpart: str, session_id: str
     ) -> None:
         """Handle a request to the username-picker 'submit' endpoint
@@ -739,90 +677,21 @@ class SsoHandler:
             request: HTTP request
             localpart: localpart requested by the user
             session_id: ID of the username mapping session, extracted from a cookie
-            use_display_name: whether the user wants to use the suggested display name
-            emails_to_use: emails that the user would like to use
         """
-        session = self.get_mapping_session(session_id)
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if not session:
+            logger.info("Couldn't find session id %s", session_id)
+            raise SynapseError(400, "unknown session")

-        # update the session with the user's choices
-        session.chosen_localpart = localpart
-        session.use_display_name = use_display_name
-
-        emails_from_idp = set(session.emails)
-        filtered_emails = set()  # type: Set[str]
-
-        # we iterate through the list rather than just building a set conjunction, so
-        # that we can log attempts to use unknown addresses
-        for email in emails_to_use:
-            if email in emails_from_idp:
-                filtered_emails.add(email)
-            else:
-                logger.warning(
-                    "[session %s] ignoring user request to use unknown email address %r",
-                    session_id,
-                    email,
-                )
-        session.emails_to_use = filtered_emails
-
-        # we may now need to collect consent from the user, in which case, redirect
-        # to the consent-extraction-unit
-        if self._consent_at_registration:
-            redirect_url = b"/_synapse/client/new_user_consent"
-
-        # otherwise, redirect to the completion page
-        else:
-            redirect_url = b"/_synapse/client/sso_register"
-
-        respond_with_redirect(request, redirect_url)
-
-    async def handle_terms_accepted(
-        self, request: Request, session_id: str, terms_version: str
-    ):
-        """Handle a request to the new-user 'consent' endpoint
-
-        Will serve an HTTP response to the request.
-
-        Args:
-            request: HTTP request
-            session_id: ID of the username mapping session, extracted from a cookie
-            terms_version: the version of the terms which the user viewed and consented
-                to
-        """
-        logger.info(
-            "[session %s] User consented to terms version %s",
-            session_id,
-            terms_version,
-        )
-        session = self.get_mapping_session(session_id)
-        session.terms_accepted_version = terms_version
-
-        # we're done; now we can register the user
-        respond_with_redirect(request, b"/_synapse/client/sso_register")
-
-    async def register_sso_user(self, request: Request, session_id: str) -> None:
-        """Called once we have all the info we need to register a new user.
-
-        Does so and serves an HTTP response
-
-        Args:
-            request: HTTP request
-            session_id: ID of the username mapping session, extracted from a cookie
-        """
-        session = self.get_mapping_session(session_id)
-        logger.info(
-            "[session %s] Registering localpart %s",
-            session_id,
-            session.chosen_localpart,
-        )
+        logger.info("[session %s] Registering localpart %s", session_id, localpart)

         attributes = UserAttributes(
-            localpart=session.chosen_localpart, emails=session.emails_to_use,
+            localpart=localpart,
+            display_name=session.display_name,
+            emails=session.emails,
         )
-        if session.use_display_name:
-            attributes.display_name = session.display_name

         # the following will raise a 400 error if the username has been taken in the
         # meantime.
@@ -833,12 +702,7 @@ class SsoHandler:
             request.getClientIP(),
         )

-        logger.info(
-            "[session %s] Registered userid %s with attributes %s",
-            session_id,
-            user_id,
-            attributes,
-        )
+        logger.info("[session %s] Registered userid %s", session_id, user_id)

         # delete the mapping session and the cookie
         del self._username_mapping_sessions[session_id]
@@ -851,21 +715,11 @@ class SsoHandler:
             path=b"/",
         )

-        auth_result = {}
-        if session.terms_accepted_version:
-            # TODO: make this less awful.
-            auth_result[LoginType.TERMS] = True
-
-        await self._registration_handler.post_registration_actions(
-            user_id, auth_result, access_token=None
-        )
-
         await self._auth_handler.complete_sso_login(
             user_id,
             request,
             session.client_redirect_url,
             session.extra_login_attributes,
-            new_user=True,
         )
     def _expire_old_sessions(self):
@@ -879,14 +733,3 @@ class SsoHandler:
         for session_id in to_expire:
             logger.info("Expiring mapping session %s", session_id)
             del self._username_mapping_sessions[session_id]
-
-
-def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
-    """Extract the session ID from the cookie
-
-    Raises a SynapseError if the cookie isn't found
-    """
-    session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
-    if not session_id:
-        raise SynapseError(code=400, msg="missing session_id")
-    return session_id.decode("ascii", errors="replace")
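The IdP-selection logic in `handle_redirect_request` above boils down to: an explicit `idp_id` wins, a single configured provider is used directly, and otherwise the client is bounced to the picker page. A standalone sketch of that branching (the provider class here and the `redirectUrl` query-parameter name are assumptions for illustration, not Synapse's API):

```python
from typing import Dict, Optional
from urllib.parse import urlencode


class IdentityProvider:
    """Hypothetical minimal IdP: it only knows its own authorize URL."""

    def __init__(self, authorize_url: str):
        self.authorize_url = authorize_url

    def handle_redirect_request(self, client_redirect_url: str) -> str:
        # a real provider would add state/nonce parameters; sketch only
        return self.authorize_url


def choose_sso_redirect(
    providers: Dict[str, IdentityProvider],
    client_redirect_url: str,
    idp_id: Optional[str] = None,
) -> str:
    """Mirror the branching shown in the diff above."""
    if idp_id is not None:
        idp = providers.get(idp_id)
        if idp is None:
            raise KeyError("Unknown identity provider")
        return idp.handle_redirect_request(client_redirect_url)
    if len(providers) == 1:
        only = next(iter(providers.values()))
        return only.handle_redirect_request(client_redirect_url)
    # more than one provider and no explicit choice: send to the picker
    # (the query-parameter name below is assumed)
    return "/_synapse/client/pick_idp?" + urlencode(
        {"redirectUrl": client_redirect_url}
    )
```

The single-provider short-circuit keeps the common one-IdP deployment free of an extra interstitial page.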

View File

@@ -14,25 +14,15 @@
 # limitations under the License.

 import logging
-from typing import TYPE_CHECKING, Optional
-
-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer

 logger = logging.getLogger(__name__)


 class StateDeltasHandler:
-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs):
         self.store = hs.get_datastore()

-    async def _get_key_change(
-        self,
-        prev_event_id: Optional[str],
-        event_id: Optional[str],
-        key_name: str,
-        public_value: str,
-    ) -> Optional[bool]:
+    async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
         """Given two events check if the `key_name` field in content changed
         from not matching `public_value` to doing so.

View File

@@ -12,19 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import logging
 from collections import Counter
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
-
-from typing_extensions import Counter as CounterType

 from synapse.api.constants import EventTypes, Membership
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import JsonDict
-
-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer

 logger = logging.getLogger(__name__)
@@ -37,7 +31,7 @@ class StatsHandler:
     Heavily derived from UserDirectoryHandler
     """

-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs):
         self.hs = hs
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
@@ -50,7 +44,7 @@ class StatsHandler:
         self.stats_enabled = hs.config.stats_enabled

         # The current position in the current_state_delta stream
-        self.pos = None  # type: Optional[int]
+        self.pos = None

         # Guard to ensure we only process deltas one at a time
         self._is_processing = False
@@ -62,7 +56,7 @@ class StatsHandler:
         # we start populating stats
         self.clock.call_later(0, self.notify_new_event)

-    def notify_new_event(self) -> None:
+    def notify_new_event(self):
         """Called when there may be more deltas to process
         """
         if not self.stats_enabled or self._is_processing:
@@ -78,7 +72,7 @@ class StatsHandler:
         run_as_background_process("stats.notify_new_event", process)

-    async def _unsafe_process(self) -> None:
+    async def _unsafe_process(self):
         # If self.pos is None then means we haven't fetched it from DB
         if self.pos is None:
             self.pos = await self.store.get_stats_positions()
@@ -116,10 +110,10 @@ class StatsHandler:
                 )

                 for room_id, fields in room_count.items():
-                    room_deltas.setdefault(room_id, Counter()).update(fields)
+                    room_deltas.setdefault(room_id, {}).update(fields)

                 for user_id, fields in user_count.items():
-                    user_deltas.setdefault(user_id, Counter()).update(fields)
+                    user_deltas.setdefault(user_id, {}).update(fields)

                 logger.debug("room_deltas: %s", room_deltas)
                 logger.debug("user_deltas: %s", user_deltas)
@@ -137,20 +131,19 @@ class StatsHandler:
             self.pos = max_pos

-    async def _handle_deltas(
-        self, deltas: Iterable[JsonDict]
-    ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
+    async def _handle_deltas(self, deltas):
         """Called with the state deltas to process

         Returns:
+            tuple[dict[str, Counter], dict[str, counter]]
             Two dicts: the room deltas and the user deltas,
             mapping from room/user ID to changes in the various fields.
         """

-        room_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
-        user_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
+        room_to_stats_deltas = {}
+        user_to_stats_deltas = {}

-        room_to_state_updates = {}  # type: Dict[str, Dict[str, Any]]
+        room_to_state_updates = {}

         for delta in deltas:
             typ = delta["type"]
@@ -180,7 +173,7 @@ class StatsHandler:
                 )
                 continue

-            event_content = {}  # type: JsonDict
+            event_content = {}

             sender = None
             if event_id is not None:
@@ -264,13 +257,13 @@ class StatsHandler:
                 )

                 if has_changed_joinedness:
-                    membership_delta = +1 if membership == Membership.JOIN else -1
+                    delta = +1 if membership == Membership.JOIN else -1

                     user_to_stats_deltas.setdefault(user_id, Counter())[
                         "joined_rooms"
-                    ] += membership_delta
+                    ] += delta

-                    room_stats_delta["local_users_in_room"] += membership_delta
+                    room_stats_delta["local_users_in_room"] += delta

             elif typ == EventTypes.Create:
                 room_state["is_federatable"] = (
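A detail worth noting in the stats hunks above: one side seeds `setdefault` with `Counter()` while the other uses `{}`. The difference matters because `dict.update` overwrites values, whereas `Counter.update` adds them. A minimal demonstration (function name is illustrative):

```python
from collections import Counter
from typing import Dict


def apply_room_delta(
    room_deltas: Dict[str, Counter], room_id: str, fields: Dict[str, int]
) -> None:
    """Accumulate per-room stat deltas, as in `_unsafe_process` above.

    Counter.update() *adds* the incoming counts to any existing ones;
    a plain dict here would silently discard earlier deltas.
    """
    room_deltas.setdefault(room_id, Counter()).update(fields)
```

With a plain-dict bucket, a second batch of deltas for the same room would replace the first batch's counts instead of summing with them.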

View File

@@ -15,13 +15,13 @@
import logging import logging
import random import random
from collections import namedtuple from collections import namedtuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from typing import TYPE_CHECKING, List, Set, Tuple
from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import TypingStream from synapse.replication.tcp.streams import TypingStream
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer from synapse.util.wheel_timer import WheelTimer
@@ -65,17 +65,17 @@ class FollowerTypingHandler:
         )
         # map room IDs to serial numbers
-        self._room_serials = {}  # type: Dict[str, int]
+        self._room_serials = {}
         # map room IDs to sets of users currently typing
-        self._room_typing = {}  # type: Dict[str, Set[str]]
+        self._room_typing = {}
-        self._member_last_federation_poke = {}  # type: Dict[RoomMember, int]
+        self._member_last_federation_poke = {}
         self.wheel_timer = WheelTimer(bucket_size=5000)
         self._latest_room_serial = 0
         self.clock.looping_call(self._handle_timeouts, 5000)
-    def _reset(self) -> None:
+    def _reset(self):
         """Reset the typing handler's data caches.
         """
         # map room IDs to serial numbers
@@ -86,7 +86,7 @@ class FollowerTypingHandler:
         self._member_last_federation_poke = {}
         self.wheel_timer = WheelTimer(bucket_size=5000)
-    def _handle_timeouts(self) -> None:
+    def _handle_timeouts(self):
         logger.debug("Checking for typing timeouts")
         now = self.clock.time_msec()
@@ -96,7 +96,7 @@ class FollowerTypingHandler:
         for member in members:
             self._handle_timeout_for_member(now, member)
-    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
+    def _handle_timeout_for_member(self, now: int, member: RoomMember):
         if not self.is_typing(member):
             # Nothing to do if they're no longer typing
             return
@@ -114,10 +114,10 @@ class FollowerTypingHandler:
         # each person typing.
         self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
-    def is_typing(self, member: RoomMember) -> bool:
+    def is_typing(self, member):
         return member.user_id in self._room_typing.get(member.room_id, [])
-    async def _push_remote(self, member: RoomMember, typing: bool) -> None:
+    async def _push_remote(self, member, typing):
         if not self.federation:
             return
@@ -148,7 +148,7 @@ class FollowerTypingHandler:
     def process_replication_rows(
         self, token: int, rows: List[TypingStream.TypingStreamRow]
-    ) -> None:
+    ):
         """Should be called whenever we receive updates for typing stream.
         """
@@ -178,7 +178,7 @@ class FollowerTypingHandler:
     async def _send_changes_in_typing_to_remotes(
         self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
-    ) -> None:
+    ):
         """Process a change in typing of a room from replication, sending EDUs
         for any local users.
         """
@@ -194,12 +194,12 @@ class FollowerTypingHandler:
         if self.is_mine_id(user_id):
             await self._push_remote(RoomMember(room_id, user_id), False)
-    def get_current_token(self) -> int:
+    def get_current_token(self):
         return self._latest_room_serial
 class TypingWriterHandler(FollowerTypingHandler):
-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs):
         super().__init__(hs)
         assert hs.config.worker.writers.typing == hs.get_instance_name()
@@ -213,15 +213,14 @@ class TypingWriterHandler(FollowerTypingHandler):
         hs.get_distributor().observe("user_left_room", self.user_left_room)
-        # clock time we expect to stop
-        self._member_typing_until = {}  # type: Dict[RoomMember, int]
+        self._member_typing_until = {}  # clock time we expect to stop
         # caches which room_ids changed at which serials
         self._typing_stream_change_cache = StreamChangeCache(
             "TypingStreamChangeCache", self._latest_room_serial
         )
-    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
+    def _handle_timeout_for_member(self, now: int, member: RoomMember):
         super()._handle_timeout_for_member(now, member)
         if not self.is_typing(member):
@@ -234,9 +233,7 @@ class TypingWriterHandler(FollowerTypingHandler):
             self._stopped_typing(member)
             return
-    async def started_typing(
-        self, target_user: UserID, requester: Requester, room_id: str, timeout: int
-    ) -> None:
+    async def started_typing(self, target_user, requester, room_id, timeout):
         target_user_id = target_user.to_string()
         auth_user_id = requester.user.to_string()
@@ -266,13 +263,11 @@ class TypingWriterHandler(FollowerTypingHandler):
         if was_present:
             # No point sending another notification
-            return
+            return None
         self._push_update(member=member, typing=True)
-    async def stopped_typing(
-        self, target_user: UserID, requester: Requester, room_id: str
-    ) -> None:
+    async def stopped_typing(self, target_user, requester, room_id):
         target_user_id = target_user.to_string()
         auth_user_id = requester.user.to_string()
@@ -295,23 +290,23 @@ class TypingWriterHandler(FollowerTypingHandler):
         self._stopped_typing(member)
-    def user_left_room(self, user: UserID, room_id: str) -> None:
+    def user_left_room(self, user, room_id):
         user_id = user.to_string()
         if self.is_mine_id(user_id):
             member = RoomMember(room_id=room_id, user_id=user_id)
             self._stopped_typing(member)
-    def _stopped_typing(self, member: RoomMember) -> None:
+    def _stopped_typing(self, member):
         if member.user_id not in self._room_typing.get(member.room_id, set()):
             # No point
-            return
+            return None
         self._member_typing_until.pop(member, None)
         self._member_last_federation_poke.pop(member, None)
         self._push_update(member=member, typing=False)
-    def _push_update(self, member: RoomMember, typing: bool) -> None:
+    def _push_update(self, member, typing):
         if self.hs.is_mine_id(member.user_id):
             # Only send updates for changes to our own users.
             run_as_background_process(
@@ -320,7 +315,7 @@ class TypingWriterHandler(FollowerTypingHandler):
         self._push_update_local(member=member, typing=typing)
-    async def _recv_edu(self, origin: str, content: JsonDict) -> None:
+    async def _recv_edu(self, origin, content):
         room_id = content["room_id"]
         user_id = content["user_id"]
@@ -345,7 +340,7 @@ class TypingWriterHandler(FollowerTypingHandler):
         self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
         self._push_update_local(member=member, typing=content["typing"])
-    def _push_update_local(self, member: RoomMember, typing: bool) -> None:
+    def _push_update_local(self, member, typing):
         room_set = self._room_typing.setdefault(member.room_id, set())
         if typing:
             room_set.add(member.user_id)
@@ -391,7 +386,7 @@ class TypingWriterHandler(FollowerTypingHandler):
         changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
             last_id
-        )  # type: Optional[Iterable[str]]
+        )
         if changed_rooms is None:
             changed_rooms = self._room_serials
@@ -417,13 +412,13 @@ class TypingWriterHandler(FollowerTypingHandler):
     def process_replication_rows(
         self, token: int, rows: List[TypingStream.TypingStreamRow]
-    ) -> None:
+    ):
         # The writing process should never get updates from replication.
         raise Exception("Typing writer instance got typing info over replication")
 class TypingNotificationEventSource:
-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs):
         self.hs = hs
         self.clock = hs.get_clock()
         # We can't call get_typing_handler here because there's a cycle:
@@ -432,7 +427,7 @@ class TypingNotificationEventSource:
         #
         self.get_typing_handler = hs.get_typing_handler
-    def _make_event_for(self, room_id: str) -> JsonDict:
+    def _make_event_for(self, room_id):
         typing = self.get_typing_handler()._room_typing[room_id]
         return {
             "type": "m.typing",
@@ -467,9 +462,7 @@ class TypingNotificationEventSource:
         return (events, handler._latest_room_serial)
-    async def get_new_events(
-        self, from_key: int, room_ids: Iterable[str], **kwargs
-    ) -> Tuple[List[JsonDict], int]:
+    async def get_new_events(self, from_key, room_ids, **kwargs):
         with Measure(self.clock, "typing.get_new_events"):
             from_key = int(from_key)
             handler = self.get_typing_handler()
@@ -485,5 +478,5 @@ class TypingNotificationEventSource:
         return (events, handler._latest_room_serial)
-    def get_current_key(self) -> int:
+    def get_current_key(self):
         return self.get_typing_handler()._latest_room_serial
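The typing handler above leans on `WheelTimer`: each typing member is inserted with a deadline, and a 5-second `looping_call` sweeps out whatever has expired. A minimal sketch of that bucketed-timeout idea (a hypothetical standalone class, not Synapse's actual `synapse.util.wheel_timer` implementation):

```python
class WheelTimer:
    """Sketch of a bucketed timeout wheel: entries are grouped into
    fixed-size time buckets, so a sweep is O(buckets), not O(entries)."""

    def __init__(self, bucket_size: int = 5000) -> None:
        self.bucket_size = bucket_size  # bucket width in milliseconds
        self.entries = []  # (bucket_index, [objs]) pairs, oldest first

    def insert(self, now: int, obj, then: int) -> None:
        # Everything landing in the same bucket expires together.
        bucket = max(then, now) // self.bucket_size
        if self.entries and self.entries[-1][0] == bucket:
            self.entries[-1][1].append(obj)
        else:
            self.entries.append((bucket, [obj]))

    def fetch(self, now: int) -> list:
        # Pop every bucket whose deadline has fully passed.
        now_bucket = now // self.bucket_size
        expired = []
        while self.entries and self.entries[0][0] < now_bucket:
            expired.extend(self.entries.pop(0)[1])
        return expired
```

The coarse buckets are why the handler re-inserts members every minute rather than tracking each deadline exactly.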


@@ -145,6 +145,10 @@ class UserDirectoryHandler(StateDeltasHandler):
         if self.pos is None:
             self.pos = await self.store.get_user_directory_stream_pos()
+        # If still None then the initial background update hasn't happened yet
+        if self.pos is None:
+            return None
         # Loop round handling deltas until we're up to date
         while True:
             with Measure(self.clock, "user_dir_delta"):
@@ -229,11 +233,6 @@ class UserDirectoryHandler(StateDeltasHandler):
             if change:  # The user joined
                 event = await self.store.get_event(event_id, allow_none=True)
-                # It isn't expected for this event to not exist, but we
-                # don't want the entire background process to break.
-                if event is None:
-                    continue
                 profile = ProfileInfo(
                     avatar_url=event.content.get("avatar_url"),
                     display_name=event.content.get("displayname"),


@@ -22,22 +22,10 @@ import types
 import urllib
 from http import HTTPStatus
 from io import BytesIO
-from typing import (
-    Any,
-    Awaitable,
-    Callable,
-    Dict,
-    Iterable,
-    Iterator,
-    List,
-    Pattern,
-    Tuple,
-    Union,
-)
+from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
 import jinja2
 from canonicaljson import iterencode_canonical_json
-from typing_extensions import Protocol
 from zope.interface import implementer
 from twisted.internet import defer, interfaces
@@ -180,25 +168,11 @@ def wrap_async_request_handler(h):
     return preserve_fn(wrapped_async_request_handler)
-# Type of a callback method for processing requests
-# it is actually called with a SynapseRequest and a kwargs dict for the params,
-# but I can't figure out how to represent that.
-ServletCallback = Callable[
-    ..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]]
-]
-class HttpServer(Protocol):
+class HttpServer:
     """ Interface for registering callbacks on a HTTP server
     """
-    def register_paths(
-        self,
-        method: str,
-        path_patterns: Iterable[Pattern],
-        callback: ServletCallback,
-        servlet_classname: str,
-    ) -> None:
+    def register_paths(self, method, path_patterns, callback):
         """ Register a callback that gets fired if we receive a http request
         with the given method for a path that matches the given regex.
@@ -206,14 +180,12 @@ class HttpServer(Protocol):
         an unpacked tuple.
         Args:
-            method: The HTTP method to listen to.
-            path_patterns: The regex used to match requests.
-            callback: The function to fire if we receive a matched
+            method (str): The method to listen to.
+            path_patterns (list<SRE_Pattern>): The regex used to match requests.
+            callback (function): The function to fire if we receive a matched
                 request. The first argument will be the request object and
                 subsequent arguments will be any matched groups from the regex.
-                This should return either tuple of (code, response), or None.
-            servlet_classname (str): The name of the handler to be used in prometheus
-                and opentracing logs.
+                This should return a tuple of (code, response).
         """
         pass
@@ -382,7 +354,7 @@ class JsonResource(DirectServeJsonResource):
     def _get_handler_for_request(
         self, request: SynapseRequest
-    ) -> Tuple[ServletCallback, str, Dict[str, str]]:
+    ) -> Tuple[Callable, str, Dict[str, str]]:
         """Finds a callback method to handle the given request.
         Returns:
@@ -761,13 +733,6 @@ def set_clickjacking_protection_headers(request: Request):
     request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';")
-def respond_with_redirect(request: Request, url: bytes) -> None:
-    """Write a 302 response to the request, if it is still alive."""
-    logger.debug("Redirect to %s", url.decode("utf-8"))
-    request.redirect(url)
-    finish_request(request)
 def finish_request(request: Request):
     """ Finish writing the response to the request.
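The hunks above replace a `typing_extensions.Protocol`-based `HttpServer` interface with a plain class. A `Protocol` gives structural ("duck") typing: any object with a matching method satisfies the interface, no inheritance needed. A minimal sketch with hypothetical names (simplified to a string pattern rather than Synapse's compiled-regex lists):

```python
from typing import Protocol  # typing_extensions.Protocol on Python < 3.8


class HasRegisterPaths(Protocol):
    """Structural interface: anything with this method shape qualifies."""

    def register_paths(self, method: str, path_pattern: str) -> None:
        ...


class FakeServer:
    # Note: no inheritance from HasRegisterPaths.
    def __init__(self) -> None:
        self.paths = []

    def register_paths(self, method: str, path_pattern: str) -> None:
        self.paths.append((method, path_pattern))


def register_login(server: HasRegisterPaths) -> None:
    # Type-checks against the Protocol, not a concrete base class.
    server.register_paths("POST", "^/login$")
```

This is why the typed version of `JsonResource` could satisfy `HttpServer` without subclassing it.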


@@ -791,7 +791,7 @@ def tag_args(func):
     @wraps(func)
     def _tag_args_inner(*args, **kwargs):
-        argspec = inspect.getfullargspec(func)
+        argspec = inspect.getargspec(func)
         for i, arg in enumerate(argspec.args[1:]):
             set_tag("ARG_" + arg, args[i])
         set_tag("args", args[len(argspec.args) :])
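This hunk swaps `inspect.getfullargspec` back to `inspect.getargspec`, which only works on the old dependency set: `getargspec` has been deprecated since Python 3.0 and was removed in 3.11. A sketch of the same tagging pattern with the modern call (hypothetical `send_edu` function, with an attribute on the wrapper standing in for opentracing's `set_tag`):

```python
import inspect
from functools import wraps


def tag_args(func):
    """Record each positional argument under its parameter name, then
    call through to the wrapped function."""

    @wraps(func)
    def _tag_args_inner(*args, **kwargs):
        # getfullargspec replaces getargspec and also understands
        # keyword-only arguments and annotations.
        argspec = inspect.getfullargspec(func)
        _tag_args_inner.last_tags = {
            "ARG_" + name: value for name, value in zip(argspec.args, args)
        }
        return func(*args, **kwargs)

    return _tag_args_inner


@tag_args
def send_edu(destination, edu_type):
    return (destination, edu_type)
```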


@@ -279,11 +279,7 @@ class ModuleApi:
         )
     async def complete_sso_login_async(
-        self,
-        registered_user_id: str,
-        request: SynapseRequest,
-        client_redirect_url: str,
-        new_user: bool = False,
+        self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
     ):
         """Complete a SSO login by redirecting the user to a page to confirm whether they
         want their access token sent to `client_redirect_url`, or redirect them to that
@@ -295,11 +291,9 @@ class ModuleApi:
             request: The request to respond to.
             client_redirect_url: The URL to which to offer to redirect the user (or to
                 redirect them directly if whitelisted).
-            new_user: set to true to use wording for the consent appropriate to a user
-                who has just registered.
         """
         await self._auth_handler.complete_sso_login(
-            registered_user_id, request, client_redirect_url, new_user=new_user
+            registered_user_id, request, client_redirect_url,
         )
     @defer.inlineCallbacks


@@ -267,20 +267,8 @@ class Mailer:
             fallback_to_members=True,
         )
-        if len(notifs_by_room) == 1:
-            # Only one room has new stuff
-            room_id = list(notifs_by_room.keys())[0]
-            summary_text = await self.make_summary_text_single_room(
-                room_id,
-                notifs_by_room[room_id],
-                state_by_room[room_id],
-                notif_events,
-                user_id,
-            )
-        else:
-            summary_text = await self.make_summary_text(
-                notifs_by_room, state_by_room, notif_events, reason
-            )
+        summary_text = await self.make_summary_text(
+            notifs_by_room, state_by_room, notif_events, user_id, reason
+        )
         template_vars = {
@@ -504,37 +492,28 @@ class Mailer:
         if "url" in event.content:
             messagevars["image_url"] = event.content["url"]
-    async def make_summary_text_single_room(
-        self,
-        room_id: str,
-        notifs: List[Dict[str, Any]],
-        room_state_ids: StateMap[str],
-        notif_events: Dict[str, EventBase],
-        user_id: str,
-    ) -> str:
-        """
-        Make a summary text for the email when only a single room has notifications.
-        Args:
-            room_id: The ID of the room.
-            notifs: The notifications for this room.
-            room_state_ids: The state map for the room.
-            notif_events: A map of event ID -> notification event.
-            user_id: The user receiving the notification.
-        Returns:
-            The summary text.
-        """
+    async def make_summary_text(
+        self,
+        notifs_by_room: Dict[str, List[Dict[str, Any]]],
+        room_state_ids: Dict[str, StateMap[str]],
+        notif_events: Dict[str, EventBase],
+        user_id: str,
+        reason: Dict[str, Any],
+    ):
+        if len(notifs_by_room) == 1:
+            # Only one room has new stuff
+            room_id = list(notifs_by_room.keys())[0]
         # If the room has some kind of name, use it, but we don't
         # want the generated-from-names one here otherwise we'll
         # end up with, "new message from Bob in the Bob room"
         room_name = await calculate_room_name(
-            self.store, room_state_ids, user_id, fallback_to_members=False
+            self.store, room_state_ids[room_id], user_id, fallback_to_members=False
         )
         # See if one of the notifs is an invite event for the user
         invite_event = None
-        for n in notifs:
+        for n in notifs_by_room[room_id]:
             ev = notif_events[n["event_id"]]
             if ev.type == EventTypes.Member and ev.state_key == user_id:
                 if ev.content.get("membership") == Membership.INVITE:
@@ -542,7 +521,7 @@ class Mailer:
                     break
         if invite_event:
-            inviter_member_event_id = room_state_ids.get(
+            inviter_member_event_id = room_state_ids[room_id].get(
                 ("m.room.member", invite_event.sender)
             )
             inviter_name = invite_event.sender
@@ -558,19 +537,21 @@ class Mailer:
                 "person": inviter_name,
                 "app": self.app_name,
             }
+            else:
             return self.email_subjects.invite_from_person_to_room % {
                 "person": inviter_name,
                 "room": room_name,
                 "app": self.app_name,
             }
-        if len(notifs) == 1:
-            # There is just the one notification, so give some detail
-            sender_name = None
-            event = notif_events[notifs[0]["event_id"]]
-            if ("m.room.member", event.sender) in room_state_ids:
-                state_event_id = room_state_ids[("m.room.member", event.sender)]
+            sender_name = None
+            if len(notifs_by_room[room_id]) == 1:
+                # There is just the one notification, so give some detail
+                event = notif_events[notifs_by_room[room_id][0]["event_id"]]
+                if ("m.room.member", event.sender) in room_state_ids[room_id]:
+                    state_event_id = room_state_ids[room_id][
+                        ("m.room.member", event.sender)
+                    ]
                 state_event = await self.store.get_event(state_event_id)
                 sender_name = name_from_member_event(state_event)
@@ -585,12 +566,6 @@ class Mailer:
                     "person": sender_name,
                     "app": self.app_name,
                 }
-            # The sender is unknown, just use the room name (or ID).
-            return self.email_subjects.messages_in_room % {
-                "room": room_name or room_id,
-                "app": self.app_name,
-            }
+            else:
             # There's more than one notification for this room, so just
             # say there are several
@@ -599,80 +574,54 @@ class Mailer:
                 "room": room_name,
                 "app": self.app_name,
             }
-        return await self.make_summary_text_from_member_events(
-            room_id, notifs, room_state_ids, notif_events
-        )
-    async def make_summary_text(
-        self,
-        notifs_by_room: Dict[str, List[Dict[str, Any]]],
-        room_state_ids: Dict[str, StateMap[str]],
-        notif_events: Dict[str, EventBase],
-        reason: Dict[str, Any],
-    ) -> str:
-        """
-        Make a summary text for the email when multiple rooms have notifications.
-        Args:
-            notifs_by_room: A map of room ID to the notifications for that room.
-            room_state_ids: A map of room ID to the state map for that room.
-            notif_events: A map of event ID -> notification event.
-            reason: The reason this notification is being sent.
-        Returns:
-            The summary text.
-        """
+            else:
+                # If the room doesn't have a name, say who the messages
+                # are from explicitly to avoid, "messages in the Bob room"
+                sender_ids = list(
+                    {
+                        notif_events[n["event_id"]].sender
+                        for n in notifs_by_room[room_id]
+                    }
+                )
+                member_events = await self.store.get_events(
+                    [
+                        room_state_ids[room_id][("m.room.member", s)]
+                        for s in sender_ids
+                    ]
+                )
+                return self.email_subjects.messages_from_person % {
+                    "person": descriptor_from_member_events(member_events.values()),
+                    "app": self.app_name,
+                }
+        else:
             # Stuff's happened in multiple different rooms
             # ...but we still refer to the 'reason' room which triggered the mail
             if reason["room_name"] is not None:
                 return self.email_subjects.messages_in_room_and_others % {
                     "room": reason["room_name"],
                     "app": self.app_name,
                 }
-        room_id = reason["room_id"]
-        return await self.make_summary_text_from_member_events(
-            room_id, notifs_by_room[room_id], room_state_ids[room_id], notif_events
-        )
-    async def make_summary_text_from_member_events(
-        self,
-        room_id: str,
-        notifs: List[Dict[str, Any]],
-        room_state_ids: StateMap[str],
-        notif_events: Dict[str, EventBase],
-    ) -> str:
-        """
-        Make a summary text for the email when only a single room has notifications.
-        Args:
-            room_id: The ID of the room.
-            notifs: The notifications for this room.
-            room_state_ids: The state map for the room.
-            notif_events: A map of event ID -> notification event.
-        Returns:
-            The summary text.
-        """
-        # If the room doesn't have a name, say who the messages
-        # are from explicitly to avoid, "messages in the Bob room"
-        sender_ids = {notif_events[n["event_id"]].sender for n in notifs}
+            else:
+                # If the reason room doesn't have a name, say who the messages
+                # are from explicitly to avoid, "messages in the Bob room"
+                room_id = reason["room_id"]
+                sender_ids = list(
+                    {
+                        notif_events[n["event_id"]].sender
+                        for n in notifs_by_room[room_id]
+                    }
+                )
         member_events = await self.store.get_events(
-            [room_state_ids[("m.room.member", s)] for s in sender_ids]
+            [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids]
         )
-        # There was a single sender.
-        if len(sender_ids) == 1:
-            return self.email_subjects.messages_from_person % {
-                "person": descriptor_from_member_events(member_events.values()),
-                "app": self.app_name,
-            }
-        # There was more than one sender, use the first one and a tweaked template.
         return self.email_subjects.messages_from_person_and_others % {
-            "person": descriptor_from_member_events(list(member_events.values())[:1]),
+            "person": descriptor_from_member_events(member_events.values()),
             "app": self.app_name,
         }
@@ -719,15 +668,6 @@ class Mailer:
 def safe_markup(raw_html: str) -> jinja2.Markup:
-    """
-    Sanitise a raw HTML string to a set of allowed tags and attributes, and linkify any bare URLs.
-    Args
-        raw_html: Unsafe HTML.
-    Returns:
-        A Markup object ready to safely use in a Jinja template.
-    """
     return jinja2.Markup(
         bleach.linkify(
             bleach.clean(
@@ -744,13 +684,8 @@ def safe_markup(raw_html: str) -> jinja2.Markup:
 def safe_text(raw_text: str) -> jinja2.Markup:
     """
-    Sanitise text (escape any HTML tags), and then linkify any bare URLs.
-    Args
-        raw_text: Unsafe text which might include HTML markup.
-    Returns:
-        A Markup object ready to safely use in a Jinja template.
+    Process text: treat it as HTML but escape any tags (ie. just escape the
+    HTML) then linkify it.
     """
     return jinja2.Markup(
         bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False))
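`safe_text` above escapes all markup with `bleach.clean(..., tags=[])` and only then linkifies bare URLs. The same escape-then-linkify ordering can be sketched with just the standard library (the naive regex here is purely illustrative; `bleach.linkify` handles far more edge cases):

```python
import html
import re


def safe_text(raw_text: str) -> str:
    """Escape every HTML tag, then turn bare URLs into links.

    Escaping first means user-supplied markup can never survive as
    markup; only the anchor tags we add afterwards are real HTML.
    """
    escaped = html.escape(raw_text, quote=False)
    # Naive linkifier: stop at whitespace or escaped entities.
    return re.sub(
        r"(https?://[^\s<&]+)",
        r'<a href="\1">\1</a>',
        escaped,
    )
```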


@@ -17,7 +17,7 @@ import logging
 import re
 from typing import TYPE_CHECKING, Dict, Iterable, Optional
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes
 from synapse.events import EventBase
 from synapse.types import StateMap
@@ -63,7 +63,7 @@ async def calculate_room_name(
         m_room_name = await store.get_event(
             room_state_ids[(EventTypes.Name, "")], allow_none=True
         )
-        if m_room_name and m_room_name.content and m_room_name.content.get("name"):
+        if m_room_name and m_room_name.content and m_room_name.content["name"]:
             return m_room_name.content["name"]
     # does it have a canonical alias?
@@ -74,11 +74,15 @@ async def calculate_room_name(
         if (
             canon_alias
             and canon_alias.content
-            and canon_alias.content.get("alias")
+            and canon_alias.content["alias"]
             and _looks_like_an_alias(canon_alias.content["alias"])
         ):
             return canon_alias.content["alias"]
+    # at this point we're going to need to search the state by all state keys
+    # for an event type, so rearrange the data structure
+    room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
     if not fallback_to_members:
         return None
@@ -90,7 +94,7 @@ async def calculate_room_name(
         if (
             my_member_event is not None
-            and my_member_event.content.get("membership") == Membership.INVITE
+            and my_member_event.content["membership"] == "invite"
         ):
             if (EventTypes.Member, my_member_event.sender) in room_state_ids:
                 inviter_member_event = await store.get_event(
@@ -107,10 +111,6 @@ async def calculate_room_name(
         else:
             return "Room Invite"
-    # at this point we're going to need to search the state by all state keys
-    # for an event type, so rearrange the data structure
-    room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
     # we're going to have to generate a name based on who's in the room,
     # so find out who is in the room that isn't the user.
     if EventTypes.Member in room_state_bytype_ids:
@@ -120,8 +120,8 @@ async def calculate_room_name(
         all_members = [
             ev
             for ev in member_events.values()
-            if ev.content.get("membership") == Membership.JOIN
-            or ev.content.get("membership") == Membership.INVITE
+            if ev.content["membership"] == "join"
+            or ev.content["membership"] == "invite"
         ]
         # Sort the member events oldest-first so the we name people in the
         # order the joined (it should at least be deterministic rather than
@@ -194,7 +194,11 @@ def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
 def name_from_member_event(member_event: EventBase) -> str:
-    if member_event.content and member_event.content.get("displayname"):
+    if (
+        member_event.content
+        and "displayname" in member_event.content
+        and member_event.content["displayname"]
+    ):
         return member_event.content["displayname"]
     return member_event.state_key
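The left-hand version of `name_from_member_event` collapses the three-clause check into `content.get("displayname")`, which covers both a missing key and a falsy value in one step and can never raise `KeyError`. The same behaviour, with a plain dict standing in for `EventBase`:

```python
def name_from_member_event(member_event: dict) -> str:
    """Pick a display name for a member, falling back to their user ID
    (the state_key) when the displayname is absent or empty."""
    content = member_event.get("content") or {}
    # .get() handles a missing key; the truthiness test handles "".
    if content.get("displayname"):
        return content["displayname"]
    return member_event["state_key"]
```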


@@ -86,8 +86,8 @@ REQUIREMENTS = [
 CONDITIONAL_REQUIREMENTS = {
     "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
-    # we use execute_values with the fetch param, which arrived in psycopg 2.8.
-    "postgres": ["psycopg2>=2.8"],
+    # we use execute_batch, which arrived in psycopg 2.7.
+    "postgres": ["psycopg2>=2.7"],
     # ACME support is required to provision TLS certificates from authorities
     # that use the protocol, such as Let's Encrypt.
     "acme": [
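The `postgres` extra above pins a psycopg2 floor: `>=2.8` on the base branch for `execute_values(..., fetch=True)`, relaxed here to `>=2.7`, which is when `execute_batch` arrived. A toy version-floor check, under the assumption of plain dotted numeric versions (Synapse's real check parses full requirement strings with `pkg_resources` rather than comparing tuples by hand):

```python
def satisfies(installed: str, minimum: str) -> bool:
    """Return True if a dotted numeric version meets a ">=" floor."""
    def to_tuple(version: str):
        return tuple(int(part) for part in version.split("."))
    # Tuple comparison is elementwise, so (2, 8, 6) >= (2, 7) holds.
    return to_tuple(installed) >= to_tuple(minimum)
```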


@@ -1,105 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-from typing import TYPE_CHECKING, Any, Optional
-
-from prometheus_client import Counter
-
-from synapse.logging.context import make_deferred_yieldable
-from synapse.util import json_decoder, json_encoder
-
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
-set_counter = Counter(
-    "synapse_external_cache_set",
-    "Number of times we set a cache",
-    labelnames=["cache_name"],
-)
-
-get_counter = Counter(
-    "synapse_external_cache_get",
-    "Number of times we get a cache",
-    labelnames=["cache_name", "hit"],
-)
-
-logger = logging.getLogger(__name__)
-
-
-class ExternalCache:
-    """A cache backed by an external Redis. Does nothing if no Redis is
-    configured.
-    """
-
-    def __init__(self, hs: "HomeServer"):
-        self._redis_connection = hs.get_outbound_redis_connection()
-
-    def _get_redis_key(self, cache_name: str, key: str) -> str:
-        return "cache_v1:%s:%s" % (cache_name, key)
-
-    def is_enabled(self) -> bool:
-        """Whether the external cache is used or not.
-
-        It's safe to use the cache when this returns false, the methods will
-        just no-op, but the function is useful to avoid doing unnecessary work.
-        """
-        return self._redis_connection is not None
-
-    async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
-        """Add the key/value to the named cache, with the expiry time given.
-        """
-
-        if self._redis_connection is None:
-            return
-
-        set_counter.labels(cache_name).inc()
-
-        # txredisapi requires the value to be string, bytes or numbers, so we
-        # encode stuff in JSON.
-        encoded_value = json_encoder.encode(value)
-
-        logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
-
-        return await make_deferred_yieldable(
-            self._redis_connection.set(
-                self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
-            )
-        )
-
-    async def get(self, cache_name: str, key: str) -> Optional[Any]:
-        """Look up a key/value in the named cache.
-        """
-
-        if self._redis_connection is None:
-            return None
-
-        result = await make_deferred_yieldable(
-            self._redis_connection.get(self._get_redis_key(cache_name, key))
-        )
-
-        logger.debug("Got cache result %s %s: %r", cache_name, key, result)
-
-        get_counter.labels(cache_name, result is not None).inc()
-
-        if not result:
-            return None
-
-        # For some reason the integers get magically converted back to integers
-        if isinstance(result, int):
-            return result
-
-        return json_decoder.decode(result)
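The deleted `ExternalCache` above relies on two simple conventions: keys are namespaced as `cache_v1:<cache_name>:<key>`, and values round-trip through JSON (with bare integers passed through on read). A minimal in-memory stand-in sketching those semantics, without Redis or Twisted (this class is an illustration, not part of Synapse):

```python
import json

# In-memory sketch of the deleted ExternalCache's key scheme and
# JSON round-trip behaviour (no Redis, no expiry handling).
class InMemoryExternalCache:
    def __init__(self):
        self._data = {}

    def _get_redis_key(self, cache_name: str, key: str) -> str:
        # Same namespacing scheme as the real class.
        return "cache_v1:%s:%s" % (cache_name, key)

    def set(self, cache_name: str, key: str, value) -> None:
        # Redis clients generally want strings/bytes/numbers, hence JSON.
        self._data[self._get_redis_key(cache_name, key)] = json.dumps(value)

    def get(self, cache_name: str, key: str):
        result = self._data.get(self._get_redis_key(cache_name, key))
        if not result:
            return None
        if isinstance(result, int):
            # Mirrors the "integers get magically converted back" branch.
            return result
        return json.loads(result)
```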


@@ -15,7 +15,6 @@
 # limitations under the License.
 import logging
 from typing import (
-    TYPE_CHECKING,
     Any,
     Awaitable,
     Dict,
@@ -64,9 +63,6 @@ from synapse.replication.tcp.streams import (
     TypingStream,
 )
 
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
 logger = logging.getLogger(__name__)
@@ -92,7 +88,7 @@ class ReplicationCommandHandler:
     back out to connections.
     """
 
-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs):
         self._replication_data_handler = hs.get_replication_data_handler()
         self._presence_handler = hs.get_presence_handler()
         self._store = hs.get_datastore()
@@ -286,6 +282,13 @@ class ReplicationCommandHandler:
         if hs.config.redis.redis_enabled:
             from synapse.replication.tcp.redis import (
                 RedisDirectTcpReplicationClientFactory,
+                lazyConnection,
             )
+
+            logger.info(
+                "Connecting to redis (host=%r port=%r)",
+                hs.config.redis_host,
+                hs.config.redis_port,
+            )
 
             # First let's ensure that we have a ReplicationStreamer started.
@@ -296,7 +299,13 @@ class ReplicationCommandHandler:
             # connection after SUBSCRIBE is called).
 
             # First create the connection for sending commands.
-            outbound_redis_connection = hs.get_outbound_redis_connection()
+            outbound_redis_connection = lazyConnection(
+                reactor=hs.get_reactor(),
+                host=hs.config.redis_host,
+                port=hs.config.redis_port,
+                password=hs.config.redis.redis_password,
+                reconnect=True,
+            )
 
             # Now create the factory/connection for the subscription stream.
             self._factory = RedisDirectTcpReplicationClientFactory(


@@ -15,7 +15,7 @@
 import logging
 from inspect import isawaitable
-from typing import TYPE_CHECKING, Optional, Type, cast
+from typing import TYPE_CHECKING, Optional
 
 import txredisapi
@@ -23,7 +23,6 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
 from synapse.metrics.background_process_metrics import (
     BackgroundProcessLoggingContext,
     run_as_background_process,
-    wrap_as_background_process,
 )
 from synapse.replication.tcp.commands import (
     Command,
@@ -60,16 +59,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     immediately after initialisation.
 
     Attributes:
-        synapse_handler: The command handler to handle incoming commands.
-        synapse_stream_name: The *redis* stream name to subscribe to and publish
-            from (not anything to do with Synapse replication streams).
-        synapse_outbound_redis_connection: The connection to redis to use to send
+        handler: The command handler to handle incoming commands.
+        stream_name: The *redis* stream name to subscribe to and publish from
+            (not anything to do with Synapse replication streams).
+        outbound_redis_connection: The connection to redis to use to send
             commands.
     """
 
-    synapse_handler = None  # type: ReplicationCommandHandler
-    synapse_stream_name = None  # type: str
-    synapse_outbound_redis_connection = None  # type: txredisapi.RedisProtocol
+    handler = None  # type: ReplicationCommandHandler
+    stream_name = None  # type: str
+    outbound_redis_connection = None  # type: txredisapi.RedisProtocol
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -89,19 +88,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         # it's important to make sure that we only send the REPLICATE command once we
         # have successfully subscribed to the stream - otherwise we might miss the
         # POSITION response sent back by the other end.
-        logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
-        await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
+        logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
+        await make_deferred_yieldable(self.subscribe(self.stream_name))
         logger.info(
             "Successfully subscribed to redis stream, sending REPLICATE command"
         )
-        self.synapse_handler.new_connection(self)
+        self.handler.new_connection(self)
         await self._async_send_command(ReplicateCommand())
         logger.info("REPLICATE successfully sent")
 
         # We send out our positions when there is a new connection in case the
         # other side missed updates. We do this for Redis connections as the
         # otherside won't know we've connected and so won't issue a REPLICATE.
-        self.synapse_handler.send_positions_to_connection(self)
+        self.handler.send_positions_to_connection(self)
 
     def messageReceived(self, pattern: str, channel: str, message: str):
         """Received a message from redis.
@@ -138,7 +137,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
             cmd: received command
         """
-        cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
+        cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
         if not cmd_func:
             logger.warning("Unhandled command: %r", cmd)
             return
@@ -156,7 +155,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     def connectionLost(self, reason):
         logger.info("Lost connection to redis")
         super().connectionLost(reason)
-        self.synapse_handler.lost_connection(self)
+        self.handler.lost_connection(self)
 
         # mark the logging context as finished
         self._logging_context.__exit__(None, None, None)
@@ -184,54 +183,11 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
 
         await make_deferred_yieldable(
-            self.synapse_outbound_redis_connection.publish(
-                self.synapse_stream_name, encoded_string
-            )
+            self.outbound_redis_connection.publish(self.stream_name, encoded_string)
         )
 
 
-class SynapseRedisFactory(txredisapi.RedisFactory):
-    """A subclass of RedisFactory that periodically sends pings to ensure that
-    we detect dead connections.
-    """
-
-    def __init__(
-        self,
-        hs: "HomeServer",
-        uuid: str,
-        dbid: Optional[int],
-        poolsize: int,
-        isLazy: bool = False,
-        handler: Type = txredisapi.ConnectionHandler,
-        charset: str = "utf-8",
-        password: Optional[str] = None,
-        replyTimeout: int = 30,
-        convertNumbers: Optional[int] = True,
-    ):
-        super().__init__(
-            uuid=uuid,
-            dbid=dbid,
-            poolsize=poolsize,
-            isLazy=isLazy,
-            handler=handler,
-            charset=charset,
-            password=password,
-            replyTimeout=replyTimeout,
-            convertNumbers=convertNumbers,
-        )
-
-        hs.get_clock().looping_call(self._send_ping, 30 * 1000)
-
-    @wrap_as_background_process("redis_ping")
-    async def _send_ping(self):
-        for connection in self.pool:
-            try:
-                await make_deferred_yieldable(connection.ping())
-            except Exception:
-                logger.warning("Failed to send ping to a redis connection")
-
-
-class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
+class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
     """This is a reconnecting factory that connects to redis and immediately
     subscribes to a stream.
@@ -250,62 +206,65 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
         self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
     ):
 
-        super().__init__(
-            hs,
-            uuid="subscriber",
-            dbid=None,
-            poolsize=1,
-            replyTimeout=30,
-            password=hs.config.redis.redis_password,
-        )
+        super().__init__()
 
-        self.synapse_handler = hs.get_tcp_replication()
-        self.synapse_stream_name = hs.hostname
+        # This sets the password on the RedisFactory base class (as
+        # SubscriberFactory constructor doesn't pass it through).
+        self.password = hs.config.redis.redis_password
 
-        self.synapse_outbound_redis_connection = outbound_redis_connection
+        self.handler = hs.get_tcp_replication()
+        self.stream_name = hs.hostname
+
+        self.outbound_redis_connection = outbound_redis_connection
 
     def buildProtocol(self, addr):
-        p = super().buildProtocol(addr)
-        p = cast(RedisSubscriber, p)
+        p = super().buildProtocol(addr)  # type: RedisSubscriber
 
         # We do this here rather than add to the constructor of `RedisSubcriber`
         # as to do so would involve overriding `buildProtocol` entirely, however
         # the base method does some other things than just instantiating the
         # protocol.
-        p.synapse_handler = self.synapse_handler
-        p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
-        p.synapse_stream_name = self.synapse_stream_name
+        p.handler = self.handler
+        p.outbound_redis_connection = self.outbound_redis_connection
+        p.stream_name = self.stream_name
+        p.password = self.password
         return p
 
 
 def lazyConnection(
-    hs: "HomeServer",
+    reactor,
     host: str = "localhost",
     port: int = 6379,
     dbid: Optional[int] = None,
     reconnect: bool = True,
+    charset: str = "utf-8",
     password: Optional[str] = None,
-    replyTimeout: int = 30,
+    connectTimeout: Optional[int] = None,
+    replyTimeout: Optional[int] = None,
+    convertNumbers: bool = True,
 ) -> txredisapi.RedisProtocol:
-    """Creates a connection to Redis that is lazily set up and reconnects if the
-    connections is lost.
+    """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
+    reactor.
     """
 
+    isLazy = True
+    poolsize = 1
+
     uuid = "%s:%d" % (host, port)
-    factory = SynapseRedisFactory(
-        hs,
-        uuid=uuid,
-        dbid=dbid,
-        poolsize=1,
-        isLazy=True,
-        handler=txredisapi.ConnectionHandler,
-        password=password,
-        replyTimeout=replyTimeout,
+    factory = txredisapi.RedisFactory(
+        uuid,
+        dbid,
+        poolsize,
+        isLazy,
+        txredisapi.ConnectionHandler,
+        charset,
+        password,
+        replyTimeout,
+        convertNumbers,
     )
     factory.continueTrying = reconnect
+    for x in range(poolsize):
+        reactor.connectTCP(host, port, factory, connectTimeout)
 
-    reactor = hs.get_reactor()
-    reactor.connectTCP(host, port, factory, 30)
-
     return factory.handler


@@ -1,88 +0,0 @@
-body {
-  font-family: "Inter", "Helvetica", "Arial", sans-serif;
-  font-size: 14px;
-  color: #17191C;
-}
-
-header {
-  max-width: 480px;
-  width: 100%;
-  margin: 24px auto;
-  text-align: center;
-}
-
-header p {
-  color: #737D8C;
-  line-height: 24px;
-}
-
-h1 {
-  font-size: 24px;
-}
-
-.error_page h1 {
-  color: #FE2928;
-}
-
-h2 {
-  font-size: 14px;
-}
-
-h2 img {
-  vertical-align: middle;
-  margin-right: 8px;
-  width: 24px;
-  height: 24px;
-}
-
-label {
-  cursor: pointer;
-}
-
-main {
-  max-width: 360px;
-  width: 100%;
-  margin: 24px auto;
-}
-
-.primary-button {
-  border: none;
-  text-decoration: none;
-  padding: 12px;
-  color: white;
-  background-color: #418DED;
-  font-weight: bold;
-  display: block;
-  border-radius: 12px;
-  width: 100%;
-  box-sizing: border-box;
-  margin: 16px 0;
-  cursor: pointer;
-  text-align: center;
-}
-
-.profile {
-  display: flex;
-  justify-content: center;
-  margin: 24px 0;
-}
-
-.profile .avatar {
-  width: 36px;
-  height: 36px;
-  border-radius: 100%;
-  display: block;
-  margin-right: 8px;
-}
-
-.profile .display-name {
-  font-weight: bold;
-  margin-bottom: 4px;
-}
-
-.profile .user-id {
-  color: #737D8C;
-}
-
-.profile .display-name, .profile .user-id {
-  line-height: 18px;
-}


@@ -3,22 +3,8 @@
 <head>
     <meta charset="UTF-8">
     <title>SSO account deactivated</title>
-    <meta name="viewport" content="width=device-width, user-scalable=no">
-    <style type="text/css">
-        {% include "sso.css" without context %}
-    </style>
 </head>
-<body class="error_page">
-    <header>
-        <h1>Your account has been deactivated</h1>
-        <p>
-            <strong>No account found</strong>
-        </p>
-        <p>
-            Your account might have been deactivated by the server administrator.
-            You can either try to create a new account or contact the servers
-            administrator.
-        </p>
-    </header>
+<body>
+    <p>This account has been deactivated.</p>
 </body>
 </html>


@@ -1,138 +0,0 @@
-<!DOCTYPE html>
-<html lang="en">
-<head>
-    <title>Synapse Login</title>
-    <meta charset="utf-8">
-    <meta name="viewport" content="width=device-width, user-scalable=no">
-    <style type="text/css">
-      {% include "sso.css" without context %}
-
-      .username_input {
-        display: flex;
-        border: 2px solid #418DED;
-        border-radius: 8px;
-        padding: 12px;
-        position: relative;
-        margin: 16px 0;
-        align-items: center;
-        font-size: 12px;
-      }
-
-      .username_input label {
-        position: absolute;
-        top: -8px;
-        left: 14px;
-        font-size: 80%;
-        background: white;
-        padding: 2px;
-      }
-
-      .username_input input {
-        flex: 1;
-        display: block;
-        min-width: 0;
-        border: none;
-      }
-
-      .username_input div {
-        color: #8D99A5;
-      }
-
-      .idp-pick-details {
-        border: 1px solid #E9ECF1;
-        border-radius: 8px;
-        margin: 24px 0;
-      }
-
-      .idp-pick-details h2 {
-        margin: 0;
-        padding: 8px 12px;
-      }
-
-      .idp-pick-details .idp-detail {
-        border-top: 1px solid #E9ECF1;
-        padding: 12px;
-      }
-
-      .idp-pick-details .check-row {
-        display: flex;
-        align-items: center;
-      }
-
-      .idp-pick-details .check-row .name {
-        flex: 1;
-      }
-
-      .idp-pick-details .use, .idp-pick-details .idp-value {
-        color: #737D8C;
-      }
-
-      .idp-pick-details .idp-value {
-        margin: 0;
-        margin-top: 8px;
-      }
-
-      .idp-pick-details .avatar {
-        width: 53px;
-        height: 53px;
-        border-radius: 100%;
-        display: block;
-        margin-top: 8px;
-      }
-    </style>
-</head>
-<body>
-    <header>
-        <h1>Your account is nearly ready</h1>
-        <p>Check your details before creating an account on {{ server_name }}</p>
-    </header>
-    <main>
-        <form method="post" class="form__input" id="form">
-            <div class="username_input">
-                <label for="field-username">Username</label>
-                <div class="prefix">@</div>
-                <input type="text" name="username" id="field-username" autofocus required pattern="[a-z0-9\-=_\/\.]+">
-                <div class="postfix">:{{ server_name }}</div>
-            </div>
-            <input type="submit" value="Continue" class="primary-button">
-            {% if user_attributes %}
-            <section class="idp-pick-details">
-                <h2><img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>Information from {{ idp.idp_name }}</h2>
-                {% if user_attributes.avatar_url %}
-                <div class="idp-detail idp-avatar">
-                    <div class="check-row">
-                        <label for="idp-avatar" class="name">Avatar</label>
-                        <label for="idp-avatar" class="use">Use</label>
-                        <input type="checkbox" name="use_avatar" id="idp-avatar" value="true" checked>
-                    </div>
-                    <img src="{{ user_attributes.avatar_url }}" class="avatar" />
-                </div>
-                {% endif %}
-                {% if user_attributes.display_name %}
-                <div class="idp-detail">
-                    <div class="check-row">
-                        <label for="idp-displayname" class="name">Display name</label>
-                        <label for="idp-displayname" class="use">Use</label>
-                        <input type="checkbox" name="use_display_name" id="idp-displayname" value="true" checked>
-                    </div>
-                    <p class="idp-value">{{ user_attributes.display_name }}</p>
-                </div>
-                {% endif %}
-                {% for email in user_attributes.emails %}
-                <div class="idp-detail">
-                    <div class="check-row">
-                        <label for="idp-email{{ loop.index }}" class="name">E-mail</label>
-                        <label for="idp-email{{ loop.index }}" class="use">Use</label>
-                        <input type="checkbox" name="use_email" id="idp-email{{ loop.index }}" value="{{ email }}" checked>
-                    </div>
-                    <p class="idp-value">{{ email }}</p>
-                </div>
-                {% endfor %}
-            </section>
-            {% endif %}
-        </form>
-    </main>
-    <script type="text/javascript">
-      {% include "sso_auth_account_details.js" without context %}
-    </script>
-</body>
-</html>


@@ -1,76 +0,0 @@
-const usernameField = document.getElementById("field-username");
-
-function throttle(fn, wait) {
-    let timeout;
-    return function() {
-        const args = Array.from(arguments);
-        if (timeout) {
-            clearTimeout(timeout);
-        }
-        timeout = setTimeout(fn.bind.apply(fn, [null].concat(args)), wait);
-    }
-}
-
-function checkUsernameAvailable(username) {
-    let check_uri = 'check?username=' + encodeURIComponent(username);
-    return fetch(check_uri, {
-        // include the cookie
-        "credentials": "same-origin",
-    }).then((response) => {
-        if(!response.ok) {
-            // for non-200 responses, raise the body of the response as an exception
-            return response.text().then((text) => { throw new Error(text); });
-        } else {
-            return response.json();
-        }
-    }).then((json) => {
-        if(json.error) {
-            return {message: json.error};
-        } else if(json.available) {
-            return {available: true};
-        } else {
-            return {message: username + " is not available, please choose another."};
-        }
-    });
-}
-
-function validateUsername(username) {
-    usernameField.setCustomValidity("");
-    if (usernameField.validity.valueMissing) {
-        usernameField.setCustomValidity("Please provide a username");
-        return;
-    }
-    if (usernameField.validity.patternMismatch) {
-        usernameField.setCustomValidity("Invalid username, please only use " + allowedCharactersString);
-        return;
-    }
-    usernameField.setCustomValidity("Checking if username is available …");
-    throttledCheckUsernameAvailable(username);
-}
-
-const throttledCheckUsernameAvailable = throttle(function(username) {
-    const handleError = function(err) {
-        // don't prevent form submission on error
-        usernameField.setCustomValidity("");
-        console.log(err.message);
-    };
-    try {
-        checkUsernameAvailable(username).then(function(result) {
-            if (!result.available) {
-                usernameField.setCustomValidity(result.message);
-                usernameField.reportValidity();
-            } else {
-                usernameField.setCustomValidity("");
-            }
-        }, handleError);
-    } catch (err) {
-        handleError(err);
-    }
-}, 500);
-
-usernameField.addEventListener("input", function(evt) {
-    validateUsername(usernameField.value);
-});
-usernameField.addEventListener("change", function(evt) {
-    validateUsername(usernameField.value);
-});
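The `throttle` helper in the deleted script above cancels any pending timer on each call, so only the last call in a burst actually runs, `wait` milliseconds after the burst ends (this pattern is usually called a debounce). A Python transcription of the same idea, offered purely as an illustration and not part of Synapse:

```python
import threading

# Trailing-edge debounce, mirroring the JS throttle() above: each new call
# cancels the pending one (like clearTimeout) and re-arms the timer.
def debounce(fn, wait: float):
    timer = None

    def wrapped(*args):
        nonlocal timer
        if timer is not None:
            timer.cancel()  # matches the JS clearTimeout(timeout)
        timer = threading.Timer(wait, fn, args)
        timer.start()

    return wrapped
```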


@@ -1,25 +1,18 @@
-<!DOCTYPE html>
-<html lang="en">
+<html>
 <head>
-    <meta charset="UTF-8">
-    <title>Authentication failed</title>
-    <meta name="viewport" content="width=device-width, user-scalable=no">
-    <style type="text/css">
-        {% include "sso.css" without context %}
-    </style>
+    <title>Authentication Failed</title>
 </head>
-<body class="error_page">
-    <header>
-        <h1>That doesn't look right</h1>
+<body>
+    <div>
         <p>
-            <strong>We were unable to validate your {{ server_name }} account</strong>
-            via single&nbsp;sign&#8209;on&nbsp;(SSO), because the SSO Identity
-            Provider returned different details than when you logged in.
+            We were unable to validate your <tt>{{server_name | e}}</tt> account via
+            single-sign-on (SSO), because the SSO Identity Provider returned
+            different details than when you logged in.
         </p>
         <p>
             Try the operation again, and ensure that you use the same details on
             the Identity Provider as when you log into your account.
         </p>
-    </header>
+    </div>
 </body>
 </html>


@@ -1,28 +1,14 @@
-<!DOCTYPE html>
-<html lang="en">
+<html>
 <head>
-    <meta charset="UTF-8">
     <title>Authentication</title>
-    <meta name="viewport" content="width=device-width, user-scalable=no">
-    <style type="text/css">
-        {% include "sso.css" without context %}
-    </style>
 </head>
 <body>
-    <header>
-        <h1>Confirm it's you to continue</h1>
+    <div>
         <p>
-            A client is trying to {{ description }}. To confirm this action
-            re-authorize your account with single sign-on.
+            A client is trying to {{ description | e }}. To confirm this action,
+            <a href="{{ redirect_url | e }}">re-authenticate with single sign-on</a>.
+            If you did not expect this, your account may be compromised!
         </p>
-        <p><strong>
-            If you did not expect this, your account may be compromised.
-        </strong></p>
-    </header>
-    <main>
-        <a href="{{ redirect_url }}" class="primary-button">
-            Continue with {{ idp.idp_name }}
-        </a>
-    </main>
+    </div>
 </body>
 </html>


@@ -1,12 +1,6 @@
-<!DOCTYPE html>
-<html lang="en">
+<html>
 <head>
-    <meta charset="UTF-8">
-    <title>Authentication successful</title>
-    <meta name="viewport" content="width=device-width, user-scalable=no">
-    <style type="text/css">
-        {% include "sso.css" without context %}
-    </style>
+    <title>Authentication Successful</title>
     <script>
         if (window.onAuthDone) {
             window.onAuthDone();
@@ -16,12 +10,9 @@
     </script>
 </head>
 <body>
-    <header>
-        <h1>Thank you</h1>
-        <p>
-            Now we know its you, you can close this window and return to the
-            application.
-        </p>
-    </header>
+    <div>
+        <p>Thank you</p>
+        <p>You may now close this window and return to the application</p>
+    </div>
 </body>
 </html>


@@ -2,42 +2,27 @@
 <html lang="en">
 <head>
     <meta charset="UTF-8">
-    <title>Authentication failed</title>
-    <meta name="viewport" content="width=device-width, user-scalable=no">
-    <style type="text/css">
-        {% include "sso.css" without context %}
-
-        #error_code {
-            margin-top: 56px;
-        }
-    </style>
+    <title>SSO error</title>
 </head>
-<body class="error_page">
+<body>
 {# If an error of unauthorised is returned it means we have actively rejected their login #}
 {% if error == "unauthorised" %}
-    <header>
-        <p>You are not allowed to log in here.</p>
-    </header>
+    <p>You are not allowed to log in here.</p>
 {% else %}
-    <header>
-        <h1>There was an error</h1>
     <p>
-        <strong id="errormsg">{{ error_description }}</strong>
+        There was an error during authentication:
     </p>
+    <div id="errormsg" style="margin:20px 80px">{{ error_description | e }}</div>
     <p>
-        If you are seeing this page after clicking a link sent to you via email,
-        make sure you only click the confirmation link once, and that you open
-        the validation link in the same client you're logging in from.
+        If you are seeing this page after clicking a link sent to you via email, make
+        sure you only click the confirmation link once, and that you open the
+        validation link in the same client you're logging in from.
     </p>
     <p>
         Try logging in again from your Matrix client and if the problem persists
         please contact the server's administrator.
     </p>
-    <div id="error_code">
-        <p><strong>Error code</strong></p>
-        <p>{{ error }}</p>
-    </div>
-    </header>
+    <p>Error: <code>{{ error }}</code></p>
 
     <script type="text/javascript">
         // Error handling to support Auth0 errors that we might get through a GET request

@@ -3,20 +3,20 @@
 <head>
     <meta charset="UTF-8">
     <link rel="stylesheet" href="/_matrix/static/client/login/style.css">
-    <title>{{ server_name }} Login</title>
+    <title>{{server_name | e}} Login</title>
 </head>
 <body>
     <div id="container">
-        <h1 id="title">{{ server_name }} Login</h1>
+        <h1 id="title">{{server_name | e}} Login</h1>
         <div class="login_flow">
             <p>Choose one of the following identity providers:</p>
             <form>
-                <input type="hidden" name="redirectUrl" value="{{ redirect_url }}">
+                <input type="hidden" name="redirectUrl" value="{{redirect_url | e}}">
                 <ul class="radiobuttons">
 {% for p in providers %}
                     <li>
                         <input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}">
-                        <label for="prov{{ loop.index }}">{{ p.idp_name }}</label>
+                        <label for="prov{{loop.index}}">{{p.idp_name | e}}</label>
 {% if p.idp_icon %}
                         <img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/>
 {% endif %}


@@ -1,39 +0,0 @@
-<!DOCTYPE html>
-<html lang="en">
-<head>
-    <meta charset="UTF-8">
-    <title>SSO redirect confirmation</title>
-    <meta name="viewport" content="width=device-width, user-scalable=no">
-    <style type="text/css">
-      {% include "sso.css" without context %}
-
-      #consent_form {
-        margin-top: 56px;
-      }
-    </style>
-</head>
-<body>
-    <header>
-        <h1>Your account is nearly ready</h1>
-        <p>Agree to the terms to create your account.</p>
-    </header>
-    <main>
-        <!-- {% if user_profile.avatar_url and user_profile.display_name %} -->
-        <div class="profile">
-            <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" />
-            <div class="profile-details">
-                <div class="display-name">{{ user_profile.display_name }}</div>
-                <div class="user-id">{{ user_id }}</div>
-            </div>
-        </div>
-        <!-- {% endif %} -->
-        <form method="post" action="{{my_url}}" id="consent_form">
-            <p>
-                <input id="accepted_version" type="checkbox" name="accepted_version" value="{{ consent_version }}" required>
-                <label for="accepted_version">I have read and agree to the <a href="{{ terms_url }}" target="_blank">terms and conditions</a>.</label>
-            </p>
-            <input type="submit" class="primary-button" value="Continue"/>
-        </form>
-    </main>
-</body>
-</html>


@@ -3,34 +3,12 @@
 <head>
     <meta charset="UTF-8">
     <title>SSO redirect confirmation</title>
-    <meta name="viewport" content="width=device-width, user-scalable=no">
-    <style type="text/css">
-        {% include "sso.css" without context %}
-    </style>
 </head>
 <body>
-    <header>
-        {% if new_user %}
-        <h1>Your account is now ready</h1>
-        <p>You've made your account on {{ server_name }}.</p>
-        {% else %}
-        <h1>Log in</h1>
-        {% endif %}
-        <p>Continue to confirm you trust <strong>{{ display_url }}</strong>.</p>
-    </header>
-    <main>
-        {% if user_profile.avatar_url %}
-        <div class="profile">
-            <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" />
-            <div class="profile-details">
-                {% if user_profile.display_name %}
-                <div class="display-name">{{ user_profile.display_name }}</div>
-                {% endif %}
-                <div class="user-id">{{ user_id }}</div>
-            </div>
-        </div>
-        {% endif %}
-        <a href="{{ redirect_url }}" class="primary-button">Continue</a>
-    </main>
+    <p>The application at <span style="font-weight:bold">{{ display_url | e }}</span> is requesting full access to your <span style="font-weight:bold">{{ server_name }}</span> Matrix account.</p>
+    <p>If you don't recognise this address, you should ignore this and close this tab.</p>
+    <p>
+        <a href="{{ redirect_url | e }}">I trust this address</a>
+    </p>
 </body>
 </html>


@@ -0,0 +1,19 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <title>Synapse Login</title>
+    <link rel="stylesheet" href="style.css" type="text/css" />
+</head>
+<body>
+    <div class="card">
+        <form method="post" class="form__input" id="form" action="submit">
+            <label for="field-username">Please pick your username:</label>
+            <input type="text" name="username" id="field-username" autofocus="">
+            <input type="submit" class="button button--full-width" id="button-submit" value="Submit">
+        </form>
+        <!-- this is used for feedback -->
+        <div role=alert class="tooltip hidden" id="message"></div>
+        <script src="script.js"></script>
+    </div>
+</body>
+</html>


@@ -0,0 +1,95 @@
let inputField = document.getElementById("field-username");
let inputForm = document.getElementById("form");
let submitButton = document.getElementById("button-submit");
let message = document.getElementById("message");

// Display feedback text in the (initially hidden) message element.
function showMessage(messageText) {
    // Unhide the message text
    message.classList.remove("hidden");
    message.textContent = messageText;
}

function doSubmit() {
    showMessage("Success. Please wait a moment for your browser to redirect.");

    // remove the event handler before re-submitting the form, so the native
    // submit is not intercepted again.
    inputForm.onsubmit = null;
    inputForm.submit();
}

function onResponse(response) {
    // Display message
    showMessage(response);

    // Re-enable the submit button
    submitButton.classList.remove('button--disabled');
    submitButton.value = "Submit";
}

// Matches any character that is NOT permitted in a username.
let disallowedUsernameCharacters = RegExp("[^a-z0-9\\.\\_\\=\\-\\/]");
function usernameIsValid(username) {
    return !disallowedUsernameCharacters.test(username);
}
let allowedCharactersString = "lowercase letters, digits, ., _, -, /, =";

function buildQueryString(params) {
    return Object.keys(params)
        .map(k => encodeURIComponent(k) + '=' + encodeURIComponent(params[k]))
        .join('&');
}

// Submit username and receive response
function submitUsername(username) {
    if (username.length == 0) {
        onResponse("Please enter a username.");
        return;
    }
    if (!usernameIsValid(username)) {
        onResponse("Invalid username. Only the following characters are allowed: " + allowedCharactersString);
        return;
    }

    // if this browser doesn't support fetch, skip the availability check.
    if (!window.fetch) {
        doSubmit();
        return;
    }

    let check_uri = 'check?' + buildQueryString({"username": username});
    fetch(check_uri, {
        // include the cookie
        "credentials": "same-origin",
    }).then((response) => {
        if (!response.ok) {
            // for non-200 responses, raise the body of the response as an exception
            return response.text().then((text) => { throw text; });
        } else {
            return response.json();
        }
    }).then((json) => {
        if (json.error) {
            throw json.error;
        } else if (json.available) {
            doSubmit();
        } else {
            onResponse("This username is not available, please choose another.");
        }
    }).catch((err) => {
        onResponse("Error checking username availability: " + err);
    });
}

function clickSubmit(event) {
    event.preventDefault();
    if (submitButton.classList.contains('button--disabled')) { return; }

    // Disable the submit button while the availability check is in flight
    submitButton.classList.add('button--disabled');
    submitButton.value = "Checking...";
    submitUsername(inputField.value);
}

inputForm.onsubmit = clickSubmit;
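The script above validates usernames against a character allow-list and builds a percent-encoded query string for the `check` endpoint. As an illustrative, hypothetical server-side mirror of those two checks (not part of this diff — the function names are made up for the sketch):

```python
import re
from urllib.parse import quote, urlencode

# Same character class as the script's RegExp: anything outside lowercase
# letters, digits, ., _, =, -, / is disallowed.
DISALLOWED = re.compile(r"[^a-z0-9._=\-/]")

def username_is_valid(username: str) -> bool:
    """True iff the username is non-empty and contains only permitted characters."""
    return bool(username) and not DISALLOWED.search(username)

def build_check_uri(username: str) -> str:
    """Build the same 'check?username=...' URI that the script fetches."""
    # quote_via=quote percent-encodes spaces as %20, matching encodeURIComponent
    return "check?" + urlencode({"username": username}, quote_via=quote)
```

Keeping the validation identical on both sides avoids a round trip to the server for usernames that would be rejected anyway.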


@@ -0,0 +1,27 @@
input[type="text"] {
    font-size: 100%;
    background-color: #ededf0;
    border: 1px solid #fff;
    border-radius: .2em;
    padding: .5em .9em;
    display: block;
    width: 26em;
}

.button--disabled {
    border-color: #fff;
    background-color: transparent;
    color: #000;
    text-transform: none;
}

.hidden {
    display: none;
}

.tooltip {
    background-color: #f9f9fa;
    padding: 1em;
    margin: 1em 0;
}


@@ -1,8 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2018-2019 New Vector Ltd
-# Copyright 2020, 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -38,13 +36,11 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi
 from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
 from synapse.rest.admin.rooms import (
     DeleteRoomRestServlet,
-    ForwardExtremitiesRestServlet,
     JoinRoomAliasServlet,
     ListRoomRestServlet,
     MakeRoomAdminRestServlet,
     RoomMembersRestServlet,
     RoomRestServlet,
-    RoomStateRestServlet,
     ShutdownRoomRestServlet,
 )
 from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
@@ -55,7 +51,6 @@ from synapse.rest.admin.users import (
     PushersRestServlet,
     ResetPasswordRestServlet,
     SearchUsersRestServlet,
-    ShadowBanRestServlet,
     UserAdminServlet,
     UserMediaRestServlet,
     UserMembershipRestServlet,
@@ -214,7 +209,6 @@ def register_servlets(hs, http_server):
     """
     register_servlets_for_client_rest_resource(hs, http_server)
     ListRoomRestServlet(hs).register(http_server)
-    RoomStateRestServlet(hs).register(http_server)
     RoomRestServlet(hs).register(http_server)
     RoomMembersRestServlet(hs).register(http_server)
     DeleteRoomRestServlet(hs).register(http_server)
@@ -236,8 +230,6 @@ def register_servlets(hs, http_server):
     EventReportsRestServlet(hs).register(http_server)
     PushersRestServlet(hs).register(http_server)
     MakeRoomAdminRestServlet(hs).register(http_server)
-    ShadowBanRestServlet(hs).register(http_server)
-    ForwardExtremitiesRestServlet(hs).register(http_server)


 def register_servlets_for_client_rest_resource(hs, http_server):


@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -292,45 +292,6 @@ class RoomMembersRestServlet(RestServlet):
         return 200, ret


-class RoomStateRestServlet(RestServlet):
-    """
-    Get full state within a room.
-    """
-
-    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state")
-
-    def __init__(self, hs: "HomeServer"):
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.clock = hs.get_clock()
-        self._event_serializer = hs.get_event_client_serializer()
-
-    async def on_GET(
-        self, request: SynapseRequest, room_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
-
-        ret = await self.store.get_room(room_id)
-        if not ret:
-            raise NotFoundError("Room not found")
-
-        event_ids = await self.store.get_current_state_ids(room_id)
-        events = await self.store.get_events(event_ids.values())
-        now = self.clock.time_msec()
-        room_state = await self._event_serializer.serialize_events(
-            events.values(),
-            now,
-            # We don't bother bundling aggregations in when asked for state
-            # events, as clients won't use them.
-            bundle_aggregations=False,
-        )
-        ret = {"state": room_state}
-
-        return 200, ret
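The servlet above serves the current room state from `GET /_synapse/admin/v1/rooms/<room_id>/state` (the path is taken from its `PATTERNS`). A minimal sketch of how an admin client might build the corresponding request — the base URL and access token here are placeholders, not values from the diff:

```python
from urllib.parse import quote
from urllib.request import Request

ADMIN_PREFIX = "/_synapse/admin/v1"

def room_state_request(base_url: str, room_id: str, access_token: str) -> Request:
    """Build a GET request for the room-state admin API.

    The room ID must be percent-encoded, since IDs like "!abc:example.com"
    contain characters that are not URL-safe.
    """
    url = "%s%s/rooms/%s/state" % (base_url, ADMIN_PREFIX, quote(room_id, safe=""))
    return Request(url, headers={"Authorization": "Bearer " + access_token})
```

The response body is a JSON object of the form `{"state": [...]}`, matching the `ret = {"state": room_state}` line in the servlet.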
 class JoinRoomAliasServlet(RestServlet):

     PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
@@ -470,17 +431,7 @@ class MakeRoomAdminRestServlet(RestServlet):
             if not admin_users:
                 raise SynapseError(400, "No local admin user in room")

-            admin_user_id = None
-
-            for admin_user in reversed(admin_users):
-                if room_state.get((EventTypes.Member, admin_user)):
-                    admin_user_id = admin_user
-                    break
-
-            if not admin_user_id:
-                raise SynapseError(
-                    400, "No local admin user in room",
-                )
+            admin_user_id = admin_users[-1]

             pl_content = power_levels.content
         else:
@@ -548,60 +499,3 @@ class MakeRoomAdminRestServlet(RestServlet):
             )

         return 200, {}
-
-
-class ForwardExtremitiesRestServlet(RestServlet):
-    """Allows a server admin to get or clear forward extremities.
-
-    Clearing does not require restarting the server.
-
-    Clear forward extremities:
-        DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
-
-    Get forward_extremities:
-        GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
-    """
-
-    PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
-
-    def __init__(self, hs: "HomeServer"):
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.room_member_handler = hs.get_room_member_handler()
-        self.store = hs.get_datastore()
-
-    async def resolve_room_id(self, room_identifier: str) -> str:
-        """Resolve to a room ID, if necessary."""
-        if RoomID.is_valid(room_identifier):
-            resolved_room_id = room_identifier
-        elif RoomAlias.is_valid(room_identifier):
-            room_alias = RoomAlias.from_string(room_identifier)
-            room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
-            resolved_room_id = room_id.to_string()
-        else:
-            raise SynapseError(
-                400, "%s was not legal room ID or room alias" % (room_identifier,)
-            )
-        if not resolved_room_id:
-            raise SynapseError(
-                400, "Unknown room ID or room alias %s" % room_identifier
-            )
-        return resolved_room_id
-
-    async def on_DELETE(self, request, room_identifier):
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
-
-        room_id = await self.resolve_room_id(room_identifier)
-
-        deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
-        return 200, {"deleted": deleted_count}
-
-    async def on_GET(self, request, room_identifier):
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
-
-        room_id = await self.resolve_room_id(room_identifier)
-
-        extremities = await self.store.get_forward_extremities_for_room(room_id)
-        return 200, {"count": len(extremities), "results": extremities}
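The `resolve_room_id` helper above accepts either a room ID or a room alias in the URL. A simplified, self-contained sketch of that resolution logic — real Synapse uses `RoomID.is_valid` / `RoomAlias.is_valid` and an async directory lookup, so the sigil checks and the `alias_directory` dict here are stand-ins:

```python
def resolve_room_id(identifier: str, alias_directory: dict) -> str:
    """Return a room ID, resolving a room alias via alias_directory if needed."""
    if identifier.startswith("!"):
        # Already a room ID; use it as-is.
        return identifier
    if identifier.startswith("#"):
        # A room alias: look it up in the directory.
        room_id = alias_directory.get(identifier)
        if room_id is None:
            raise ValueError("Unknown room ID or room alias %s" % identifier)
        return room_id
    raise ValueError("%s was not legal room ID or room alias" % (identifier,))
```

Accepting both forms keeps the admin API convenient: the caller can pass the human-readable alias and let the server do the directory lookup.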


@@ -890,39 +890,3 @@ class UserTokenRestServlet(RestServlet):
         )

         return 200, {"access_token": token}
-
-
-class ShadowBanRestServlet(RestServlet):
-    """An admin API for shadow-banning a user.
-
-    A shadow-banned user receives successful responses to their client-server
-    API requests, but the events are not propagated into rooms.
-
-    Shadow-banning a user should be used as a tool of last resort and may lead
-    to confusing or broken behaviour for the client.
-
-    Example:
-
-        POST /_synapse/admin/v1/users/@test:example.com/shadow_ban
-        {}
-
-        200 OK
-        {}
-    """
-
-    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
-
-    def __init__(self, hs: "HomeServer"):
-        self.hs = hs
-        self.store = hs.get_datastore()
-        self.auth = hs.get_auth()
-
-    async def on_POST(self, request, user_id):
-        await assert_requester_is_admin(self.auth, request)
-
-        if not self.hs.is_mine_id(user_id):
-            raise SynapseError(400, "Only local users can be shadow-banned")
-
-        await self.store.set_shadow_banned(UserID.from_string(user_id), True)
-
-        return 200, {}
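The docstring above describes the key property of shadow-banning: the banned user's requests appear to succeed, but the events never reach the room. An illustrative toy model of that behaviour — `TinyRoom` is a made-up stand-in, not Synapse code:

```python
class TinyRoom:
    """A toy room that models shadow-ban semantics."""

    def __init__(self):
        self.events = []
        self.shadow_banned = set()

    def shadow_ban(self, user_id: str) -> None:
        self.shadow_banned.add(user_id)

    def send(self, user_id: str, body: str) -> dict:
        # Only events from non-banned users are propagated into the room...
        if user_id not in self.shadow_banned:
            self.events.append((user_id, body))
        # ...but the sender always sees a successful response.
        return {"event_id": "$local:example.com"}
```

This is why the docstring warns it may lead to "confusing or broken behaviour for the client": the client believes its events were delivered.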


@@ -19,8 +19,7 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.appservice import ApplicationService
-from synapse.handlers.sso import SsoIdentityProvider
-from synapse.http.server import HttpServer, finish_request
+from synapse.http.server import finish_request
 from synapse.http.servlet import (
     RestServlet,
     parse_json_object_from_request,
@@ -61,14 +60,11 @@ class LoginRestServlet(RestServlet):
         self.saml2_enabled = hs.config.saml2_enabled
         self.cas_enabled = hs.config.cas_enabled
         self.oidc_enabled = hs.config.oidc_enabled
-        self._msc2858_enabled = hs.config.experimental.msc2858_enabled
         self.auth = hs.get_auth()
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
-        self._sso_handler = hs.get_sso_handler()
         self._well_known_builder = WellKnownBuilder(hs)
         self._address_ratelimiter = Ratelimiter(
             clock=hs.get_clock(),
@@ -93,17 +89,8 @@ class LoginRestServlet(RestServlet):
             flows.append({"type": LoginRestServlet.CAS_TYPE})
         if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
-            sso_flow = {"type": LoginRestServlet.SSO_TYPE}  # type: JsonDict
-
-            if self._msc2858_enabled:
-                sso_flow["org.matrix.msc2858.identity_providers"] = [
-                    _get_auth_flow_dict_for_idp(idp)
-                    for idp in self._sso_handler.get_identity_providers().values()
-                ]
-
-            flows.append(sso_flow)
-
-            # While it's valid for us to advertise this login type generally,
+            flows.append({"type": LoginRestServlet.SSO_TYPE})
+
+            # While its valid for us to advertise this login type generally,
             # synapse currently only gives out these tokens as part of the
             # SSO login flow.
             # Generally we don't want to advertise login flows that clients
@@ -324,22 +311,8 @@ class LoginRestServlet(RestServlet):
         return result


-def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
-    """Return an entry for the login flow dict
-
-    Returns an entry suitable for inclusion in "identity_providers" in the
-    response to GET /_matrix/client/r0/login
-    """
-    e = {"id": idp.idp_id, "name": idp.idp_name}  # type: JsonDict
-    if idp.idp_icon:
-        e["icon"] = idp.idp_icon
-    if idp.idp_brand:
-        e["brand"] = idp.idp_brand
-    return e
-
-
 class SsoRedirectServlet(RestServlet):
-    PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
+    PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)

     def __init__(self, hs: "HomeServer"):
         # make sure that the relevant handlers are instantiated, so that they
@@ -351,31 +324,13 @@ class SsoRedirectServlet(RestServlet):
         if hs.config.oidc_enabled:
             hs.get_oidc_handler()
         self._sso_handler = hs.get_sso_handler()
-        self._msc2858_enabled = hs.config.experimental.msc2858_enabled

-    def register(self, http_server: HttpServer) -> None:
-        super().register(http_server)
-        if self._msc2858_enabled:
-            # expose additional endpoint for MSC2858 support
-            http_server.register_paths(
-                "GET",
-                client_patterns(
-                    "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
-                    releases=(),
-                    unstable=True,
-                ),
-                self.on_GET,
-                self.__class__.__name__,
-            )
-
-    async def on_GET(
-        self, request: SynapseRequest, idp_id: Optional[str] = None
-    ) -> None:
+    async def on_GET(self, request: SynapseRequest):
         client_redirect_url = parse_string(
             request, "redirectUrl", required=True, encoding=None
         )
         sso_url = await self._sso_handler.handle_redirect_request(
-            request, client_redirect_url, idp_id,
+            request, client_redirect_url
         )
         logger.info("Redirecting to %s", sso_url)
         request.redirect(sso_url)
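The `_get_auth_flow_dict_for_idp` helper removed in the diff above builds one entry per identity provider for the MSC2858 `"org.matrix.msc2858.identity_providers"` list in the `GET /login` response. A stand-alone mirror of that logic, where `FakeIdp` is a hypothetical stand-in for Synapse's `SsoIdentityProvider` interface:

```python
class FakeIdp:
    """Minimal stand-in exposing the attributes the helper reads."""

    def __init__(self, idp_id, idp_name, idp_icon=None, idp_brand=None):
        self.idp_id = idp_id
        self.idp_name = idp_name
        self.idp_icon = idp_icon
        self.idp_brand = idp_brand

def get_auth_flow_dict_for_idp(idp) -> dict:
    """Build an "identity_providers" entry: id and name always, icon/brand only if set."""
    e = {"id": idp.idp_id, "name": idp.idp_name}
    if idp.idp_icon:
        e["icon"] = idp.idp_icon
    if idp.idp_brand:
        e["brand"] = idp.idp_brand
    return e
```

Omitting unset optional fields (rather than emitting `"icon": null`) keeps the advertised login flows compact for clients.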


@@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
 class EmailPasswordRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/password/email/requestToken$")

-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs):
         super().__init__()
         self.hs = hs
         self.datastore = hs.get_datastore()
@@ -103,8 +103,6 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)

-        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
-
         # The email will be sent to the stored address.
         # This avoids a potential account hijack by requesting a password reset to
         # an email address which is controlled by the attacker but which, after
@@ -381,8 +379,6 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )

-        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
-
         if next_link:
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
@@ -434,7 +430,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
 class MsisdnThreepidRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")

-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs):
         self.hs = hs
         super().__init__()
         self.store = self.hs.get_datastore()
@@ -462,10 +458,6 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )

-        self.identity_handler.ratelimit_request_token_requests(
-            request, "msisdn", msisdn
-        )
-
         if next_link:
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)


@@ -126,8 +126,6 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )

-        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
-
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "email", email
         )
@@ -207,10 +205,6 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )

-        self.identity_handler.ratelimit_request_token_requests(
-            request, "msisdn", msisdn
-        )
-
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "msisdn", msisdn
         )
