1
0

Compare commits

...

59 Commits

Author SHA1 Message Date
Erik Johnston
11f5579129 Newsfile 2020-03-30 16:39:06 +01:00
Erik Johnston
22fc68762e Move command processing out of transport 2020-03-30 16:38:11 +01:00
Erik Johnston
4f21c33be3 Remove usage of "conn_id" for presence. (#7128)
* Remove `conn_id` usage for UserSyncCommand.

Each tcp replication connection is assigned a "conn_id", which is used
to give an ID to a remotely connected worker. In a redis world, there
will no longer be a one to one mapping between connection and instance,
so instead we need to replace such usages with an ID generated by the
remote instances and included in the replicaiton commands.

This really only effects UserSyncCommand.

* Add CLEAR_USER_SYNCS command that is sent on shutdown.

This should help with the case where a synchrotron gets restarted
gracefully, rather than rely on 5 minute timeout.
2020-03-30 16:37:24 +01:00
David Baker
07569f25d1 Merge pull request #7160 from matrix-org/dbkr/always_send_own_device_list_updates
Always send the user updates to their own device list
2020-03-30 14:34:28 +01:00
Andrew Morgan
104844c1e1 Add explanatory comment 2020-03-30 14:00:11 +01:00
Richard van der Hoff
6486c96b65 Merge pull request #7157 from matrix-org/rev.outbound_device_pokes_tests
Add tests for outbound device pokes
2020-03-30 13:59:07 +01:00
Patrick Cloke
c5f89fba55 Add developer documentation for running a local CAS server (#7147) 2020-03-30 07:28:42 -04:00
David Baker
7406477525 black 2020-03-30 10:18:33 +01:00
David Baker
9fc588e6dc Just add own user ID to the list we track device changes for 2020-03-30 10:11:26 +01:00
Richard van der Hoff
b7da598a61 Always whitelist the login fallback for SSO (#7153)
That fallback sets the redirect URL to itself (so it can process the login
token then return gracefully to the client). This would make it pointless to
ask the user for confirmation, since the URL the confirmation page would be
showing wouldn't be the client's.
2020-03-27 20:24:52 +00:00
Brendan Abolivier
84f7eaed16 Improve the UX of the login fallback when using SSO (#7152)
* Don't show the login forms if we're currently logging in with a
  password or a token.
* Submit directly the SSO login form, showing only a spinner to the
  user, in order to eliminate from the clunkiness of SSO through this
  fallback.
2020-03-27 20:19:54 +00:00
Dirk Klimpel
fb69690761 Admin API to join users to a room. (#7051) 2020-03-27 19:16:43 +00:00
Dirk Klimpel
8327eb9280 Add options to prevent users from changing their profile. (#7096) 2020-03-27 19:15:23 +00:00
txt-file
ae219fb411 update debian installation instructions to recommend installing virtualenv instead of python3-virtualenv (#6892)
* change debian package from python3-virtualenv to virtualenv

The virtualenv package is needed for the virtualenv command. The
virtualenv package depends on python3-virtualenv (at least since
debian jessie) so there is no need to specify python3-virtualenv
additionally.

Signed-off-by: Vieno Hakkerinen <vieno@hakkerinen.eu>

* Add changelog

Co-authored-by: Andrew Morgan <andrew@amorgan.xyz>
2020-03-27 15:02:00 +00:00
Andrew Morgan
12aa5a7fa7 Ensure is_verified on /_matrix/client/r0/room_keys/keys is a boolean (#7150) 2020-03-27 13:30:22 +00:00
David Vo
fbf0782c63 Only import sqlite3 when type checking (#7155)
Fixes: #7127
Signed-off-by: David Vo <david@vovo.id.au>
2020-03-27 13:20:00 +00:00
David Baker
16ee97988a Fix undefined variable & remove debug logging 2020-03-27 12:39:54 +00:00
David Baker
a07e03ce90 black 2020-03-27 12:35:32 +00:00
David Baker
d9965fb8d6 changelog 2020-03-27 12:30:59 +00:00
David Baker
09cc058a4c Always send the user updates to their own device list
This will allow clients to notify users about new devices even if
the user isn't in any rooms (yet).
2020-03-27 12:26:47 +00:00
Richard van der Hoff
665630fcaa Add tests for outbound device pokes 2020-03-27 12:01:37 +00:00
Jason Robinson
7496d3d2f6 Merge pull request #7151 from matrix-org/jaywink/saml-redirect-fix
Allow RedirectResponse in SAML response handler
2020-03-26 22:10:31 +02:00
Patrick Cloke
fa4f12102d Refactor the CAS code (move the logic out of the REST layer to a handler) (#7136) 2020-03-26 15:05:26 -04:00
Jason Robinson
55ca6cf88c Update changelog.d/7151.bugfix
Co-Authored-By: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
2020-03-26 20:35:50 +02:00
Nektarios Katakis
825fb5d0a5 Don't default to an invalid sqlite config if no database configuration is provided (#6573) 2020-03-26 17:13:14 +00:00
Jason Robinson
060e7dce09 Allow RedirectResponse in SAML response handler
Allow custom SAML handlers to redirect after processing an auth response.

Fixes #7149

Signed-off-by: Jason Robinson <jasonr@matrix.org>
2020-03-26 19:02:35 +02:00
Dirk Klimpel
e8e2ddb60a Allow server admins to define and enforce a password policy (MSC2000). (#7118) 2020-03-26 16:51:13 +00:00
Patrick Cloke
1c1242acba Validate that the session is not modified during UI-Auth (#7068) 2020-03-26 07:39:34 -04:00
Aaron Raimist
6ca5e56fd1 Remove unused captcha_bypass_secret option (#7137)
Signed-off-by: Aaron Raimist <aaron@raim.ist>
2020-03-25 17:49:34 +00:00
Erik Johnston
4cff617df1 Move catchup of replication streams to worker. (#7024)
This changes the replication protocol so that the server does not send down `RDATA` for rows that happened before the client connected. Instead, the server will send a `POSITION` and clients then query the database (or master out of band) to get up to date.
2020-03-25 14:54:01 +00:00
Andrew Morgan
7bab642707 Various cleanups to INSTALL.md (#7141) 2020-03-25 13:56:40 +00:00
Erik Johnston
b1cfaf08af Merge pull request #7133 from matrix-org/erikj/fix_worker_startup
Fix starting workers when federation sending not split out.
2020-03-25 09:42:39 +00:00
Richard van der Hoff
28d9d6e8a9 Remove spurious "name" parameter to default_config
this is never set to anything other than "test", and is a source of unnecessary
boilerplate.
2020-03-24 18:33:49 +00:00
Richard van der Hoff
39230d2171 Clean up some LoggingContext stuff (#7120)
* Pull Sentinel out of LoggingContext

... and drop a few unnecessary references to it

* Factor out LoggingContext.current_context

move `current_context` and `set_context` out to top-level functions.

Mostly this means that I can more easily trace what's actually referring to
LoggingContext, but I think it's generally neater.

* move copy-to-parent into `stop`

this really just makes `start` and `stop` more symetric. It also means that it
behaves correctly if you manually `set_log_context` rather than using the
context manager.

* Replace `LoggingContext.alive` with `finished`

Turn `alive` into `finished` and make it a bit better defined.
2020-03-24 14:45:33 +00:00
Naugrimm
1fcf9c6f95 Fix CAS redirect url (#6634)
Build the same service URL when requesting the CAS ticket and when calling the proxyValidate URL.
2020-03-24 11:59:04 +00:00
Erik Johnston
d6828c129f Newsfile 2020-03-24 10:36:44 +00:00
Erik Johnston
c816072d47 Fix starting workers when federation sending not split out. 2020-03-24 10:35:00 +00:00
Patrick Cloke
190ab593b7 Use the proper error code when a canonical alias that does not exist is used. (#7109) 2020-03-23 15:21:54 -04:00
Kartikaya Gupta (kats)
e341518f92 Update pre-built package name for FreeBSD (#7107). (#7107)
Signed-off-by: Kartikaya Gupta <kats@trevize.staktrace.com>
2020-03-23 15:31:02 +00:00
Richard van der Hoff
a564b92d37 Convert *StreamRow classes to inner classes (#7116)
This just helps keep the rows closer to their streams, so that it's easier to
see what the format of each stream is.
2020-03-23 13:59:11 +00:00
Richard van der Hoff
5126cb1253 Merge branch 'master' into develop 2020-03-23 13:54:29 +00:00
Richard van der Hoff
229eb81498 Merge tag 'v1.12.0'
Synapse 1.12.0 (2020-03-23)
===========================

No significant changes since 1.12.0rc1.

Debian packages and Docker images are rebuilt using the latest versions of
dependency libraries, including Twisted 20.3.0. **Please see security advisory
below**.

Security advisory
-----------------

Synapse may be vulnerable to request-smuggling attacks when it is used with a
reverse-proxy. The vulnerabilties are fixed in Twisted 20.3.0, and are
described in
[CVE-2020-10108](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-10108)
and
[CVE-2020-10109](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-10109).
For a good introduction to this class of request-smuggling attacks, see
https://portswigger.net/research/http-desync-attacks-request-smuggling-reborn.

We are not aware of these vulnerabilities being exploited in the wild, and
do not believe that they are exploitable with current versions of any reverse
proxies. Nevertheless, we recommend that all Synapse administrators ensure that
they have the latest versions of the Twisted library to ensure that their
installation remains secure.

* Administrators using the [`matrix.org` Docker
  image](https://hub.docker.com/r/matrixdotorg/synapse/) or the [Debian/Ubuntu
  packages from
  `matrix.org`](https://github.com/matrix-org/synapse/blob/master/INSTALL.md#matrixorg-packages)
  should ensure that they have version 1.12.0 installed: these images include
  Twisted 20.3.0.
* Administrators who have [installed Synapse from
  source](https://github.com/matrix-org/synapse/blob/master/INSTALL.md#installing-from-source)
  should upgrade Twisted within their virtualenv by running:
  ```sh
  <path_to_virtualenv>/bin/pip install 'Twisted>=20.3.0'
  ```
* Administrators who have installed Synapse from distribution packages should
  consult the information from their distributions.

The `matrix.org` Synapse instance was not vulnerable to these vulnerabilities.

Advance notice of change to the default `git` branch for Synapse
----------------------------------------------------------------

Currently, the default `git` branch for Synapse is `master`, which tracks the
latest release.

After the release of Synapse 1.13.0, we intend to change this default to
`develop`, which is the development tip. This is more consistent with common
practice and modern `git` usage.

Although we try to keep `develop` in a stable state, there may be occasions
where regressions creep in. Developers and distributors who have scripts which
run builds using the default branch of `Synapse` should therefore consider
pinning their scripts to `master`.

Synapse 1.12.0rc1 (2020-03-19)
==============================

Features
--------

- Changes related to room alias management ([MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432)):
  - Publishing/removing a room from the room directory now requires the user to have a power level capable of modifying the canonical alias, instead of the room aliases. ([\#6965](https://github.com/matrix-org/synapse/issues/6965))
  - Validate the `alt_aliases` property of canonical alias events. ([\#6971](https://github.com/matrix-org/synapse/issues/6971))
  - Users with a power level sufficient to modify the canonical alias of a room can now delete room aliases. ([\#6986](https://github.com/matrix-org/synapse/issues/6986))
  - Implement updated authorization rules and redaction rules for aliases events, from [MSC2261](https://github.com/matrix-org/matrix-doc/pull/2261) and [MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432). ([\#7037](https://github.com/matrix-org/synapse/issues/7037))
  - Stop sending m.room.aliases events during room creation and upgrade. ([\#6941](https://github.com/matrix-org/synapse/issues/6941))
  - Synapse no longer uses room alias events to calculate room names for push notifications. ([\#6966](https://github.com/matrix-org/synapse/issues/6966))
  - The room list endpoint no longer returns a list of aliases. ([\#6970](https://github.com/matrix-org/synapse/issues/6970))
  - Remove special handling of aliases events from [MSC2260](https://github.com/matrix-org/matrix-doc/pull/2260) added in v1.10.0rc1. ([\#7034](https://github.com/matrix-org/synapse/issues/7034))
- Expose the `synctl`, `hash_password` and `generate_config` commands in the snapcraft package. Contributed by @devec0. ([\#6315](https://github.com/matrix-org/synapse/issues/6315))
- Check that server_name is correctly set before running database updates. ([\#6982](https://github.com/matrix-org/synapse/issues/6982))
- Break down monthly active users by `appservice_id` and emit via Prometheus. ([\#7030](https://github.com/matrix-org/synapse/issues/7030))
- Render a configurable and comprehensible error page if something goes wrong during the SAML2 authentication process. ([\#7058](https://github.com/matrix-org/synapse/issues/7058), [\#7067](https://github.com/matrix-org/synapse/issues/7067))
- Add an optional parameter to control whether other sessions are logged out when a user's password is modified. ([\#7085](https://github.com/matrix-org/synapse/issues/7085))
- Add prometheus metrics for the number of active pushers. ([\#7103](https://github.com/matrix-org/synapse/issues/7103), [\#7106](https://github.com/matrix-org/synapse/issues/7106))
- Improve performance when making HTTPS requests to sygnal, sydent, etc, by sharing the SSL context object between connections. ([\#7094](https://github.com/matrix-org/synapse/issues/7094))

Bugfixes
--------

- When a user's profile is updated via the admin API, also generate a displayname/avatar update for that user in each room. ([\#6572](https://github.com/matrix-org/synapse/issues/6572))
- Fix a couple of bugs in email configuration handling. ([\#6962](https://github.com/matrix-org/synapse/issues/6962))
- Fix an issue affecting worker-based deployments where replication would stop working, necessitating a full restart, after joining a large room. ([\#6967](https://github.com/matrix-org/synapse/issues/6967))
- Fix `duplicate key` error which was logged when rejoining a room over federation. ([\#6968](https://github.com/matrix-org/synapse/issues/6968))
- Prevent user from setting 'deactivated' to anything other than a bool on the v2 PUT /users Admin API. ([\#6990](https://github.com/matrix-org/synapse/issues/6990))
- Fix py35-old CI by using native tox package. ([\#7018](https://github.com/matrix-org/synapse/issues/7018))
- Fix a bug causing `org.matrix.dummy_event` to be included in responses from `/sync`. ([\#7035](https://github.com/matrix-org/synapse/issues/7035))
- Fix a bug that renders UTF-8 text files incorrectly when loaded from media. Contributed by @TheStranjer. ([\#7044](https://github.com/matrix-org/synapse/issues/7044))
- Fix a bug that would cause Synapse to respond with an error about event visibility if a client tried to request the state of a room at a given token. ([\#7066](https://github.com/matrix-org/synapse/issues/7066))
- Repair a data-corruption issue which was introduced in Synapse 1.10, and fixed in Synapse 1.11, and which could cause `/sync` to return with 404 errors about missing events and unknown rooms. ([\#7070](https://github.com/matrix-org/synapse/issues/7070))
- Fix a bug causing account validity renewal emails to be sent even if the feature is turned off in some cases. ([\#7074](https://github.com/matrix-org/synapse/issues/7074))

Improved Documentation
----------------------

- Updated CentOS8 install instructions. Contributed by Richard Kellner. ([\#6925](https://github.com/matrix-org/synapse/issues/6925))
- Fix `POSTGRES_INITDB_ARGS` in the `contrib/docker/docker-compose.yml` example docker-compose configuration. ([\#6984](https://github.com/matrix-org/synapse/issues/6984))
- Change date in [INSTALL.md](./INSTALL.md#tls-certificates) for last date of getting TLS certificates to November 2019. ([\#7015](https://github.com/matrix-org/synapse/issues/7015))
- Document that the fallback auth endpoints must be routed to the same worker node as the register endpoints. ([\#7048](https://github.com/matrix-org/synapse/issues/7048))

Deprecations and Removals
-------------------------

- Remove the unused query_auth federation endpoint per [MSC2451](https://github.com/matrix-org/matrix-doc/pull/2451). ([\#7026](https://github.com/matrix-org/synapse/issues/7026))

Internal Changes
----------------

- Add type hints to `logging/context.py`. ([\#6309](https://github.com/matrix-org/synapse/issues/6309))
- Add some clarifications to `README.md` in the database schema directory. ([\#6615](https://github.com/matrix-org/synapse/issues/6615))
- Refactoring work in preparation for changing the event redaction algorithm. ([\#6874](https://github.com/matrix-org/synapse/issues/6874), [\#6875](https://github.com/matrix-org/synapse/issues/6875), [\#6983](https://github.com/matrix-org/synapse/issues/6983), [\#7003](https://github.com/matrix-org/synapse/issues/7003))
- Improve performance of v2 state resolution for large rooms. ([\#6952](https://github.com/matrix-org/synapse/issues/6952), [\#7095](https://github.com/matrix-org/synapse/issues/7095))
- Reduce time spent doing GC, by freezing objects on startup. ([\#6953](https://github.com/matrix-org/synapse/issues/6953))
- Minor perfermance fixes to `get_auth_chain_ids`. ([\#6954](https://github.com/matrix-org/synapse/issues/6954))
- Don't record remote cross-signing keys in the `devices` table. ([\#6956](https://github.com/matrix-org/synapse/issues/6956))
- Use flake8-comprehensions to enforce good hygiene of list/set/dict comprehensions. ([\#6957](https://github.com/matrix-org/synapse/issues/6957))
- Merge worker apps together. ([\#6964](https://github.com/matrix-org/synapse/issues/6964), [\#7002](https://github.com/matrix-org/synapse/issues/7002), [\#7055](https://github.com/matrix-org/synapse/issues/7055), [\#7104](https://github.com/matrix-org/synapse/issues/7104))
- Remove redundant `store_room` call from `FederationHandler._process_received_pdu`. ([\#6979](https://github.com/matrix-org/synapse/issues/6979))
- Update warning for incorrect database collation/ctype to include link to documentation. ([\#6985](https://github.com/matrix-org/synapse/issues/6985))
- Add some type annotations to the database storage classes. ([\#6987](https://github.com/matrix-org/synapse/issues/6987))
- Port `synapse.handlers.presence` to async/await. ([\#6991](https://github.com/matrix-org/synapse/issues/6991), [\#7019](https://github.com/matrix-org/synapse/issues/7019))
- Add some type annotations to the federation base & client classes. ([\#6995](https://github.com/matrix-org/synapse/issues/6995))
- Port `synapse.rest.keys` to async/await. ([\#7020](https://github.com/matrix-org/synapse/issues/7020))
- Add a type check to `is_verified` when processing room keys. ([\#7045](https://github.com/matrix-org/synapse/issues/7045))
- Add type annotations and comments to the auth handler. ([\#7063](https://github.com/matrix-org/synapse/issues/7063))
2020-03-23 13:54:17 +00:00
Richard van der Hoff
b3cee0ce67 Fix processing of groups stream, and use symbolic names for streams (#7117)
`groups` != `receipts`

Introduced in #6964
2020-03-23 11:39:36 +00:00
Dionysis Grigoropoulos
96071eea8f Set Referrer-Policy to no-referrer for media (#7009) 2020-03-23 09:48:28 +00:00
Patrick Cloke
477c4f5b1c Clean-up some auth/login REST code (#7115) 2020-03-20 16:22:47 -04:00
Richard van der Hoff
c165c1233b Improve database configuration docs (#6988)
Attempts to clarify the sample config for databases, and add some stuff about
tcp keepalives to `postgres.md`.
2020-03-20 15:24:22 +00:00
Erik Johnston
fdb1344716 Remove concept of a non-limited stream. (#7011) 2020-03-20 14:40:47 +00:00
Patrick Cloke
caec7d4fa0 Convert some of the media REST code to async/await (#7110) 2020-03-20 07:20:02 -04:00
Patrick Cloke
c2db6599c8 Fix a bug in the federation API which could cause occasional "Failed to get PDU" errors (#7089). 2020-03-19 08:22:56 -04:00
Erik Johnston
a319cb1dd1 Change device list streams to have one row per ID (#7010)
* Add 'device_lists_outbound_pokes' as extra table.

This makes sure we check all the relevant tables to get the current max
stream ID.

Currently not doing so isn't problematic as the max stream ID in
`device_lists_outbound_pokes` is the same as in `device_lists_stream`,
however that will change.

* Change device lists stream to have one row per id.

This will make it possible to process the streams more incrementally,
avoiding having to process large chunks at once.

* Change device list replication to match new semantics.

Instead of sending down batches of user ID/host tuples, send down a row
per entity (user ID or host).

* Newsfile

* Remove handling of multiple rows per ID

* Fix worker handling

* Comments from review
2020-03-19 11:36:53 +00:00
Erik Johnston
6e6476ef07 Comments from review 2020-03-18 10:13:55 +00:00
Richard van der Hoff
4ce50519cd Update postgres.md
fix broken link
2020-03-17 18:08:43 +00:00
Erik Johnston
65a941d1f8 Merge branch 'develop' of github.com:matrix-org/synapse into erikj/fixup_devices_stream 2020-03-02 16:55:55 +00:00
Erik Johnston
e53744c737 Fix worker handling 2020-03-02 12:52:28 +00:00
Erik Johnston
f70f44abc7 Remove handling of multiple rows per ID 2020-02-28 11:45:35 +00:00
Erik Johnston
59ad93d2a4 Newsfile 2020-02-28 11:27:37 +00:00
Erik Johnston
9ce4e344a8 Change device list replication to match new semantics.
Instead of sending down batches of user ID/host tuples, send down a row
per entity (user ID or host).
2020-02-28 11:25:34 +00:00
Erik Johnston
f5caa1864e Change device lists stream to have one row per id.
This will make it possible to process the streams more incrementally,
avoiding having to process large chunks at once.
2020-02-28 11:21:25 +00:00
Erik Johnston
c3c6c0e622 Add 'device_lists_outbound_pokes' as extra table.
This makes sure we check all the relevant tables to get the current max
stream ID.

Currently not doing so isn't problematic as the max stream ID in
`device_lists_outbound_pokes` is the same as in `device_lists_stream`,
however that will change.
2020-02-28 11:15:11 +00:00
150 changed files with 4248 additions and 2164 deletions

View File

@@ -2,7 +2,6 @@
- [Installing Synapse](#installing-synapse)
- [Installing from source](#installing-from-source)
- [Platform-Specific Instructions](#platform-specific-instructions)
- [Troubleshooting Installation](#troubleshooting-installation)
- [Prebuilt packages](#prebuilt-packages)
- [Setting up Synapse](#setting-up-synapse)
- [TLS certificates](#tls-certificates)
@@ -10,6 +9,7 @@
- [Registering a user](#registering-a-user)
- [Setting up a TURN server](#setting-up-a-turn-server)
- [URL previews](#url-previews)
- [Troubleshooting Installation](#troubleshooting-installation)
# Choosing your server name
@@ -70,7 +70,7 @@ pip install -U matrix-synapse
```
Before you can start Synapse, you will need to generate a configuration
file. To do this, run (in your virtualenv, as before)::
file. To do this, run (in your virtualenv, as before):
```
cd ~/synapse
@@ -84,22 +84,24 @@ python -m synapse.app.homeserver \
... substituting an appropriate value for `--server-name`.
This command will generate you a config file that you can then customise, but it will
also generate a set of keys for you. These keys will allow your Home Server to
identify itself to other Home Servers, so don't lose or delete them. It would be
also generate a set of keys for you. These keys will allow your homeserver to
identify itself to other homeserver, so don't lose or delete them. It would be
wise to back them up somewhere safe. (If, for whatever reason, you do need to
change your Home Server's keys, you may find that other Home Servers have the
change your homeserver's keys, you may find that other homeserver have the
old key cached. If you update the signing key, you should change the name of the
key in the `<server name>.signing.key` file (the second word) to something
different. See the
[spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys)
for more information on key management.)
for more information on key management).
To actually run your new homeserver, pick a working directory for Synapse to
run (e.g. `~/synapse`), and::
run (e.g. `~/synapse`), and:
cd ~/synapse
source env/bin/activate
synctl start
```
cd ~/synapse
source env/bin/activate
synctl start
```
### Platform-Specific Instructions
@@ -110,7 +112,7 @@ Installing prerequisites on Ubuntu or Debian:
```
sudo apt-get install build-essential python3-dev libffi-dev \
python3-pip python3-setuptools sqlite3 \
libssl-dev python3-virtualenv libjpeg-dev libxslt1-dev
libssl-dev virtualenv libjpeg-dev libxslt1-dev
```
#### ArchLinux
@@ -188,7 +190,7 @@ doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \
There is currently no port for OpenBSD. Additionally, OpenBSD's security
settings require a slightly more difficult installation process.
XXX: I suspect this is out of date.
(XXX: I suspect this is out of date)
1. Create a new directory in `/usr/local` called `_synapse`. Also, create a
new user called `_synapse` and set that directory as the new user's home.
@@ -196,7 +198,7 @@ XXX: I suspect this is out of date.
write and execute permissions on the same memory space to be run from
`/usr/local`.
2. `su` to the new `_synapse` user and change to their home directory.
3. Create a new virtualenv: `virtualenv -p python2.7 ~/.synapse`
3. Create a new virtualenv: `virtualenv -p python3 ~/.synapse`
4. Source the virtualenv configuration located at
`/usr/local/_synapse/.synapse/bin/activate`. This is done in `ksh` by
using the `.` command, rather than `bash`'s `source`.
@@ -217,45 +219,6 @@ be found at https://docs.microsoft.com/en-us/windows/wsl/install-win10 for
Windows 10 and https://docs.microsoft.com/en-us/windows/wsl/install-on-server
for Windows Server.
### Troubleshooting Installation
XXX a bunch of this is no longer relevant.
Synapse requires pip 8 or later, so if your OS provides too old a version you
may need to manually upgrade it::
sudo pip install --upgrade pip
Installing may fail with `Could not find any downloads that satisfy the requirement pymacaroons-pynacl (from matrix-synapse==0.12.0)`.
You can fix this by manually upgrading pip and virtualenv::
sudo pip install --upgrade virtualenv
You can next rerun `virtualenv -p python3 synapse` to update the virtual env.
Installing may fail during installing virtualenv with `InsecurePlatformWarning: A true SSLContext object is not available. This prevents urllib3 from configuring SSL appropriately and may cause certain SSL connections to fail. For more information, see https://urllib3.readthedocs.org/en/latest/security.html#insecureplatformwarning.`
You can fix this by manually installing ndg-httpsclient::
pip install --upgrade ndg-httpsclient
Installing may fail with `mock requires setuptools>=17.1. Aborting installation`.
You can fix this by upgrading setuptools::
pip install --upgrade setuptools
If pip crashes mid-installation for reason (e.g. lost terminal), pip may
refuse to run until you remove the temporary installation directory it
created. To reset the installation::
rm -rf /tmp/pip_install_matrix
pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
failing, e.g.::
pip install twisted
## Prebuilt packages
As an alternative to installing from source, prebuilt packages are available
@@ -314,7 +277,7 @@ For `buster` and `sid`, Synapse is available in the Debian repositories and
it should be possible to install it with simply:
```
sudo apt install matrix-synapse
sudo apt install matrix-synapse
```
There is also a version of `matrix-synapse` in `stretch-backports`. Please see
@@ -375,15 +338,17 @@ sudo pip install py-bcrypt
Synapse can be found in the void repositories as 'synapse':
xbps-install -Su
xbps-install -S synapse
```
xbps-install -Su
xbps-install -S synapse
```
### FreeBSD
Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from:
- Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean`
- Packages: `pkg install py27-matrix-synapse`
- Packages: `pkg install py37-matrix-synapse`
### NixOS
@@ -420,6 +385,7 @@ so, you will need to edit `homeserver.yaml`, as follows:
resources:
- names: [client, federation]
```
* You will also need to uncomment the `tls_certificate_path` and
`tls_private_key_path` lines under the `TLS` section. You can either
point these settings at an existing certificate and key, or you can
@@ -435,7 +401,7 @@ so, you will need to edit `homeserver.yaml`, as follows:
`cert.pem`).
For a more detailed guide to configuring your server for federation, see
[federate.md](docs/federate.md)
[federate.md](docs/federate.md).
## Email
@@ -482,7 +448,7 @@ on your server even if `enable_registration` is `false`.
## Setting up a TURN server
For reliable VoIP calls to be routed via this homeserver, you MUST configure
a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details.
a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details.
## URL previews
@@ -491,10 +457,24 @@ turn it on you must enable the `url_preview_enabled: True` config parameter
and explicitly specify the IP ranges that Synapse is not allowed to spider for
previewing in the `url_preview_ip_range_blacklist` configuration parameter.
This is critical from a security perspective to stop arbitrary Matrix users
spidering 'internal' URLs on your network. At the very least we recommend that
spidering 'internal' URLs on your network. At the very least we recommend that
your loopback and RFC1918 IP addresses are blacklisted.
This also requires the optional lxml and netaddr python dependencies to be
installed. This in turn requires the libxml2 library to be available - on
This also requires the optional `lxml` and `netaddr` python dependencies to be
installed. This in turn requires the `libxml2` library to be available - on
Debian/Ubuntu this means `apt-get install libxml2-dev`, or equivalent for
your OS.
# Troubleshooting Installation
`pip` seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
failing, e.g.:
```
pip install twisted
```
If you have any other problems, feel free to ask in
[#synapse:matrix.org](https://matrix.to/#/#synapse:matrix.org).

1
changelog.d/6573.bugfix Normal file
View File

@@ -0,0 +1 @@
Don't attempt to use an invalid sqlite config if no database configuration is provided. Contributed by @nekatak.

1
changelog.d/6634.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix single-sign on with CAS systems: pass the same service URL when requesting the CAS ticket and when calling the `proxyValidate` URL. Contributed by @Naugrimm.

1
changelog.d/6892.doc Normal file
View File

@@ -0,0 +1 @@
Update Debian installation instructions to recommend installing the `virtualenv` package instead of `python3-virtualenv`.

1
changelog.d/6988.doc Normal file
View File

@@ -0,0 +1 @@
Improve the documentation for database configuration.

1
changelog.d/7009.feature Normal file
View File

@@ -0,0 +1 @@
Set `Referrer-Policy` header to `no-referrer` on media downloads.

1
changelog.d/7010.misc Normal file
View File

@@ -0,0 +1 @@
Change device list streams to have one row per ID.

1
changelog.d/7011.misc Normal file
View File

@@ -0,0 +1 @@
Remove concept of a non-limited stream.

1
changelog.d/7024.misc Normal file
View File

@@ -0,0 +1 @@
Move catchup of replication streams logic to worker.

1
changelog.d/7051.feature Normal file
View File

@@ -0,0 +1 @@
Admin API `POST /_synapse/admin/v1/join/<roomIdOrAlias>` to join users to a room like `auto_join_rooms` for creation of users.

1
changelog.d/7068.bugfix Normal file
View File

@@ -0,0 +1 @@
Ensure that a user inteactive authentication session is tied to a single request.

1
changelog.d/7089.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix a bug in the federation API which could cause occasional "Failed to get PDU" errors.

1
changelog.d/7096.feature Normal file
View File

@@ -0,0 +1 @@
Add options to prevent users from changing their profile or associated 3PIDs.

1
changelog.d/7107.doc Normal file
View File

@@ -0,0 +1 @@
Update pre-built package name for FreeBSD.

1
changelog.d/7109.bugfix Normal file
View File

@@ -0,0 +1 @@
Return the proper error (M_BAD_ALIAS) when a non-existant canonical alias is provided.

1
changelog.d/7110.misc Normal file
View File

@@ -0,0 +1 @@
Convert some of synapse.rest.media to async/await.

1
changelog.d/7115.misc Normal file
View File

@@ -0,0 +1 @@
De-duplicate / remove unused REST code for login and auth.

1
changelog.d/7116.misc Normal file
View File

@@ -0,0 +1 @@
Convert `*StreamRow` classes to inner classes.

1
changelog.d/7117.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix a bug which meant that groups updates were not correctly replicated between workers.

1
changelog.d/7118.feature Normal file
View File

@@ -0,0 +1 @@
Allow server admins to define and enforce a password policy (MSC2000).

1
changelog.d/7120.misc Normal file
View File

@@ -0,0 +1 @@
Clean up some LoggingContext code.

1
changelog.d/7128.misc Normal file
View File

@@ -0,0 +1 @@
Add explicit `instance_id` for USER_SYNC commands and remove implicit `conn_id` usage.

1
changelog.d/7133.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix starting workers when federation sending not split out.

1
changelog.d/7134.misc Normal file
View File

@@ -0,0 +1 @@
Refactor replication command handling to reduce difference between server vs client.

1
changelog.d/7136.misc Normal file
View File

@@ -0,0 +1 @@
Refactored the CAS authentication logic to a separate class.

1
changelog.d/7137.removal Normal file
View File

@@ -0,0 +1 @@
Remove nonfunctional `captcha_bypass_secret` option from `homeserver.yaml`.

1
changelog.d/7141.doc Normal file
View File

@@ -0,0 +1 @@
Clean up INSTALL.md a bit.

1
changelog.d/7147.doc Normal file
View File

@@ -0,0 +1 @@
Add documentation for running a local CAS server for testing.

1
changelog.d/7150.bugfix Normal file
View File

@@ -0,0 +1 @@
Ensure `is_verified` is a boolean in responses to `GET /_matrix/client/r0/room_keys/keys`. Also warn the user if they forgot the `version` query param.

1
changelog.d/7151.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix error page being shown when a custom SAML handler attempted to redirect when processing an auth response.

1
changelog.d/7152.feature Normal file
View File

@@ -0,0 +1 @@
Improve the support for SSO authentication on the login fallback page.

1
changelog.d/7153.feature Normal file
View File

@@ -0,0 +1 @@
Always whitelist the login fallback in the SSO configuration if `public_baseurl` is set.

1
changelog.d/7155.bugfix Normal file
View File

@@ -0,0 +1 @@
Avoid importing `sqlite3` when using the postgres backend. Contributed by David Vo.

1
changelog.d/7157.misc Normal file
View File

@@ -0,0 +1 @@
Add tests for outbound device pokes.

1
changelog.d/7160.feature Normal file
View File

@@ -0,0 +1 @@
Always send users their own device updates.

View File

@@ -0,0 +1,34 @@
# Edit Room Membership API
This API allows an administrator to join an user account with a given `user_id`
to a room with a given `room_id_or_alias`. You can only modify the membership of
local users. The server administrator must be in the room and have permission to
invite users.
## Parameters
The following parameters are available:
* `user_id` - Fully qualified user: for example, `@user:server.com`.
* `room_id_or_alias` - The room identifier or alias to join: for example,
`!636q39766251:server.com`.
## Usage
```
POST /_synapse/admin/v1/join/<room_id_or_alias>
{
"user_id": "@user:server.com"
}
```
Including an `access_token` of a server admin.
Response:
```
{
"room_id": "!636q39766251:server.com"
}
```

64
docs/dev/cas.md Normal file
View File

@@ -0,0 +1,64 @@
# How to test CAS as a developer without a server
The [django-mama-cas](https://github.com/jbittel/django-mama-cas) project is an
easy to run CAS implementation built on top of Django.
## Prerequisites
1. Create a new virtualenv: `python3 -m venv <your virtualenv>`
2. Activate your virtualenv: `source /path/to/your/virtualenv/bin/activate`
3. Install Django and django-mama-cas:
```
python -m pip install "django<3" "django-mama-cas==2.4.0"
```
4. Create a Django project in the current directory:
```
django-admin startproject cas_test .
```
5. Follow the [install directions](https://django-mama-cas.readthedocs.io/en/latest/installation.html#configuring) for django-mama-cas
6. Setup the SQLite database: `python manage.py migrate`
7. Create a user:
```
python manage.py createsuperuser
```
1. Use whatever you want as the username and password.
2. Leave the other fields blank.
8. Use the built-in Django test server to serve the CAS endpoints on port 8000:
```
python manage.py runserver
```
You should now have a Django project configured to serve CAS authentication with
a single user created.
## Configure Synapse (and Riot) to use CAS
1. Modify your `homeserver.yaml` to enable CAS and point it to your locally
running Django test server:
```yaml
cas_config:
enabled: true
server_url: "http://localhost:8000"
service_url: "http://localhost:8081"
#displayname_attribute: name
#required_attributes:
# name: value
```
2. Restart Synapse.
Note that the above configuration assumes the homeserver is running on port 8081
and that the CAS server is on port 8000, both on localhost.
## Testing the configuration
Then in Riot:
1. Visit the login page with a Riot pointing at your homeserver.
2. Click the Single Sign-On button.
3. Login using the credentials created with `createsuperuser`.
4. You should be logged in.
If you want to repeat this process you'll need to manually logout first:
1. http://localhost:8000/admin/
2. Click "logout" in the top right.

View File

@@ -18,9 +18,13 @@ To make Synapse (and therefore Riot) use it:
metadata:
local: ["samling.xml"]
```
5. Run `apt-get install xmlsec1` and `pip install --upgrade --force 'pysaml2>=4.5.0'` to ensure
5. Ensure that your `homeserver.yaml` has a setting for `public_baseurl`:
```yaml
public_baseurl: http://localhost:8080/
```
6. Run `apt-get install xmlsec1` and `pip install --upgrade --force 'pysaml2>=4.5.0'` to ensure
the dependencies are installed and ready to go.
6. Restart Synapse.
7. Restart Synapse.
Then in Riot:

View File

@@ -29,14 +29,13 @@ from synapse.logging import context # omitted from future snippets
def handle_request(request_id):
request_context = context.LoggingContext()
calling_context = context.LoggingContext.current_context()
context.LoggingContext.set_current_context(request_context)
calling_context = context.set_current_context(request_context)
try:
request_context.request = request_id
do_request_handling()
logger.debug("finished")
finally:
context.LoggingContext.set_current_context(calling_context)
context.set_current_context(calling_context)
def do_request_handling():
logger.debug("phew") # this will be logged against request_id

View File

@@ -72,8 +72,7 @@ underneath the database, or if a different version of the locale is used on any
replicas.
The safest way to fix the issue is to take a dump and recreate the database with
the correct `COLLATE` and `CTYPE` parameters (as per
[docs/postgres.md](docs/postgres.md)). It is also possible to change the
the correct `COLLATE` and `CTYPE` parameters (as shown above). It is also possible to change the
parameters on a live database and run a `REINDEX` on the entire database,
however extreme care must be taken to avoid database corruption.
@@ -105,19 +104,41 @@ of free memory the database host has available.
When you are ready to start using PostgreSQL, edit the `database`
section in your config file to match the following lines:
database:
name: psycopg2
args:
user: <user>
password: <pass>
database: <db>
host: <host>
cp_min: 5
cp_max: 10
```yaml
database:
name: psycopg2
args:
user: <user>
password: <pass>
database: <db>
host: <host>
cp_min: 5
cp_max: 10
```
All key, values in `args` are passed to the `psycopg2.connect(..)`
function, except keys beginning with `cp_`, which are consumed by the
twisted adbapi connection pool.
twisted adbapi connection pool. See the [libpq
documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS)
for a list of options which can be passed.
You should consider tuning the `args.keepalives_*` options if there is any danger of
the connection between your homeserver and database dropping, otherwise Synapse
may block for an extended period while it waits for a response from the
database server. Example values might be:
```yaml
# seconds of inactivity after which TCP should send a keepalive message to the server
keepalives_idle: 10
# the number of seconds after which a TCP keepalive message that is not
# acknowledged by the server should be retransmitted
keepalives_interval: 10
# the number of TCP keepalives that can be lost before the client's connection
# to the server is considered dead
keepalives_count: 3
```
## Porting from SQLite

View File

@@ -578,13 +578,46 @@ acme:
## Database ##
# The 'database' setting defines the database that synapse uses to store all of
# its data.
#
# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or
# 'psycopg2' (for PostgreSQL).
#
# 'args' gives options which are passed through to the database engine,
# except for options starting 'cp_', which are used to configure the Twisted
# connection pool. For a reference to valid arguments, see:
# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__
#
#
# Example SQLite configuration:
#
#database:
# name: sqlite3
# args:
# database: /path/to/homeserver.db
#
#
# Example Postgres configuration:
#
#database:
# name: psycopg2
# args:
# user: synapse
# password: secretpassword
# database: synapse
# host: localhost
# cp_min: 5
# cp_max: 10
#
# For more information on using Synapse with Postgres, see `docs/postgres.md`.
#
database:
# The database engine name
name: "sqlite3"
# Arguments to pass to the engine
name: sqlite3
args:
# Path to the database
database: "DATADIR/homeserver.db"
database: DATADIR/homeserver.db
# Number of events to cache in memory.
#
@@ -839,10 +872,6 @@ media_store_path: "DATADIR/media_store"
#
#enable_registration_captcha: false
# A secret key used to bypass the captcha test entirely.
#
#captcha_bypass_secret: "YOUR_SECRET_HERE"
# The API endpoint to use for verifying m.login.recaptcha responses.
#
#recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify"
@@ -1057,6 +1086,29 @@ account_threepid_delegates:
#email: https://example.com # Delegate email sending to example.com
#msisdn: http://localhost:8090 # Delegate SMS sending to this local process
# Whether users are allowed to change their displayname after it has
# been initially set. Useful when provisioning users based on the
# contents of a third-party directory.
#
# Does not apply to server administrators. Defaults to 'true'
#
#enable_set_displayname: false
# Whether users are allowed to change their avatar after it has been
# initially set. Useful when provisioning users based on the contents
# of a third-party directory.
#
# Does not apply to server administrators. Defaults to 'true'
#
#enable_set_avatar_url: false
# Whether users can change the 3PIDs associated with their accounts
# (email address and msisdn).
#
# Defaults to 'true'
#
#enable_3pid_changes: false
# Users who register on this homeserver will automatically be joined
# to these rooms
#
@@ -1392,6 +1444,10 @@ sso:
# phishing attacks from evil.site. To avoid this, include a slash after the
# hostname: "https://my.client/".
#
# If public_baseurl is set, then the login fallback page (used by clients
# that don't natively support the required login flows) is whitelisted in
# addition to any URLs in this list.
#
# By default, this list is empty.
#
#client_whitelist:
@@ -1453,6 +1509,41 @@ password_config:
#
#pepper: "EVEN_MORE_SECRET"
# Define and enforce a password policy. Each parameter is optional.
# This is an implementation of MSC2000.
#
policy:
# Whether to enforce the password policy.
# Defaults to 'false'.
#
#enabled: true
# Minimum accepted length for a password.
# Defaults to 0.
#
#minimum_length: 15
# Whether a password must contain at least one digit.
# Defaults to 'false'.
#
#require_digit: true
# Whether a password must contain at least one symbol.
# A symbol is any character that's not a number or a letter.
# Defaults to 'false'.
#
#require_symbol: true
# Whether a password must contain at least one lowercase letter.
# Defaults to 'false'.
#
#require_lowercase: true
# Whether a password must contain at least one lowercase letter.
# Defaults to 'false'.
#
#require_uppercase: true
# Configuration for sending emails from Synapse.
#

View File

@@ -14,16 +14,16 @@ example flow would be (where '>' indicates master to worker and
'<' worker to master flows):
> SERVER example.com
< REPLICATE events 53
< REPLICATE
> POSITION events 53
> RDATA events 54 ["$foo1:bar.com", ...]
> RDATA events 55 ["$foo4:bar.com", ...]
The example shows the server accepting a new connection and sending its
identity with the `SERVER` command, followed by the client asking to
subscribe to the `events` stream from the token `53`. The server then
periodically sends `RDATA` commands which have the format
`RDATA <stream_name> <token> <row>`, where the format of `<row>` is
defined by the individual streams.
The example shows the server accepting a new connection and sending its identity
with the `SERVER` command, followed by the client server to respond with the
position of all streams. The server then periodically sends `RDATA` commands
which have the format `RDATA <stream_name> <token> <row>`, where the format of
`<row>` is defined by the individual streams.
Error reporting happens by either the client or server sending an ERROR
command, and usually the connection will be closed.
@@ -32,9 +32,6 @@ Since the protocol is a simple line based, its possible to manually
connect to the server using a tool like netcat. A few things should be
noted when manually using the protocol:
- When subscribing to a stream using `REPLICATE`, the special token
`NOW` can be used to get all future updates. The special stream name
`ALL` can be used with `NOW` to subscribe to all available streams.
- The federation stream is only available if federation sending has
been disabled on the main process.
- The server will only time connections out that have sent a `PING`
@@ -91,9 +88,7 @@ The client:
- Sends a `NAME` command, allowing the server to associate a human
friendly name with the connection. This is optional.
- Sends a `PING` as above
- For each stream the client wishes to subscribe to it sends a
`REPLICATE` with the `stream_name` and token it wants to subscribe
from.
- Sends a `REPLICATE` to get the current position of all streams.
- On receipt of a `SERVER` command, checks that the server name
matches the expected server name.
@@ -140,9 +135,7 @@ the wire:
> PING 1490197665618
< NAME synapse.app.appservice
< PING 1490197665618
< REPLICATE events 1
< REPLICATE backfill 1
< REPLICATE caches 1
< REPLICATE
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
@@ -181,9 +174,9 @@ client (C):
#### POSITION (S)
The position of the stream has been updated. Sent to the client
after all missing updates for a stream have been sent to the client
and they're now up to date.
On receipt of a POSITION command clients should check if they have missed any
updates, and if so then fetch them out of band. Sent in response to a
REPLICATE command (but can happen at any time).
#### ERROR (S, C)
@@ -199,25 +192,18 @@ client (C):
#### REPLICATE (C)
Asks the server to replicate a given stream. The syntax is:
```
REPLICATE <stream_name> <token>
```
Where `<token>` may be either:
* a numeric stream_id to stream updates since (exclusive)
* `NOW` to stream all subsequent updates.
The `<stream_name>` is the name of a replication stream to subscribe
to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
of streams). It can also be `ALL` to subscribe to all known streams,
in which case the `<token>` must be set to `NOW`.
Asks the server for the current position of all streams.
#### USER_SYNC (C)
A user has started or stopped syncing
#### CLEAR_USER_SYNC (C)
The server should clear all associated user sync data from the worker.
This is used when a worker is shutting down.
#### FEDERATION_ACK (C)
Acknowledge receipt of some federation data

View File

@@ -64,6 +64,13 @@ class Codes(object):
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION"
EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT"
PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT"
PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE"
PASSWORD_NO_LOWERCASE = "M_PASSWORD_NO_LOWERCASE"
PASSWORD_NO_SYMBOL = "M_PASSWORD_NO_SYMBOL"
PASSWORD_IN_DICTIONARY = "M_PASSWORD_IN_DICTIONARY"
WEAK_PASSWORD = "M_WEAK_PASSWORD"
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
BAD_ALIAS = "M_BAD_ALIAS"
@@ -439,6 +446,20 @@ class IncompatibleRoomVersionError(SynapseError):
return cs_error(self.msg, self.errcode, room_version=self._room_version)
class PasswordRefusedError(SynapseError):
"""A password has been refused, either during password reset/change or registration.
"""
def __init__(
self,
msg="This password doesn't comply with the server's policy",
errcode=Codes.WEAK_PASSWORD,
):
super(PasswordRefusedError, self).__init__(
code=400, msg=msg, errcode=errcode,
)
class RequestSendFailed(RuntimeError):
"""Sending a HTTP request over federation failed due to not being able to
talk to the remote server for some reason.

View File

@@ -64,13 +64,26 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.replication.tcp.streams._base import (
from synapse.replication.tcp.client import ReplicationClientFactory
from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.handler import WorkerReplicationDataHandler
from synapse.replication.tcp.streams import (
AccountDataStream,
DeviceListsStream,
GroupServerStream,
PresenceStream,
PushersStream,
PushRulesStream,
ReceiptsStream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
)
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.replication.tcp.streams.events import EventsStreamEventRow, EventsStreamRow
from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client.v1 import events
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
@@ -113,7 +126,6 @@ from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.stringutils import random_string
from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.generic_worker")
@@ -222,6 +234,7 @@ class GenericWorkerPresence(object):
self.user_to_num_current_syncs = {}
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self.instance_id = hs.get_instance_id()
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {state.user_id: state for state in active_presence}
@@ -234,13 +247,24 @@ class GenericWorkerPresence(object):
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
)
self.process_id = random_string(16)
logger.info("Presence process_id is %r", self.process_id)
hs.get_reactor().addSystemEventTrigger(
"before",
"shutdown",
run_as_background_process,
"generic_presence.on_shutdown",
self._on_shutdown,
)
def _on_shutdown(self):
if self.hs.config.use_presence:
self.hs.get_tcp_replication().send_command(
ClearUserSyncsCommand(self.instance_id)
)
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
if self.hs.config.use_presence:
self.hs.get_tcp_replication().send_user_sync(
user_id, is_syncing, last_sync_ms
self.instance_id, user_id, is_syncing, last_sync_ms
)
def mark_as_coming_online(self, user_id):
@@ -390,6 +414,9 @@ class GenericWorkerTyping(object):
self._room_serials[row.room_id] = token
self._room_typing[row.room_id] = row.user_ids
def get_current_token(self) -> int:
return self._latest_room_serial
class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
@@ -572,25 +599,26 @@ class GenericWorkerServer(HomeServer):
else:
logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
factory = ReplicationClientFactory(self, self.config.worker_name)
host = self.config.worker_replication_host
port = self.config.worker_replication_port
self.get_reactor().connectTCP(host, port, factory)
def remove_pusher(self, app_id, push_key, user_id):
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
def build_tcp_replication(self):
return GenericWorkerReplicationHandler(self)
def build_presence_handler(self):
return GenericWorkerPresence(self)
def build_typing_handler(self):
return GenericWorkerTyping(self)
def build_replication_data_handler(self):
return GenericWorkerReplicationHandler(self)
class GenericWorkerReplicationHandler(ReplicationClientHandler):
class GenericWorkerReplicationHandler(WorkerReplicationDataHandler):
def __init__(self, hs):
super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore())
self.store = hs.get_datastore()
self.typing_handler = hs.get_typing_handler()
# NB this is a SynchrotronPresence, not a normal PresenceHandler
@@ -618,15 +646,12 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
args.update(self.send_handler.stream_positions())
return args
def get_currently_syncing_users(self):
return self.presence_handler.get_currently_syncing_users()
async def process_and_notify(self, stream_name, token, rows):
try:
if self.send_handler:
self.send_handler.process_replication_rows(stream_name, token, rows)
if stream_name == "events":
if stream_name == EventsStream.NAME:
# We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows.
for row in rows:
@@ -649,43 +674,44 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
)
await self.pusher_pool.on_new_notifications(token, token)
elif stream_name == "push_rules":
elif stream_name == PushRulesStream.NAME:
self.notifier.on_new_event(
"push_rules_key", token, users=[row.user_id for row in rows]
)
elif stream_name in ("account_data", "tag_account_data"):
elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
self.notifier.on_new_event(
"account_data_key", token, users=[row.user_id for row in rows]
)
elif stream_name == "receipts":
elif stream_name == ReceiptsStream.NAME:
self.notifier.on_new_event(
"receipt_key", token, rooms=[row.room_id for row in rows]
)
await self.pusher_pool.on_new_receipts(
token, token, {row.room_id for row in rows}
)
elif stream_name == "typing":
elif stream_name == TypingStream.NAME:
self.typing_handler.process_replication_rows(token, rows)
self.notifier.on_new_event(
"typing_key", token, rooms=[row.room_id for row in rows]
)
elif stream_name == "to_device":
elif stream_name == ToDeviceStream.NAME:
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:
self.notifier.on_new_event("to_device_key", token, users=entities)
elif stream_name == "device_lists":
elif stream_name == DeviceListsStream.NAME:
all_room_ids = set()
for row in rows:
room_ids = await self.store.get_rooms_for_user(row.user_id)
all_room_ids.update(room_ids)
if row.entity.startswith("@"):
room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids)
self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
elif stream_name == "presence":
elif stream_name == PresenceStream.NAME:
await self.presence_handler.process_replication_rows(token, rows)
elif stream_name == "receipts":
elif stream_name == GroupServerStream.NAME:
self.notifier.on_new_event(
"groups_key", token, users=[row.user_id for row in rows]
)
elif stream_name == "pushers":
elif stream_name == PushersStream.NAME:
for row in rows:
if row.deleted:
self.stop_pusher(row.user_id, row.app_id, row.pushkey)
@@ -774,7 +800,10 @@ class FederationSenderHandler(object):
# ... as well as device updates and messages
elif stream_name == DeviceListsStream.NAME:
hosts = {row.destination for row in rows}
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
hosts = {row.entity for row in rows if not row.entity.startswith("@")}
for host in hosts:
self.federation_sender.send_device_messages(host)
@@ -789,7 +818,7 @@ class FederationSenderHandler(object):
async def _on_new_receipts(self, rows):
"""
Args:
rows (iterable[synapse.replication.tcp.streams.ReceiptsStreamRow]):
rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
new receipts to be processed
"""
for receipt in rows:
@@ -860,6 +889,9 @@ def start(config_options):
# Force the appservice to start since they will be disabled in the main config
config.notify_appservices = True
else:
# For other worker types we force this to off.
config.notify_appservices = False
if config.worker_app == "synapse.app.pusher":
if config.start_pushers:
@@ -873,6 +905,9 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.start_pushers = True
else:
# For other worker types we force this to off.
config.start_pushers = False
if config.worker_app == "synapse.app.user_dir":
if config.update_user_directory:
@@ -886,6 +921,9 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.update_user_directory = True
else:
# For other worker types we force this to off.
config.update_user_directory = False
if config.worker_app == "synapse.app.federation_sender":
if config.send_federation:
@@ -899,6 +937,9 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.send_federation = True
else:
# For other worker types we force this to off.
config.send_federation = False
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts

View File

@@ -294,7 +294,6 @@ class RootConfig(object):
report_stats=None,
open_private_ports=False,
listeners=None,
database_conf=None,
tls_certificate_path=None,
tls_private_key_path=None,
acme_domain=None,
@@ -367,7 +366,6 @@ class RootConfig(object):
report_stats=report_stats,
open_private_ports=open_private_ports,
listeners=listeners,
database_conf=database_conf,
tls_certificate_path=tls_certificate_path,
tls_private_key_path=tls_private_key_path,
acme_domain=acme_domain,

View File

@@ -24,7 +24,6 @@ class CaptchaConfig(Config):
self.enable_registration_captcha = config.get(
"enable_registration_captcha", False
)
self.captcha_bypass_secret = config.get("captcha_bypass_secret")
self.recaptcha_siteverify_api = config.get(
"recaptcha_siteverify_api",
"https://www.recaptcha.net/recaptcha/api/siteverify",
@@ -49,10 +48,6 @@ class CaptchaConfig(Config):
#
#enable_registration_captcha: false
# A secret key used to bypass the captcha test entirely.
#
#captcha_bypass_secret: "YOUR_SECRET_HERE"
# The API endpoint to use for verifying m.login.recaptcha responses.
#
#recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify"

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,14 +15,65 @@
# limitations under the License.
import logging
import os
from textwrap import indent
import yaml
from synapse.config._base import Config, ConfigError
logger = logging.getLogger(__name__)
NON_SQLITE_DATABASE_PATH_WARNING = """\
Ignoring 'database_path' setting: not using a sqlite3 database.
--------------------------------------------------------------------------------
"""
DEFAULT_CONFIG = """\
## Database ##
# The 'database' setting defines the database that synapse uses to store all of
# its data.
#
# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or
# 'psycopg2' (for PostgreSQL).
#
# 'args' gives options which are passed through to the database engine,
# except for options starting 'cp_', which are used to configure the Twisted
# connection pool. For a reference to valid arguments, see:
# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__
#
#
# Example SQLite configuration:
#
#database:
# name: sqlite3
# args:
# database: /path/to/homeserver.db
#
#
# Example Postgres configuration:
#
#database:
# name: psycopg2
# args:
# user: synapse
# password: secretpassword
# database: synapse
# host: localhost
# cp_min: 5
# cp_max: 10
#
# For more information on using Synapse with Postgres, see `docs/postgres.md`.
#
database:
name: sqlite3
args:
database: %(database_path)s
# Number of events to cache in memory.
#
#event_cache_size: 10K
"""
class DatabaseConnectionConfig:
"""Contains the connection config for a particular database.
@@ -36,10 +88,12 @@ class DatabaseConnectionConfig:
"""
def __init__(self, name: str, db_config: dict):
if db_config["name"] not in ("sqlite3", "psycopg2"):
raise ConfigError("Unsupported database type %r" % (db_config["name"],))
db_engine = db_config.get("name", "sqlite3")
if db_config["name"] == "sqlite3":
if db_engine not in ("sqlite3", "psycopg2"):
raise ConfigError("Unsupported database type %r" % (db_engine,))
if db_engine == "sqlite3":
db_config.setdefault("args", {}).update(
{"cp_min": 1, "cp_max": 1, "check_same_thread": False}
)
@@ -56,6 +110,11 @@ class DatabaseConnectionConfig:
class DatabaseConfig(Config):
section = "database"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.databases = []
def read_config(self, config, **kwargs):
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
@@ -76,12 +135,13 @@ class DatabaseConfig(Config):
multi_database_config = config.get("databases")
database_config = config.get("database")
database_path = config.get("database_path")
if multi_database_config and database_config:
raise ConfigError("Can't specify both 'database' and 'datbases' in config")
if multi_database_config:
if config.get("database_path"):
if database_path:
raise ConfigError("Can't specify 'database_path' with 'databases'")
self.databases = [
@@ -89,65 +149,55 @@ class DatabaseConfig(Config):
for name, db_conf in multi_database_config.items()
]
else:
if database_config is None:
database_config = {"name": "sqlite3", "args": {}}
if database_config:
self.databases = [DatabaseConnectionConfig("master", database_config)]
self.set_databasepath(config.get("database_path"))
if database_path:
if self.databases and self.databases[0].name != "sqlite3":
logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
return
def generate_config_section(self, data_dir_path, database_conf, **kwargs):
if not database_conf:
database_path = os.path.join(data_dir_path, "homeserver.db")
database_conf = (
"""# The database engine name
name: "sqlite3"
# Arguments to pass to the engine
args:
# Path to the database
database: "%(database_path)s"
"""
% locals()
)
else:
database_conf = indent(yaml.dump(database_conf), " " * 10).lstrip()
database_config = {"name": "sqlite3", "args": {}}
self.databases = [DatabaseConnectionConfig("master", database_config)]
self.set_databasepath(database_path)
return (
"""\
## Database ##
database:
%(database_conf)s
# Number of events to cache in memory.
#
#event_cache_size: 10K
"""
% locals()
)
def generate_config_section(self, data_dir_path, **kwargs):
return DEFAULT_CONFIG % {
"database_path": os.path.join(data_dir_path, "homeserver.db")
}
def read_arguments(self, args):
self.set_databasepath(args.database_path)
"""
Cases for the cli input:
- If no databases are configured and no database_path is set, raise.
- No databases and only database_path available ==> sqlite3 db.
- If there are multiple databases and a database_path raise an error.
- If the database set in the config file is sqlite then
overwrite with the command line argument.
"""
if args.database_path is None:
if not self.databases:
raise ConfigError("No database config provided")
return
if len(self.databases) == 0:
database_config = {"name": "sqlite3", "args": {}}
self.databases = [DatabaseConnectionConfig("master", database_config)]
self.set_databasepath(args.database_path)
return
if self.get_single_database().name == "sqlite3":
self.set_databasepath(args.database_path)
else:
logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
def set_databasepath(self, database_path):
if database_path is None:
return
if database_path != ":memory:":
database_path = self.abspath(database_path)
# We only support setting a database path if we have a single sqlite3
# database.
if len(self.databases) != 1:
raise ConfigError("Cannot specify 'database_path' with multiple databases")
database = self.get_single_database()
if database.config["name"] != "sqlite3":
# We don't raise here as we haven't done so before for this case.
logger.warn("Ignoring 'database_path' for non-sqlite3 database")
return
database.config["args"]["database"] = database_path
self.databases[0].config["args"]["database"] = database_path
@staticmethod
def add_arguments(parser):
@@ -162,7 +212,7 @@ class DatabaseConfig(Config):
def get_single_database(self) -> DatabaseConnectionConfig:
"""Returns the database if there is only one, useful for e.g. tests
"""
if len(self.databases) != 1:
if not self.databases:
raise Exception("More than one database exists")
return self.databases[0]

View File

@@ -31,6 +31,10 @@ class PasswordConfig(Config):
self.password_localdb_enabled = password_config.get("localdb_enabled", True)
self.password_pepper = password_config.get("pepper", "")
# Password policy
self.password_policy = password_config.get("policy") or {}
self.password_policy_enabled = self.password_policy.get("enabled", False)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
password_config:
@@ -48,4 +52,39 @@ class PasswordConfig(Config):
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
#
#pepper: "EVEN_MORE_SECRET"
# Define and enforce a password policy. Each parameter is optional.
# This is an implementation of MSC2000.
#
policy:
# Whether to enforce the password policy.
# Defaults to 'false'.
#
#enabled: true
# Minimum accepted length for a password.
# Defaults to 0.
#
#minimum_length: 15
# Whether a password must contain at least one digit.
# Defaults to 'false'.
#
#require_digit: true
# Whether a password must contain at least one symbol.
# A symbol is any character that's not a number or a letter.
# Defaults to 'false'.
#
#require_symbol: true
# Whether a password must contain at least one lowercase letter.
# Defaults to 'false'.
#
#require_lowercase: true
# Whether a password must contain at least one lowercase letter.
# Defaults to 'false'.
#
#require_uppercase: true
"""

View File

@@ -129,6 +129,10 @@ class RegistrationConfig(Config):
raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,))
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
self.enable_set_displayname = config.get("enable_set_displayname", True)
self.enable_set_avatar_url = config.get("enable_set_avatar_url", True)
self.enable_3pid_changes = config.get("enable_3pid_changes", True)
self.disable_msisdn_registration = config.get(
"disable_msisdn_registration", False
)
@@ -330,6 +334,29 @@ class RegistrationConfig(Config):
#email: https://example.com # Delegate email sending to example.com
#msisdn: http://localhost:8090 # Delegate SMS sending to this local process
# Whether users are allowed to change their displayname after it has
# been initially set. Useful when provisioning users based on the
# contents of a third-party directory.
#
# Does not apply to server administrators. Defaults to 'true'
#
#enable_set_displayname: false
# Whether users are allowed to change their avatar after it has been
# initially set. Useful when provisioning users based on the contents
# of a third-party directory.
#
# Does not apply to server administrators. Defaults to 'true'
#
#enable_set_avatar_url: false
# Whether users can change the 3PIDs associated with their accounts
# (email address and msisdn).
#
# Defaults to 'true'
#
#enable_3pid_changes: false
# Users who register on this homeserver will automatically be joined
# to these rooms
#

View File

@@ -39,6 +39,17 @@ class SSOConfig(Config):
self.sso_client_whitelist = sso_config.get("client_whitelist") or []
# Attempt to also whitelist the server's login fallback, since that fallback sets
# the redirect URL to itself (so it can process the login token then return
# gracefully to the client). This would make it pointless to ask the user for
# confirmation, since the URL the confirmation page would be showing wouldn't be
# the client's.
# public_baseurl is an optional setting, so we only add the fallback's URL to the
# list if it's provided (because we can't figure out what that URL is otherwise).
if self.public_baseurl:
login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
self.sso_client_whitelist.append(login_fallback_url)
def generate_config_section(self, **kwargs):
return """\
# Additional settings to use with single-sign on systems such as SAML2 and CAS.
@@ -54,6 +65,10 @@ class SSOConfig(Config):
# phishing attacks from evil.site. To avoid this, include a slash after the
# hostname: "https://my.client/".
#
# If public_baseurl is set, then the login fallback page (used by clients
# that don't natively support the required login flows) is whitelisted in
# addition to any URLs in this list.
#
# By default, this list is empty.
#
#client_whitelist:

View File

@@ -43,8 +43,8 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
current_context,
make_deferred_yieldable,
preserve_fn,
run_in_background,
@@ -236,7 +236,7 @@ class Keyring(object):
"""
try:
ctx = LoggingContext.current_context()
ctx = current_context()
# map from server name to a set of outstanding request ids
server_to_request_ids = {}

View File

@@ -25,19 +25,15 @@ from twisted.python.failure import Failure
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
RoomVersion,
)
from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.crypto.event_signing import check_event_content_hash
from synapse.crypto.keyring import Keyring
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
current_context,
make_deferred_yieldable,
)
from synapse.types import JsonDict, get_domain_from_id
@@ -55,13 +51,15 @@ class FederationBase(object):
self.store = hs.get_datastore()
self._clock = hs.get_clock()
def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
def _check_sigs_and_hash(
self, room_version: RoomVersion, pdu: EventBase
) -> Deferred:
return make_deferred_yieldable(
self._check_sigs_and_hashes(room_version, [pdu])[0]
)
def _check_sigs_and_hashes(
self, room_version: str, pdus: List[EventBase]
self, room_version: RoomVersion, pdus: List[EventBase]
) -> List[Deferred]:
"""Checks that each of the received events is correctly signed by the
sending server.
@@ -80,7 +78,7 @@ class FederationBase(object):
"""
deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
ctx = LoggingContext.current_context()
ctx = current_context()
def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx):
@@ -146,7 +144,7 @@ class PduToCheckSig(
def _check_sigs_on_pdus(
keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
) -> List[Deferred]:
"""Check that the given events are correctly signed
@@ -191,10 +189,6 @@ def _check_sigs_on_pdus(
for p in pdus
]
v = KNOWN_ROOM_VERSIONS.get(room_version)
if not v:
raise RuntimeError("Unrecognized room version %s" % (room_version,))
# First we check that the sender event is signed by the sender's domain
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
@@ -204,7 +198,7 @@ def _check_sigs_on_pdus(
(
p.sender_domain,
p.redacted_pdu_json,
p.pdu.origin_server_ts if v.enforce_key_validity else 0,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_sender
@@ -227,7 +221,7 @@ def _check_sigs_on_pdus(
# event id's domain (normally only the case for joins/leaves), and add additional
# checks. Only do this if the room version has a concept of event ID domain
# (ie, the room version uses old-style non-hash event IDs).
if v.event_format == EventFormatVersions.V1:
if room_version.event_format == EventFormatVersions.V1:
pdus_to_check_event_id = [
p
for p in pdus_to_check
@@ -239,7 +233,7 @@ def _check_sigs_on_pdus(
(
get_domain_from_id(p.pdu.event_id),
p.redacted_pdu_json,
p.pdu.origin_server_ts if v.enforce_key_validity else 0,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_event_id

View File

@@ -220,8 +220,7 @@ class FederationClient(FederationBase):
# FIXME: We should handle signature failures more gracefully.
pdus[:] = await make_deferred_yieldable(
defer.gatherResults(
self._check_sigs_and_hashes(room_version.identifier, pdus),
consumeErrors=True,
self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True,
).addErrback(unwrapFirstError)
)
@@ -291,9 +290,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
signed_pdu = await self._check_sigs_and_hash(
room_version.identifier, pdu
)
signed_pdu = await self._check_sigs_and_hash(room_version, pdu)
break
@@ -350,7 +347,7 @@ class FederationClient(FederationBase):
self,
origin: str,
pdus: List[EventBase],
room_version: str,
room_version: RoomVersion,
outlier: bool = False,
include_none: bool = False,
) -> List[EventBase]:
@@ -396,7 +393,7 @@ class FederationClient(FederationBase):
self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
room_version=room_version, # type: ignore
room_version=room_version,
outlier=outlier,
timeout=10000,
)
@@ -434,7 +431,7 @@ class FederationClient(FederationBase):
]
signed_auth = await self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True, room_version=room_version.identifier
destination, auth_chain, outlier=True, room_version=room_version
)
signed_auth.sort(key=lambda e: e.depth)
@@ -661,7 +658,7 @@ class FederationClient(FederationBase):
destination,
list(pdus.values()),
outlier=True,
room_version=room_version.identifier,
room_version=room_version,
)
valid_pdus_map = {p.event_id: p for p in valid_pdus}
@@ -756,7 +753,7 @@ class FederationClient(FederationBase):
pdu = event_from_pdu_json(pdu_dict, room_version)
# Check signatures are correct.
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
pdu = await self._check_sigs_and_hash(room_version, pdu)
# FIXME: We should handle signature failures more gracefully.
@@ -948,7 +945,7 @@ class FederationClient(FederationBase):
]
signed_events = await self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False, room_version=room_version.identifier
destination, events, outlier=False, room_version=room_version
)
except HttpResponseException as e:
if not e.code == 400:

View File

@@ -409,7 +409,7 @@ class FederationServer(FederationBase):
pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, pdu.room_id)
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
pdu = await self._check_sigs_and_hash(room_version, pdu)
ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version)
time_now = self._clock.time_msec()
return {"event": ret_pdu.get_pdu_json(time_now)}
@@ -425,7 +425,7 @@ class FederationServer(FederationBase):
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
pdu = await self._check_sigs_and_hash(room_version, pdu)
res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
@@ -455,7 +455,7 @@ class FederationServer(FederationBase):
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
pdu = await self._check_sigs_and_hash(room_version, pdu)
await self.handler.on_send_leave_request(origin, pdu)
return {}
@@ -611,7 +611,7 @@ class FederationServer(FederationBase):
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
# We've already checked that we know the room version by this point
room_version = await self.store.get_room_version_id(pdu.room_id)
room_version = await self.store.get_room_version(pdu.room_id)
# Check signature.
try:

View File

@@ -477,7 +477,7 @@ def process_rows_for_federation(transaction_queue, rows):
Args:
transaction_queue (FederationSender)
rows (list(synapse.replication.tcp.streams.FederationStreamRow))
rows (list(synapse.replication.tcp.streams.federation.FederationStream.FederationStreamRow))
"""
# The federation stream contains a bunch of different types of

View File

@@ -499,4 +499,13 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction()
def get_current_token(self) -> int:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return 0
async def get_replication_rows(
self, from_token, to_token, limit, federation_ack=None
):
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return []

View File

@@ -125,7 +125,11 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def validate_user_via_ui_auth(
self, requester: Requester, request_body: Dict[str, Any], clientip: str
self,
requester: Requester,
request: SynapseRequest,
request_body: Dict[str, Any],
clientip: str,
):
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -137,6 +141,8 @@ class AuthHandler(BaseHandler):
Args:
requester: The user, as given by the access token
request: The request sent by the client.
request_body: The body of the request sent by the client
clientip: The IP address of the client.
@@ -172,7 +178,9 @@ class AuthHandler(BaseHandler):
flows = [[login_type] for login_type in self._supported_login_types]
try:
result, params, _ = yield self.check_auth(flows, request_body, clientip)
result, params, _ = yield self.check_auth(
flows, request, request_body, clientip
)
except LoginError:
# Update the ratelimite to say we failed (`can_do_action` doesn't raise).
self._failed_uia_attempts_ratelimiter.can_do_action(
@@ -211,7 +219,11 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def check_auth(
self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str
self,
flows: List[List[str]],
request: SynapseRequest,
clientdict: Dict[str, Any],
clientip: str,
):
"""
Takes a dictionary sent by the client in the login / registration
@@ -231,6 +243,8 @@ class AuthHandler(BaseHandler):
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
request: The request sent by the client.
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
@@ -270,13 +284,27 @@ class AuthHandler(BaseHandler):
# email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects
# on a homeserver.
# Revisit: Assumimg the REST APIs do sensible validation, the data
# Revisit: Assuming the REST APIs do sensible validation, the data
# isn't arbintrary.
session["clientdict"] = clientdict
self._save_session(session)
elif "clientdict" in session:
clientdict = session["clientdict"]
# Ensure that the queried operation does not vary between stages of
# the UI authentication session. This is done by generating a stable
# comparator based on the URI, method, and body (minus the auth dict)
# and storing it during the initial query. Subsequent queries ensure
# that this comparator has not changed.
comparator = (request.uri, request.method, clientdict)
if "ui_auth" not in session:
session["ui_auth"] = comparator
elif session["ui_auth"] != comparator:
raise SynapseError(
403,
"Requested operation has changed during the UI authentication session.",
)
if not authdict:
raise InteractiveAuthIncompleteError(
self._auth_dict_for_flows(flows, session)
@@ -322,6 +350,7 @@ class AuthHandler(BaseHandler):
creds,
list(clientdict),
)
return creds, clientdict, session["id"]
ret = self._auth_dict_for_flows(flows, session)

View File

@@ -0,0 +1,204 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import xml.etree.ElementTree as ET
from typing import AnyStr, Dict, Optional, Tuple
from six.moves import urllib
from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart
logger = logging.getLogger(__name__)
class CasHandler:
"""
Utility class for to handle the response from a CAS SSO service.
Args:
hs (synapse.server.HomeServer)
"""
def __init__(self, hs):
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
self._cas_displayname_attribute = hs.config.cas_displayname_attribute
self._cas_required_attributes = hs.config.cas_required_attributes
self._http_client = hs.get_proxied_http_client()
def _build_service_param(self, client_redirect_url: AnyStr) -> str:
return "%s%s?%s" % (
self._cas_service_url,
"/_matrix/client/r0/login/cas/ticket",
urllib.parse.urlencode({"redirectUrl": client_redirect_url}),
)
async def _handle_cas_response(
self, request: SynapseRequest, cas_response_body: str, client_redirect_url: str
) -> None:
"""
Retrieves the user and display name from the CAS response and continues with the authentication.
Args:
request: The original client request.
cas_response_body: The response from the CAS server.
client_redirect_url: The URl to redirect the client to when
everything is done.
"""
user, attributes = self._parse_cas_response(cas_response_body)
displayname = attributes.pop(self._cas_displayname_attribute, None)
for required_attribute, required_value in self._cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
await self._on_successful_auth(user, request, client_redirect_url, displayname)
def _parse_cas_response(
self, cas_response_body: str
) -> Tuple[str, Dict[str, Optional[str]]]:
"""
Retrieve the user and other parameters from the CAS response.
Args:
cas_response_body: The response from the CAS query.
Returns:
A tuple of the user and a mapping of other attributes.
"""
user = None
attributes = {}
try:
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
success = root[0].tag.endswith("authenticationSuccess")
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
for attribute in child:
# ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace.
# We don't care about namespace here and it will always
# be encased in curly braces, so we remove them.
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text
if user is None:
raise Exception("CAS response does not contain user")
except Exception:
logger.exception("Error parsing CAS response")
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
)
return user, attributes
async def _on_successful_auth(
self,
username: str,
request: SynapseRequest,
client_redirect_url: str,
user_display_name: Optional[str] = None,
) -> None:
"""Called once the user has successfully authenticated with the SSO.
Registers the user if necessary, and then returns a redirect (with
a login token) to the client.
Args:
username: the remote user id. We'll map this onto
something sane for a MXID localpath.
request: the incoming request from the browser. We'll
respond to it with a redirect.
client_redirect_url: the redirect_url the client gave us when
it first started the process.
user_display_name: if set, and we have to register a new user,
we will set their displayname to this.
"""
localpart = map_username_to_mxid_localpart(username)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=user_display_name
)
self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url
)
def handle_redirect_request(self, client_redirect_url: bytes) -> bytes:
"""
Generates a URL to the CAS server where the client should be redirected.
Args:
client_redirect_url: The final URL the client should go to after the
user has negotiated SSO.
Returns:
The URL to redirect to.
"""
args = urllib.parse.urlencode(
{"service": self._build_service_param(client_redirect_url)}
)
return ("%s/login?%s" % (self._cas_server_url, args)).encode("ascii")
async def handle_ticket_request(
self, request: SynapseRequest, client_redirect_url: str, ticket: str
) -> None:
"""
Validates a CAS ticket sent by the client for login/registration.
On a successful request, writes a redirect to the request.
"""
uri = self._cas_server_url + "/proxyValidate"
args = {
"ticket": ticket,
"service": self._build_service_param(client_redirect_url),
}
try:
body = await self._http_client.get_raw(uri, args)
except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
await self._handle_cas_response(request, body, client_redirect_url)

View File

@@ -125,8 +125,14 @@ class DeviceWorkerHandler(BaseHandler):
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
tracked_users = set(users_who_share_room)
# Always tell the user about their own devices
tracked_users.add(user_id)
changed = yield self.store.get_users_whose_devices_changed(
from_token.device_list_key, users_who_share_room
from_token.device_list_key, tracked_users
)
# Then work out if any users have since joined
@@ -456,7 +462,11 @@ class DeviceHandler(DeviceWorkerHandler):
room_ids = yield self.store.get_rooms_for_user(user_id)
yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids)
# specify the user ID too since the user should always get their own device list
# updates, even if they aren't in any rooms.
yield self.notifier.on_new_event(
"device_list_key", position, users=[user_id], rooms=room_ids
)
if hosts:
logger.info(

View File

@@ -851,6 +851,38 @@ class EventCreationHandler(object):
self.store.remove_push_actions_from_staging, event.event_id
)
@defer.inlineCallbacks
def _validate_canonical_alias(
self, directory_handler, room_alias_str, expected_room_id
):
"""
Ensure that the given room alias points to the expected room ID.
Args:
directory_handler: The directory handler object.
room_alias_str: The room alias to check.
expected_room_id: The room ID that the alias should point to.
"""
room_alias = RoomAlias.from_string(room_alias_str)
try:
mapping = yield directory_handler.get_association(room_alias)
except SynapseError as e:
# Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
if e.errcode == Codes.NOT_FOUND:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (room_alias_str,),
Codes.BAD_ALIAS,
)
raise
if mapping["room_id"] != expected_room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (room_alias_str,),
Codes.BAD_ALIAS,
)
@defer.inlineCallbacks
def persist_and_notify_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
@@ -905,15 +937,9 @@ class EventCreationHandler(object):
room_alias_str = event.content.get("alias", None)
directory_handler = self.hs.get_handlers().directory_handler
if room_alias_str and room_alias_str != original_alias:
room_alias = RoomAlias.from_string(room_alias_str)
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (room_alias_str,),
Codes.BAD_ALIAS,
)
yield self._validate_canonical_alias(
directory_handler, room_alias_str, event.room_id
)
# Check that alt_aliases is the proper form.
alt_aliases = event.content.get("alt_aliases", [])
@@ -931,16 +957,9 @@ class EventCreationHandler(object):
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
if new_alt_aliases:
for alias_str in new_alt_aliases:
room_alias = RoomAlias.from_string(alias_str)
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room"
% (room_alias_str,),
Codes.BAD_ALIAS,
)
yield self._validate_canonical_alias(
directory_handler, alias_str, event.room_id
)
federation_handler = self.hs.get_handlers().federation_handler

View File

@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from synapse.api.errors import Codes, PasswordRefusedError
logger = logging.getLogger(__name__)
class PasswordPolicyHandler(object):
def __init__(self, hs):
self.policy = hs.config.password_policy
self.enabled = hs.config.password_policy_enabled
# Regexps for the spec'd policy parameters.
self.regexp_digit = re.compile("[0-9]")
self.regexp_symbol = re.compile("[^a-zA-Z0-9]")
self.regexp_uppercase = re.compile("[A-Z]")
self.regexp_lowercase = re.compile("[a-z]")
def validate_password(self, password):
"""Checks whether a given password complies with the server's policy.
Args:
password (str): The password to check against the server's policy.
Raises:
PasswordRefusedError: The password doesn't comply with the server's policy.
"""
if not self.enabled:
return
minimum_accepted_length = self.policy.get("minimum_length", 0)
if len(password) < minimum_accepted_length:
raise PasswordRefusedError(
msg=(
"The password must be at least %d characters long"
% minimum_accepted_length
),
errcode=Codes.PASSWORD_TOO_SHORT,
)
if (
self.policy.get("require_digit", False)
and self.regexp_digit.search(password) is None
):
raise PasswordRefusedError(
msg="The password must include at least one digit",
errcode=Codes.PASSWORD_NO_DIGIT,
)
if (
self.policy.get("require_symbol", False)
and self.regexp_symbol.search(password) is None
):
raise PasswordRefusedError(
msg="The password must include at least one symbol",
errcode=Codes.PASSWORD_NO_SYMBOL,
)
if (
self.policy.get("require_uppercase", False)
and self.regexp_uppercase.search(password) is None
):
raise PasswordRefusedError(
msg="The password must include at least one uppercase letter",
errcode=Codes.PASSWORD_NO_UPPERCASE,
)
if (
self.policy.get("require_lowercase", False)
and self.regexp_lowercase.search(password) is None
):
raise PasswordRefusedError(
msg="The password must include at least one lowercase letter",
errcode=Codes.PASSWORD_NO_LOWERCASE,
)

View File

@@ -747,7 +747,7 @@ class PresenceHandler(object):
return False
async def get_all_presence_updates(self, last_id, current_id):
async def get_all_presence_updates(self, last_id, current_id, limit):
"""
Gets a list of presence update rows from between the given stream ids.
Each row has:
@@ -762,7 +762,7 @@ class PresenceHandler(object):
"""
# TODO(markjh): replicate the unpersisted changes.
# This could use the in-memory stores for recent changes.
rows = await self.store.get_all_presence_updates(last_id, current_id)
rows = await self.store.get_all_presence_updates(last_id, current_id, limit)
return rows
def notify_new_event(self):

View File

@@ -157,6 +157,15 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
if not by_admin and not self.hs.config.enable_set_displayname:
profile = yield self.store.get_profileinfo(target_user.localpart)
if profile.display_name:
raise SynapseError(
400,
"Changing display name is disabled on this server",
Codes.FORBIDDEN,
)
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
raise SynapseError(
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
@@ -218,6 +227,13 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url")
if not by_admin and not self.hs.config.enable_set_avatar_url:
profile = yield self.store.get_profileinfo(target_user.localpart)
if profile.avatar_url:
raise SynapseError(
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
)
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
raise SynapseError(
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)

View File

@@ -26,6 +26,7 @@ from synapse.config import ConfigError
from synapse.http.server import finish_request
from synapse.http.servlet import parse_string
from synapse.module_api import ModuleApi
from synapse.module_api.errors import RedirectException
from synapse.types import (
UserID,
map_username_to_mxid_localpart,
@@ -119,6 +120,9 @@ class SamlHandler:
try:
user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
except RedirectException:
# Raise the exception as per the wishes of the SAML module response
raise
except Exception as e:
# If decoding the response or mapping it to a user failed, then log the
# error and tell the user that something went wrong.

View File

@@ -32,6 +32,7 @@ class SetPasswordHandler(BaseHandler):
super(SetPasswordHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
self._password_policy_handler = hs.get_password_policy_handler()
@defer.inlineCallbacks
def set_password(
@@ -44,6 +45,7 @@ class SetPasswordHandler(BaseHandler):
if not self.hs.config.password_localdb_enabled:
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
self._password_policy_handler.validate_password(new_password)
password_hash = yield self._auth_handler.hash(new_password)
try:

View File

@@ -26,7 +26,7 @@ from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import FilterCollection
from synapse.events import EventBase
from synapse.logging.context import LoggingContext
from synapse.logging.context import current_context
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
@@ -301,7 +301,7 @@ class SyncHandler(object):
else:
sync_type = "incremental_sync"
context = LoggingContext.current_context()
context = current_context()
if context:
context.tag = sync_type
@@ -1143,9 +1143,14 @@ class SyncHandler(object):
user_id
)
tracked_users = set(users_who_share_room)
# Always tell the user about their own devices
tracked_users.add(user_id)
# Step 1a, check for changes in devices of users we share a room with
users_that_have_changed = await self.store.get_users_whose_devices_changed(
since_token.device_list_key, users_who_share_room
since_token.device_list_key, tracked_users
)
# Step 1b, check for newly joined rooms

View File

@@ -15,6 +15,7 @@
import logging
from collections import namedtuple
from typing import List
from twisted.internet import defer
@@ -257,7 +258,13 @@ class TypingHandler(object):
"typing_key", self._latest_room_serial, rooms=[member.room_id]
)
async def get_all_typing_updates(self, last_id, current_id):
async def get_all_typing_updates(
self, last_id: int, current_id: int, limit: int
) -> List[dict]:
"""Get up to `limit` typing updates between the given tokens, earliest
updates first.
"""
if last_id == current_id:
return []
@@ -275,7 +282,7 @@ class TypingHandler(object):
typing = self._room_typing[room_id]
rows.append((serial, room_id, list(typing)))
rows.sort()
return rows
return rows[:limit]
def get_current_token(self):
return self._latest_room_serial

View File

@@ -19,7 +19,7 @@ import threading
from prometheus_client.core import Counter, Histogram
from synapse.logging.context import LoggingContext
from synapse.logging.context import current_context
from synapse.metrics import LaterGauge
logger = logging.getLogger(__name__)
@@ -148,7 +148,7 @@ LaterGauge(
class RequestMetrics(object):
def start(self, time_sec, name, method):
self.start = time_sec
self.start_context = LoggingContext.current_context()
self.start_context = current_context()
self.name = name
self.method = method
@@ -163,7 +163,7 @@ class RequestMetrics(object):
with _in_flight_requests_lock:
_in_flight_requests.discard(self)
context = LoggingContext.current_context()
context = current_context()
tag = ""
if context:

View File

@@ -42,7 +42,7 @@ from synapse.logging._terse_json import (
TerseJSONToConsoleLogObserver,
TerseJSONToTCPLogObserver,
)
from synapse.logging.context import LoggingContext
from synapse.logging.context import current_context
def stdlib_log_level_to_twisted(level: str) -> LogLevel:
@@ -86,7 +86,7 @@ class LogContextObserver(object):
].startswith("Timing out client"):
return
context = LoggingContext.current_context()
context = current_context()
# Copy the context information to the log event.
if context is not None:

View File

@@ -175,7 +175,54 @@ class ContextResourceUsage(object):
return res
LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"]
LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
class _Sentinel(object):
"""Sentinel to represent the root context"""
__slots__ = ["previous_context", "finished", "request", "scope", "tag"]
def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
self.previous_context = None
self.finished = False
self.request = None
self.scope = None
self.tag = None
def __str__(self):
return "sentinel"
def copy_to(self, record):
pass
def copy_to_twisted_log_entry(self, record):
record["request"] = None
record["scope"] = None
def start(self):
pass
def stop(self):
pass
def add_database_transaction(self, duration_sec):
pass
def add_database_scheduled(self, sched_sec):
pass
def record_event_fetch(self, event_count):
pass
def __nonzero__(self):
return False
__bool__ = __nonzero__ # python3
SENTINEL_CONTEXT = _Sentinel()
class LoggingContext(object):
@@ -199,76 +246,33 @@ class LoggingContext(object):
"_resource_usage",
"usage_start",
"main_thread",
"alive",
"finished",
"request",
"tag",
"scope",
]
thread_local = threading.local()
class Sentinel(object):
"""Sentinel to represent the root context"""
__slots__ = ["previous_context", "alive", "request", "scope", "tag"]
def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
self.previous_context = None
self.alive = None
self.request = None
self.scope = None
self.tag = None
def __str__(self):
return "sentinel"
def copy_to(self, record):
pass
def copy_to_twisted_log_entry(self, record):
record["request"] = None
record["scope"] = None
def start(self):
pass
def stop(self):
pass
def add_database_transaction(self, duration_sec):
pass
def add_database_scheduled(self, sched_sec):
pass
def record_event_fetch(self, event_count):
pass
def __nonzero__(self):
return False
__bool__ = __nonzero__ # python3
sentinel = Sentinel()
def __init__(self, name=None, parent_context=None, request=None) -> None:
self.previous_context = LoggingContext.current_context()
self.previous_context = current_context()
self.name = name
# track the resources used by this context so far
self._resource_usage = ContextResourceUsage()
# If alive has the thread resource usage when the logcontext last
# became active.
# The thread resource usage when the logcontext became active. None
# if the context is not currently active.
self.usage_start = None
self.main_thread = get_thread_id()
self.request = None
self.tag = ""
self.alive = True
self.scope = None # type: Optional[_LogContextScope]
# keep track of whether we have hit the __exit__ block for this context
# (suggesting that the the thing that created the context thinks it should
# be finished, and that re-activating it would suggest an error).
self.finished = False
self.parent_context = parent_context
if self.parent_context is not None:
@@ -283,44 +287,15 @@ class LoggingContext(object):
return str(self.request)
return "%s@%x" % (self.name, id(self))
@classmethod
def current_context(cls) -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage
Returns:
LoggingContext: the current logging context
"""
return getattr(cls.thread_local, "current_context", cls.sentinel)
@classmethod
def set_current_context(
cls, context: LoggingContextOrSentinel
) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage
Args:
context(LoggingContext): The context to activate.
Returns:
The context that was previously active
"""
current = cls.current_context()
if current is not context:
current.stop()
cls.thread_local.current_context = context
context.start()
return current
def __enter__(self) -> "LoggingContext":
"""Enters this logging context into thread local storage"""
old_context = self.set_current_context(self)
old_context = set_current_context(self)
if self.previous_context != old_context:
logger.warning(
"Expected previous context %r, found %r",
self.previous_context,
old_context,
)
self.alive = True
return self
def __exit__(self, type, value, traceback) -> None:
@@ -329,24 +304,19 @@ class LoggingContext(object):
Returns:
None to avoid suppressing any exceptions that were thrown.
"""
current = self.set_current_context(self.previous_context)
current = set_current_context(self.previous_context)
if current is not self:
if current is self.sentinel:
if current is SENTINEL_CONTEXT:
logger.warning("Expected logging context %s was lost", self)
else:
logger.warning(
"Expected logging context %s but found %s", self, current
)
self.alive = False
# if we have a parent, pass our CPU usage stats on
if self.parent_context is not None and hasattr(
self.parent_context, "_resource_usage"
):
self.parent_context._resource_usage += self._resource_usage
# reset them in case we get entered again
self._resource_usage.reset()
# the fact that we are here suggests that the caller thinks that everything
# is done and dusted for this logcontext, and further activity will not get
# recorded against the correct metrics.
self.finished = True
def copy_to(self, record) -> None:
"""Copy logging fields from this context to a log record or
@@ -371,9 +341,14 @@ class LoggingContext(object):
logger.warning("Started logcontext %s on different thread", self)
return
if self.finished:
logger.warning("Re-starting finished log context %s", self)
# If we haven't already started record the thread resource usage so
# far
if not self.usage_start:
if self.usage_start:
logger.warning("Re-starting already-active log context %s", self)
else:
self.usage_start = get_thread_resource_usage()
def stop(self) -> None:
@@ -396,6 +371,15 @@ class LoggingContext(object):
self.usage_start = None
# if we have a parent, pass our CPU usage stats on
if self.parent_context is not None and hasattr(
self.parent_context, "_resource_usage"
):
self.parent_context._resource_usage += self._resource_usage
# reset them in case we get entered again
self._resource_usage.reset()
def get_resource_usage(self) -> ContextResourceUsage:
"""Get resources used by this logcontext so far.
@@ -409,7 +393,7 @@ class LoggingContext(object):
# If we are on the correct thread and we're currently running then we
# can include resource usage so far.
is_main_thread = get_thread_id() == self.main_thread
if self.alive and self.usage_start and is_main_thread:
if self.usage_start and is_main_thread:
utime_delta, stime_delta = self._get_cputime()
res.ru_utime += utime_delta
res.ru_stime += stime_delta
@@ -492,7 +476,7 @@ class LoggingContextFilter(logging.Filter):
Returns:
True to include the record in the log output.
"""
context = LoggingContext.current_context()
context = current_context()
for key, value in self.defaults.items():
setattr(record, key, value)
@@ -512,27 +496,24 @@ class PreserveLoggingContext(object):
__slots__ = ["current_context", "new_context", "has_parent"]
def __init__(self, new_context: Optional[LoggingContextOrSentinel] = None) -> None:
if new_context is None:
self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel
else:
self.new_context = new_context
def __init__(
self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT
) -> None:
self.new_context = new_context
def __enter__(self) -> None:
"""Captures the current logging context"""
self.current_context = LoggingContext.set_current_context(self.new_context)
self.current_context = set_current_context(self.new_context)
if self.current_context:
self.has_parent = self.current_context.previous_context is not None
if not self.current_context.alive:
logger.debug("Entering dead context: %s", self.current_context)
def __exit__(self, type, value, traceback) -> None:
"""Restores the current logging context"""
context = LoggingContext.set_current_context(self.current_context)
context = set_current_context(self.current_context)
if context != self.new_context:
if context is LoggingContext.sentinel:
if not context:
logger.warning("Expected logging context %s was lost", self.new_context)
else:
logger.warning(
@@ -541,9 +522,30 @@ class PreserveLoggingContext(object):
context,
)
if self.current_context is not LoggingContext.sentinel:
if not self.current_context.alive:
logger.debug("Restoring dead context: %s", self.current_context)
_thread_local = threading.local()
_thread_local.current_context = SENTINEL_CONTEXT
def current_context() -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage"""
return getattr(_thread_local, "current_context", SENTINEL_CONTEXT)
def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage
Args:
context(LoggingContext): The context to activate.
Returns:
The context that was previously active
"""
current = current_context()
if current is not context:
current.stop()
_thread_local.current_context = context
context.start()
return current
def nested_logging_context(
@@ -572,7 +574,7 @@ def nested_logging_context(
if parent_context is not None:
context = parent_context # type: LoggingContextOrSentinel
else:
context = LoggingContext.current_context()
context = current_context()
return LoggingContext(
parent_context=context, request=str(context.request) + "-" + suffix
)
@@ -604,7 +606,7 @@ def run_in_background(f, *args, **kwargs):
CRITICAL error about an unhandled error will be logged without much
indication about where it came from.
"""
current = LoggingContext.current_context()
current = current_context()
try:
res = f(*args, **kwargs)
except: # noqa: E722
@@ -625,7 +627,7 @@ def run_in_background(f, *args, **kwargs):
# The function may have reset the context before returning, so
# we need to restore it now.
ctx = LoggingContext.set_current_context(current)
ctx = set_current_context(current)
# The original context will be restored when the deferred
# completes, but there is nothing waiting for it, so it will
@@ -674,7 +676,7 @@ def make_deferred_yieldable(deferred):
# ok, we can't be sure that a yield won't block, so let's reset the
# logcontext, and add a callback to the deferred to restore it.
prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
prev_context = set_current_context(SENTINEL_CONTEXT)
deferred.addBoth(_set_context_cb, prev_context)
return deferred
@@ -684,7 +686,7 @@ ResultT = TypeVar("ResultT")
def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
"""A callback function which just sets the logging context"""
LoggingContext.set_current_context(context)
set_current_context(context)
return result
@@ -752,7 +754,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
Deferred: A Deferred which fires a callback with the result of `f`, or an
errback if `f` throws an exception.
"""
logcontext = LoggingContext.current_context()
logcontext = current_context()
def g():
with LoggingContext(parent_context=logcontext):

View File

@@ -19,7 +19,7 @@ from opentracing import Scope, ScopeManager
import twisted
from synapse.logging.context import LoggingContext, nested_logging_context
from synapse.logging.context import current_context, nested_logging_context
logger = logging.getLogger(__name__)
@@ -49,11 +49,8 @@ class LogContextScopeManager(ScopeManager):
(Scope) : the Scope that is active, or None if not
available.
"""
ctx = LoggingContext.current_context()
if ctx is LoggingContext.sentinel:
return None
else:
return ctx.scope
ctx = current_context()
return ctx.scope
def activate(self, span, finish_on_close):
"""
@@ -70,9 +67,9 @@ class LogContextScopeManager(ScopeManager):
"""
enter_logcontext = False
ctx = LoggingContext.current_context()
ctx = current_context()
if ctx is LoggingContext.sentinel:
if not ctx:
# We don't want this scope to affect.
logger.error("Tried to activate scope outside of loggingcontext")
return Scope(None, span)

View File

@@ -21,6 +21,7 @@ from synapse.replication.http import (
membership,
register,
send_event,
streams,
)
REPLICATION_PREFIX = "/_synapse/replication"
@@ -38,3 +39,4 @@ class ReplicationRestResource(JsonResource):
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)
streams.register_servlets(hs, self)

View File

@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer
from synapse.replication.http._base import ReplicationEndpoint
logger = logging.getLogger(__name__)
class ReplicationGetStreamUpdates(ReplicationEndpoint):
"""Fetches stream updates from a server. Used for streams not persisted to
the database, e.g. typing notifications.
The API looks like:
GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&to_token=10&limit=100
200 OK
{
updates: [ ... ],
upto_token: 10,
limited: False,
}
"""
NAME = "get_repl_stream_updates"
PATH_ARGS = ("stream_name",)
METHOD = "GET"
def __init__(self, hs):
super().__init__(hs)
# We pull the streams from the replication steamer (if we try and make
# them ourselves we end up in an import loop).
self.streams = hs.get_replication_streamer().get_streams()
@staticmethod
def _serialize_payload(stream_name, from_token, upto_token, limit):
return {"from_token": from_token, "upto_token": upto_token, "limit": limit}
async def _handle_request(self, request, stream_name):
stream = self.streams.get(stream_name)
if stream is None:
raise SynapseError(400, "Unknown stream")
from_token = parse_integer(request, "from_token", required=True)
upto_token = parse_integer(request, "upto_token", required=True)
limit = parse_integer(request, "limit", required=True)
updates, upto_token, limited = await stream.get_updates_since(
from_token, upto_token, limit
)
return (
200,
{"updates": updates, "upto_token": upto_token, "limited": limited},
)
def register_servlets(hs, http_server):
ReplicationGetStreamUpdates(hs).register(http_server)

View File

@@ -18,8 +18,10 @@ from typing import Dict, Optional
import six
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
from synapse.storage.data_stores.main.cache import (
CURRENT_STATE_CACHE_NAME,
CacheInvalidationWorkerStore,
)
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
@@ -35,7 +37,7 @@ def __func__(inp):
return inp.__func__
class BaseSlavedStore(SQLBaseStore):
class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: Database, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
@@ -60,6 +62,12 @@ class BaseSlavedStore(SQLBaseStore):
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
else:
return 0
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches":
if self._cache_id_gen:

View File

@@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
db_conn, "device_lists_stream", "stream_id"
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
],
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
@@ -55,23 +61,27 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
def process_replication_rows(self, stream_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
for row in rows:
self._invalidate_caches_for_devices(token, row.user_id, row.destination)
self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
def _invalidate_caches_for_devices(self, token, user_id, destination):
self._device_list_stream_cache.entity_has_changed(user_id, token)
def _invalidate_caches_for_devices(self, token, rows):
for row in rows:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
if row.entity.startswith("@"):
self._device_list_stream_cache.entity_has_changed(row.entity, token)
self.get_cached_devices_for_user.invalidate((row.entity,))
self._get_cached_user_device.invalidate_many((row.entity,))
self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
if destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, token
)
self.get_cached_devices_for_user.invalidate((user_id,))
self._get_cached_user_device.invalidate_many((user_id,))
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
else:
self._device_list_federation_stream_cache.entity_has_changed(
row.entity, token
)

View File

@@ -33,6 +33,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
result["pushers"] = self._pushers_id_gen.get_current_token()
return result
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "pushers":
self._pushers_id_gen.advance(token)

View File

@@ -16,26 +16,10 @@
"""
import logging
from typing import Dict, List, Optional
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.tcp.protocol import (
AbstractReplicationClientHandler,
ClientReplicationStreamProtocol,
)
from .commands import (
Command,
FederationAckCommand,
InvalidateCacheCommand,
RemoteServerUpCommand,
RemovePusherCommand,
UserIpCommand,
UserSyncCommand,
)
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
logger = logging.getLogger(__name__)
@@ -51,10 +35,11 @@ class ReplicationClientFactory(ReconnectingClientFactory):
initialDelay = 0.1
maxDelay = 1 # Try at least once every N seconds
def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
def __init__(self, hs, client_name):
self.client_name = client_name
self.handler = handler
self.handler = hs.get_tcp_replication()
self.server_name = hs.config.server_name
self.hs = hs
self._clock = hs.get_clock() # As self.clock is defined in super class
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -65,7 +50,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def buildProtocol(self, addr):
logger.info("Connected to replication: %r", addr)
return ClientReplicationStreamProtocol(
self.client_name, self.server_name, self._clock, self.handler
self.hs, self.client_name, self.server_name, self._clock, self.handler,
)
def clientConnectionLost(self, connector, reason):
@@ -75,170 +60,3 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def clientConnectionFailed(self, connector, reason):
logger.error("Failed to connect to replication: %r", reason)
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
class ReplicationClientHandler(AbstractReplicationClientHandler):
"""A base handler that can be passed to the ReplicationClientFactory.
By default proxies incoming replication data to the SlaveStore.
"""
def __init__(self, store: BaseSlavedStore):
self.store = store
# The current connection. None if we are currently (re)connecting
self.connection = None
# Any pending commands to be sent once a new connection has been
# established
self.pending_commands = [] # type: List[Command]
# Map from string -> deferred, to wake up when receiveing a SYNC with
# the given string.
# Used for tests.
self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
# The factory used to create connections.
self.factory = None # type: Optional[ReplicationClientFactory]
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
using TCP.
"""
client_name = hs.config.worker_name
self.factory = ReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self.factory)
async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
handle more.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
self.store.process_replication_rows(stream_name, token, rows)
async def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
the slave store.
Can be overriden in subclasses to handle more.
"""
self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting
for the sync with the given data.
Used by tests.
"""
d = self.awaiting_syncs.pop(data, None)
if d:
d.callback(data)
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
room_account_data = args.pop("room_account_data", None)
if user_account_data:
args["account_data"] = user_account_data
elif room_account_data:
args["account_data"] = room_account_data
return args
def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users. (Overriden by the synchrotron's only)
"""
return []
def send_command(self, cmd):
"""Send a command to master (when we get establish a connection if we
don't have one already.)
"""
if self.connection:
self.connection.send_command(cmd)
else:
logger.warning("Queuing command as not connected: %r", cmd.NAME)
self.pending_commands.append(cmd)
def send_federation_ack(self, token):
"""Ack data for the federation stream. This allows the master to drop
data stored purely in memory.
"""
self.send_command(FederationAckCommand(token))
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
"""Poke the master that a user has started/stopped syncing.
"""
self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
def send_remove_pusher(self, app_id, push_key, user_id):
"""Poke the master to remove a pusher for a user
"""
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
def send_invalidate_cache(self, cache_func, keys):
"""Poke the master to invalidate a cache.
"""
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
self.send_command(cmd)
def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
"""Tell the master that the user made a request.
"""
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
self.send_command(cmd)
def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))
def await_sync(self, data):
"""Returns a deferred that is resolved when we receive a SYNC command
with given data.
[Not currently] used by tests.
"""
return self.awaiting_syncs.setdefault(data, defer.Deferred())
def update_connection(self, connection):
"""Called when a connection has been established (or lost with None).
"""
self.connection = connection
if connection:
for cmd in self.pending_commands:
connection.send_command(cmd)
self.pending_commands = []
def finished_connecting(self):
"""Called when we have successfully subscribed and caught up to all
streams we're interested in.
"""
logger.info("Finished connecting to server")
# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
# server.
if self.factory:
self.factory.resetDelay()

View File

@@ -136,8 +136,8 @@ class PositionCommand(Command):
"""Sent by the server to tell the client the stream postition without
needing to send an RDATA.
Sent to the client after all missing updates for a stream have been sent
to the client and they're now up to date.
On receipt of a POSITION command clients should check if they have missed
any updates, and if so then fetch them out of band.
"""
NAME = "POSITION"
@@ -179,42 +179,24 @@ class NameCommand(Command):
class ReplicateCommand(Command):
"""Sent by the client to subscribe to the stream.
"""Sent by the client to subscribe to streams.
Format::
REPLICATE <stream_name> <token>
Where <token> may be either:
* a numeric stream_id to stream updates from
* "NOW" to stream all subsequent updates.
The <stream_name> can be "ALL" to subscribe to all known streams, in which
case the <token> must be set to "NOW", i.e.::
REPLICATE ALL NOW
REPLICATE
"""
NAME = "REPLICATE"
def __init__(self, stream_name, token):
self.stream_name = stream_name
self.token = token
def __init__(self):
pass
@classmethod
def from_line(cls, line):
stream_name, token = line.split(" ", 1)
if token in ("NOW", "now"):
token = "NOW"
else:
token = int(token)
return cls(stream_name, token)
return cls()
def to_line(self):
return " ".join((self.stream_name, str(self.token)))
def get_logcontext_id(self):
return "REPLICATE-" + self.stream_name
return ""
class UserSyncCommand(Command):
@@ -225,30 +207,32 @@ class UserSyncCommand(Command):
Format::
USER_SYNC <user_id> <state> <last_sync_ms>
USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
Where <state> is either "start" or "stop"
"""
NAME = "USER_SYNC"
def __init__(self, user_id, is_syncing, last_sync_ms):
def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
self.instance_id = instance_id
self.user_id = user_id
self.is_syncing = is_syncing
self.last_sync_ms = last_sync_ms
@classmethod
def from_line(cls, line):
user_id, state, last_sync_ms = line.split(" ", 2)
instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
if state not in ("start", "end"):
raise Exception("Invalid USER_SYNC state %r" % (state,))
return cls(user_id, state == "start", int(last_sync_ms))
return cls(instance_id, user_id, state == "start", int(last_sync_ms))
def to_line(self):
return " ".join(
(
self.instance_id,
self.user_id,
"start" if self.is_syncing else "end",
str(self.last_sync_ms),
@@ -256,6 +240,30 @@ class UserSyncCommand(Command):
)
class ClearUserSyncsCommand(Command):
"""Sent by the client to inform the server that it should drop all
information about syncing users sent by the client.
Mainly used when client is about to shut down.
Format::
CLEAR_USER_SYNC <instance_id>
"""
NAME = "CLEAR_USER_SYNC"
def __init__(self, instance_id):
self.instance_id = instance_id
@classmethod
def from_line(cls, line):
return cls(line)
def to_line(self):
return self.instance_id
class FederationAckCommand(Command):
"""Sent by the client when it has processed up to a given point in the
federation stream. This allows the master to drop in-memory caches of the
@@ -416,6 +424,7 @@ _COMMANDS = (
InvalidateCacheCommand,
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
) # type: Tuple[Type[Command], ...]
# Map of command name to command type.
@@ -438,6 +447,7 @@ VALID_CLIENT_COMMANDS = (
ReplicateCommand.NAME,
PingCommand.NAME,
UserSyncCommand.NAME,
ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME,
RemovePusherCommand.NAME,
InvalidateCacheCommand.NAME,

View File

@@ -0,0 +1,410 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A replication client for use by synapse workers.
"""
import logging
from typing import Any, Callable, Dict, List
from prometheus_client import Counter
from synapse.metrics import LaterGauge
from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command,
FederationAckCommand,
InvalidateCacheCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
RemovePusherCommand,
ReplicateCommand,
UserIpCommand,
UserSyncCommand,
)
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
invalidate_cache_counter = Counter(
"synapse_replication_tcp_resource_invalidate_cache", ""
)
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
class ReplicationClientHandler:
"""Handles incoming commands from replication.
Proxies data to `HomeServer.get_replication_data_handler()`.
"""
def __init__(self, hs):
self.replication_data_handler = hs.get_replication_data_handler()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.clock = hs.get_clock()
self.presence_handler = hs.get_presence_handler()
self.instance_id = hs.get_instance_id()
self._position_linearizer = Linearizer("replication_position")
self.connections = [] # type: List[Any]
self.streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
[],
lambda: len(self.connections),
)
LaterGauge(
"synapse_replication_tcp_resource_connections_per_stream",
"",
["stream_name"],
lambda: {
(stream_name,): len(
[
conn
for conn in self.connections
if stream_name in conn.replication_streams
]
)
for stream_name in self.streams
},
)
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
self.pending_batches = {} # type: Dict[str, List[Any]]
self.is_master = hs.config.worker_app is None
self.federation_sender = None
if self.is_master and not hs.config.send_federation:
self.federation_sender = hs.get_federation_sender()
self._server_notices_sender = None
if self.is_master:
self._server_notices_sender = hs.get_server_notices_sender()
self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
def new_connection(self, connection):
self.connections.append(connection)
def lost_connection(self, connection):
try:
self.connections.remove(connection)
except ValueError:
pass
def connected(self) -> bool:
"""Do we have any replication connections open?
Used to no-op if nothing is connected.
"""
return bool(self.connections)
async def on_REPLICATE(self, cmd: ReplicateCommand):
# We only want to announce positions by the writer of the streams.
# Currently this is just the master process.
if not self.is_master:
return
if not self.connections:
raise Exception("Not connected")
for stream_name, stream in self.streams.items():
current_token = stream.current_token()
self.send_command(PositionCommand(stream_name, current_token))
async def on_USER_SYNC(self, cmd: UserSyncCommand):
user_sync_counter.inc()
if self.is_master:
await self.presence_handler.update_external_syncs_row(
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
if self.is_master:
await self.presence_handler.update_external_syncs_clear(cmd.instance_id)
async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
federation_ack_counter.inc()
if self.federation_sender:
self.federation_sender.federation_ack(cmd.token)
async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
remove_pusher_counter.inc()
if self.is_master:
await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
)
self.notifier.on_new_replication_data()
async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
invalidate_cache_counter.inc()
if self.is_master:
# We invalidate the cache locally, but then also stream that to other
# workers.
await self.store.invalidate_cache_and_stream(
cmd.cache_func, tuple(cmd.keys)
)
async def on_USER_IP(self, cmd: UserIpCommand):
user_ip_cache_counter.inc()
if self.is_master:
await self.store.insert_client_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
cmd.user_agent,
cmd.device_id,
cmd.last_seen,
)
if self._server_notices_sender:
await self._server_notices_sender.on_user_ip(cmd.user_id)
async def on_RDATA(self, cmd: RdataCommand):
stream_name = cmd.stream_name
try:
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception("[%s] Failed to parse RDATA: %r", stream_name, cmd.row)
raise
if cmd.token is None:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
self.pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
await self.on_rdata(stream_name, cmd.token, rows)
async def on_rdata(self, stream_name: str, token: int, rows: list):
"""Called to handle a batch of replication data with a given stream token.
Args:
stream_name: name of the replication stream for this batch of rows
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
await self.replication_data_handler.on_rdata(stream_name, token, rows)
async def on_POSITION(self, cmd: PositionCommand):
stream = self.streams.get(cmd.stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
return
# We protect catching up with a linearizer in case the replicaiton
# connection reconnects under us.
with await self._position_linearizer.queue(cmd.stream_name):
# Find where we previously streamed up to.
current_token = self.replication_data_handler.get_streams_to_replicate().get(
cmd.stream_name
)
if current_token is None:
logger.debug(
"Got POSITION for stream we're not subscribed to: %s",
cmd.stream_name,
)
return
# Fetch all updates between then and now.
limited = cmd.token != current_token
while limited:
updates, current_token, limited = await stream.get_updates_since(
current_token, cmd.token
)
if updates:
await self.on_rdata(
cmd.stream_name,
current_token,
[stream.parse_row(update[1]) for update in updates],
)
# We've now caught up to position sent to us, notify handler.
await self.replication_data_handler.on_position(cmd.stream_name, cmd.token)
# Handle any RDATA that came in while we were catching up.
rows = self.pending_batches.pop(cmd.stream_name, [])
if rows:
await self.on_rdata(cmd.stream_name, rows[-1].token, rows)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
"""Called when get a new REMOTE_SERVER_UP command."""
if self.is_master:
self.notifier.notify_remote_server_up(cmd.data)
def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users.
"""
return self.presence_handler.get_currently_syncing_users()
def send_command(self, cmd: Command):
"""Send a command to master (when we get establish a connection if we
don't have one already.)
"""
for conn in self.connections:
conn.send_command(cmd)
def send_federation_ack(self, token: int):
"""Ack data for the federation stream. This allows the master to drop
data stored purely in memory.
"""
self.send_command(FederationAckCommand(token))
def send_user_sync(
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
):
"""Poke the master that a user has started/stopped syncing.
"""
self.send_command(
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
)
def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
"""Poke the master to remove a pusher for a user
"""
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
"""Poke the master to invalidate a cache.
"""
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
self.send_command(cmd)
def send_user_ip(
self,
user_id: str,
access_token: str,
ip: str,
user_agent: str,
device_id: str,
last_seen: int,
):
"""Tell the master that the user made a request.
"""
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
self.send_command(cmd)
def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))
def stream_update(self, stream_name: str, token: str, data: Any):
"""Called when a new update is available to stream to clients.
We need to check if the client is interested in the stream or not
"""
self.send_command(RdataCommand(stream_name, token, data))
class DummyReplicationDataHandler:
"""A replication data handler that simply discards all data.
"""
async def on_rdata(self, stream_name: str, token: int, rows: list):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
handle more.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
pass
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
return {}
async def on_position(self, stream_name: str, token: int):
pass
class WorkerReplicationDataHandler:
"""A replication data handler that calls slave data stores.
"""
def __init__(self, store):
self.store = store
async def on_rdata(self, stream_name: str, token: int, rows: list):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
handle more.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
self.store.process_replication_rows(stream_name, token, rows)
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
room_account_data = args.pop("room_account_data", None)
if user_account_data:
args["account_data"] = user_account_data
elif room_account_data:
args["account_data"] = room_account_data
return args
async def on_position(self, stream_name: str, token: int):
self.store.process_replication_rows(stream_name, token, [])

View File

@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
> PING 1490197665618
< NAME synapse.app.appservice
< PING 1490197665618
< REPLICATE events 1
< REPLICATE backfill 1
< REPLICATE caches 1
< REPLICATE
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
@@ -48,45 +46,40 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
import abc
import fcntl
import logging
import struct
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Set, Tuple
from typing import Any, DefaultDict, Dict, List, Set
from six import iteritems, iterkeys
from six import iteritems
from prometheus_client import Counter
from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
COMMAND_MAP,
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
Command,
ErrorCommand,
NameCommand,
PingCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
ReplicateCommand,
ServerCommand,
SyncCommand,
UserSyncCommand,
)
from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.types import Collection
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.util import Clock
from synapse.util.stringutils import random_string
MYPY = False
if MYPY:
from synapse.server import HomeServer
connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
)
@@ -127,16 +120,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
delimiter = b"\n"
# Valid commands we expect to receive
VALID_INBOUND_COMMANDS = [] # type: Collection[str]
# Valid commands we can send
VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]
max_line_buffer = 10000
def __init__(self, clock):
def __init__(self, clock, handler):
self.clock = clock
self.handler = handler
self.last_received_command = self.clock.time_msec()
self.last_sent_command = 0
@@ -176,6 +164,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# can time us out.
self.send_command(PingCommand(self.clock.time_msec()))
self.handler.new_connection(self)
def send_ping(self):
"""Periodically sends a ping and checks if we should close the connection
due to the other side timing out.
@@ -213,11 +203,6 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
line = line.decode("utf-8")
cmd_name, rest_of_line = line.split(" ", 1)
if cmd_name not in self.VALID_INBOUND_COMMANDS:
logger.error("[%s] invalid command %s", self.id(), cmd_name)
self.send_error("invalid command: %s", cmd_name)
return
self.last_received_command = self.clock.time_msec()
self.inbound_commands_counter[cmd_name] = (
@@ -249,8 +234,23 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
Args:
cmd: received command
"""
handler = getattr(self, "on_%s" % (cmd.NAME,))
await handler(cmd)
handled = False
# First call any command handlers on this instance. These are for TCP
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
if not handled:
logger.warning("Unhandled command: %r", cmd)
def close(self):
logger.warning("[%s] Closing connection", self.id())
@@ -258,6 +258,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.transport.loseConnection()
self.on_connection_closed()
def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))
def send_error(self, error_string, *args):
"""Send an error to remote and close the connection.
"""
@@ -379,6 +382,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.state = ConnectionStates.CLOSED
self.pending_commands = []
self.handler.lost_connection(self)
if self.transport:
self.transport.unregisterProducer()
@@ -402,346 +407,66 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
def __init__(self, server_name, clock, streamer):
BaseReplicationStreamProtocol.__init__(self, clock) # Old style class
def __init__(self, hs, server_name, clock, handler):
BaseReplicationStreamProtocol.__init__(self, clock, handler) # Old style class
self.server_name = server_name
self.streamer = streamer
# The streams the client has subscribed to and is up to date with
self.replication_streams = set() # type: Set[str]
# The streams the client is currently subscribing to.
self.connecting_streams = set() # type: Set[str]
# Map from stream name to list of updates to send once we've finished
# subscribing the client to the stream.
self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
BaseReplicationStreamProtocol.connectionMade(self)
self.streamer.new_connection(self)
async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
async def on_USER_SYNC(self, cmd):
await self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
async def on_REPLICATE(self, cmd):
stream_name = cmd.stream_name
token = cmd.token
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
deferreds = [
run_in_background(self.subscribe_to_stream, stream, token)
for stream in iterkeys(self.streamer.streams_by_name)
]
await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else:
await self.subscribe_to_stream(stream_name, token)
async def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token)
async def on_REMOVE_PUSHER(self, cmd):
await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
async def on_INVALIDATE_CACHE(self, cmd):
await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.streamer.on_remote_server_up(cmd.data)
async def on_USER_IP(self, cmd):
self.streamer.on_user_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
cmd.user_agent,
cmd.device_id,
cmd.last_seen,
)
async def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a stream.
This invloves checking if they've missed anything and sending those
updates down if they have. During that time new updates for the stream
are queued and sent once we've sent down any missed updates.
"""
self.replication_streams.discard(stream_name)
self.connecting_streams.add(stream_name)
try:
# Get missing updates
updates, current_token = await self.streamer.get_stream_updates(
stream_name, token
)
# Send all the missing updates
for update in updates:
token, row = update[0], update[1]
self.send_command(RdataCommand(stream_name, token, row))
# We send a POSITION command to ensure that they have an up to
# date token (especially useful if we didn't send any updates
# above)
self.send_command(PositionCommand(stream_name, current_token))
# Now we can send any updates that came in while we were subscribing
pending_rdata = self.pending_rdata.pop(stream_name, [])
updates = []
for token, update in pending_rdata:
# If the token is null, it is part of a batch update. Batches
# are multiple updates that share a single token. To denote
# this, the token is set to None for all tokens in the batch
# except for the last. If we find a None token, we keep looking
# through tokens until we find one that is not None and then
# process all previous updates in the batch as if they had the
# final token.
if token is None:
# Store this update as part of a batch
updates.append(update)
continue
if token <= current_token:
# This update or batch of updates is older than
# current_token, dismiss it
updates = []
continue
updates.append(update)
# Send all updates that are part of this batch with the
# found token
for update in updates:
self.send_command(RdataCommand(stream_name, token, update))
# Clear stored updates
updates = []
# They're now fully subscribed
self.replication_streams.add(stream_name)
except Exception as e:
logger.exception("[%s] Failed to handle REPLICATE command", self.id())
self.send_error("failed to handle replicate: %r", e)
finally:
self.connecting_streams.discard(stream_name)
def stream_update(self, stream_name, token, data):
"""Called when a new update is available to stream to clients.
We need to check if the client is interested in the stream or not
"""
if stream_name in self.replication_streams:
# The client is subscribed to the stream
self.send_command(RdataCommand(stream_name, token, data))
elif stream_name in self.connecting_streams:
# The client is being subscribed to the stream
logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
self.pending_rdata.setdefault(stream_name, []).append((token, data))
else:
# The client isn't subscribed
logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
def send_sync(self, data):
self.send_command(SyncCommand(data))
def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
self.streamer.lost_connection(self)
class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
"""
The interface for the handler that should be passed to
ClientReplicationStreamProtocol
"""
@abc.abstractmethod
async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
raise NotImplementedError()
@abc.abstractmethod
async def on_position(self, stream_name, token):
"""Called when we get new position data."""
raise NotImplementedError()
@abc.abstractmethod
def on_sync(self, data):
"""Called when get a new SYNC command."""
raise NotImplementedError()
@abc.abstractmethod
async def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
raise NotImplementedError()
@abc.abstractmethod
def get_streams_to_replicate(self):
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
raise NotImplementedError()
@abc.abstractmethod
def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users."""
raise NotImplementedError()
@abc.abstractmethod
def update_connection(self, connection):
"""Called when a connection has been established (or lost with None).
"""
raise NotImplementedError()
@abc.abstractmethod
def finished_connecting(self):
"""Called when we have successfully subscribed and caught up to all
streams we're interested in.
"""
raise NotImplementedError()
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
def __init__(
self,
hs: "HomeServer",
client_name: str,
server_name: str,
clock: Clock,
handler: AbstractReplicationClientHandler,
handler,
):
BaseReplicationStreamProtocol.__init__(self, clock)
BaseReplicationStreamProtocol.__init__(self, clock, handler)
self.instance_id = hs.get_instance_id()
self.client_name = client_name
self.server_name = server_name
self.handler = handler
self.streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
self.streams_connecting = set() # type: Set[str]
self.streams_connecting = set(STREAMS_MAP) # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
self.pending_batches = {} # type: Dict[str, Any]
self.pending_batches = {} # type: Dict[str, List[Any]]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
BaseReplicationStreamProtocol.connectionMade(self)
# Once we've connected subscribe to the necessary streams
for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
self.replicate(stream_name, token)
# Tell the server if we have any users currently syncing (should only
# happen on synchrotrons)
currently_syncing = self.handler.get_currently_syncing_users()
now = self.clock.time_msec()
for user_id in currently_syncing:
self.send_command(UserSyncCommand(user_id, True, now))
# We've now finished connecting to so inform the client handler
self.handler.update_connection(self)
# This will happen if we don't actually subscribe to any streams
if not self.streams_connecting:
self.handler.finished_connecting()
self.send_command(NameCommand(self.client_name))
self.replicate()
async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
async def on_RDATA(self, cmd):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
try:
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception(
"[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
)
raise
if cmd.token is None:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
self.pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
await self.handler.on_rdata(stream_name, cmd.token, rows)
async def on_POSITION(self, cmd):
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
await self.handler.on_position(cmd.stream_name, cmd.token)
async def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.handler.on_remote_server_up(cmd.data)
def replicate(self, stream_name, token):
def replicate(self):
"""Send the subscription request to the server
"""
if stream_name not in STREAMS_MAP:
raise Exception("Invalid stream name %r" % (stream_name,))
logger.info("[%s] Subscribing to replication streams", self.id())
logger.info(
"[%s] Subscribing to replication stream: %r from %r",
self.id(),
stream_name,
token,
)
self.streams_connecting.add(stream_name)
self.send_command(ReplicateCommand(stream_name, token))
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
self.handler.update_connection(None)
self.send_command(ReplicateCommand())
# The following simply registers metrics for the replication connections

View File

@@ -17,7 +17,7 @@
import logging
import random
from typing import Any, List
from typing import Dict
from six import itervalues
@@ -25,24 +25,16 @@ from prometheus_client import Counter
from twisted.internet.protocol import Factory
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.metrics import Measure, measure_func
from synapse.util.metrics import Measure
from .protocol import ServerReplicationStreamProtocol
from .streams import STREAMS_MAP
from .streams import STREAMS_MAP, Stream
from .streams.federation import FederationStream
stream_updates_counter = Counter(
"synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
)
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
invalidate_cache_counter = Counter(
"synapse_replication_tcp_resource_invalidate_cache", ""
)
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
logger = logging.getLogger(__name__)
@@ -52,13 +44,18 @@ class ReplicationStreamProtocolFactory(Factory):
"""
def __init__(self, hs):
self.streamer = ReplicationStreamer(hs)
self.handler = hs.get_tcp_replication()
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
self.hs = hs
# Ensure the replication streamer is started if we register a
# replication server endpoint.
hs.get_replication_streamer()
def buildProtocol(self, addr):
return ServerReplicationStreamProtocol(
self.server_name, self.clock, self.streamer
self.hs, self.server_name, self.clock, self.handler
)
@@ -78,16 +75,6 @@ class ReplicationStreamer(object):
self._replication_torture_level = hs.config.replication_torture_level
# Current connections.
self.connections = [] # type: List[ServerReplicationStreamProtocol]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
[],
lambda: len(self.connections),
)
# List of streams that clients can subscribe to.
# We only support federation stream if federation sending hase been
# disabled on the master.
@@ -99,39 +86,22 @@ class ReplicationStreamer(object):
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
LaterGauge(
"synapse_replication_tcp_resource_connections_per_stream",
"",
["stream_name"],
lambda: {
(stream_name,): len(
[
conn
for conn in self.connections
if stream_name in conn.replication_streams
]
)
for stream_name in self.streams_by_name
},
)
self.federation_sender = None
if not hs.config.send_federation:
self.federation_sender = hs.get_federation_sender()
self.notifier.add_replication_callback(self.on_notifier_poke)
self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
self.pending_updates = False
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
self.client = hs.get_tcp_replication()
def on_shutdown(self):
# close all connections on shutdown
for conn in self.connections:
conn.send_error("server shutting down")
def get_streams(self) -> Dict[str, Stream]:
"""Get a mapp from stream name to stream instance.
"""
return self.streams_by_name
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
@@ -140,7 +110,7 @@ class ReplicationStreamer(object):
This should get called each time new data is available, even if it
is currently being executed, so that nothing gets missed
"""
if not self.connections:
if not self.client.connected():
# Don't bother if nothing is listening. We still need to advance
# the stream tokens otherwise they'll fall beihind forever
for stream in self.streams:
@@ -166,11 +136,6 @@ class ReplicationStreamer(object):
self.pending_updates = False
with Measure(self.clock, "repl.stream.get_updates"):
# First we tell the streams that they should update their
# current tokens.
for stream in self.streams:
stream.advance_current_token()
all_streams = self.streams
if self._replication_torture_level is not None:
@@ -180,7 +145,7 @@ class ReplicationStreamer(object):
random.shuffle(all_streams)
for stream in all_streams:
if stream.last_token == stream.upto_token:
if stream.last_token == stream.current_token():
continue
if self._replication_torture_level:
@@ -192,18 +157,17 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME,
stream.last_token,
stream.upto_token,
stream.current_token(),
)
try:
updates, current_token = await stream.get_updates()
updates, current_token, limited = await stream.get_updates()
self.pending_updates |= limited
except Exception:
logger.info("Failed to handle stream %s", stream.NAME)
raise
logger.debug(
"Sending %d updates to %d connections",
len(updates),
len(self.connections),
"Sending %d updates", len(updates),
)
if updates:
@@ -219,116 +183,17 @@ class ReplicationStreamer(object):
# token. See RdataCommand for more details.
batched_updates = _batch_updates(updates)
for conn in self.connections:
for token, row in batched_updates:
try:
conn.stream_update(stream.NAME, token, row)
except Exception:
logger.exception("Failed to replicate")
for token, row in batched_updates:
try:
self.client.stream_update(stream.NAME, token, row)
except Exception:
logger.exception("Failed to replicate")
logger.debug("No more pending updates, breaking poke loop")
finally:
self.pending_updates = False
self.is_looping = False
@measure_func("repl.get_stream_updates")
async def get_stream_updates(self, stream_name, token):
"""For a given stream get all updates since token. This is called when
a client first subscribes to a stream.
"""
stream = self.streams_by_name.get(stream_name, None)
if not stream:
raise Exception("unknown stream %s", stream_name)
return await stream.get_updates_since(token)
@measure_func("repl.federation_ack")
def federation_ack(self, token):
"""We've received an ack for federation stream from a client.
"""
federation_ack_counter.inc()
if self.federation_sender:
self.federation_sender.federation_ack(token)
@measure_func("repl.on_user_sync")
async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
"""A client has started/stopped syncing on a worker.
"""
user_sync_counter.inc()
await self.presence_handler.update_external_syncs_row(
conn_id, user_id, is_syncing, last_sync_ms
)
@measure_func("repl.on_remove_pusher")
async def on_remove_pusher(self, app_id, push_key, user_id):
"""A client has asked us to remove a pusher
"""
remove_pusher_counter.inc()
await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id=app_id, pushkey=push_key, user_id=user_id
)
self.notifier.on_new_replication_data()
@measure_func("repl.on_invalidate_cache")
async def on_invalidate_cache(self, cache_func: str, keys: List[Any]):
"""The client has asked us to invalidate a cache
"""
invalidate_cache_counter.inc()
# We invalidate the cache locally, but then also stream that to other
# workers.
await self.store.invalidate_cache_and_stream(cache_func, tuple(keys))
@measure_func("repl.on_user_ip")
async def on_user_ip(
self, user_id, access_token, ip, user_agent, device_id, last_seen
):
"""The client saw a user request
"""
user_ip_cache_counter.inc()
await self.store.insert_client_ip(
user_id, access_token, ip, user_agent, device_id, last_seen
)
await self._server_notices_sender.on_user_ip(user_id)
@measure_func("repl.on_remote_server_up")
def on_remote_server_up(self, server: str):
self.notifier.notify_remote_server_up(server)
def send_remote_server_up(self, server: str):
for conn in self.connections:
conn.send_remote_server_up(server)
def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients.
Used in tests.
"""
for conn in self.connections:
conn.send_sync(data)
def new_connection(self, connection):
"""A new client connection has been established
"""
self.connections.append(connection)
def lost_connection(self, connection):
"""A client connection has been lost
"""
try:
self.connections.remove(connection)
except ValueError:
pass
# We need to tell the presence handler that the connection has been
# lost so that it can handle any ongoing syncs on that connection.
run_as_background_process(
"update_external_syncs_clear",
self.presence_handler.update_external_syncs_clear,
connection.conn_id,
)
def _batch_updates(updates):
"""Takes a list of updates of form [(token, row)] and sets the token to

View File

@@ -25,26 +25,66 @@ Each stream is defined by the following information:
update_function: The function that returns a list of updates between two tokens
"""
from . import _base, events, federation
from typing import Dict, Type
from synapse.replication.tcp.streams._base import (
AccountDataStream,
BackfillStream,
CachesStream,
DeviceListsStream,
GroupServerStream,
PresenceStream,
PublicRoomsStream,
PushersStream,
PushRulesStream,
ReceiptsStream,
Stream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
UserSignatureStream,
)
from synapse.replication.tcp.streams.events import EventsStream
from synapse.replication.tcp.streams.federation import FederationStream
STREAMS_MAP = {
stream.NAME: stream
for stream in (
events.EventsStream,
_base.BackfillStream,
_base.PresenceStream,
_base.TypingStream,
_base.ReceiptsStream,
_base.PushRulesStream,
_base.PushersStream,
_base.CachesStream,
_base.PublicRoomsStream,
_base.DeviceListsStream,
_base.ToDeviceStream,
federation.FederationStream,
_base.TagAccountDataStream,
_base.AccountDataStream,
_base.GroupServerStream,
_base.UserSignatureStream,
EventsStream,
BackfillStream,
PresenceStream,
TypingStream,
ReceiptsStream,
PushRulesStream,
PushersStream,
CachesStream,
PublicRoomsStream,
DeviceListsStream,
ToDeviceStream,
FederationStream,
TagAccountDataStream,
AccountDataStream,
GroupServerStream,
UserSignatureStream,
)
}
} # type: Dict[str, Type[Stream]]
__all__ = [
"STREAMS_MAP",
"Stream",
"BackfillStream",
"PresenceStream",
"TypingStream",
"ReceiptsStream",
"PushRulesStream",
"PushersStream",
"CachesStream",
"PublicRoomsStream",
"DeviceListsStream",
"ToDeviceStream",
"TagAccountDataStream",
"AccountDataStream",
"GroupServerStream",
"UserSignatureStream",
]

View File

@@ -14,114 +14,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from collections import namedtuple
from typing import Any, List, Optional
from typing import Any, Awaitable, Callable, List, Optional, Tuple
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
MAX_EVENTS_BEHIND = 500000
BackfillStreamRow = namedtuple(
"BackfillStreamRow",
(
"event_id", # str
"room_id", # str
"type", # str
"state_key", # str, optional
"redacts", # str, optional
"relates_to", # str, optional
),
)
PresenceStreamRow = namedtuple(
"PresenceStreamRow",
(
"user_id", # str
"state", # str
"last_active_ts", # int
"last_federation_update_ts", # int
"last_user_sync_ts", # int
"status_msg", # str
"currently_active", # bool
),
)
TypingStreamRow = namedtuple(
"TypingStreamRow", ("room_id", "user_ids") # str # list(str)
)
ReceiptsStreamRow = namedtuple(
"ReceiptsStreamRow",
(
"room_id", # str
"receipt_type", # str
"user_id", # str
"event_id", # str
"data", # dict
),
)
PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
PushersStreamRow = namedtuple(
"PushersStreamRow",
("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
)
# Some type aliases to make things a bit easier.
@attr.s
class CachesStreamRow:
"""Stream to inform workers they should invalidate their cache.
# A stream position token
Token = int
Attributes:
cache_func: Name of the cached function.
keys: The entry in the cache to invalidate. If None then will
invalidate all.
invalidation_ts: Timestamp of when the invalidation took place.
"""
cache_func = attr.ib(type=str)
keys = attr.ib(type=Optional[List[Any]])
invalidation_ts = attr.ib(type=int)
PublicRoomsStreamRow = namedtuple(
"PublicRoomsStreamRow",
(
"room_id", # str
"visibility", # str
"appservice_id", # str, optional
"network_id", # str, optional
),
)
DeviceListsStreamRow = namedtuple(
"DeviceListsStreamRow", ("user_id", "destination") # str # str
)
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
)
AccountDataStreamRow = namedtuple(
"AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
)
GroupsStreamRow = namedtuple(
"GroupsStreamRow",
("group_id", "user_id", "type", "content"), # str # str # str # dict
)
UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
StreamRow = Tuple[Token, tuple]
class Stream(object):
"""Base class for the streams.
Provides a `get_updates()` function that returns new updates since the last
time it was called up until the point `advance_current_token` was called.
time it was called.
"""
NAME = None # type: str # The name of the stream
# The type of the row. Used by the default impl of parse_row.
ROW_TYPE = None # type: Any
_LIMITED = True # Whether the update function takes a limit
@classmethod
def parse_row(cls, row):
@@ -139,80 +65,56 @@ class Stream(object):
return cls.ROW_TYPE(*row)
def __init__(self, hs):
# The token from which we last asked for updates
self.last_token = self.current_token()
# The token that we will get updates up to
self.upto_token = self.current_token()
def advance_current_token(self):
"""Updates `upto_token` to "now", which updates up until which point
get_updates[_since] will fetch rows till.
"""
self.upto_token = self.current_token()
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
self.upto_token = self.current_token()
self.last_token = self.upto_token
self.last_token = self.current_token()
async def get_updates(self):
async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before),
until the `upto_token`
since the stream was constructed if it hadn't been called before).
Returns:
Deferred[Tuple[List[Tuple[int, Any]], int]:
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
list of ``(token, row)`` entries. ``row`` will be json-serialised and
sent over the replication steam.
A triplet `(updates, new_last_token, limited)`, where `updates` is
a list of `(token, row)` entries, `new_last_token` is the new
position in stream, and `limited` is whether there are more updates
to fetch.
"""
updates, current_token = await self.get_updates_since(self.last_token)
current_token = self.current_token()
updates, current_token, limited = await self.get_updates_since(
self.last_token, current_token
)
self.last_token = current_token
return updates, current_token
return updates, current_token, limited
async def get_updates_since(self, from_token):
async def get_updates_since(
self, from_token: Token, upto_token: Token, limit: int = 100
) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Like get_updates except allows specifying from when we should
stream updates
Returns:
Deferred[Tuple[List[Tuple[int, Any]], int]:
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
list of ``(token, row)`` entries. ``row`` will be json-serialised and
sent over the replication steam.
A triplet `(updates, new_last_token, limited)`, where `updates` is
a list of `(token, row)` entries, `new_last_token` is the new
position in stream, and `limited` is whether there are more updates
to fetch.
"""
if from_token in ("NOW", "now"):
return [], self.upto_token
current_token = self.upto_token
from_token = int(from_token)
if from_token == current_token:
return [], current_token
if from_token == upto_token:
return [], upto_token, False
logger.info("get_updates_since: %s", self.__class__)
if self._LIMITED:
rows = await self.update_function(
from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
)
# never turn more than MAX_EVENTS_BEHIND + 1 into updates.
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
else:
rows = await self.update_function(from_token, current_token)
updates = [(row[0], row[1:]) for row in rows]
# check we didn't get more rows than the limit.
# doing it like this allows the update_function to be a generator.
if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
raise Exception("stream %s has fallen behind" % (self.NAME))
return updates, current_token
updates, upto_token, limited = await self.update_function(
from_token, upto_token, limit=limit,
)
return updates, upto_token, limited
def current_token(self):
"""Gets the current token of the underlying streams. Should be provided
@@ -223,9 +125,8 @@ class Stream(object):
"""
raise NotImplementedError()
def update_function(self, from_token, current_token, limit=None):
"""Get updates between from_token and to_token. If Stream._LIMITED is
True then limit is provided, otherwise it's not.
def update_function(self, from_token, current_token, limit):
"""Get updates between from_token and to_token.
Returns:
Deferred(list(tuple)): the first entry in the tuple is the token for
@@ -235,52 +136,144 @@ class Stream(object):
raise NotImplementedError()
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
async def update_function(from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) == limit:
upto_token = rows[-1][0]
limited = True
return updates, upto_token, limited
return update_function
def make_http_update_function(
hs, stream_name: str
) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
"""Makes a suitable function for use as an `update_function` that queries
the master process for updates.
"""
client = ReplicationGetStreamUpdates.make_client(hs)
async def update_function(
from_token: int, upto_token: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
return await client(
stream_name=stream_name,
from_token=from_token,
upto_token=upto_token,
limit=limit,
)
return update_function
class BackfillStream(Stream):
"""We fetched some old events and either we had never seen that event before
or it went from being an outlier to not.
"""
BackfillStreamRow = namedtuple(
"BackfillStreamRow",
(
"event_id", # str
"room_id", # str
"type", # str
"state_key", # str, optional
"redacts", # str, optional
"relates_to", # str, optional
),
)
NAME = "backfill"
ROW_TYPE = BackfillStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_current_backfill_token # type: ignore
self.update_function = store.get_all_new_backfill_event_rows # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
super(BackfillStream, self).__init__(hs)
class PresenceStream(Stream):
PresenceStreamRow = namedtuple(
"PresenceStreamRow",
(
"user_id", # str
"state", # str
"last_active_ts", # int
"last_federation_update_ts", # int
"last_user_sync_ts", # int
"status_msg", # str
"currently_active", # bool
),
)
NAME = "presence"
_LIMITED = False
ROW_TYPE = PresenceStreamRow
def __init__(self, hs):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
self._is_worker = hs.config.worker_app is not None
self.current_token = store.get_current_presence_token # type: ignore
self.update_function = presence_handler.get_all_presence_updates # type: ignore
if hs.config.worker_app is None:
self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
else:
# Query master process
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super(PresenceStream, self).__init__(hs)
class TypingStream(Stream):
TypingStreamRow = namedtuple(
"TypingStreamRow", ("room_id", "user_ids") # str # list(str)
)
NAME = "typing"
_LIMITED = False
ROW_TYPE = TypingStreamRow
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
self.current_token = typing_handler.get_current_token # type: ignore
self.update_function = typing_handler.get_all_typing_updates # type: ignore
if hs.config.worker_app is None:
self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
else:
# Query master process
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super(TypingStream, self).__init__(hs)
class ReceiptsStream(Stream):
ReceiptsStreamRow = namedtuple(
"ReceiptsStreamRow",
(
"room_id", # str
"receipt_type", # str
"user_id", # str
"event_id", # str
"data", # dict
),
)
NAME = "receipts"
ROW_TYPE = ReceiptsStreamRow
@@ -288,7 +281,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_receipt_stream_id # type: ignore
self.update_function = store.get_all_updated_receipts # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
super(ReceiptsStream, self).__init__(hs)
@@ -297,6 +290,8 @@ class PushRulesStream(Stream):
"""A user has changed their push rules
"""
PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
NAME = "push_rules"
ROW_TYPE = PushRulesStreamRow
@@ -310,13 +305,24 @@ class PushRulesStream(Stream):
async def update_function(self, from_token, to_token, limit):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
return [(row[0], row[2]) for row in rows]
limited = False
if len(rows) == limit:
to_token = rows[-1][0]
limited = True
return [(row[0], (row[2],)) for row in rows], to_token, limited
class PushersStream(Stream):
"""A user has added/changed/removed a pusher
"""
PushersStreamRow = namedtuple(
"PushersStreamRow",
("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
)
NAME = "pushers"
ROW_TYPE = PushersStreamRow
@@ -324,7 +330,7 @@ class PushersStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_pushers_stream_token # type: ignore
self.update_function = store.get_all_updated_pushers_rows # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
super(PushersStream, self).__init__(hs)
@@ -334,6 +340,21 @@ class CachesStream(Stream):
the cache on the workers
"""
@attr.s
class CachesStreamRow:
"""Stream to inform workers they should invalidate their cache.
Attributes:
cache_func: Name of the cached function.
keys: The entry in the cache to invalidate. If None then will
invalidate all.
invalidation_ts: Timestamp of when the invalidation took place.
"""
cache_func = attr.ib(type=str)
keys = attr.ib(type=Optional[List[Any]])
invalidation_ts = attr.ib(type=int)
NAME = "caches"
ROW_TYPE = CachesStreamRow
@@ -341,7 +362,7 @@ class CachesStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_cache_stream_token # type: ignore
self.update_function = store.get_all_updated_caches # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
super(CachesStream, self).__init__(hs)
@@ -350,6 +371,16 @@ class PublicRoomsStream(Stream):
"""The public rooms list changed
"""
PublicRoomsStreamRow = namedtuple(
"PublicRoomsStreamRow",
(
"room_id", # str
"visibility", # str
"appservice_id", # str, optional
"network_id", # str, optional
),
)
NAME = "public_rooms"
ROW_TYPE = PublicRoomsStreamRow
@@ -357,24 +388,28 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_current_public_room_stream_id # type: ignore
self.update_function = store.get_all_new_public_rooms # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
super(PublicRoomsStream, self).__init__(hs)
class DeviceListsStream(Stream):
"""Someone added/changed/removed a device
"""Either a user has updated their devices or a remote server needs to be
told about a device update.
"""
@attr.s
class DeviceListsStreamRow:
entity = attr.ib(type=str)
NAME = "device_lists"
_LIMITED = False
ROW_TYPE = DeviceListsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
super(DeviceListsStream, self).__init__(hs)
@@ -383,6 +418,8 @@ class ToDeviceStream(Stream):
"""New to_device messages for a client
"""
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
NAME = "to_device"
ROW_TYPE = ToDeviceStreamRow
@@ -390,7 +427,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_to_device_stream_token # type: ignore
self.update_function = store.get_all_new_device_messages # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
super(ToDeviceStream, self).__init__(hs)
@@ -399,6 +436,10 @@ class TagAccountDataStream(Stream):
"""Someone added/removed a tag for a room
"""
TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
)
NAME = "tag_account_data"
ROW_TYPE = TagAccountDataStreamRow
@@ -406,7 +447,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_account_data_stream_id # type: ignore
self.update_function = store.get_all_updated_tags # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
super(TagAccountDataStream, self).__init__(hs)
@@ -415,6 +456,10 @@ class AccountDataStream(Stream):
"""Global or per room account data was changed
"""
AccountDataStreamRow = namedtuple(
"AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
)
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
@@ -422,10 +467,11 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore()
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(AccountDataStream, self).__init__(hs)
async def update_function(self, from_token, to_token, limit):
async def _update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit
)
@@ -440,6 +486,11 @@ class AccountDataStream(Stream):
class GroupServerStream(Stream):
GroupsStreamRow = namedtuple(
"GroupsStreamRow",
("group_id", "user_id", "type", "content"), # str # str # str # dict
)
NAME = "groups"
ROW_TYPE = GroupsStreamRow
@@ -447,7 +498,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_group_stream_token # type: ignore
self.update_function = store.get_all_groups_changes # type: ignore
self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
super(GroupServerStream, self).__init__(hs)
@@ -456,14 +507,15 @@ class UserSignatureStream(Stream):
"""A user has signed their own device with their user-signing key
"""
UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
NAME = "user_signature"
_LIMITED = False
ROW_TYPE = UserSignatureStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
super(UserSignatureStream, self).__init__(hs)

View File

@@ -19,7 +19,7 @@ from typing import Tuple, Type
import attr
from ._base import Stream
from ._base import Stream, db_query_to_update_function
"""Handling of the 'events' replication stream
@@ -117,10 +117,11 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
self.current_token = self._store.get_current_events_token # type: ignore
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(EventsStream, self).__init__(hs)
async def update_function(self, from_token, current_token, limit=None):
async def _update_function(self, from_token, current_token, limit=None):
event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, limit
)

View File

@@ -15,15 +15,9 @@
# limitations under the License.
from collections import namedtuple
from ._base import Stream
from twisted.internet import defer
FederationStreamRow = namedtuple(
"FederationStreamRow",
(
"type", # str, the type of data as defined in the BaseFederationRows
"data", # dict, serialization of a federation.send_queue.BaseFederationRow
),
)
from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
class FederationStream(Stream):
@@ -31,13 +25,28 @@ class FederationStream(Stream):
sending disabled.
"""
FederationStreamRow = namedtuple(
"FederationStreamRow",
(
"type", # str, the type of data as defined in the BaseFederationRows
"data", # dict, serialization of a federation.send_queue.BaseFederationRow
),
)
NAME = "federation"
ROW_TYPE = FederationStreamRow
_QUERY_MASTER = True
def __init__(self, hs):
federation_sender = hs.get_federation_sender()
self.current_token = federation_sender.get_current_token # type: ignore
self.update_function = federation_sender.get_replication_rows # type: ignore
# Not all synapse instances will have a federation sender instance,
# whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
# so we stub the stream out when that is the case.
if hs.config.worker_app is None or hs.should_send_federation():
federation_sender = hs.get_federation_sender()
self.current_token = federation_sender.get_current_token # type: ignore
self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
else:
self.current_token = lambda: 0 # type: ignore
self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool)) # type: ignore
super(FederationStream, self).__init__(hs)

View File

@@ -41,6 +41,7 @@ from synapse.rest.client.v2_alpha import (
keys,
notifications,
openid,
password_policy,
read_marker,
receipts,
register,
@@ -118,6 +119,7 @@ class ClientRestResource(JsonResource):
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
password_policy.register_servlets(hs, client_resource)
# moving to /_synapse/admin
synapse.rest.admin.register_servlets_for_client_rest_resource(

View File

@@ -29,7 +29,11 @@ from synapse.rest.admin._base import (
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.rooms import ListRoomRestServlet, ShutdownRoomRestServlet
from synapse.rest.admin.rooms import (
JoinRoomAliasServlet,
ListRoomRestServlet,
ShutdownRoomRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.users import (
AccountValidityRenewServlet,
@@ -189,6 +193,7 @@ def register_servlets(hs, http_server):
"""
register_servlets_for_client_rest_resource(hs, http_server)
ListRoomRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
PurgeRoomServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)

View File

@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Optional
from synapse.api.constants import Membership
from synapse.api.errors import Codes, SynapseError
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -29,7 +30,7 @@ from synapse.rest.admin._base import (
historical_admin_path_patterns,
)
from synapse.storage.data_stores.main.room import RoomSortOrder
from synapse.types import create_requester
from synapse.types import RoomAlias, RoomID, UserID, create_requester
from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -237,3 +238,75 @@ class ListRoomRestServlet(RestServlet):
response["prev_batch"] = 0
return 200, response
class JoinRoomAliasServlet(RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
self.admin_handler = hs.get_handlers().admin_handler
self.state_handler = hs.get_state_handler()
async def on_POST(self, request, room_identifier):
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ["user_id"])
target_user = UserID.from_string(content["user_id"])
if not self.hs.is_mine(target_user):
raise SynapseError(400, "This endpoint can only be used with local users")
if not await self.admin_handler.get_user(target_user):
raise NotFoundError("User not found")
if RoomID.is_valid(room_identifier):
room_id = room_identifier
try:
remote_room_hosts = [
x.decode("ascii") for x in request.args[b"server_name"]
] # type: Optional[List[str]]
except Exception:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
fake_requester = create_requester(target_user)
# send invite if room has "JoinRules.INVITE"
room_state = await self.state_handler.get_current_state(room_id)
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
await self.room_member_handler.update_membership(
requester=requester,
target=fake_requester.user,
room_id=room_id,
action="invite",
remote_room_hosts=remote_room_hosts,
ratelimit=False,
)
await self.room_member_handler.update_membership(
requester=fake_requester,
target=fake_requester.user,
room_id=room_id,
action="join",
remote_room_hosts=remote_room_hosts,
ratelimit=False,
)
return 200, {"room_id": room_id}

View File

@@ -14,11 +14,6 @@
# limitations under the License.
import logging
import xml.etree.ElementTree as ET
from six.moves import urllib
from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -28,10 +23,10 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.push.mailer import load_jinja2_templates
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
logger = logging.getLogger(__name__)
@@ -402,7 +397,7 @@ class BaseSSORedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def on_GET(self, request):
def on_GET(self, request: SynapseRequest):
args = request.args
if b"redirectUrl" not in args:
return 400, "Redirect URL not specified for SSO auth"
@@ -411,15 +406,15 @@ class BaseSSORedirectServlet(RestServlet):
request.redirect(sso_url)
finish_request(request)
def get_sso_url(self, client_redirect_url):
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
"""Get the URL to redirect to, to perform SSO auth
Args:
client_redirect_url (bytes): the URL that we should redirect the
client_redirect_url: the URL that we should redirect the
client to when everything is done
Returns:
bytes: URL to redirect to
URL to redirect to
"""
# to be implemented by subclasses
raise NotImplementedError()
@@ -427,19 +422,10 @@ class BaseSSORedirectServlet(RestServlet):
class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
super(CasRedirectServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url.encode("ascii")
self.cas_service_url = hs.config.cas_service_url.encode("ascii")
self._cas_handler = hs.get_cas_handler()
def get_sso_url(self, client_redirect_url):
client_redirect_url_param = urllib.parse.urlencode(
{b"redirectUrl": client_redirect_url}
).encode("ascii")
hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
service_param = urllib.parse.urlencode(
{b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
).encode("ascii")
return b"%s/login?%s" % (self.cas_server_url, service_param)
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
return self._cas_handler.handle_redirect_request(client_redirect_url)
class CasTicketServlet(RestServlet):
@@ -447,81 +433,15 @@ class CasTicketServlet(RestServlet):
def __init__(self, hs):
super(CasTicketServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs)
self._http_client = hs.get_proxied_http_client()
self._cas_handler = hs.get_cas_handler()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> None:
client_redirect_url = parse_string(request, "redirectUrl", required=True)
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url,
}
try:
body = await self._http_client.get_raw(uri, args)
except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
result = await self.handle_cas_response(request, body, client_redirect_url)
return result
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body)
displayname = attributes.pop(self.cas_displayname_attribute, None)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
return self._sso_auth_handler.on_successful_auth(
user, request, client_redirect_url, displayname
ticket = parse_string(request, "ticket", required=True)
await self._cas_handler.handle_ticket_request(
request, client_redirect_url, ticket
)
def parse_cas_response(self, cas_response_body):
user = None
attributes = {}
try:
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
success = root[0].tag.endswith("authenticationSuccess")
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
for attribute in child:
# ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace.
# We don't care about namespace here and it will always
# be encased in curly braces, so we remove them.
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text
if user is None:
raise Exception("CAS response does not contain user")
except Exception:
logger.exception("Error parsing CAS response")
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
)
return user, attributes
class SAMLRedirectServlet(BaseSSORedirectServlet):
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@@ -529,72 +449,10 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._saml_handler = hs.get_saml_handler()
def get_sso_url(self, client_redirect_url):
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
return self._saml_handler.handle_redirect_request(client_redirect_url)
class SSOAuthHandler(object):
"""
Utility class for Resources and Servlets which handle the response from a SSO
service
Args:
hs (synapse.server.HomeServer)
"""
def __init__(self, hs):
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._macaroon_gen = hs.get_macaroon_generator()
# Load the redirect page HTML template
self._template = load_jinja2_templates(
hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
)[0]
self._server_name = hs.config.server_name
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
async def on_successful_auth(
self, username, request, client_redirect_url, user_display_name=None
):
"""Called once the user has successfully authenticated with the SSO.
Registers the user if necessary, and then returns a redirect (with
a login token) to the client.
Args:
username (unicode|bytes): the remote user id. We'll map this onto
something sane for a MXID localpath.
request (SynapseRequest): the incoming request from the browser. We'll
respond to it with a redirect.
client_redirect_url (unicode): the redirect_url the client gave us when
it first started the process.
user_display_name (unicode|None): if set, and we have to register a new user,
we will set their displayname to this.
Returns:
Deferred[none]: Completes once we have handled the request.
"""
localpart = map_username_to_mxid_localpart(username)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=user_display_name
)
self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url
)
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
if hs.config.cas_enabled:

View File

@@ -234,13 +234,16 @@ class PasswordRestServlet(RestServlet):
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
params = await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request)
requester, request, body, self.hs.get_ip_from_request(request),
)
user_id = requester.user.to_string()
else:
requester = None
result, params, _ = await self.auth_handler.check_auth(
[[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request)
[[LoginType.EMAIL_IDENTITY]],
request,
body,
self.hs.get_ip_from_request(request),
)
if LoginType.EMAIL_IDENTITY in result:
@@ -308,7 +311,7 @@ class DeactivateAccountRestServlet(RestServlet):
return 200, {}
await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request)
requester, request, body, self.hs.get_ip_from_request(request),
)
result = await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase, id_server=body.get("id_server")
@@ -602,6 +605,11 @@ class ThreepidRestServlet(RestServlet):
return 200, {"threepids": threepids}
async def on_POST(self, request):
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -646,6 +654,11 @@ class ThreepidAddRestServlet(RestServlet):
@interactive_auth_handler
async def on_POST(self, request):
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -656,7 +669,7 @@ class ThreepidAddRestServlet(RestServlet):
assert_valid_client_secret(client_secret)
await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request)
requester, request, body, self.hs.get_ip_from_request(request),
)
validation_session = await self.identity_handler.validate_threepid_session(
@@ -741,10 +754,16 @@ class ThreepidDeleteRestServlet(RestServlet):
def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
async def on_POST(self, request):
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])

View File

@@ -142,14 +142,6 @@ class AuthRestServlet(RestServlet):
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
"sitekey": self.hs.config.recaptcha_public_key,
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
elif stagetype == LoginType.TERMS:
html = TERMS_TEMPLATE % {
"session": session,
@@ -158,17 +150,19 @@ class AuthRestServlet(RestServlet):
"myurl": "%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
else:
raise SynapseError(404, "Unknown auth stage type")
# Render the HTML and return.
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
async def on_POST(self, request, stagetype):
session = parse_string(request, "session")
@@ -196,15 +190,6 @@ class AuthRestServlet(RestServlet):
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
"sitekey": self.hs.config.recaptcha_public_key,
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
elif stagetype == LoginType.TERMS:
authdict = {"session": session}
@@ -225,17 +210,19 @@ class AuthRestServlet(RestServlet):
"myurl": "%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
else:
raise SynapseError(404, "Unknown auth stage type")
# Render the HTML and return.
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
def on_OPTIONS(self, _):
return 200, {}

View File

@@ -81,7 +81,7 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"])
await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request)
requester, request, body, self.hs.get_ip_from_request(request),
)
await self.device_handler.delete_devices(
@@ -127,7 +127,7 @@ class DeviceRestServlet(RestServlet):
raise
await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request)
requester, request, body, self.hs.get_ip_from_request(request),
)
await self.device_handler.delete_device(requester.user.to_string(), device_id)

View File

@@ -263,7 +263,7 @@ class SigningKeyUploadServlet(RestServlet):
body = parse_json_object_from_request(request)
await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request)
requester, request, body, self.hs.get_ip_from_request(request),
)
result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)

View File

@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.http.servlet import RestServlet
from ._base import client_patterns
logger = logging.getLogger(__name__)
class PasswordPolicyServlet(RestServlet):
PATTERNS = client_patterns("/password_policy$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(PasswordPolicyServlet, self).__init__()
self.policy = hs.config.password_policy
self.enabled = hs.config.password_policy_enabled
def on_GET(self, request):
if not self.enabled or not self.policy:
return (200, {})
policy = {}
for param in [
"minimum_length",
"require_digit",
"require_symbol",
"require_lowercase",
"require_uppercase",
]:
if param in self.policy:
policy["m.%s" % param] = self.policy[param]
return (200, policy)
def register_servlets(hs, http_server):
PasswordPolicyServlet(hs).register(http_server)

View File

@@ -373,6 +373,7 @@ class RegisterRestServlet(RestServlet):
self.room_member_handler = hs.get_room_member_handler()
self.macaroon_gen = hs.get_macaroon_generator()
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
self._registration_flows = _calculate_registration_flows(
@@ -420,6 +421,7 @@ class RegisterRestServlet(RestServlet):
or len(body["password"]) > 512
):
raise SynapseError(400, "Invalid password")
self.password_policy_handler.validate_password(body["password"])
desired_username = None
if "username" in body:
@@ -499,7 +501,10 @@ class RegisterRestServlet(RestServlet):
)
auth_result, params, session_id = await self.auth_handler.check_auth(
self._registration_flows, body, self.hs.get_ip_from_request(request)
self._registration_flows,
request,
body,
self.hs.get_ip_from_request(request),
)
# Check that we're not trying to register a denied 3pid.

View File

@@ -188,7 +188,7 @@ class RoomKeysServlet(RestServlet):
"""
requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
version = parse_string(request, "version")
version = parse_string(request, "version", required=True)
room_keys = await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id

View File

@@ -50,6 +50,9 @@ class DownloadResource(DirectServeResource):
b" media-src 'self';"
b" object-src 'self';",
)
request.setHeader(
b"Referrer-Policy", b"no-referrer",
)
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
await self.media_repo.get_local_media(request, media_id, name)

View File

@@ -24,7 +24,6 @@ from six import iteritems
import twisted.internet.error
import twisted.web.http
from twisted.internet import defer
from twisted.web.resource import Resource
from synapse.api.errors import (
@@ -114,15 +113,14 @@ class MediaRepository(object):
"update_recently_accessed_media", self._update_recently_accessed
)
@defer.inlineCallbacks
def _update_recently_accessed(self):
async def _update_recently_accessed(self):
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
local_media = self.recently_accessed_locals
self.recently_accessed_locals = set()
yield self.store.update_cached_last_access_time(
await self.store.update_cached_last_access_time(
local_media, remote_media, self.clock.time_msec()
)
@@ -138,8 +136,7 @@ class MediaRepository(object):
else:
self.recently_accessed_locals.add(media_id)
@defer.inlineCallbacks
def create_content(
async def create_content(
self, media_type, upload_name, content, content_length, auth_user
):
"""Store uploaded content for a local user and return the mxc URL
@@ -158,11 +155,11 @@ class MediaRepository(object):
file_info = FileInfo(server_name=None, file_id=media_id)
fname = yield self.media_storage.store_file(content, file_info)
fname = await self.media_storage.store_file(content, file_info)
logger.info("Stored local media in file %r", fname)
yield self.store.store_local_media(
await self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
@@ -171,12 +168,11 @@ class MediaRepository(object):
user_id=auth_user,
)
yield self._generate_thumbnails(None, media_id, media_id, media_type)
await self._generate_thumbnails(None, media_id, media_id, media_type)
return "mxc://%s/%s" % (self.server_name, media_id)
@defer.inlineCallbacks
def get_local_media(self, request, media_id, name):
async def get_local_media(self, request, media_id, name):
"""Responds to reqests for local media, if exists, or returns 404.
Args:
@@ -190,7 +186,7 @@ class MediaRepository(object):
Deferred: Resolves once a response has successfully been written
to request
"""
media_info = yield self.store.get_local_media(media_id)
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
@@ -204,13 +200,12 @@ class MediaRepository(object):
file_info = FileInfo(None, media_id, url_cache=url_cache)
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
request, responder, media_type, media_length, upload_name
)
@defer.inlineCallbacks
def get_remote_media(self, request, server_name, media_id, name):
async def get_remote_media(self, request, server_name, media_id, name):
"""Respond to requests for remote media.
Args:
@@ -236,8 +231,8 @@ class MediaRepository(object):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
with (await self.remote_media_linearizer.queue(key)):
responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)
@@ -246,14 +241,13 @@ class MediaRepository(object):
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
yield respond_with_responder(
await respond_with_responder(
request, responder, media_type, media_length, upload_name
)
else:
respond_404(request)
@defer.inlineCallbacks
def get_remote_media_info(self, server_name, media_id):
async def get_remote_media_info(self, server_name, media_id):
"""Gets the media info associated with the remote file, downloading
if necessary.
@@ -274,8 +268,8 @@ class MediaRepository(object):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
with (await self.remote_media_linearizer.queue(key)):
responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)
@@ -286,8 +280,7 @@ class MediaRepository(object):
return media_info
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
async def _get_remote_media_impl(self, server_name, media_id):
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -299,7 +292,7 @@ class MediaRepository(object):
Returns:
Deferred[(Responder, media_info)]
"""
media_info = yield self.store.get_cached_remote_media(server_name, media_id)
media_info = await self.store.get_cached_remote_media(server_name, media_id)
# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise genereate a new
@@ -317,19 +310,18 @@ class MediaRepository(object):
logger.info("Media is quarantined")
raise NotFoundError()
responder = yield self.media_storage.fetch_media(file_info)
responder = await self.media_storage.fetch_media(file_info)
if responder:
return responder, media_info
# Failed to find the file anywhere, lets download it.
media_info = yield self._download_remote_file(server_name, media_id, file_id)
media_info = await self._download_remote_file(server_name, media_id, file_id)
responder = yield self.media_storage.fetch_media(file_info)
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
@defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id, file_id):
async def _download_remote_file(self, server_name, media_id, file_id):
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -351,7 +343,7 @@ class MediaRepository(object):
("/_matrix/media/v1/download", server_name, media_id)
)
try:
length, headers = yield self.client.get_file(
length, headers = await self.client.get_file(
server_name,
request_path,
output_stream=f,
@@ -397,7 +389,7 @@ class MediaRepository(object):
)
raise SynapseError(502, "Failed to fetch remote media")
yield finish()
await finish()
media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
@@ -405,7 +397,7 @@ class MediaRepository(object):
logger.info("Stored remote media in file %r", fname)
yield self.store.store_cached_remote_media(
await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
@@ -423,7 +415,7 @@ class MediaRepository(object):
"filesystem_id": file_id,
}
yield self._generate_thumbnails(server_name, media_id, file_id, media_type)
await self._generate_thumbnails(server_name, media_id, file_id, media_type)
return media_info
@@ -458,16 +450,15 @@ class MediaRepository(object):
return t_byte_source
@defer.inlineCallbacks
def generate_local_exact_thumbnail(
async def generate_local_exact_thumbnail(
self, media_id, t_width, t_height, t_method, t_type, url_cache
):
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
)
thumbnailer = Thumbnailer(input_path)
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
@@ -490,7 +481,7 @@ class MediaRepository(object):
thumbnail_type=t_type,
)
output_path = yield self.media_storage.store_file(
output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -500,22 +491,21 @@ class MediaRepository(object):
t_len = os.path.getsize(output_path)
yield self.store.store_local_thumbnail(
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
return output_path
@defer.inlineCallbacks
def generate_remote_exact_thumbnail(
async def generate_remote_exact_thumbnail(
self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
):
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=False)
)
thumbnailer = Thumbnailer(input_path)
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
@@ -537,7 +527,7 @@ class MediaRepository(object):
thumbnail_type=t_type,
)
output_path = yield self.media_storage.store_file(
output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -547,7 +537,7 @@ class MediaRepository(object):
t_len = os.path.getsize(output_path)
yield self.store.store_remote_media_thumbnail(
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
@@ -560,8 +550,7 @@ class MediaRepository(object):
return output_path
@defer.inlineCallbacks
def _generate_thumbnails(
async def _generate_thumbnails(
self, server_name, media_id, file_id, media_type, url_cache=False
):
"""Generate and store thumbnails for an image.
@@ -582,7 +571,7 @@ class MediaRepository(object):
if not requirements:
return
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
)
@@ -600,7 +589,7 @@ class MediaRepository(object):
return
if thumbnailer.transpose_method is not None:
m_width, m_height = yield defer_to_thread(
m_width, m_height = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.transpose
)
@@ -620,11 +609,11 @@ class MediaRepository(object):
for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail
if t_method == "crop":
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
)
elif t_method == "scale":
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
)
else:
@@ -646,7 +635,7 @@ class MediaRepository(object):
url_cache=url_cache,
)
output_path = yield self.media_storage.store_file(
output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -656,7 +645,7 @@ class MediaRepository(object):
# Write to database
if server_name:
yield self.store.store_remote_media_thumbnail(
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
@@ -667,15 +656,14 @@ class MediaRepository(object):
t_len,
)
else:
yield self.store.store_local_thumbnail(
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
return {"width": m_width, "height": m_height}
@defer.inlineCallbacks
def delete_old_remote_media(self, before_ts):
old_media = yield self.store.get_remote_media_before(before_ts)
async def delete_old_remote_media(self, before_ts):
old_media = await self.store.get_remote_media_before(before_ts)
deleted = 0
@@ -689,7 +677,7 @@ class MediaRepository(object):
# TODO: Should we delete from the backup store
with (yield self.remote_media_linearizer.queue(key)):
with (await self.remote_media_linearizer.queue(key)):
full_path = self.filepaths.remote_media_filepath(origin, file_id)
try:
os.remove(full_path)
@@ -705,7 +693,7 @@ class MediaRepository(object):
)
shutil.rmtree(thumbnail_dir, ignore_errors=True)
yield self.store.delete_remote_media(origin, media_id)
await self.store.delete_remote_media(origin, media_id)
deleted += 1
return {"deleted": deleted}

View File

@@ -165,8 +165,7 @@ class PreviewUrlResource(DirectServeResource):
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
respond_with_json_bytes(request, 200, og, send_cors=True)
@defer.inlineCallbacks
def _do_preview(self, url, user, ts):
async def _do_preview(self, url, user, ts):
"""Check the db, and download the URL and build a preview
Args:
@@ -179,7 +178,7 @@ class PreviewUrlResource(DirectServeResource):
"""
# check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
cache_result = yield self.store.get_url_cache(url, ts)
cache_result = await self.store.get_url_cache(url, ts)
if (
cache_result
and cache_result["expires_ts"] > ts
@@ -192,13 +191,13 @@ class PreviewUrlResource(DirectServeResource):
og = og.encode("utf8")
return og
media_info = yield self._download_url(url, user)
media_info = await self._download_url(url, user)
logger.debug("got media_info of '%s'", media_info)
if _is_media(media_info["media_type"]):
file_id = media_info["filesystem_id"]
dims = yield self.media_repo._generate_thumbnails(
dims = await self.media_repo._generate_thumbnails(
None, file_id, file_id, media_info["media_type"], url_cache=True
)
@@ -248,14 +247,14 @@ class PreviewUrlResource(DirectServeResource):
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if "og:image" in og and og["og:image"]:
image_info = yield self._download_url(
image_info = await self._download_url(
_rebase_url(og["og:image"], media_info["uri"]), user
)
if _is_media(image_info["media_type"]):
# TODO: make sure we don't choke on white-on-transparent images
file_id = image_info["filesystem_id"]
dims = yield self.media_repo._generate_thumbnails(
dims = await self.media_repo._generate_thumbnails(
None, file_id, file_id, image_info["media_type"], url_cache=True
)
if dims:
@@ -293,7 +292,7 @@ class PreviewUrlResource(DirectServeResource):
jsonog = json.dumps(og)
# store OG in history-aware DB cache
yield self.store.store_url_cache(
await self.store.store_url_cache(
url,
media_info["response_code"],
media_info["etag"],
@@ -305,8 +304,7 @@ class PreviewUrlResource(DirectServeResource):
return jsonog.encode("utf8")
@defer.inlineCallbacks
def _download_url(self, url, user):
async def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -318,7 +316,7 @@ class PreviewUrlResource(DirectServeResource):
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
logger.debug("Trying to get url '%s'", url)
length, headers, uri, code = yield self.client.get_file(
length, headers, uri, code = await self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size
)
except SynapseError:
@@ -345,7 +343,7 @@ class PreviewUrlResource(DirectServeResource):
% (traceback.format_exception_only(sys.exc_info()[0], e),),
Codes.UNKNOWN,
)
yield finish()
await finish()
try:
if b"Content-Type" in headers:
@@ -356,7 +354,7 @@ class PreviewUrlResource(DirectServeResource):
download_name = get_filename_from_headers(headers)
yield self.store.store_local_media(
await self.store.store_local_media(
media_id=file_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
@@ -393,8 +391,7 @@ class PreviewUrlResource(DirectServeResource):
"expire_url_cache_data", self._expire_url_cache_data
)
@defer.inlineCallbacks
def _expire_url_cache_data(self):
async def _expire_url_cache_data(self):
"""Clean up expired url cache content, media and thumbnails.
"""
# TODO: Delete from backup media store
@@ -403,12 +400,12 @@ class PreviewUrlResource(DirectServeResource):
logger.info("Running url preview cache expiry")
if not (yield self.store.db.updates.has_completed_background_updates()):
if not (await self.store.db.updates.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry")
return
# First we delete expired url cache entries
media_ids = yield self.store.get_expired_url_cache(now)
media_ids = await self.store.get_expired_url_cache(now)
removed_media = []
for media_id in media_ids:
@@ -430,7 +427,7 @@ class PreviewUrlResource(DirectServeResource):
except Exception:
pass
yield self.store.delete_url_cache(removed_media)
await self.store.delete_url_cache(removed_media)
if removed_media:
logger.info("Deleted %d entries from url cache", len(removed_media))
@@ -440,7 +437,7 @@ class PreviewUrlResource(DirectServeResource):
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
expire_before = now - 2 * 24 * 60 * 60 * 1000
media_ids = yield self.store.get_url_cache_media_before(expire_before)
media_ids = await self.store.get_url_cache_media_before(expire_before)
removed_media = []
for media_id in media_ids:
@@ -478,7 +475,7 @@ class PreviewUrlResource(DirectServeResource):
except Exception:
pass
yield self.store.delete_url_cache_media(removed_media)
await self.store.delete_url_cache_media(removed_media)
logger.info("Deleted %d media from url cache", len(removed_media))

View File

@@ -16,8 +16,6 @@
import logging
from twisted.internet import defer
from synapse.http.server import (
DirectServeResource,
set_cors_headers,
@@ -79,11 +77,10 @@ class ThumbnailResource(DirectServeResource):
)
self.media_repo.mark_recently_accessed(server_name, media_id)
@defer.inlineCallbacks
def _respond_local_thumbnail(
async def _respond_local_thumbnail(
self, request, media_id, width, height, method, m_type
):
media_info = yield self.store.get_local_media(media_id)
media_info = await self.store.get_local_media(media_id)
if not media_info:
respond_404(request)
@@ -93,7 +90,7 @@ class ThumbnailResource(DirectServeResource):
respond_404(request)
return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
if thumbnail_infos:
thumbnail_info = self._select_thumbnail(
@@ -114,14 +111,13 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(request, responder, t_type, t_length)
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Couldn't find any generated thumbnails")
respond_404(request)
@defer.inlineCallbacks
def _select_or_generate_local_thumbnail(
async def _select_or_generate_local_thumbnail(
self,
request,
media_id,
@@ -130,7 +126,7 @@ class ThumbnailResource(DirectServeResource):
desired_method,
desired_type,
):
media_info = yield self.store.get_local_media(media_id)
media_info = await self.store.get_local_media(media_id)
if not media_info:
respond_404(request)
@@ -140,7 +136,7 @@ class ThumbnailResource(DirectServeResource):
respond_404(request)
return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
t_w = info["thumbnail_width"] == desired_width
t_h = info["thumbnail_height"] == desired_height
@@ -162,15 +158,15 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
responder = await self.media_storage.fetch_media(file_info)
if responder:
yield respond_with_responder(request, responder, t_type, t_length)
await respond_with_responder(request, responder, t_type, t_length)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_local_exact_thumbnail(
file_path = await self.media_repo.generate_local_exact_thumbnail(
media_id,
desired_width,
desired_height,
@@ -180,13 +176,12 @@ class ThumbnailResource(DirectServeResource):
)
if file_path:
yield respond_with_file(request, desired_type, file_path)
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
respond_404(request)
@defer.inlineCallbacks
def _select_or_generate_remote_thumbnail(
async def _select_or_generate_remote_thumbnail(
self,
request,
server_name,
@@ -196,9 +191,9 @@ class ThumbnailResource(DirectServeResource):
desired_method,
desired_type,
):
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
@@ -224,15 +219,15 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
responder = await self.media_storage.fetch_media(file_info)
if responder:
yield respond_with_responder(request, responder, t_type, t_length)
await respond_with_responder(request, responder, t_type, t_length)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
file_path = await self.media_repo.generate_remote_exact_thumbnail(
server_name,
file_id,
media_id,
@@ -243,21 +238,20 @@ class ThumbnailResource(DirectServeResource):
)
if file_path:
yield respond_with_file(request, desired_type, file_path)
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
respond_404(request)
@defer.inlineCallbacks
def _respond_remote_thumbnail(
async def _respond_remote_thumbnail(
self, request, server_name, media_id, width, height, method, m_type
):
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
@@ -278,8 +272,8 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(request, responder, t_type, t_length)
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Failed to find any generated thumbnails")
respond_404(request)

Some files were not shown because too many files have changed in this diff Show More