
Compare commits


36 Commits

Author SHA1 Message Date
Andrew Morgan
3dace4b1aa uh 2020-03-25 15:50:40 +00:00
Andrew Morgan
4ac60a17a5 Possibly appease mypy 2020-03-25 15:46:28 +00:00
Andrew Morgan
f5fd9b98c7 Don't import Sqlite3Engine unless running synapse with sqlite3 2020-03-25 15:42:25 +00:00
Andrew Morgan
8895c38202 Use MYPY variable instead 2020-03-25 15:31:44 +00:00
Andrew Morgan
e0ee1b2224 __future__ import 2020-03-25 15:29:48 +00:00
Andrew Morgan
14c4f08f5c Add changelog 2020-03-25 15:24:58 +00:00
Andrew Morgan
cb76e53b7f Only import sqlite3 by default if running mypy checks 2020-03-25 15:23:46 +00:00
Erik Johnston
4cff617df1 Move catchup of replication streams to worker. (#7024)
This changes the replication protocol so that the server does not send down `RDATA` for rows that happened before the client connected. Instead, the server will send a `POSITION` and clients then query the database (or master out of band) to get up to date.
2020-03-25 14:54:01 +00:00
Andrew Morgan
7bab642707 Various cleanups to INSTALL.md (#7141) 2020-03-25 13:56:40 +00:00
Erik Johnston
b1cfaf08af Merge pull request #7133 from matrix-org/erikj/fix_worker_startup
Fix starting workers when federation sending not split out.
2020-03-25 09:42:39 +00:00
Richard van der Hoff
39230d2171 Clean up some LoggingContext stuff (#7120)
* Pull Sentinel out of LoggingContext

... and drop a few unnecessary references to it

* Factor out LoggingContext.current_context

move `current_context` and `set_context` out to top-level functions.

Mostly this means that I can more easily trace what's actually referring to
LoggingContext, but I think it's generally neater.

* move copy-to-parent into `stop`

this really just makes `start` and `stop` more symmetric. It also means that it
behaves correctly if you manually `set_log_context` rather than using the
context manager.

* Replace `LoggingContext.alive` with `finished`

Turn `alive` into `finished` and make it a bit better defined.
2020-03-24 14:45:33 +00:00
Naugrimm
1fcf9c6f95 Fix CAS redirect url (#6634)
Build the same service URL when requesting the CAS ticket and when calling the proxyValidate URL.
2020-03-24 11:59:04 +00:00
Erik Johnston
d6828c129f Newsfile 2020-03-24 10:36:44 +00:00
Erik Johnston
c816072d47 Fix starting workers when federation sending not split out. 2020-03-24 10:35:00 +00:00
Patrick Cloke
190ab593b7 Use the proper error code when a canonical alias that does not exist is used. (#7109) 2020-03-23 15:21:54 -04:00
Kartikaya Gupta (kats)
e341518f92 Update pre-built package name for FreeBSD (#7107). (#7107)
Signed-off-by: Kartikaya Gupta <kats@trevize.staktrace.com>
2020-03-23 15:31:02 +00:00
Richard van der Hoff
a564b92d37 Convert *StreamRow classes to inner classes (#7116)
This just helps keep the rows closer to their streams, so that it's easier to
see what the format of each stream is.
2020-03-23 13:59:11 +00:00
Richard van der Hoff
5126cb1253 Merge branch 'master' into develop 2020-03-23 13:54:29 +00:00
Richard van der Hoff
229eb81498 Merge tag 'v1.12.0'
Synapse 1.12.0 (2020-03-23)
===========================

No significant changes since 1.12.0rc1.

Debian packages and Docker images are rebuilt using the latest versions of
dependency libraries, including Twisted 20.3.0. **Please see security advisory
below**.

Security advisory
-----------------

Synapse may be vulnerable to request-smuggling attacks when it is used with a
reverse-proxy. The vulnerabilities are fixed in Twisted 20.3.0, and are
described in
[CVE-2020-10108](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-10108)
and
[CVE-2020-10109](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-10109).
For a good introduction to this class of request-smuggling attacks, see
https://portswigger.net/research/http-desync-attacks-request-smuggling-reborn.

We are not aware of these vulnerabilities being exploited in the wild, and
do not believe that they are exploitable with current versions of any reverse
proxies. Nevertheless, we recommend that all Synapse administrators ensure that
they have the latest versions of the Twisted library to ensure that their
installation remains secure.

* Administrators using the [`matrix.org` Docker
  image](https://hub.docker.com/r/matrixdotorg/synapse/) or the [Debian/Ubuntu
  packages from
  `matrix.org`](https://github.com/matrix-org/synapse/blob/master/INSTALL.md#matrixorg-packages)
  should ensure that they have version 1.12.0 installed: these images include
  Twisted 20.3.0.
* Administrators who have [installed Synapse from
  source](https://github.com/matrix-org/synapse/blob/master/INSTALL.md#installing-from-source)
  should upgrade Twisted within their virtualenv by running:
  ```sh
  <path_to_virtualenv>/bin/pip install 'Twisted>=20.3.0'
  ```
* Administrators who have installed Synapse from distribution packages should
  consult the information from their distributions.

The `matrix.org` Synapse instance was not vulnerable to these vulnerabilities.
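
After upgrading by any of these routes, it is worth confirming which Twisted version Synapse will actually import. A minimal check (run with the virtualenv's Python; the `<path_to_virtualenv>` placeholder is the same one used above):

```python
# Run inside the Synapse virtualenv, e.g. <path_to_virtualenv>/bin/python
import twisted

print(twisted.version)  # expect 20.3.0 or later, e.g. "[Twisted, version 20.3.0]"
```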

Advance notice of change to the default `git` branch for Synapse
----------------------------------------------------------------

Currently, the default `git` branch for Synapse is `master`, which tracks the
latest release.

After the release of Synapse 1.13.0, we intend to change this default to
`develop`, which is the development tip. This is more consistent with common
practice and modern `git` usage.

Although we try to keep `develop` in a stable state, there may be occasions
where regressions creep in. Developers and distributors who have scripts which
run builds using the default branch of `Synapse` should therefore consider
pinning their scripts to `master`.

Synapse 1.12.0rc1 (2020-03-19)
==============================

Features
--------

- Changes related to room alias management ([MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432)):
  - Publishing/removing a room from the room directory now requires the user to have a power level capable of modifying the canonical alias, instead of the room aliases. ([\#6965](https://github.com/matrix-org/synapse/issues/6965))
  - Validate the `alt_aliases` property of canonical alias events. ([\#6971](https://github.com/matrix-org/synapse/issues/6971))
  - Users with a power level sufficient to modify the canonical alias of a room can now delete room aliases. ([\#6986](https://github.com/matrix-org/synapse/issues/6986))
  - Implement updated authorization rules and redaction rules for aliases events, from [MSC2261](https://github.com/matrix-org/matrix-doc/pull/2261) and [MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432). ([\#7037](https://github.com/matrix-org/synapse/issues/7037))
  - Stop sending m.room.aliases events during room creation and upgrade. ([\#6941](https://github.com/matrix-org/synapse/issues/6941))
  - Synapse no longer uses room alias events to calculate room names for push notifications. ([\#6966](https://github.com/matrix-org/synapse/issues/6966))
  - The room list endpoint no longer returns a list of aliases. ([\#6970](https://github.com/matrix-org/synapse/issues/6970))
  - Remove special handling of aliases events from [MSC2260](https://github.com/matrix-org/matrix-doc/pull/2260) added in v1.10.0rc1. ([\#7034](https://github.com/matrix-org/synapse/issues/7034))
- Expose the `synctl`, `hash_password` and `generate_config` commands in the snapcraft package. Contributed by @devec0. ([\#6315](https://github.com/matrix-org/synapse/issues/6315))
- Check that server_name is correctly set before running database updates. ([\#6982](https://github.com/matrix-org/synapse/issues/6982))
- Break down monthly active users by `appservice_id` and emit via Prometheus. ([\#7030](https://github.com/matrix-org/synapse/issues/7030))
- Render a configurable and comprehensible error page if something goes wrong during the SAML2 authentication process. ([\#7058](https://github.com/matrix-org/synapse/issues/7058), [\#7067](https://github.com/matrix-org/synapse/issues/7067))
- Add an optional parameter to control whether other sessions are logged out when a user's password is modified. ([\#7085](https://github.com/matrix-org/synapse/issues/7085))
- Add prometheus metrics for the number of active pushers. ([\#7103](https://github.com/matrix-org/synapse/issues/7103), [\#7106](https://github.com/matrix-org/synapse/issues/7106))
- Improve performance when making HTTPS requests to sygnal, sydent, etc, by sharing the SSL context object between connections. ([\#7094](https://github.com/matrix-org/synapse/issues/7094))

Bugfixes
--------

- When a user's profile is updated via the admin API, also generate a displayname/avatar update for that user in each room. ([\#6572](https://github.com/matrix-org/synapse/issues/6572))
- Fix a couple of bugs in email configuration handling. ([\#6962](https://github.com/matrix-org/synapse/issues/6962))
- Fix an issue affecting worker-based deployments where replication would stop working, necessitating a full restart, after joining a large room. ([\#6967](https://github.com/matrix-org/synapse/issues/6967))
- Fix `duplicate key` error which was logged when rejoining a room over federation. ([\#6968](https://github.com/matrix-org/synapse/issues/6968))
- Prevent user from setting 'deactivated' to anything other than a bool on the v2 PUT /users Admin API. ([\#6990](https://github.com/matrix-org/synapse/issues/6990))
- Fix py35-old CI by using native tox package. ([\#7018](https://github.com/matrix-org/synapse/issues/7018))
- Fix a bug causing `org.matrix.dummy_event` to be included in responses from `/sync`. ([\#7035](https://github.com/matrix-org/synapse/issues/7035))
- Fix a bug that renders UTF-8 text files incorrectly when loaded from media. Contributed by @TheStranjer. ([\#7044](https://github.com/matrix-org/synapse/issues/7044))
- Fix a bug that would cause Synapse to respond with an error about event visibility if a client tried to request the state of a room at a given token. ([\#7066](https://github.com/matrix-org/synapse/issues/7066))
- Repair a data-corruption issue which was introduced in Synapse 1.10, and fixed in Synapse 1.11, and which could cause `/sync` to return with 404 errors about missing events and unknown rooms. ([\#7070](https://github.com/matrix-org/synapse/issues/7070))
- Fix a bug causing account validity renewal emails to be sent even if the feature is turned off in some cases. ([\#7074](https://github.com/matrix-org/synapse/issues/7074))

Improved Documentation
----------------------

- Updated CentOS8 install instructions. Contributed by Richard Kellner. ([\#6925](https://github.com/matrix-org/synapse/issues/6925))
- Fix `POSTGRES_INITDB_ARGS` in the `contrib/docker/docker-compose.yml` example docker-compose configuration. ([\#6984](https://github.com/matrix-org/synapse/issues/6984))
- Change date in [INSTALL.md](./INSTALL.md#tls-certificates) for last date of getting TLS certificates to November 2019. ([\#7015](https://github.com/matrix-org/synapse/issues/7015))
- Document that the fallback auth endpoints must be routed to the same worker node as the register endpoints. ([\#7048](https://github.com/matrix-org/synapse/issues/7048))

Deprecations and Removals
-------------------------

- Remove the unused query_auth federation endpoint per [MSC2451](https://github.com/matrix-org/matrix-doc/pull/2451). ([\#7026](https://github.com/matrix-org/synapse/issues/7026))

Internal Changes
----------------

- Add type hints to `logging/context.py`. ([\#6309](https://github.com/matrix-org/synapse/issues/6309))
- Add some clarifications to `README.md` in the database schema directory. ([\#6615](https://github.com/matrix-org/synapse/issues/6615))
- Refactoring work in preparation for changing the event redaction algorithm. ([\#6874](https://github.com/matrix-org/synapse/issues/6874), [\#6875](https://github.com/matrix-org/synapse/issues/6875), [\#6983](https://github.com/matrix-org/synapse/issues/6983), [\#7003](https://github.com/matrix-org/synapse/issues/7003))
- Improve performance of v2 state resolution for large rooms. ([\#6952](https://github.com/matrix-org/synapse/issues/6952), [\#7095](https://github.com/matrix-org/synapse/issues/7095))
- Reduce time spent doing GC, by freezing objects on startup. ([\#6953](https://github.com/matrix-org/synapse/issues/6953))
- Minor performance fixes to `get_auth_chain_ids`. ([\#6954](https://github.com/matrix-org/synapse/issues/6954))
- Don't record remote cross-signing keys in the `devices` table. ([\#6956](https://github.com/matrix-org/synapse/issues/6956))
- Use flake8-comprehensions to enforce good hygiene of list/set/dict comprehensions. ([\#6957](https://github.com/matrix-org/synapse/issues/6957))
- Merge worker apps together. ([\#6964](https://github.com/matrix-org/synapse/issues/6964), [\#7002](https://github.com/matrix-org/synapse/issues/7002), [\#7055](https://github.com/matrix-org/synapse/issues/7055), [\#7104](https://github.com/matrix-org/synapse/issues/7104))
- Remove redundant `store_room` call from `FederationHandler._process_received_pdu`. ([\#6979](https://github.com/matrix-org/synapse/issues/6979))
- Update warning for incorrect database collation/ctype to include link to documentation. ([\#6985](https://github.com/matrix-org/synapse/issues/6985))
- Add some type annotations to the database storage classes. ([\#6987](https://github.com/matrix-org/synapse/issues/6987))
- Port `synapse.handlers.presence` to async/await. ([\#6991](https://github.com/matrix-org/synapse/issues/6991), [\#7019](https://github.com/matrix-org/synapse/issues/7019))
- Add some type annotations to the federation base & client classes. ([\#6995](https://github.com/matrix-org/synapse/issues/6995))
- Port `synapse.rest.keys` to async/await. ([\#7020](https://github.com/matrix-org/synapse/issues/7020))
- Add a type check to `is_verified` when processing room keys. ([\#7045](https://github.com/matrix-org/synapse/issues/7045))
- Add type annotations and comments to the auth handler. ([\#7063](https://github.com/matrix-org/synapse/issues/7063))
2020-03-23 13:54:17 +00:00
Richard van der Hoff
b3cee0ce67 Fix processing of groups stream, and use symbolic names for streams (#7117)
`groups` != `receipts`

Introduced in #6964
2020-03-23 11:39:36 +00:00
Dionysis Grigoropoulos
96071eea8f Set Referrer-Policy to no-referrer for media (#7009) 2020-03-23 09:48:28 +00:00
Patrick Cloke
477c4f5b1c Clean-up some auth/login REST code (#7115) 2020-03-20 16:22:47 -04:00
Richard van der Hoff
c165c1233b Improve database configuration docs (#6988)
Attempts to clarify the sample config for databases, and add some stuff about
tcp keepalives to `postgres.md`.
2020-03-20 15:24:22 +00:00
Erik Johnston
fdb1344716 Remove concept of a non-limited stream. (#7011) 2020-03-20 14:40:47 +00:00
Patrick Cloke
caec7d4fa0 Convert some of the media REST code to async/await (#7110) 2020-03-20 07:20:02 -04:00
Patrick Cloke
c2db6599c8 Fix a bug in the federation API which could cause occasional "Failed to get PDU" errors (#7089). 2020-03-19 08:22:56 -04:00
Erik Johnston
a319cb1dd1 Change device list streams to have one row per ID (#7010)
* Add 'device_lists_outbound_pokes' as extra table.

This makes sure we check all the relevant tables to get the current max
stream ID.

Currently not doing so isn't problematic as the max stream ID in
`device_lists_outbound_pokes` is the same as in `device_lists_stream`,
however that will change.

* Change device lists stream to have one row per id.

This will make it possible to process the streams more incrementally,
avoiding having to process large chunks at once.

* Change device list replication to match new semantics.

Instead of sending down batches of user ID/host tuples, send down a row
per entity (user ID or host).

* Newsfile

* Remove handling of multiple rows per ID

* Fix worker handling

* Comments from review
2020-03-19 11:36:53 +00:00
Erik Johnston
6e6476ef07 Comments from review 2020-03-18 10:13:55 +00:00
Richard van der Hoff
4ce50519cd Update postgres.md
fix broken link
2020-03-17 18:08:43 +00:00
Erik Johnston
65a941d1f8 Merge branch 'develop' of github.com:matrix-org/synapse into erikj/fixup_devices_stream 2020-03-02 16:55:55 +00:00
Erik Johnston
e53744c737 Fix worker handling 2020-03-02 12:52:28 +00:00
Erik Johnston
f70f44abc7 Remove handling of multiple rows per ID 2020-02-28 11:45:35 +00:00
Erik Johnston
59ad93d2a4 Newsfile 2020-02-28 11:27:37 +00:00
Erik Johnston
9ce4e344a8 Change device list replication to match new semantics.
Instead of sending down batches of user ID/host tuples, send down a row
per entity (user ID or host).
2020-02-28 11:25:34 +00:00
Erik Johnston
f5caa1864e Change device lists stream to have one row per id.
This will make it possible to process the streams more incrementally,
avoiding having to process large chunks at once.
2020-02-28 11:21:25 +00:00
Erik Johnston
c3c6c0e622 Add 'device_lists_outbound_pokes' as extra table.
This makes sure we check all the relevant tables to get the current max
stream ID.

Currently not doing so isn't problematic as the max stream ID in
`device_lists_outbound_pokes` is the same as in `device_lists_stream`,
however that will change.
2020-02-28 11:15:11 +00:00
88 changed files with 1594 additions and 1354 deletions

View File

@@ -2,7 +2,6 @@
- [Installing Synapse](#installing-synapse)
- [Installing from source](#installing-from-source)
- [Platform-Specific Instructions](#platform-specific-instructions)
- [Troubleshooting Installation](#troubleshooting-installation)
- [Prebuilt packages](#prebuilt-packages)
- [Setting up Synapse](#setting-up-synapse)
- [TLS certificates](#tls-certificates)
@@ -10,6 +9,7 @@
- [Registering a user](#registering-a-user)
- [Setting up a TURN server](#setting-up-a-turn-server)
- [URL previews](#url-previews)
- [Troubleshooting Installation](#troubleshooting-installation)
# Choosing your server name
@@ -70,7 +70,7 @@ pip install -U matrix-synapse
```
Before you can start Synapse, you will need to generate a configuration
file. To do this, run (in your virtualenv, as before)::
file. To do this, run (in your virtualenv, as before):
```
cd ~/synapse
@@ -84,22 +84,24 @@ python -m synapse.app.homeserver \
... substituting an appropriate value for `--server-name`.
This command will generate you a config file that you can then customise, but it will
also generate a set of keys for you. These keys will allow your Home Server to
identify itself to other Home Servers, so don't lose or delete them. It would be
also generate a set of keys for you. These keys will allow your homeserver to
identify itself to other homeserver, so don't lose or delete them. It would be
wise to back them up somewhere safe. (If, for whatever reason, you do need to
change your Home Server's keys, you may find that other Home Servers have the
change your homeserver's keys, you may find that other homeserver have the
old key cached. If you update the signing key, you should change the name of the
key in the `<server name>.signing.key` file (the second word) to something
different. See the
[spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys)
for more information on key management.)
for more information on key management).
To actually run your new homeserver, pick a working directory for Synapse to
run (e.g. `~/synapse`), and::
run (e.g. `~/synapse`), and:
cd ~/synapse
source env/bin/activate
synctl start
```
cd ~/synapse
source env/bin/activate
synctl start
```
### Platform-Specific Instructions
@@ -188,7 +190,7 @@ doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \
There is currently no port for OpenBSD. Additionally, OpenBSD's security
settings require a slightly more difficult installation process.
XXX: I suspect this is out of date.
(XXX: I suspect this is out of date)
1. Create a new directory in `/usr/local` called `_synapse`. Also, create a
new user called `_synapse` and set that directory as the new user's home.
@@ -196,7 +198,7 @@ XXX: I suspect this is out of date.
write and execute permissions on the same memory space to be run from
`/usr/local`.
2. `su` to the new `_synapse` user and change to their home directory.
3. Create a new virtualenv: `virtualenv -p python2.7 ~/.synapse`
3. Create a new virtualenv: `virtualenv -p python3 ~/.synapse`
4. Source the virtualenv configuration located at
`/usr/local/_synapse/.synapse/bin/activate`. This is done in `ksh` by
using the `.` command, rather than `bash`'s `source`.
@@ -217,45 +219,6 @@ be found at https://docs.microsoft.com/en-us/windows/wsl/install-win10 for
Windows 10 and https://docs.microsoft.com/en-us/windows/wsl/install-on-server
for Windows Server.
### Troubleshooting Installation
XXX a bunch of this is no longer relevant.
Synapse requires pip 8 or later, so if your OS provides too old a version you
may need to manually upgrade it::
sudo pip install --upgrade pip
Installing may fail with `Could not find any downloads that satisfy the requirement pymacaroons-pynacl (from matrix-synapse==0.12.0)`.
You can fix this by manually upgrading pip and virtualenv::
sudo pip install --upgrade virtualenv
You can next rerun `virtualenv -p python3 synapse` to update the virtual env.
Installing may fail during installing virtualenv with `InsecurePlatformWarning: A true SSLContext object is not available. This prevents urllib3 from configuring SSL appropriately and may cause certain SSL connections to fail. For more information, see https://urllib3.readthedocs.org/en/latest/security.html#insecureplatformwarning.`
You can fix this by manually installing ndg-httpsclient::
pip install --upgrade ndg-httpsclient
Installing may fail with `mock requires setuptools>=17.1. Aborting installation`.
You can fix this by upgrading setuptools::
pip install --upgrade setuptools
If pip crashes mid-installation for reason (e.g. lost terminal), pip may
refuse to run until you remove the temporary installation directory it
created. To reset the installation::
rm -rf /tmp/pip_install_matrix
pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
failing, e.g.::
pip install twisted
## Prebuilt packages
As an alternative to installing from source, prebuilt packages are available
@@ -314,7 +277,7 @@ For `buster` and `sid`, Synapse is available in the Debian repositories and
it should be possible to install it with simply:
```
sudo apt install matrix-synapse
sudo apt install matrix-synapse
```
There is also a version of `matrix-synapse` in `stretch-backports`. Please see
@@ -375,15 +338,17 @@ sudo pip install py-bcrypt
Synapse can be found in the void repositories as 'synapse':
xbps-install -Su
xbps-install -S synapse
```
xbps-install -Su
xbps-install -S synapse
```
### FreeBSD
Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from:
- Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean`
- Packages: `pkg install py27-matrix-synapse`
- Packages: `pkg install py37-matrix-synapse`
### NixOS
@@ -420,6 +385,7 @@ so, you will need to edit `homeserver.yaml`, as follows:
resources:
- names: [client, federation]
```
* You will also need to uncomment the `tls_certificate_path` and
`tls_private_key_path` lines under the `TLS` section. You can either
point these settings at an existing certificate and key, or you can
@@ -435,7 +401,7 @@ so, you will need to edit `homeserver.yaml`, as follows:
`cert.pem`).
For a more detailed guide to configuring your server for federation, see
[federate.md](docs/federate.md)
[federate.md](docs/federate.md).
## Email
@@ -482,7 +448,7 @@ on your server even if `enable_registration` is `false`.
## Setting up a TURN server
For reliable VoIP calls to be routed via this homeserver, you MUST configure
a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details.
a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details.
## URL previews
@@ -491,10 +457,24 @@ turn it on you must enable the `url_preview_enabled: True` config parameter
and explicitly specify the IP ranges that Synapse is not allowed to spider for
previewing in the `url_preview_ip_range_blacklist` configuration parameter.
This is critical from a security perspective to stop arbitrary Matrix users
spidering 'internal' URLs on your network. At the very least we recommend that
spidering 'internal' URLs on your network. At the very least we recommend that
your loopback and RFC1918 IP addresses are blacklisted.
This also requires the optional lxml and netaddr python dependencies to be
installed. This in turn requires the libxml2 library to be available - on
This also requires the optional `lxml` and `netaddr` python dependencies to be
installed. This in turn requires the `libxml2` library to be available - on
Debian/Ubuntu this means `apt-get install libxml2-dev`, or equivalent for
your OS.
# Troubleshooting Installation
`pip` seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
failing, e.g.:
```
pip install twisted
```
If you have any other problems, feel free to ask in
[#synapse:matrix.org](https://matrix.to/#/#synapse:matrix.org).

1
changelog.d/6634.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix single-sign on with CAS systems: pass the same service URL when requesting the CAS ticket and when calling the `proxyValidate` URL. Contributed by @Naugrimm.

1
changelog.d/6988.doc Normal file
View File

@@ -0,0 +1 @@
Improve the documentation for database configuration.

1
changelog.d/7009.feature Normal file
View File

@@ -0,0 +1 @@
Set `Referrer-Policy` header to `no-referrer` on media downloads.

1
changelog.d/7010.misc Normal file
View File

@@ -0,0 +1 @@
Change device list streams to have one row per ID.

1
changelog.d/7011.misc Normal file
View File

@@ -0,0 +1 @@
Remove concept of a non-limited stream.

1
changelog.d/7024.misc Normal file
View File

@@ -0,0 +1 @@
Move catchup of replication streams logic to worker.

1
changelog.d/7089.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix a bug in the federation API which could cause occasional "Failed to get PDU" errors.

1
changelog.d/7107.doc Normal file
View File

@@ -0,0 +1 @@
Update pre-built package name for FreeBSD.

1
changelog.d/7109.bugfix Normal file
View File

@@ -0,0 +1 @@
Return the proper error (M_BAD_ALIAS) when a non-existent canonical alias is provided.

1
changelog.d/7110.misc Normal file
View File

@@ -0,0 +1 @@
Convert some of synapse.rest.media to async/await.

1
changelog.d/7115.misc Normal file
View File

@@ -0,0 +1 @@
De-duplicate / remove unused REST code for login and auth.

1
changelog.d/7116.misc Normal file
View File

@@ -0,0 +1 @@
Convert `*StreamRow` classes to inner classes.

1
changelog.d/7117.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix a bug which meant that groups updates were not correctly replicated between workers.

1
changelog.d/7120.misc Normal file
View File

@@ -0,0 +1 @@
Clean up some LoggingContext code.

1
changelog.d/7133.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix starting workers when federation sending not split out.

1
changelog.d/7141.doc Normal file
View File

@@ -0,0 +1 @@
Clean up INSTALL.md a bit.

1
changelog.d/7143.bugfix Normal file
View File

@@ -0,0 +1 @@
Prevent `sqlite3` module from being imported even when using the postgres backend.

View File

@@ -29,14 +29,13 @@ from synapse.logging import context # omitted from future snippets
def handle_request(request_id):
request_context = context.LoggingContext()
calling_context = context.LoggingContext.current_context()
context.LoggingContext.set_current_context(request_context)
calling_context = context.set_current_context(request_context)
try:
request_context.request = request_id
do_request_handling()
logger.debug("finished")
finally:
context.LoggingContext.set_current_context(calling_context)
context.set_current_context(calling_context)
def do_request_handling():
logger.debug("phew") # this will be logged against request_id

View File

@@ -72,8 +72,7 @@ underneath the database, or if a different version of the locale is used on any
replicas.
The safest way to fix the issue is to take a dump and recreate the database with
the correct `COLLATE` and `CTYPE` parameters (as per
[docs/postgres.md](docs/postgres.md)). It is also possible to change the
the correct `COLLATE` and `CTYPE` parameters (as shown above). It is also possible to change the
parameters on a live database and run a `REINDEX` on the entire database,
however extreme care must be taken to avoid database corruption.
@@ -105,19 +104,41 @@ of free memory the database host has available.
When you are ready to start using PostgreSQL, edit the `database`
section in your config file to match the following lines:
database:
name: psycopg2
args:
user: <user>
password: <pass>
database: <db>
host: <host>
cp_min: 5
cp_max: 10
```yaml
database:
name: psycopg2
args:
user: <user>
password: <pass>
database: <db>
host: <host>
cp_min: 5
cp_max: 10
```
All key, values in `args` are passed to the `psycopg2.connect(..)`
function, except keys beginning with `cp_`, which are consumed by the
twisted adbapi connection pool.
twisted adbapi connection pool. See the [libpq
documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS)
for a list of options which can be passed.
You should consider tuning the `args.keepalives_*` options if there is any danger of
the connection between your homeserver and database dropping, otherwise Synapse
may block for an extended period while it waits for a response from the
database server. Example values might be:
```yaml
# seconds of inactivity after which TCP should send a keepalive message to the server
keepalives_idle: 10
# the number of seconds after which a TCP keepalive message that is not
# acknowledged by the server should be retransmitted
keepalives_interval: 10
# the number of TCP keepalives that can be lost before the client's connection
# to the server is considered dead
keepalives_count: 3
```
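
To make the example concrete: everything under `args` apart from the `cp_*` keys is handed straight to `psycopg2.connect(..)`, so the keepalive settings above become ordinary libpq connection parameters. A rough sketch of the equivalent direct call (using the placeholder credentials from the sample config, not a real deployment):

```python
import psycopg2

# All non-cp_* keys from the `args` section are forwarded verbatim to the
# driver; the keepalive options therefore apply to the underlying libpq socket.
conn = psycopg2.connect(
    user="synapse",
    password="secretpassword",
    dbname="synapse",
    host="localhost",
    keepalives_idle=10,      # idle seconds before a keepalive probe is sent
    keepalives_interval=10,  # seconds between unacknowledged probes
    keepalives_count=3,      # lost probes before the connection is declared dead
)
```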
## Porting from SQLite

View File

@@ -578,13 +578,46 @@ acme:
## Database ##
# The 'database' setting defines the database that synapse uses to store all of
# its data.
#
# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or
# 'psycopg2' (for PostgreSQL).
#
# 'args' gives options which are passed through to the database engine,
# except for options starting 'cp_', which are used to configure the Twisted
# connection pool. For a reference to valid arguments, see:
# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__
#
#
# Example SQLite configuration:
#
#database:
# name: sqlite3
# args:
# database: /path/to/homeserver.db
#
#
# Example Postgres configuration:
#
#database:
# name: psycopg2
# args:
# user: synapse
# password: secretpassword
# database: synapse
# host: localhost
# cp_min: 5
# cp_max: 10
#
# For more information on using Synapse with Postgres, see `docs/postgres.md`.
#
database:
# The database engine name
name: "sqlite3"
# Arguments to pass to the engine
name: sqlite3
args:
# Path to the database
database: "DATADIR/homeserver.db"
database: DATADIR/homeserver.db
# Number of events to cache in memory.
#

View File

@@ -14,16 +14,16 @@ example flow would be (where '>' indicates master to worker and
'<' worker to master flows):
> SERVER example.com
< REPLICATE events 53
< REPLICATE
> POSITION events 53
> RDATA events 54 ["$foo1:bar.com", ...]
> RDATA events 55 ["$foo4:bar.com", ...]
The example shows the server accepting a new connection and sending its
identity with the `SERVER` command, followed by the client asking to
subscribe to the `events` stream from the token `53`. The server then
periodically sends `RDATA` commands which have the format
`RDATA <stream_name> <token> <row>`, where the format of `<row>` is
defined by the individual streams.
The example shows the server accepting a new connection and sending its identity
with the `SERVER` command, followed by the client server to respond with the
position of all streams. The server then periodically sends `RDATA` commands
which have the format `RDATA <stream_name> <token> <row>`, where the format of
`<row>` is defined by the individual streams.
Error reporting happens by either the client or server sending an ERROR
command, and usually the connection will be closed.
@@ -32,9 +32,6 @@ Since the protocol is a simple line based, its possible to manually
connect to the server using a tool like netcat. A few things should be
noted when manually using the protocol:
- When subscribing to a stream using `REPLICATE`, the special token
`NOW` can be used to get all future updates. The special stream name
`ALL` can be used with `NOW` to subscribe to all available streams.
- The federation stream is only available if federation sending has
been disabled on the main process.
- The server will only time connections out that have sent a `PING`
@@ -91,9 +88,7 @@ The client:
- Sends a `NAME` command, allowing the server to associate a human
friendly name with the connection. This is optional.
- Sends a `PING` as above
- For each stream the client wishes to subscribe to it sends a
`REPLICATE` with the `stream_name` and token it wants to subscribe
from.
- Sends a `REPLICATE` to get the current position of all streams.
- On receipt of a `SERVER` command, checks that the server name
matches the expected server name.
@@ -140,9 +135,7 @@ the wire:
> PING 1490197665618
< NAME synapse.app.appservice
< PING 1490197665618
< REPLICATE events 1
< REPLICATE backfill 1
< REPLICATE caches 1
< REPLICATE
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
@@ -181,9 +174,9 @@ client (C):
#### POSITION (S)
The position of the stream has been updated. Sent to the client
after all missing updates for a stream have been sent to the client
and they're now up to date.
On receipt of a POSITION command clients should check if they have missed any
updates, and if so then fetch them out of band. Sent in response to a
REPLICATE command (but can happen at any time).
#### ERROR (S, C)
@@ -199,20 +192,7 @@ client (C):
#### REPLICATE (C)
Asks the server to replicate a given stream. The syntax is:
```
REPLICATE <stream_name> <token>
```
Where `<token>` may be either:
* a numeric stream_id to stream updates since (exclusive)
* `NOW` to stream all subsequent updates.
The `<stream_name>` is the name of a replication stream to subscribe
to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
of streams). It can also be `ALL` to subscribe to all known streams,
in which case the `<token>` must be set to `NOW`.
Asks the server for the current position of all streams.
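
To illustrate the revised flow, here is a hypothetical client loop (not Synapse's actual implementation; the connection object and the `fetch_updates_since`/`process_row` callables are placeholders invented for this sketch). The client sends a bare `REPLICATE`, and on each `POSITION` catches up out of band before trusting subsequent `RDATA`:

```python
# Hypothetical sketch of the handshake described above; the helpers are
# placeholders, not real Synapse APIs.
def run_replication_client(conn, fetch_updates_since, process_row):
    positions = {}  # stream_name -> token the client is up to date with

    conn.send_command("REPLICATE")  # no per-stream arguments any more

    for cmd, *args in conn.read_commands():
        if cmd == "POSITION":
            stream_name, token = args[0], int(args[1])
            # The server does not replay old RDATA, so fetch anything missed
            # out of band (e.g. from the database) before continuing.
            for row in fetch_updates_since(stream_name, positions.get(stream_name, 0), token):
                process_row(stream_name, row)
            positions[stream_name] = token
        elif cmd == "RDATA":
            stream_name, token, row = args[0], int(args[1]), args[2]
            process_row(stream_name, row)
            positions[stream_name] = token
```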
#### USER_SYNC (C)

View File

@@ -65,12 +65,23 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.replication.tcp.streams._base import (
from synapse.replication.tcp.streams import (
AccountDataStream,
DeviceListsStream,
GroupServerStream,
PresenceStream,
PushersStream,
PushRulesStream,
ReceiptsStream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
)
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.replication.tcp.streams.events import EventsStreamEventRow, EventsStreamRow
from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client.v1 import events
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
@@ -390,6 +401,9 @@ class GenericWorkerTyping(object):
self._room_serials[row.room_id] = token
self._room_typing[row.room_id] = row.user_ids
def get_current_token(self) -> int:
return self._latest_room_serial
class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
@@ -626,7 +640,7 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
if self.send_handler:
self.send_handler.process_replication_rows(stream_name, token, rows)
if stream_name == "events":
if stream_name == EventsStream.NAME:
# We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows.
for row in rows:
@@ -649,43 +663,44 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
)
await self.pusher_pool.on_new_notifications(token, token)
elif stream_name == "push_rules":
elif stream_name == PushRulesStream.NAME:
self.notifier.on_new_event(
"push_rules_key", token, users=[row.user_id for row in rows]
)
elif stream_name in ("account_data", "tag_account_data"):
elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
self.notifier.on_new_event(
"account_data_key", token, users=[row.user_id for row in rows]
)
elif stream_name == "receipts":
elif stream_name == ReceiptsStream.NAME:
self.notifier.on_new_event(
"receipt_key", token, rooms=[row.room_id for row in rows]
)
await self.pusher_pool.on_new_receipts(
token, token, {row.room_id for row in rows}
)
elif stream_name == "typing":
elif stream_name == TypingStream.NAME:
self.typing_handler.process_replication_rows(token, rows)
self.notifier.on_new_event(
"typing_key", token, rooms=[row.room_id for row in rows]
)
elif stream_name == "to_device":
elif stream_name == ToDeviceStream.NAME:
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:
self.notifier.on_new_event("to_device_key", token, users=entities)
elif stream_name == "device_lists":
elif stream_name == DeviceListsStream.NAME:
all_room_ids = set()
for row in rows:
room_ids = await self.store.get_rooms_for_user(row.user_id)
all_room_ids.update(room_ids)
if row.entity.startswith("@"):
room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids)
self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
elif stream_name == "presence":
elif stream_name == PresenceStream.NAME:
await self.presence_handler.process_replication_rows(token, rows)
elif stream_name == "receipts":
elif stream_name == GroupServerStream.NAME:
self.notifier.on_new_event(
"groups_key", token, users=[row.user_id for row in rows]
)
elif stream_name == "pushers":
elif stream_name == PushersStream.NAME:
for row in rows:
if row.deleted:
self.stop_pusher(row.user_id, row.app_id, row.pushkey)
@@ -774,7 +789,10 @@ class FederationSenderHandler(object):
# ... as well as device updates and messages
elif stream_name == DeviceListsStream.NAME:
hosts = {row.destination for row in rows}
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
hosts = {row.entity for row in rows if not row.entity.startswith("@")}
for host in hosts:
self.federation_sender.send_device_messages(host)
@@ -789,7 +807,7 @@ class FederationSenderHandler(object):
async def _on_new_receipts(self, rows):
"""
Args:
rows (iterable[synapse.replication.tcp.streams.ReceiptsStreamRow]):
rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
new receipts to be processed
"""
for receipt in rows:
@@ -860,6 +878,9 @@ def start(config_options):
# Force the appservice to start since they will be disabled in the main config
config.notify_appservices = True
else:
# For other worker types we force this to off.
config.notify_appservices = False
if config.worker_app == "synapse.app.pusher":
if config.start_pushers:
@@ -873,6 +894,9 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.start_pushers = True
else:
# For other worker types we force this to off.
config.start_pushers = False
if config.worker_app == "synapse.app.user_dir":
if config.update_user_directory:
@@ -886,6 +910,9 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.update_user_directory = True
else:
# For other worker types we force this to off.
config.update_user_directory = False
if config.worker_app == "synapse.app.federation_sender":
if config.send_federation:
@@ -899,6 +926,9 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.send_federation = True
else:
# For other worker types we force this to off.
config.send_federation = False
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts

View File

@@ -294,7 +294,6 @@ class RootConfig(object):
report_stats=None,
open_private_ports=False,
listeners=None,
database_conf=None,
tls_certificate_path=None,
tls_private_key_path=None,
acme_domain=None,
@@ -367,7 +366,6 @@ class RootConfig(object):
report_stats=report_stats,
open_private_ports=open_private_ports,
listeners=listeners,
database_conf=database_conf,
tls_certificate_path=tls_certificate_path,
tls_private_key_path=tls_private_key_path,
acme_domain=acme_domain,

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,14 +15,60 @@
# limitations under the License.
import logging
import os
from textwrap import indent
import yaml
from synapse.config._base import Config, ConfigError
logger = logging.getLogger(__name__)
DEFAULT_CONFIG = """\
## Database ##
# The 'database' setting defines the database that synapse uses to store all of
# its data.
#
# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or
# 'psycopg2' (for PostgreSQL).
#
# 'args' gives options which are passed through to the database engine,
# except for options starting 'cp_', which are used to configure the Twisted
# connection pool. For a reference to valid arguments, see:
# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__
#
#
# Example SQLite configuration:
#
#database:
# name: sqlite3
# args:
# database: /path/to/homeserver.db
#
#
# Example Postgres configuration:
#
#database:
# name: psycopg2
# args:
# user: synapse
# password: secretpassword
# database: synapse
# host: localhost
# cp_min: 5
# cp_max: 10
#
# For more information on using Synapse with Postgres, see `docs/postgres.md`.
#
database:
name: sqlite3
args:
database: %(database_path)s
# Number of events to cache in memory.
#
#event_cache_size: 10K
"""
class DatabaseConnectionConfig:
"""Contains the connection config for a particular database.
@@ -36,10 +83,12 @@ class DatabaseConnectionConfig:
"""
def __init__(self, name: str, db_config: dict):
if db_config["name"] not in ("sqlite3", "psycopg2"):
raise ConfigError("Unsupported database type %r" % (db_config["name"],))
db_engine = db_config.get("name", "sqlite3")
if db_config["name"] == "sqlite3":
if db_engine not in ("sqlite3", "psycopg2"):
raise ConfigError("Unsupported database type %r" % (db_engine,))
if db_engine == "sqlite3":
db_config.setdefault("args", {}).update(
{"cp_min": 1, "cp_max": 1, "check_same_thread": False}
)
@@ -97,34 +146,10 @@ class DatabaseConfig(Config):
self.set_databasepath(config.get("database_path"))
def generate_config_section(self, data_dir_path, database_conf, **kwargs):
if not database_conf:
database_path = os.path.join(data_dir_path, "homeserver.db")
database_conf = (
"""# The database engine name
name: "sqlite3"
# Arguments to pass to the engine
args:
# Path to the database
database: "%(database_path)s"
"""
% locals()
)
else:
database_conf = indent(yaml.dump(database_conf), " " * 10).lstrip()
return (
"""\
## Database ##
database:
%(database_conf)s
# Number of events to cache in memory.
#
#event_cache_size: 10K
"""
% locals()
)
def generate_config_section(self, data_dir_path, **kwargs):
return DEFAULT_CONFIG % {
"database_path": os.path.join(data_dir_path, "homeserver.db")
}
def read_arguments(self, args):
self.set_databasepath(args.database_path)

View File

@@ -43,8 +43,8 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
current_context,
make_deferred_yieldable,
preserve_fn,
run_in_background,
@@ -236,7 +236,7 @@ class Keyring(object):
"""
try:
ctx = LoggingContext.current_context()
ctx = current_context()
# map from server name to a set of outstanding request ids
server_to_request_ids = {}

View File

@@ -25,19 +25,15 @@ from twisted.python.failure import Failure
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
RoomVersion,
)
from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.crypto.event_signing import check_event_content_hash
from synapse.crypto.keyring import Keyring
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
current_context,
make_deferred_yieldable,
)
from synapse.types import JsonDict, get_domain_from_id
@@ -55,13 +51,15 @@ class FederationBase(object):
self.store = hs.get_datastore()
self._clock = hs.get_clock()
def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
def _check_sigs_and_hash(
self, room_version: RoomVersion, pdu: EventBase
) -> Deferred:
return make_deferred_yieldable(
self._check_sigs_and_hashes(room_version, [pdu])[0]
)
def _check_sigs_and_hashes(
self, room_version: str, pdus: List[EventBase]
self, room_version: RoomVersion, pdus: List[EventBase]
) -> List[Deferred]:
"""Checks that each of the received events is correctly signed by the
sending server.
@@ -80,7 +78,7 @@ class FederationBase(object):
"""
deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
ctx = LoggingContext.current_context()
ctx = current_context()
def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx):
@@ -146,7 +144,7 @@ class PduToCheckSig(
def _check_sigs_on_pdus(
keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
) -> List[Deferred]:
"""Check that the given events are correctly signed
@@ -191,10 +189,6 @@ def _check_sigs_on_pdus(
for p in pdus
]
v = KNOWN_ROOM_VERSIONS.get(room_version)
if not v:
raise RuntimeError("Unrecognized room version %s" % (room_version,))
# First we check that the sender event is signed by the sender's domain
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
@@ -204,7 +198,7 @@ def _check_sigs_on_pdus(
(
p.sender_domain,
p.redacted_pdu_json,
p.pdu.origin_server_ts if v.enforce_key_validity else 0,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_sender
@@ -227,7 +221,7 @@ def _check_sigs_on_pdus(
# event id's domain (normally only the case for joins/leaves), and add additional
# checks. Only do this if the room version has a concept of event ID domain
# (ie, the room version uses old-style non-hash event IDs).
if v.event_format == EventFormatVersions.V1:
if room_version.event_format == EventFormatVersions.V1:
pdus_to_check_event_id = [
p
for p in pdus_to_check
@@ -239,7 +233,7 @@ def _check_sigs_on_pdus(
(
get_domain_from_id(p.pdu.event_id),
p.redacted_pdu_json,
p.pdu.origin_server_ts if v.enforce_key_validity else 0,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_event_id

View File

@@ -220,8 +220,7 @@ class FederationClient(FederationBase):
# FIXME: We should handle signature failures more gracefully.
pdus[:] = await make_deferred_yieldable(
defer.gatherResults(
self._check_sigs_and_hashes(room_version.identifier, pdus),
consumeErrors=True,
self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True,
).addErrback(unwrapFirstError)
)
@@ -291,9 +290,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
signed_pdu = await self._check_sigs_and_hash(
room_version.identifier, pdu
)
signed_pdu = await self._check_sigs_and_hash(room_version, pdu)
break
@@ -350,7 +347,7 @@ class FederationClient(FederationBase):
self,
origin: str,
pdus: List[EventBase],
room_version: str,
room_version: RoomVersion,
outlier: bool = False,
include_none: bool = False,
) -> List[EventBase]:
@@ -396,7 +393,7 @@ class FederationClient(FederationBase):
self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
room_version=room_version, # type: ignore
room_version=room_version,
outlier=outlier,
timeout=10000,
)
@@ -434,7 +431,7 @@ class FederationClient(FederationBase):
]
signed_auth = await self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True, room_version=room_version.identifier
destination, auth_chain, outlier=True, room_version=room_version
)
signed_auth.sort(key=lambda e: e.depth)
@@ -661,7 +658,7 @@ class FederationClient(FederationBase):
destination,
list(pdus.values()),
outlier=True,
room_version=room_version.identifier,
room_version=room_version,
)
valid_pdus_map = {p.event_id: p for p in valid_pdus}
@@ -756,7 +753,7 @@ class FederationClient(FederationBase):
pdu = event_from_pdu_json(pdu_dict, room_version)
# Check signatures are correct.
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
pdu = await self._check_sigs_and_hash(room_version, pdu)
# FIXME: We should handle signature failures more gracefully.
@@ -948,7 +945,7 @@ class FederationClient(FederationBase):
]
signed_events = await self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False, room_version=room_version.identifier
destination, events, outlier=False, room_version=room_version
)
except HttpResponseException as e:
if not e.code == 400:

View File

@@ -409,7 +409,7 @@ class FederationServer(FederationBase):
pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, pdu.room_id)
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
pdu = await self._check_sigs_and_hash(room_version, pdu)
ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version)
time_now = self._clock.time_msec()
return {"event": ret_pdu.get_pdu_json(time_now)}
@@ -425,7 +425,7 @@ class FederationServer(FederationBase):
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
pdu = await self._check_sigs_and_hash(room_version, pdu)
res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
@@ -455,7 +455,7 @@ class FederationServer(FederationBase):
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
pdu = await self._check_sigs_and_hash(room_version, pdu)
await self.handler.on_send_leave_request(origin, pdu)
return {}
@@ -611,7 +611,7 @@ class FederationServer(FederationBase):
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
# We've already checked that we know the room version by this point
room_version = await self.store.get_room_version_id(pdu.room_id)
room_version = await self.store.get_room_version(pdu.room_id)
# Check signature.
try:

View File

@@ -477,7 +477,7 @@ def process_rows_for_federation(transaction_queue, rows):
Args:
transaction_queue (FederationSender)
rows (list(synapse.replication.tcp.streams.FederationStreamRow))
rows (list(synapse.replication.tcp.streams.federation.FederationStream.FederationStreamRow))
"""
# The federation stream contains a bunch of different types of

View File

@@ -499,4 +499,13 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction()
def get_current_token(self) -> int:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return 0
async def get_replication_rows(
self, from_token, to_token, limit, federation_ack=None
):
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return []

View File

@@ -851,6 +851,38 @@ class EventCreationHandler(object):
self.store.remove_push_actions_from_staging, event.event_id
)
@defer.inlineCallbacks
def _validate_canonical_alias(
self, directory_handler, room_alias_str, expected_room_id
):
"""
Ensure that the given room alias points to the expected room ID.
Args:
directory_handler: The directory handler object.
room_alias_str: The room alias to check.
expected_room_id: The room ID that the alias should point to.
"""
room_alias = RoomAlias.from_string(room_alias_str)
try:
mapping = yield directory_handler.get_association(room_alias)
except SynapseError as e:
# Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
if e.errcode == Codes.NOT_FOUND:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (room_alias_str,),
Codes.BAD_ALIAS,
)
raise
if mapping["room_id"] != expected_room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (room_alias_str,),
Codes.BAD_ALIAS,
)
@defer.inlineCallbacks
def persist_and_notify_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
@@ -905,15 +937,9 @@ class EventCreationHandler(object):
room_alias_str = event.content.get("alias", None)
directory_handler = self.hs.get_handlers().directory_handler
if room_alias_str and room_alias_str != original_alias:
room_alias = RoomAlias.from_string(room_alias_str)
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (room_alias_str,),
Codes.BAD_ALIAS,
)
yield self._validate_canonical_alias(
directory_handler, room_alias_str, event.room_id
)
# Check that alt_aliases is the proper form.
alt_aliases = event.content.get("alt_aliases", [])
@@ -931,16 +957,9 @@ class EventCreationHandler(object):
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
if new_alt_aliases:
for alias_str in new_alt_aliases:
room_alias = RoomAlias.from_string(alias_str)
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room"
% (room_alias_str,),
Codes.BAD_ALIAS,
)
yield self._validate_canonical_alias(
directory_handler, alias_str, event.room_id
)
federation_handler = self.hs.get_handlers().federation_handler

View File

@@ -747,7 +747,7 @@ class PresenceHandler(object):
return False
async def get_all_presence_updates(self, last_id, current_id):
async def get_all_presence_updates(self, last_id, current_id, limit):
"""
Gets a list of presence update rows from between the given stream ids.
Each row has:
@@ -762,7 +762,7 @@ class PresenceHandler(object):
"""
# TODO(markjh): replicate the unpersisted changes.
# This could use the in-memory stores for recent changes.
rows = await self.store.get_all_presence_updates(last_id, current_id)
rows = await self.store.get_all_presence_updates(last_id, current_id, limit)
return rows
def notify_new_event(self):

View File

@@ -26,7 +26,7 @@ from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import FilterCollection
from synapse.events import EventBase
from synapse.logging.context import LoggingContext
from synapse.logging.context import current_context
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
@@ -301,7 +301,7 @@ class SyncHandler(object):
else:
sync_type = "incremental_sync"
context = LoggingContext.current_context()
context = current_context()
if context:
context.tag = sync_type

View File

@@ -15,6 +15,7 @@
import logging
from collections import namedtuple
from typing import List
from twisted.internet import defer
@@ -257,7 +258,13 @@ class TypingHandler(object):
"typing_key", self._latest_room_serial, rooms=[member.room_id]
)
async def get_all_typing_updates(self, last_id, current_id):
async def get_all_typing_updates(
self, last_id: int, current_id: int, limit: int
) -> List[dict]:
"""Get up to `limit` typing updates between the given tokens, earliest
updates first.
"""
if last_id == current_id:
return []
@@ -275,7 +282,7 @@ class TypingHandler(object):
typing = self._room_typing[room_id]
rows.append((serial, room_id, list(typing)))
rows.sort()
return rows
return rows[:limit]
def get_current_token(self):
return self._latest_room_serial

View File

@@ -19,7 +19,7 @@ import threading
from prometheus_client.core import Counter, Histogram
from synapse.logging.context import LoggingContext
from synapse.logging.context import current_context
from synapse.metrics import LaterGauge
logger = logging.getLogger(__name__)
@@ -148,7 +148,7 @@ LaterGauge(
class RequestMetrics(object):
def start(self, time_sec, name, method):
self.start = time_sec
self.start_context = LoggingContext.current_context()
self.start_context = current_context()
self.name = name
self.method = method
@@ -163,7 +163,7 @@ class RequestMetrics(object):
with _in_flight_requests_lock:
_in_flight_requests.discard(self)
context = LoggingContext.current_context()
context = current_context()
tag = ""
if context:

View File

@@ -42,7 +42,7 @@ from synapse.logging._terse_json import (
TerseJSONToConsoleLogObserver,
TerseJSONToTCPLogObserver,
)
from synapse.logging.context import LoggingContext
from synapse.logging.context import current_context
def stdlib_log_level_to_twisted(level: str) -> LogLevel:
@@ -86,7 +86,7 @@ class LogContextObserver(object):
].startswith("Timing out client"):
return
context = LoggingContext.current_context()
context = current_context()
# Copy the context information to the log event.
if context is not None:

View File

@@ -175,7 +175,54 @@ class ContextResourceUsage(object):
return res
LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"]
LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
class _Sentinel(object):
"""Sentinel to represent the root context"""
__slots__ = ["previous_context", "finished", "request", "scope", "tag"]
def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
self.previous_context = None
self.finished = False
self.request = None
self.scope = None
self.tag = None
def __str__(self):
return "sentinel"
def copy_to(self, record):
pass
def copy_to_twisted_log_entry(self, record):
record["request"] = None
record["scope"] = None
def start(self):
pass
def stop(self):
pass
def add_database_transaction(self, duration_sec):
pass
def add_database_scheduled(self, sched_sec):
pass
def record_event_fetch(self, event_count):
pass
def __nonzero__(self):
return False
__bool__ = __nonzero__ # python3
SENTINEL_CONTEXT = _Sentinel()
class LoggingContext(object):
@@ -199,76 +246,33 @@ class LoggingContext(object):
"_resource_usage",
"usage_start",
"main_thread",
"alive",
"finished",
"request",
"tag",
"scope",
]
thread_local = threading.local()
class Sentinel(object):
"""Sentinel to represent the root context"""
__slots__ = ["previous_context", "alive", "request", "scope", "tag"]
def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
self.previous_context = None
self.alive = None
self.request = None
self.scope = None
self.tag = None
def __str__(self):
return "sentinel"
def copy_to(self, record):
pass
def copy_to_twisted_log_entry(self, record):
record["request"] = None
record["scope"] = None
def start(self):
pass
def stop(self):
pass
def add_database_transaction(self, duration_sec):
pass
def add_database_scheduled(self, sched_sec):
pass
def record_event_fetch(self, event_count):
pass
def __nonzero__(self):
return False
__bool__ = __nonzero__ # python3
sentinel = Sentinel()
def __init__(self, name=None, parent_context=None, request=None) -> None:
self.previous_context = LoggingContext.current_context()
self.previous_context = current_context()
self.name = name
# track the resources used by this context so far
self._resource_usage = ContextResourceUsage()
# If alive has the thread resource usage when the logcontext last
# became active.
# The thread resource usage when the logcontext became active. None
# if the context is not currently active.
self.usage_start = None
self.main_thread = get_thread_id()
self.request = None
self.tag = ""
self.alive = True
self.scope = None # type: Optional[_LogContextScope]
# keep track of whether we have hit the __exit__ block for this context
# (suggesting that the thing that created the context thinks it should
# be finished, and that re-activating it would suggest an error).
self.finished = False
self.parent_context = parent_context
if self.parent_context is not None:
@@ -283,44 +287,15 @@ class LoggingContext(object):
return str(self.request)
return "%s@%x" % (self.name, id(self))
@classmethod
def current_context(cls) -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage
Returns:
LoggingContext: the current logging context
"""
return getattr(cls.thread_local, "current_context", cls.sentinel)
@classmethod
def set_current_context(
cls, context: LoggingContextOrSentinel
) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage
Args:
context(LoggingContext): The context to activate.
Returns:
The context that was previously active
"""
current = cls.current_context()
if current is not context:
current.stop()
cls.thread_local.current_context = context
context.start()
return current
def __enter__(self) -> "LoggingContext":
"""Enters this logging context into thread local storage"""
old_context = self.set_current_context(self)
old_context = set_current_context(self)
if self.previous_context != old_context:
logger.warning(
"Expected previous context %r, found %r",
self.previous_context,
old_context,
)
self.alive = True
return self
def __exit__(self, type, value, traceback) -> None:
@@ -329,24 +304,19 @@ class LoggingContext(object):
Returns:
None to avoid suppressing any exceptions that were thrown.
"""
current = self.set_current_context(self.previous_context)
current = set_current_context(self.previous_context)
if current is not self:
if current is self.sentinel:
if current is SENTINEL_CONTEXT:
logger.warning("Expected logging context %s was lost", self)
else:
logger.warning(
"Expected logging context %s but found %s", self, current
)
self.alive = False
# if we have a parent, pass our CPU usage stats on
if self.parent_context is not None and hasattr(
self.parent_context, "_resource_usage"
):
self.parent_context._resource_usage += self._resource_usage
# reset them in case we get entered again
self._resource_usage.reset()
# the fact that we are here suggests that the caller thinks that everything
# is done and dusted for this logcontext, and further activity will not get
# recorded against the correct metrics.
self.finished = True
def copy_to(self, record) -> None:
"""Copy logging fields from this context to a log record or
@@ -371,9 +341,14 @@ class LoggingContext(object):
logger.warning("Started logcontext %s on different thread", self)
return
if self.finished:
logger.warning("Re-starting finished log context %s", self)
# If we haven't already started, record the thread resource usage so
# far
if not self.usage_start:
if self.usage_start:
logger.warning("Re-starting already-active log context %s", self)
else:
self.usage_start = get_thread_resource_usage()
def stop(self) -> None:
@@ -396,6 +371,15 @@ class LoggingContext(object):
self.usage_start = None
# if we have a parent, pass our CPU usage stats on
if self.parent_context is not None and hasattr(
self.parent_context, "_resource_usage"
):
self.parent_context._resource_usage += self._resource_usage
# reset them in case we get entered again
self._resource_usage.reset()
def get_resource_usage(self) -> ContextResourceUsage:
"""Get resources used by this logcontext so far.
@@ -409,7 +393,7 @@ class LoggingContext(object):
# If we are on the correct thread and we're currently running then we
# can include resource usage so far.
is_main_thread = get_thread_id() == self.main_thread
if self.alive and self.usage_start and is_main_thread:
if self.usage_start and is_main_thread:
utime_delta, stime_delta = self._get_cputime()
res.ru_utime += utime_delta
res.ru_stime += stime_delta
@@ -492,7 +476,7 @@ class LoggingContextFilter(logging.Filter):
Returns:
True to include the record in the log output.
"""
context = LoggingContext.current_context()
context = current_context()
for key, value in self.defaults.items():
setattr(record, key, value)
@@ -512,27 +496,24 @@ class PreserveLoggingContext(object):
__slots__ = ["current_context", "new_context", "has_parent"]
def __init__(self, new_context: Optional[LoggingContextOrSentinel] = None) -> None:
if new_context is None:
self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel
else:
self.new_context = new_context
def __init__(
self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT
) -> None:
self.new_context = new_context
def __enter__(self) -> None:
"""Captures the current logging context"""
self.current_context = LoggingContext.set_current_context(self.new_context)
self.current_context = set_current_context(self.new_context)
if self.current_context:
self.has_parent = self.current_context.previous_context is not None
if not self.current_context.alive:
logger.debug("Entering dead context: %s", self.current_context)
def __exit__(self, type, value, traceback) -> None:
"""Restores the current logging context"""
context = LoggingContext.set_current_context(self.current_context)
context = set_current_context(self.current_context)
if context != self.new_context:
if context is LoggingContext.sentinel:
if not context:
logger.warning("Expected logging context %s was lost", self.new_context)
else:
logger.warning(
@@ -541,9 +522,30 @@ class PreserveLoggingContext(object):
context,
)
if self.current_context is not LoggingContext.sentinel:
if not self.current_context.alive:
logger.debug("Restoring dead context: %s", self.current_context)
_thread_local = threading.local()
_thread_local.current_context = SENTINEL_CONTEXT
def current_context() -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage"""
return getattr(_thread_local, "current_context", SENTINEL_CONTEXT)
def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage
Args:
context(LoggingContext): The context to activate.
Returns:
The context that was previously active
"""
current = current_context()
if current is not context:
current.stop()
_thread_local.current_context = context
context.start()
return current
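As a minimal usage sketch of the new module-level API, assuming only the names introduced in this file (`LoggingContext`, `PreserveLoggingContext`, `current_context`); the `tag_active_context` helper is illustrative, not code from the branch:

from synapse.logging.context import (
    LoggingContext,
    PreserveLoggingContext,
    current_context,
)

def tag_active_context(tag: str) -> None:
    # current_context() now returns the module-level sentinel (which is falsy)
    # when no context has been entered, so a plain truthiness check works.
    ctx = current_context()
    if ctx:
        ctx.tag = tag

with LoggingContext(name="example", request="GET-1"):
    tag_active_context("example_request")
    # Drop back to the sentinel around work that should not be accounted
    # to this request.
    with PreserveLoggingContext():
        pass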
def nested_logging_context(
@@ -572,7 +574,7 @@ def nested_logging_context(
if parent_context is not None:
context = parent_context # type: LoggingContextOrSentinel
else:
context = LoggingContext.current_context()
context = current_context()
return LoggingContext(
parent_context=context, request=str(context.request) + "-" + suffix
)
@@ -604,7 +606,7 @@ def run_in_background(f, *args, **kwargs):
CRITICAL error about an unhandled error will be logged without much
indication about where it came from.
"""
current = LoggingContext.current_context()
current = current_context()
try:
res = f(*args, **kwargs)
except: # noqa: E722
@@ -625,7 +627,7 @@ def run_in_background(f, *args, **kwargs):
# The function may have reset the context before returning, so
# we need to restore it now.
ctx = LoggingContext.set_current_context(current)
ctx = set_current_context(current)
# The original context will be restored when the deferred
# completes, but there is nothing waiting for it, so it will
@@ -674,7 +676,7 @@ def make_deferred_yieldable(deferred):
# ok, we can't be sure that a yield won't block, so let's reset the
# logcontext, and add a callback to the deferred to restore it.
prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
prev_context = set_current_context(SENTINEL_CONTEXT)
deferred.addBoth(_set_context_cb, prev_context)
return deferred
@@ -684,7 +686,7 @@ ResultT = TypeVar("ResultT")
def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
"""A callback function which just sets the logging context"""
LoggingContext.set_current_context(context)
set_current_context(context)
return result
@@ -752,7 +754,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
Deferred: A Deferred which fires a callback with the result of `f`, or an
errback if `f` throws an exception.
"""
logcontext = LoggingContext.current_context()
logcontext = current_context()
def g():
with LoggingContext(parent_context=logcontext):

View File

@@ -19,7 +19,7 @@ from opentracing import Scope, ScopeManager
import twisted
from synapse.logging.context import LoggingContext, nested_logging_context
from synapse.logging.context import current_context, nested_logging_context
logger = logging.getLogger(__name__)
@@ -49,11 +49,8 @@ class LogContextScopeManager(ScopeManager):
(Scope) : the Scope that is active, or None if not
available.
"""
ctx = LoggingContext.current_context()
if ctx is LoggingContext.sentinel:
return None
else:
return ctx.scope
ctx = current_context()
return ctx.scope
def activate(self, span, finish_on_close):
"""
@@ -70,9 +67,9 @@ class LogContextScopeManager(ScopeManager):
"""
enter_logcontext = False
ctx = LoggingContext.current_context()
ctx = current_context()
if ctx is LoggingContext.sentinel:
if not ctx:
# There is no logcontext to attach the scope to, so log an error and
# return a scope that is not tracked by a logcontext.
logger.error("Tried to activate scope outside of loggingcontext")
return Scope(None, span)

View File

@@ -21,6 +21,7 @@ from synapse.replication.http import (
membership,
register,
send_event,
streams,
)
REPLICATION_PREFIX = "/_synapse/replication"
@@ -38,3 +39,4 @@ class ReplicationRestResource(JsonResource):
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)
streams.register_servlets(hs, self)

View File

@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer
from synapse.replication.http._base import ReplicationEndpoint
logger = logging.getLogger(__name__)
class ReplicationGetStreamUpdates(ReplicationEndpoint):
"""Fetches stream updates from a server. Used for streams not persisted to
the database, e.g. typing notifications.
The API looks like:
GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&upto_token=10&limit=100
200 OK
{
updates: [ ... ],
upto_token: 10,
limited: False,
}
"""
NAME = "get_repl_stream_updates"
PATH_ARGS = ("stream_name",)
METHOD = "GET"
def __init__(self, hs):
super().__init__(hs)
# We pull the streams from the replication streamer (if we try and make
# them ourselves we end up in an import loop).
self.streams = hs.get_replication_streamer().get_streams()
@staticmethod
def _serialize_payload(stream_name, from_token, upto_token, limit):
return {"from_token": from_token, "upto_token": upto_token, "limit": limit}
async def _handle_request(self, request, stream_name):
stream = self.streams.get(stream_name)
if stream is None:
raise SynapseError(400, "Unknown stream")
from_token = parse_integer(request, "from_token", required=True)
upto_token = parse_integer(request, "upto_token", required=True)
limit = parse_integer(request, "limit", required=True)
updates, upto_token, limited = await stream.get_updates_since(
from_token, upto_token, limit
)
return (
200,
{"updates": updates, "upto_token": upto_token, "limited": limited},
)
def register_servlets(hs, http_server):
ReplicationGetStreamUpdates(hs).register(http_server)
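A sketch of how a worker might drive this endpoint via the `make_client` helper that `ReplicationEndpoint` subclasses expose (the same pattern is used by `make_http_update_function` further down); the `catch_up_typing` function and its catch-up loop are illustrative assumptions, not code from the branch:

from synapse.replication.http.streams import ReplicationGetStreamUpdates

async def catch_up_typing(hs, from_token: int, upto_token: int):
    # Build an async callable that issues the GET request to the master.
    client = ReplicationGetStreamUpdates.make_client(hs)

    limited = True
    while limited:
        # The response body carries `updates`, `upto_token` and `limited`,
        # matching the docstring above.
        result = await client(
            stream_name="typing",
            from_token=from_token,
            upto_token=upto_token,
            limit=100,
        )
        from_token = result["upto_token"]
        limited = result["limited"]
        # ... hand result["updates"] to the local handler here ...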

View File

@@ -18,8 +18,10 @@ from typing import Dict, Optional
import six
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
from synapse.storage.data_stores.main.cache import (
CURRENT_STATE_CACHE_NAME,
CacheInvalidationWorkerStore,
)
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
@@ -35,7 +37,7 @@ def __func__(inp):
return inp.__func__
class BaseSlavedStore(SQLBaseStore):
class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: Database, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
@@ -60,6 +62,12 @@ class BaseSlavedStore(SQLBaseStore):
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
else:
return 0
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches":
if self._cache_id_gen:

View File

@@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
db_conn, "device_lists_stream", "stream_id"
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
],
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
@@ -55,23 +61,27 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
def process_replication_rows(self, stream_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
for row in rows:
self._invalidate_caches_for_devices(token, row.user_id, row.destination)
self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
def _invalidate_caches_for_devices(self, token, user_id, destination):
self._device_list_stream_cache.entity_has_changed(user_id, token)
def _invalidate_caches_for_devices(self, token, rows):
for row in rows:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
if row.entity.startswith("@"):
self._device_list_stream_cache.entity_has_changed(row.entity, token)
self.get_cached_devices_for_user.invalidate((row.entity,))
self._get_cached_user_device.invalidate_many((row.entity,))
self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
if destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, token
)
self.get_cached_devices_for_user.invalidate((user_id,))
self._get_cached_user_device.invalidate_many((user_id,))
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
else:
self._device_list_federation_stream_cache.entity_has_changed(
row.entity, token
)

View File

@@ -33,6 +33,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
result["pushers"] = self._pushers_id_gen.get_current_token()
return result
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "pushers":
self._pushers_id_gen.advance(token)

View File

@@ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
self.client_name = client_name
self.handler = handler
self.server_name = hs.config.server_name
self.hs = hs
self._clock = hs.get_clock() # As self.clock is defined in super class
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def buildProtocol(self, addr):
logger.info("Connected to replication: %r", addr)
return ClientReplicationStreamProtocol(
self.client_name, self.server_name, self._clock, self.handler
self.hs, self.client_name, self.server_name, self._clock, self.handler,
)
def clientConnectionLost(self, connector, reason):

View File

@@ -136,8 +136,8 @@ class PositionCommand(Command):
"""Sent by the server to tell the client the stream postition without
needing to send an RDATA.
Sent to the client after all missing updates for a stream have been sent
to the client and they're now up to date.
On receipt of a POSITION command clients should check if they have missed
any updates, and if so then fetch them out of band.
"""
NAME = "POSITION"
@@ -179,42 +179,24 @@ class NameCommand(Command):
class ReplicateCommand(Command):
"""Sent by the client to subscribe to the stream.
"""Sent by the client to subscribe to streams.
Format::
REPLICATE <stream_name> <token>
Where <token> may be either:
* a numeric stream_id to stream updates from
* "NOW" to stream all subsequent updates.
The <stream_name> can be "ALL" to subscribe to all known streams, in which
case the <token> must be set to "NOW", i.e.::
REPLICATE ALL NOW
REPLICATE
"""
NAME = "REPLICATE"
def __init__(self, stream_name, token):
self.stream_name = stream_name
self.token = token
def __init__(self):
pass
@classmethod
def from_line(cls, line):
stream_name, token = line.split(" ", 1)
if token in ("NOW", "now"):
token = "NOW"
else:
token = int(token)
return cls(stream_name, token)
return cls()
def to_line(self):
return " ".join((self.stream_name, str(self.token)))
def get_logcontext_id(self):
return "REPLICATE-" + self.stream_name
return ""
class UserSyncCommand(Command):

View File

@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
> PING 1490197665618
< NAME synapse.app.appservice
< PING 1490197665618
< REPLICATE events 1
< REPLICATE backfill 1
< REPLICATE caches 1
< REPLICATE
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
@@ -53,17 +51,15 @@ import fcntl
import logging
import struct
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Set, Tuple
from typing import Any, DefaultDict, Dict, List, Set
from six import iteritems, iterkeys
from six import iteritems
from prometheus_client import Counter
from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
@@ -82,11 +78,16 @@ from synapse.replication.tcp.commands import (
SyncCommand,
UserSyncCommand,
)
from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string
MYPY = False
if MYPY:
from synapse.server import HomeServer
connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
)
@@ -411,16 +412,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name
self.streamer = streamer
# The streams the client has subscribed to and is up to date with
self.replication_streams = set() # type: Set[str]
# The streams the client is currently subscribing to.
self.connecting_streams = set() # type: Set[str]
# Map from stream name to list of updates to send once we've finished
# subscribing the client to the stream.
self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
BaseReplicationStreamProtocol.connectionMade(self)
@@ -436,21 +427,10 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
)
async def on_REPLICATE(self, cmd):
stream_name = cmd.stream_name
token = cmd.token
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
deferreds = [
run_in_background(self.subscribe_to_stream, stream, token)
for stream in iterkeys(self.streamer.streams_by_name)
]
await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else:
await self.subscribe_to_stream(stream_name, token)
# Subscribe to all streams we're publishing to.
for stream_name in self.streamer.streams_by_name:
current_token = self.streamer.get_stream_token(stream_name)
self.send_command(PositionCommand(stream_name, current_token))
async def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token)
@@ -474,87 +454,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
cmd.last_seen,
)
async def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a stream.
This involves checking if they've missed anything and sending those
updates down if they have. During that time new updates for the stream
are queued and sent once we've sent down any missed updates.
"""
self.replication_streams.discard(stream_name)
self.connecting_streams.add(stream_name)
try:
# Get missing updates
updates, current_token = await self.streamer.get_stream_updates(
stream_name, token
)
# Send all the missing updates
for update in updates:
token, row = update[0], update[1]
self.send_command(RdataCommand(stream_name, token, row))
# We send a POSITION command to ensure that they have an up to
# date token (especially useful if we didn't send any updates
# above)
self.send_command(PositionCommand(stream_name, current_token))
# Now we can send any updates that came in while we were subscribing
pending_rdata = self.pending_rdata.pop(stream_name, [])
updates = []
for token, update in pending_rdata:
# If the token is null, it is part of a batch update. Batches
# are multiple updates that share a single token. To denote
# this, the token is set to None for all tokens in the batch
# except for the last. If we find a None token, we keep looking
# through tokens until we find one that is not None and then
# process all previous updates in the batch as if they had the
# final token.
if token is None:
# Store this update as part of a batch
updates.append(update)
continue
if token <= current_token:
# This update or batch of updates is older than
# current_token, dismiss it
updates = []
continue
updates.append(update)
# Send all updates that are part of this batch with the
# found token
for update in updates:
self.send_command(RdataCommand(stream_name, token, update))
# Clear stored updates
updates = []
# They're now fully subscribed
self.replication_streams.add(stream_name)
except Exception as e:
logger.exception("[%s] Failed to handle REPLICATE command", self.id())
self.send_error("failed to handle replicate: %r", e)
finally:
self.connecting_streams.discard(stream_name)
def stream_update(self, stream_name, token, data):
"""Called when a new update is available to stream to clients.
We need to check if the client is interested in the stream or not
"""
if stream_name in self.replication_streams:
# The client is subscribed to the stream
self.send_command(RdataCommand(stream_name, token, data))
elif stream_name in self.connecting_streams:
# The client is being subscribed to the stream
logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
self.pending_rdata.setdefault(stream_name, []).append((token, data))
else:
# The client isn't subscribed
logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
self.send_command(RdataCommand(stream_name, token, data))
def send_sync(self, data):
self.send_command(SyncCommand(data))
@@ -638,6 +543,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
def __init__(
self,
hs: "HomeServer",
client_name: str,
server_name: str,
clock: Clock,
@@ -649,22 +555,25 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name
self.handler = handler
self.streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
self.streams_connecting = set() # type: Set[str]
self.streams_connecting = set(STREAMS_MAP) # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
self.pending_batches = {} # type: Dict[str, Any]
self.pending_batches = {} # type: Dict[str, List[Any]]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
BaseReplicationStreamProtocol.connectionMade(self)
# Once we've connected subscribe to the necessary streams
for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
self.replicate(stream_name, token)
self.replicate()
# Tell the server if we have any users currently syncing (should only
# happen on synchrotrons)
@@ -676,10 +585,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# We've now finished connecting, so inform the client handler
self.handler.update_connection(self)
# This will happen if we don't actually subscribe to any streams
if not self.streams_connecting:
self.handler.finished_connecting()
async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@@ -697,7 +602,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
)
raise
if cmd.token is None:
if cmd.token is None or stream_name in self.streams_connecting:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
self.pending_batches.setdefault(stream_name, []).append(row)
@@ -707,14 +612,55 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
rows.append(row)
await self.handler.on_rdata(stream_name, cmd.token, rows)
async def on_POSITION(self, cmd):
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
async def on_POSITION(self, cmd: PositionCommand):
stream = self.streams.get(cmd.stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
return
# Find where we previously streamed up to.
current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
if current_token is None:
logger.warning(
"Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
)
return
# Fetch all updates between then and now.
limited = True
while limited:
updates, current_token, limited = await stream.get_updates_since(
current_token, cmd.token
)
# Check if the connection was closed underneath us, if so we bail
# rather than risk having concurrent catch ups going on.
if self.state == ConnectionStates.CLOSED:
return
if updates:
await self.handler.on_rdata(
cmd.stream_name,
current_token,
[stream.parse_row(update[1]) for update in updates],
)
# We've now caught up to position sent to us, notify handler.
await self.handler.on_position(cmd.stream_name, cmd.token)
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
await self.handler.on_position(cmd.stream_name, cmd.token)
# Check if the connection was closed underneath us, if so we bail
# rather than risk having concurrent catch ups going on.
if self.state == ConnectionStates.CLOSED:
return
# Handle any RDATA that came in while we were catching up.
rows = self.pending_batches.pop(cmd.stream_name, [])
if rows:
await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
async def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data)
@@ -722,22 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.handler.on_remote_server_up(cmd.data)
def replicate(self, stream_name, token):
def replicate(self):
"""Send the subscription request to the server
"""
if stream_name not in STREAMS_MAP:
raise Exception("Invalid stream name %r" % (stream_name,))
logger.info("[%s] Subscribing to replication streams", self.id())
logger.info(
"[%s] Subscribing to replication stream: %r from %r",
self.id(),
stream_name,
token,
)
self.streams_connecting.add(stream_name)
self.send_command(ReplicateCommand(stream_name, token))
self.send_command(ReplicateCommand())
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)

View File

@@ -17,7 +17,7 @@
import logging
import random
from typing import Any, List
from typing import Any, Dict, List
from six import itervalues
@@ -30,7 +30,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.metrics import Measure, measure_func
from .protocol import ServerReplicationStreamProtocol
from .streams import STREAMS_MAP
from .streams import STREAMS_MAP, Stream
from .streams.federation import FederationStream
stream_updates_counter = Counter(
@@ -52,7 +52,7 @@ class ReplicationStreamProtocolFactory(Factory):
"""
def __init__(self, hs):
self.streamer = ReplicationStreamer(hs)
self.streamer = hs.get_replication_streamer()
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
@@ -133,6 +133,11 @@ class ReplicationStreamer(object):
for conn in self.connections:
conn.send_error("server shutting down")
def get_streams(self) -> Dict[str, Stream]:
"""Get a mapp from stream name to stream instance.
"""
return self.streams_by_name
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
connections if there are.
@@ -166,11 +171,6 @@ class ReplicationStreamer(object):
self.pending_updates = False
with Measure(self.clock, "repl.stream.get_updates"):
# First we tell the streams that they should update their
# current tokens.
for stream in self.streams:
stream.advance_current_token()
all_streams = self.streams
if self._replication_torture_level is not None:
@@ -180,7 +180,7 @@ class ReplicationStreamer(object):
random.shuffle(all_streams)
for stream in all_streams:
if stream.last_token == stream.upto_token:
if stream.last_token == stream.current_token():
continue
if self._replication_torture_level:
@@ -192,10 +192,11 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME,
stream.last_token,
stream.upto_token,
stream.current_token(),
)
try:
updates, current_token = await stream.get_updates()
updates, current_token, limited = await stream.get_updates()
self.pending_updates |= limited
except Exception:
logger.info("Failed to handle stream %s", stream.NAME)
raise
@@ -231,8 +232,7 @@ class ReplicationStreamer(object):
self.pending_updates = False
self.is_looping = False
@measure_func("repl.get_stream_updates")
async def get_stream_updates(self, stream_name, token):
def get_stream_token(self, stream_name):
"""For a given stream get all updates since token. This is called when
a client first subscribes to a stream.
"""
@@ -240,7 +240,7 @@ class ReplicationStreamer(object):
if not stream:
raise Exception("unknown stream %s", stream_name)
return await stream.get_updates_since(token)
return stream.current_token()
@measure_func("repl.federation_ack")
def federation_ack(self, token):

View File

@@ -25,26 +25,66 @@ Each stream is defined by the following information:
update_function: The function that returns a list of updates between two tokens
"""
from . import _base, events, federation
from typing import Dict, Type
from synapse.replication.tcp.streams._base import (
AccountDataStream,
BackfillStream,
CachesStream,
DeviceListsStream,
GroupServerStream,
PresenceStream,
PublicRoomsStream,
PushersStream,
PushRulesStream,
ReceiptsStream,
Stream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
UserSignatureStream,
)
from synapse.replication.tcp.streams.events import EventsStream
from synapse.replication.tcp.streams.federation import FederationStream
STREAMS_MAP = {
stream.NAME: stream
for stream in (
events.EventsStream,
_base.BackfillStream,
_base.PresenceStream,
_base.TypingStream,
_base.ReceiptsStream,
_base.PushRulesStream,
_base.PushersStream,
_base.CachesStream,
_base.PublicRoomsStream,
_base.DeviceListsStream,
_base.ToDeviceStream,
federation.FederationStream,
_base.TagAccountDataStream,
_base.AccountDataStream,
_base.GroupServerStream,
_base.UserSignatureStream,
EventsStream,
BackfillStream,
PresenceStream,
TypingStream,
ReceiptsStream,
PushRulesStream,
PushersStream,
CachesStream,
PublicRoomsStream,
DeviceListsStream,
ToDeviceStream,
FederationStream,
TagAccountDataStream,
AccountDataStream,
GroupServerStream,
UserSignatureStream,
)
}
} # type: Dict[str, Type[Stream]]
__all__ = [
"STREAMS_MAP",
"Stream",
"BackfillStream",
"PresenceStream",
"TypingStream",
"ReceiptsStream",
"PushRulesStream",
"PushersStream",
"CachesStream",
"PublicRoomsStream",
"DeviceListsStream",
"ToDeviceStream",
"TagAccountDataStream",
"AccountDataStream",
"GroupServerStream",
"UserSignatureStream",
]
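For reference, a short sketch of how this map is consumed, mirroring the per-connection instantiation added to the client protocol later in this diff (the `build_streams` helper name is an assumption):

from synapse.replication.tcp.streams import STREAMS_MAP

def build_streams(hs):
    # One Stream instance per registered stream class, keyed by stream name.
    return {stream.NAME: stream(hs) for stream in STREAMS_MAP.values()}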

View File

@@ -14,114 +14,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from collections import namedtuple
from typing import Any, List, Optional
from typing import Any, Awaitable, Callable, List, Optional, Tuple
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
MAX_EVENTS_BEHIND = 500000
BackfillStreamRow = namedtuple(
"BackfillStreamRow",
(
"event_id", # str
"room_id", # str
"type", # str
"state_key", # str, optional
"redacts", # str, optional
"relates_to", # str, optional
),
)
PresenceStreamRow = namedtuple(
"PresenceStreamRow",
(
"user_id", # str
"state", # str
"last_active_ts", # int
"last_federation_update_ts", # int
"last_user_sync_ts", # int
"status_msg", # str
"currently_active", # bool
),
)
TypingStreamRow = namedtuple(
"TypingStreamRow", ("room_id", "user_ids") # str # list(str)
)
ReceiptsStreamRow = namedtuple(
"ReceiptsStreamRow",
(
"room_id", # str
"receipt_type", # str
"user_id", # str
"event_id", # str
"data", # dict
),
)
PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
PushersStreamRow = namedtuple(
"PushersStreamRow",
("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
)
# Some type aliases to make things a bit easier.
@attr.s
class CachesStreamRow:
"""Stream to inform workers they should invalidate their cache.
# A stream position token
Token = int
Attributes:
cache_func: Name of the cached function.
keys: The entry in the cache to invalidate. If None then the whole
cache will be invalidated.
invalidation_ts: Timestamp of when the invalidation took place.
"""
cache_func = attr.ib(type=str)
keys = attr.ib(type=Optional[List[Any]])
invalidation_ts = attr.ib(type=int)
PublicRoomsStreamRow = namedtuple(
"PublicRoomsStreamRow",
(
"room_id", # str
"visibility", # str
"appservice_id", # str, optional
"network_id", # str, optional
),
)
DeviceListsStreamRow = namedtuple(
"DeviceListsStreamRow", ("user_id", "destination") # str # str
)
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
)
AccountDataStreamRow = namedtuple(
"AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
)
GroupsStreamRow = namedtuple(
"GroupsStreamRow",
("group_id", "user_id", "type", "content"), # str # str # str # dict
)
UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
StreamRow = Tuple[Token, tuple]
class Stream(object):
"""Base class for the streams.
Provides a `get_updates()` function that returns new updates since the last
time it was called up until the point `advance_current_token` was called.
time it was called.
"""
NAME = None # type: str # The name of the stream
# The type of the row. Used by the default impl of parse_row.
ROW_TYPE = None # type: Any
_LIMITED = True # Whether the update function takes a limit
@classmethod
def parse_row(cls, row):
@@ -139,80 +65,56 @@ class Stream(object):
return cls.ROW_TYPE(*row)
def __init__(self, hs):
# The token from which we last asked for updates
self.last_token = self.current_token()
# The token that we will get updates up to
self.upto_token = self.current_token()
def advance_current_token(self):
"""Updates `upto_token` to "now", which updates up until which point
get_updates[_since] will fetch rows till.
"""
self.upto_token = self.current_token()
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
self.upto_token = self.current_token()
self.last_token = self.upto_token
self.last_token = self.current_token()
async def get_updates(self):
async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before),
until the `upto_token`
since the stream was constructed if it hadn't been called before).
Returns:
Deferred[Tuple[List[Tuple[int, Any]], int]:
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
list of ``(token, row)`` entries. ``row`` will be json-serialised and
sent over the replication stream.
A triplet `(updates, new_last_token, limited)`, where `updates` is
a list of `(token, row)` entries, `new_last_token` is the new
position in stream, and `limited` is whether there are more updates
to fetch.
"""
updates, current_token = await self.get_updates_since(self.last_token)
current_token = self.current_token()
updates, current_token, limited = await self.get_updates_since(
self.last_token, current_token
)
self.last_token = current_token
return updates, current_token
return updates, current_token, limited
async def get_updates_since(self, from_token):
async def get_updates_since(
self, from_token: Token, upto_token: Token, limit: int = 100
) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Like get_updates except allows specifying from when we should
stream updates
Returns:
Deferred[Tuple[List[Tuple[int, Any]], int]:
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
list of ``(token, row)`` entries. ``row`` will be json-serialised and
sent over the replication stream.
A triplet `(updates, new_last_token, limited)`, where `updates` is
a list of `(token, row)` entries, `new_last_token` is the new
position in stream, and `limited` is whether there are more updates
to fetch.
"""
if from_token in ("NOW", "now"):
return [], self.upto_token
current_token = self.upto_token
from_token = int(from_token)
if from_token == current_token:
return [], current_token
if from_token == upto_token:
return [], upto_token, False
logger.info("get_updates_since: %s", self.__class__)
if self._LIMITED:
rows = await self.update_function(
from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
)
# never turn more than MAX_EVENTS_BEHIND + 1 into updates.
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
else:
rows = await self.update_function(from_token, current_token)
updates = [(row[0], row[1:]) for row in rows]
# check we didn't get more rows than the limit.
# doing it like this allows the update_function to be a generator.
if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
raise Exception("stream %s has fallen behind" % (self.NAME))
return updates, current_token
updates, upto_token, limited = await self.update_function(
from_token, upto_token, limit=limit,
)
return updates, upto_token, limited
def current_token(self):
"""Gets the current token of the underlying streams. Should be provided
@@ -223,9 +125,8 @@ class Stream(object):
"""
raise NotImplementedError()
def update_function(self, from_token, current_token, limit=None):
"""Get updates between from_token and to_token. If Stream._LIMITED is
True then limit is provided, otherwise it's not.
def update_function(self, from_token, current_token, limit):
"""Get updates between from_token and to_token.
Returns:
Deferred(list(tuple)): the first entry in the tuple is the token for
@@ -235,52 +136,144 @@ class Stream(object):
raise NotImplementedError()
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
async def update_function(from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) == limit:
upto_token = rows[-1][0]
limited = True
return updates, upto_token, limited
return update_function
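To illustrate the contract the wrapper expects, here is a hypothetical query function (not from the codebase) wrapped with it; rows are tuples whose first element is the stream token:

async def get_example_rows(from_token, upto_token, limit):
    # Pretend every token in (from_token, upto_token] has one single-column row.
    return [
        (token, "entity-%d" % token)
        for token in range(from_token + 1, upto_token + 1)
    ][:limit]

example_update_function = db_query_to_update_function(get_example_rows)
# example_update_function(0, 500, 100) reports limited=True, with upto_token
# set to the token of the last row actually returned (100 here).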
def make_http_update_function(
hs, stream_name: str
) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
"""Makes a suitable function for use as an `update_function` that queries
the master process for updates.
"""
client = ReplicationGetStreamUpdates.make_client(hs)
async def update_function(
from_token: int, upto_token: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
result = await client(
stream_name=stream_name,
from_token=from_token,
upto_token=upto_token,
limit=limit,
)
# The endpoint returns a dict; unpack it into the (updates, token, limited)
# triple that callers of an update_function expect.
return result["updates"], result["upto_token"], result["limited"]
return update_function
class BackfillStream(Stream):
"""We fetched some old events and either we had never seen that event before
or it went from being an outlier to not.
"""
BackfillStreamRow = namedtuple(
"BackfillStreamRow",
(
"event_id", # str
"room_id", # str
"type", # str
"state_key", # str, optional
"redacts", # str, optional
"relates_to", # str, optional
),
)
NAME = "backfill"
ROW_TYPE = BackfillStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_current_backfill_token # type: ignore
self.update_function = store.get_all_new_backfill_event_rows # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
super(BackfillStream, self).__init__(hs)
class PresenceStream(Stream):
PresenceStreamRow = namedtuple(
"PresenceStreamRow",
(
"user_id", # str
"state", # str
"last_active_ts", # int
"last_federation_update_ts", # int
"last_user_sync_ts", # int
"status_msg", # str
"currently_active", # bool
),
)
NAME = "presence"
_LIMITED = False
ROW_TYPE = PresenceStreamRow
def __init__(self, hs):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
self._is_worker = hs.config.worker_app is not None
self.current_token = store.get_current_presence_token # type: ignore
self.update_function = presence_handler.get_all_presence_updates # type: ignore
if hs.config.worker_app is None:
self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
else:
# Query master process
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super(PresenceStream, self).__init__(hs)
class TypingStream(Stream):
TypingStreamRow = namedtuple(
"TypingStreamRow", ("room_id", "user_ids") # str # list(str)
)
NAME = "typing"
_LIMITED = False
ROW_TYPE = TypingStreamRow
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
self.current_token = typing_handler.get_current_token # type: ignore
self.update_function = typing_handler.get_all_typing_updates # type: ignore
if hs.config.worker_app is None:
self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
else:
# Query master process
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super(TypingStream, self).__init__(hs)
class ReceiptsStream(Stream):
ReceiptsStreamRow = namedtuple(
"ReceiptsStreamRow",
(
"room_id", # str
"receipt_type", # str
"user_id", # str
"event_id", # str
"data", # dict
),
)
NAME = "receipts"
ROW_TYPE = ReceiptsStreamRow
@@ -288,7 +281,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_receipt_stream_id # type: ignore
self.update_function = store.get_all_updated_receipts # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
super(ReceiptsStream, self).__init__(hs)
@@ -297,6 +290,8 @@ class PushRulesStream(Stream):
"""A user has changed their push rules
"""
PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
NAME = "push_rules"
ROW_TYPE = PushRulesStreamRow
@@ -310,13 +305,24 @@ class PushRulesStream(Stream):
async def update_function(self, from_token, to_token, limit):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
return [(row[0], row[2]) for row in rows]
limited = False
if len(rows) == limit:
to_token = rows[-1][0]
limited = True
return [(row[0], (row[2],)) for row in rows], to_token, limited
class PushersStream(Stream):
"""A user has added/changed/removed a pusher
"""
PushersStreamRow = namedtuple(
"PushersStreamRow",
("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
)
NAME = "pushers"
ROW_TYPE = PushersStreamRow
@@ -324,7 +330,7 @@ class PushersStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_pushers_stream_token # type: ignore
self.update_function = store.get_all_updated_pushers_rows # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
super(PushersStream, self).__init__(hs)
@@ -334,6 +340,21 @@ class CachesStream(Stream):
the cache on the workers
"""
@attr.s
class CachesStreamRow:
"""Stream to inform workers they should invalidate their cache.
Attributes:
cache_func: Name of the cached function.
keys: The entry in the cache to invalidate. If None then the whole
cache will be invalidated.
invalidation_ts: Timestamp of when the invalidation took place.
"""
cache_func = attr.ib(type=str)
keys = attr.ib(type=Optional[List[Any]])
invalidation_ts = attr.ib(type=int)
NAME = "caches"
ROW_TYPE = CachesStreamRow
@@ -341,7 +362,7 @@ class CachesStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_cache_stream_token # type: ignore
self.update_function = store.get_all_updated_caches # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
super(CachesStream, self).__init__(hs)
@@ -350,6 +371,16 @@ class PublicRoomsStream(Stream):
"""The public rooms list changed
"""
PublicRoomsStreamRow = namedtuple(
"PublicRoomsStreamRow",
(
"room_id", # str
"visibility", # str
"appservice_id", # str, optional
"network_id", # str, optional
),
)
NAME = "public_rooms"
ROW_TYPE = PublicRoomsStreamRow
@@ -357,24 +388,28 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_current_public_room_stream_id # type: ignore
self.update_function = store.get_all_new_public_rooms # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
super(PublicRoomsStream, self).__init__(hs)
class DeviceListsStream(Stream):
"""Someone added/changed/removed a device
"""Either a user has updated their devices or a remote server needs to be
told about a device update.
"""
@attr.s
class DeviceListsStreamRow:
entity = attr.ib(type=str)
NAME = "device_lists"
_LIMITED = False
ROW_TYPE = DeviceListsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
super(DeviceListsStream, self).__init__(hs)
@@ -383,6 +418,8 @@ class ToDeviceStream(Stream):
"""New to_device messages for a client
"""
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
NAME = "to_device"
ROW_TYPE = ToDeviceStreamRow
@@ -390,7 +427,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_to_device_stream_token # type: ignore
self.update_function = store.get_all_new_device_messages # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
super(ToDeviceStream, self).__init__(hs)
@@ -399,6 +436,10 @@ class TagAccountDataStream(Stream):
"""Someone added/removed a tag for a room
"""
TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
)
NAME = "tag_account_data"
ROW_TYPE = TagAccountDataStreamRow
@@ -406,7 +447,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_account_data_stream_id # type: ignore
self.update_function = store.get_all_updated_tags # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
super(TagAccountDataStream, self).__init__(hs)
@@ -415,6 +456,10 @@ class AccountDataStream(Stream):
"""Global or per room account data was changed
"""
AccountDataStreamRow = namedtuple(
"AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
)
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
@@ -422,10 +467,11 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore()
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(AccountDataStream, self).__init__(hs)
async def update_function(self, from_token, to_token, limit):
async def _update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit
)
@@ -440,6 +486,11 @@ class AccountDataStream(Stream):
class GroupServerStream(Stream):
GroupsStreamRow = namedtuple(
"GroupsStreamRow",
("group_id", "user_id", "type", "content"), # str # str # str # dict
)
NAME = "groups"
ROW_TYPE = GroupsStreamRow
@@ -447,7 +498,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_group_stream_token # type: ignore
self.update_function = store.get_all_groups_changes # type: ignore
self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
super(GroupServerStream, self).__init__(hs)
@@ -456,14 +507,15 @@ class UserSignatureStream(Stream):
"""A user has signed their own device with their user-signing key
"""
UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
NAME = "user_signature"
_LIMITED = False
ROW_TYPE = UserSignatureStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
super(UserSignatureStream, self).__init__(hs)

View File

@@ -19,7 +19,7 @@ from typing import Tuple, Type
import attr
from ._base import Stream
from ._base import Stream, db_query_to_update_function
"""Handling of the 'events' replication stream
@@ -117,10 +117,11 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
self.current_token = self._store.get_current_events_token # type: ignore
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(EventsStream, self).__init__(hs)
async def update_function(self, from_token, current_token, limit=None):
async def _update_function(self, from_token, current_token, limit=None):
event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, limit
)

View File

@@ -15,15 +15,9 @@
# limitations under the License.
from collections import namedtuple
from ._base import Stream
from twisted.internet import defer
FederationStreamRow = namedtuple(
"FederationStreamRow",
(
"type", # str, the type of data as defined in the BaseFederationRows
"data", # dict, serialization of a federation.send_queue.BaseFederationRow
),
)
from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
class FederationStream(Stream):
@@ -31,13 +25,28 @@ class FederationStream(Stream):
sending disabled.
"""
FederationStreamRow = namedtuple(
"FederationStreamRow",
(
"type", # str, the type of data as defined in the BaseFederationRows
"data", # dict, serialization of a federation.send_queue.BaseFederationRow
),
)
NAME = "federation"
ROW_TYPE = FederationStreamRow
_QUERY_MASTER = True
def __init__(self, hs):
federation_sender = hs.get_federation_sender()
self.current_token = federation_sender.get_current_token # type: ignore
self.update_function = federation_sender.get_replication_rows # type: ignore
# Not all synapse instances will have a federation sender instance,
# whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
# so we stub the stream out when that is the case.
if hs.config.worker_app is None or hs.should_send_federation():
federation_sender = hs.get_federation_sender()
self.current_token = federation_sender.get_current_token # type: ignore
self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
else:
self.current_token = lambda: 0 # type: ignore
self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, False)) # type: ignore
super(FederationStream, self).__init__(hs)
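Worth noting: on processes that do not send federation, the stream is stubbed out rather than omitted, so replication clients can still subscribe to it and simply see an empty, non-advancing stream. A hedged illustration of that branch with placeholder names (not the real Synapse classes):

class StubbableStreamSketch:
    def __init__(self, sends_federation: bool, sender=None):
        if sends_federation:
            # A real sender is available on this process.
            self.current_token = sender.get_current_token
            self.update_function = sender.get_replication_rows
        else:
            # No sender here: report a fixed token and never return rows.
            self.current_token = lambda: 0
            self.update_function = self._stub_update_function

    @staticmethod
    async def _stub_update_function(from_token, upto_token, limit):
        return [], upto_token, False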

View File

@@ -28,7 +28,6 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.push.mailer import load_jinja2_templates
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
@@ -73,6 +72,14 @@ def login_id_thirdparty_from_phone(identifier):
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
def build_service_param(cas_service_url, client_redirect_url):
return "%s%s?redirectUrl=%s" % (
cas_service_url,
"/_matrix/client/r0/login/cas/ticket",
urllib.parse.quote(client_redirect_url, safe=""),
)
class LoginRestServlet(RestServlet):
PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas"
@@ -428,18 +435,15 @@ class BaseSSORedirectServlet(RestServlet):
class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
super(CasRedirectServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url.encode("ascii")
self.cas_service_url = hs.config.cas_service_url.encode("ascii")
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
def get_sso_url(self, client_redirect_url):
client_redirect_url_param = urllib.parse.urlencode(
{b"redirectUrl": client_redirect_url}
).encode("ascii")
hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
service_param = urllib.parse.urlencode(
{b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
).encode("ascii")
return b"%s/login?%s" % (self.cas_server_url, service_param)
args = urllib.parse.urlencode(
{"service": build_service_param(self.cas_service_url, client_redirect_url)}
)
return "%s/login?%s" % (self.cas_server_url, args)
class CasTicketServlet(RestServlet):
@@ -459,7 +463,7 @@ class CasTicketServlet(RestServlet):
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url,
"service": build_service_param(self.cas_service_url, client_redirect_url),
}
try:
body = await self._http_client.get_raw(uri, args)
@@ -548,13 +552,6 @@ class SSOAuthHandler(object):
self._registration_handler = hs.get_registration_handler()
self._macaroon_gen = hs.get_macaroon_generator()
# Load the redirect page HTML template
self._template = load_jinja2_templates(
hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
)[0]
self._server_name = hs.config.server_name
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
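The key property of build_service_param is that the `service` value sent when redirecting the user to the CAS server is byte-for-byte identical to the one later sent to /proxyValidate; CAS ticket validation fails if the two differ. A quick usage sketch with hypothetical URLs (the helper body is the one from the hunk above):

import urllib.parse

def build_service_param(cas_service_url, client_redirect_url):
    return "%s%s?redirectUrl=%s" % (
        cas_service_url,
        "/_matrix/client/r0/login/cas/ticket",
        urllib.parse.quote(client_redirect_url, safe=""),
    )

# Hypothetical homeserver and client URLs:
print(build_service_param("https://hs.example.com", "https://app.example.com/?room=1"))
# https://hs.example.com/_matrix/client/r0/login/cas/ticket?redirectUrl=https%3A%2F%2Fapp.example.com%2F%3Froom%3D1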

View File

@@ -142,14 +142,6 @@ class AuthRestServlet(RestServlet):
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
"sitekey": self.hs.config.recaptcha_public_key,
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
elif stagetype == LoginType.TERMS:
html = TERMS_TEMPLATE % {
"session": session,
@@ -158,17 +150,19 @@ class AuthRestServlet(RestServlet):
"myurl": "%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
else:
raise SynapseError(404, "Unknown auth stage type")
# Render the HTML and return.
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
async def on_POST(self, request, stagetype):
session = parse_string(request, "session")
@@ -196,15 +190,6 @@ class AuthRestServlet(RestServlet):
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
"sitekey": self.hs.config.recaptcha_public_key,
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
elif stagetype == LoginType.TERMS:
authdict = {"session": session}
@@ -225,17 +210,19 @@ class AuthRestServlet(RestServlet):
"myurl": "%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
else:
raise SynapseError(404, "Unknown auth stage type")
# Render the HTML and return.
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
def on_OPTIONS(self, _):
return 200, {}
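Both fallback handlers now fall through to a single "render the HTML and return" tail instead of repeating the response boilerplate in every branch. That shared tail could equally live in a small helper; a sketch of that shape (the helper name is hypothetical and not introduced by this change; finish_request stands in for the function of the same name already imported in this module):

def respond_with_html_sketch(request, html: str, finish_request) -> None:
    # Same steps as the shared tail above: encode, set headers, write, finish.
    html_bytes = html.encode("utf8")
    request.setResponseCode(200)
    request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
    request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
    request.write(html_bytes)
    finish_request(request)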

View File

@@ -50,6 +50,9 @@ class DownloadResource(DirectServeResource):
b" media-src 'self';"
b" object-src 'self';",
)
request.setHeader(
b"Referrer-Policy", b"no-referrer",
)
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
await self.media_repo.get_local_media(request, media_id, name)

View File

@@ -24,7 +24,6 @@ from six import iteritems
import twisted.internet.error
import twisted.web.http
from twisted.internet import defer
from twisted.web.resource import Resource
from synapse.api.errors import (
@@ -114,15 +113,14 @@ class MediaRepository(object):
"update_recently_accessed_media", self._update_recently_accessed
)
@defer.inlineCallbacks
def _update_recently_accessed(self):
async def _update_recently_accessed(self):
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
local_media = self.recently_accessed_locals
self.recently_accessed_locals = set()
yield self.store.update_cached_last_access_time(
await self.store.update_cached_last_access_time(
local_media, remote_media, self.clock.time_msec()
)
@@ -138,8 +136,7 @@ class MediaRepository(object):
else:
self.recently_accessed_locals.add(media_id)
@defer.inlineCallbacks
def create_content(
async def create_content(
self, media_type, upload_name, content, content_length, auth_user
):
"""Store uploaded content for a local user and return the mxc URL
@@ -158,11 +155,11 @@ class MediaRepository(object):
file_info = FileInfo(server_name=None, file_id=media_id)
fname = yield self.media_storage.store_file(content, file_info)
fname = await self.media_storage.store_file(content, file_info)
logger.info("Stored local media in file %r", fname)
yield self.store.store_local_media(
await self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
@@ -171,12 +168,11 @@ class MediaRepository(object):
user_id=auth_user,
)
yield self._generate_thumbnails(None, media_id, media_id, media_type)
await self._generate_thumbnails(None, media_id, media_id, media_type)
return "mxc://%s/%s" % (self.server_name, media_id)
@defer.inlineCallbacks
def get_local_media(self, request, media_id, name):
async def get_local_media(self, request, media_id, name):
"""Responds to reqests for local media, if exists, or returns 404.
Args:
@@ -190,7 +186,7 @@ class MediaRepository(object):
Deferred: Resolves once a response has successfully been written
to request
"""
media_info = yield self.store.get_local_media(media_id)
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
@@ -204,13 +200,12 @@ class MediaRepository(object):
file_info = FileInfo(None, media_id, url_cache=url_cache)
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
request, responder, media_type, media_length, upload_name
)
@defer.inlineCallbacks
def get_remote_media(self, request, server_name, media_id, name):
async def get_remote_media(self, request, server_name, media_id, name):
"""Respond to requests for remote media.
Args:
@@ -236,8 +231,8 @@ class MediaRepository(object):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
with (await self.remote_media_linearizer.queue(key)):
responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)
@@ -246,14 +241,13 @@ class MediaRepository(object):
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
yield respond_with_responder(
await respond_with_responder(
request, responder, media_type, media_length, upload_name
)
else:
respond_404(request)
@defer.inlineCallbacks
def get_remote_media_info(self, server_name, media_id):
async def get_remote_media_info(self, server_name, media_id):
"""Gets the media info associated with the remote file, downloading
if necessary.
@@ -274,8 +268,8 @@ class MediaRepository(object):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
with (await self.remote_media_linearizer.queue(key)):
responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)
@@ -286,8 +280,7 @@ class MediaRepository(object):
return media_info
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
async def _get_remote_media_impl(self, server_name, media_id):
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -299,7 +292,7 @@ class MediaRepository(object):
Returns:
Deferred[(Responder, media_info)]
"""
media_info = yield self.store.get_cached_remote_media(server_name, media_id)
media_info = await self.store.get_cached_remote_media(server_name, media_id)
# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise generate a new
@@ -317,19 +310,18 @@ class MediaRepository(object):
logger.info("Media is quarantined")
raise NotFoundError()
responder = yield self.media_storage.fetch_media(file_info)
responder = await self.media_storage.fetch_media(file_info)
if responder:
return responder, media_info
# Failed to find the file anywhere, lets download it.
media_info = yield self._download_remote_file(server_name, media_id, file_id)
media_info = await self._download_remote_file(server_name, media_id, file_id)
responder = yield self.media_storage.fetch_media(file_info)
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
@defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id, file_id):
async def _download_remote_file(self, server_name, media_id, file_id):
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -351,7 +343,7 @@ class MediaRepository(object):
("/_matrix/media/v1/download", server_name, media_id)
)
try:
length, headers = yield self.client.get_file(
length, headers = await self.client.get_file(
server_name,
request_path,
output_stream=f,
@@ -397,7 +389,7 @@ class MediaRepository(object):
)
raise SynapseError(502, "Failed to fetch remote media")
yield finish()
await finish()
media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
@@ -405,7 +397,7 @@ class MediaRepository(object):
logger.info("Stored remote media in file %r", fname)
yield self.store.store_cached_remote_media(
await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
@@ -423,7 +415,7 @@ class MediaRepository(object):
"filesystem_id": file_id,
}
yield self._generate_thumbnails(server_name, media_id, file_id, media_type)
await self._generate_thumbnails(server_name, media_id, file_id, media_type)
return media_info
@@ -458,16 +450,15 @@ class MediaRepository(object):
return t_byte_source
@defer.inlineCallbacks
def generate_local_exact_thumbnail(
async def generate_local_exact_thumbnail(
self, media_id, t_width, t_height, t_method, t_type, url_cache
):
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
)
thumbnailer = Thumbnailer(input_path)
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
@@ -490,7 +481,7 @@ class MediaRepository(object):
thumbnail_type=t_type,
)
output_path = yield self.media_storage.store_file(
output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -500,22 +491,21 @@ class MediaRepository(object):
t_len = os.path.getsize(output_path)
yield self.store.store_local_thumbnail(
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
return output_path
@defer.inlineCallbacks
def generate_remote_exact_thumbnail(
async def generate_remote_exact_thumbnail(
self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
):
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=False)
)
thumbnailer = Thumbnailer(input_path)
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
@@ -537,7 +527,7 @@ class MediaRepository(object):
thumbnail_type=t_type,
)
output_path = yield self.media_storage.store_file(
output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -547,7 +537,7 @@ class MediaRepository(object):
t_len = os.path.getsize(output_path)
yield self.store.store_remote_media_thumbnail(
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
@@ -560,8 +550,7 @@ class MediaRepository(object):
return output_path
@defer.inlineCallbacks
def _generate_thumbnails(
async def _generate_thumbnails(
self, server_name, media_id, file_id, media_type, url_cache=False
):
"""Generate and store thumbnails for an image.
@@ -582,7 +571,7 @@ class MediaRepository(object):
if not requirements:
return
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
)
@@ -600,7 +589,7 @@ class MediaRepository(object):
return
if thumbnailer.transpose_method is not None:
m_width, m_height = yield defer_to_thread(
m_width, m_height = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.transpose
)
@@ -620,11 +609,11 @@ class MediaRepository(object):
for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail
if t_method == "crop":
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
)
elif t_method == "scale":
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
)
else:
@@ -646,7 +635,7 @@ class MediaRepository(object):
url_cache=url_cache,
)
output_path = yield self.media_storage.store_file(
output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -656,7 +645,7 @@ class MediaRepository(object):
# Write to database
if server_name:
yield self.store.store_remote_media_thumbnail(
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
@@ -667,15 +656,14 @@ class MediaRepository(object):
t_len,
)
else:
yield self.store.store_local_thumbnail(
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
return {"width": m_width, "height": m_height}
@defer.inlineCallbacks
def delete_old_remote_media(self, before_ts):
old_media = yield self.store.get_remote_media_before(before_ts)
async def delete_old_remote_media(self, before_ts):
old_media = await self.store.get_remote_media_before(before_ts)
deleted = 0
@@ -689,7 +677,7 @@ class MediaRepository(object):
# TODO: Should we delete from the backup store
with (yield self.remote_media_linearizer.queue(key)):
with (await self.remote_media_linearizer.queue(key)):
full_path = self.filepaths.remote_media_filepath(origin, file_id)
try:
os.remove(full_path)
@@ -705,7 +693,7 @@ class MediaRepository(object):
)
shutil.rmtree(thumbnail_dir, ignore_errors=True)
yield self.store.delete_remote_media(origin, media_id)
await self.store.delete_remote_media(origin, media_id)
deleted += 1
return {"deleted": deleted}

View File

@@ -165,8 +165,7 @@ class PreviewUrlResource(DirectServeResource):
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
respond_with_json_bytes(request, 200, og, send_cors=True)
@defer.inlineCallbacks
def _do_preview(self, url, user, ts):
async def _do_preview(self, url, user, ts):
"""Check the db, and download the URL and build a preview
Args:
@@ -179,7 +178,7 @@ class PreviewUrlResource(DirectServeResource):
"""
# check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
cache_result = yield self.store.get_url_cache(url, ts)
cache_result = await self.store.get_url_cache(url, ts)
if (
cache_result
and cache_result["expires_ts"] > ts
@@ -192,13 +191,13 @@ class PreviewUrlResource(DirectServeResource):
og = og.encode("utf8")
return og
media_info = yield self._download_url(url, user)
media_info = await self._download_url(url, user)
logger.debug("got media_info of '%s'", media_info)
if _is_media(media_info["media_type"]):
file_id = media_info["filesystem_id"]
dims = yield self.media_repo._generate_thumbnails(
dims = await self.media_repo._generate_thumbnails(
None, file_id, file_id, media_info["media_type"], url_cache=True
)
@@ -248,14 +247,14 @@ class PreviewUrlResource(DirectServeResource):
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if "og:image" in og and og["og:image"]:
image_info = yield self._download_url(
image_info = await self._download_url(
_rebase_url(og["og:image"], media_info["uri"]), user
)
if _is_media(image_info["media_type"]):
# TODO: make sure we don't choke on white-on-transparent images
file_id = image_info["filesystem_id"]
dims = yield self.media_repo._generate_thumbnails(
dims = await self.media_repo._generate_thumbnails(
None, file_id, file_id, image_info["media_type"], url_cache=True
)
if dims:
@@ -293,7 +292,7 @@ class PreviewUrlResource(DirectServeResource):
jsonog = json.dumps(og)
# store OG in history-aware DB cache
yield self.store.store_url_cache(
await self.store.store_url_cache(
url,
media_info["response_code"],
media_info["etag"],
@@ -305,8 +304,7 @@ class PreviewUrlResource(DirectServeResource):
return jsonog.encode("utf8")
@defer.inlineCallbacks
def _download_url(self, url, user):
async def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -318,7 +316,7 @@ class PreviewUrlResource(DirectServeResource):
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
logger.debug("Trying to get url '%s'", url)
length, headers, uri, code = yield self.client.get_file(
length, headers, uri, code = await self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size
)
except SynapseError:
@@ -345,7 +343,7 @@ class PreviewUrlResource(DirectServeResource):
% (traceback.format_exception_only(sys.exc_info()[0], e),),
Codes.UNKNOWN,
)
yield finish()
await finish()
try:
if b"Content-Type" in headers:
@@ -356,7 +354,7 @@ class PreviewUrlResource(DirectServeResource):
download_name = get_filename_from_headers(headers)
yield self.store.store_local_media(
await self.store.store_local_media(
media_id=file_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
@@ -393,8 +391,7 @@ class PreviewUrlResource(DirectServeResource):
"expire_url_cache_data", self._expire_url_cache_data
)
@defer.inlineCallbacks
def _expire_url_cache_data(self):
async def _expire_url_cache_data(self):
"""Clean up expired url cache content, media and thumbnails.
"""
# TODO: Delete from backup media store
@@ -403,12 +400,12 @@ class PreviewUrlResource(DirectServeResource):
logger.info("Running url preview cache expiry")
if not (yield self.store.db.updates.has_completed_background_updates()):
if not (await self.store.db.updates.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry")
return
# First we delete expired url cache entries
media_ids = yield self.store.get_expired_url_cache(now)
media_ids = await self.store.get_expired_url_cache(now)
removed_media = []
for media_id in media_ids:
@@ -430,7 +427,7 @@ class PreviewUrlResource(DirectServeResource):
except Exception:
pass
yield self.store.delete_url_cache(removed_media)
await self.store.delete_url_cache(removed_media)
if removed_media:
logger.info("Deleted %d entries from url cache", len(removed_media))
@@ -440,7 +437,7 @@ class PreviewUrlResource(DirectServeResource):
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
expire_before = now - 2 * 24 * 60 * 60 * 1000
media_ids = yield self.store.get_url_cache_media_before(expire_before)
media_ids = await self.store.get_url_cache_media_before(expire_before)
removed_media = []
for media_id in media_ids:
@@ -478,7 +475,7 @@ class PreviewUrlResource(DirectServeResource):
except Exception:
pass
yield self.store.delete_url_cache_media(removed_media)
await self.store.delete_url_cache_media(removed_media)
logger.info("Deleted %d media from url cache", len(removed_media))

View File

@@ -16,8 +16,6 @@
import logging
from twisted.internet import defer
from synapse.http.server import (
DirectServeResource,
set_cors_headers,
@@ -79,11 +77,10 @@ class ThumbnailResource(DirectServeResource):
)
self.media_repo.mark_recently_accessed(server_name, media_id)
@defer.inlineCallbacks
def _respond_local_thumbnail(
async def _respond_local_thumbnail(
self, request, media_id, width, height, method, m_type
):
media_info = yield self.store.get_local_media(media_id)
media_info = await self.store.get_local_media(media_id)
if not media_info:
respond_404(request)
@@ -93,7 +90,7 @@ class ThumbnailResource(DirectServeResource):
respond_404(request)
return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
if thumbnail_infos:
thumbnail_info = self._select_thumbnail(
@@ -114,14 +111,13 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(request, responder, t_type, t_length)
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Couldn't find any generated thumbnails")
respond_404(request)
@defer.inlineCallbacks
def _select_or_generate_local_thumbnail(
async def _select_or_generate_local_thumbnail(
self,
request,
media_id,
@@ -130,7 +126,7 @@ class ThumbnailResource(DirectServeResource):
desired_method,
desired_type,
):
media_info = yield self.store.get_local_media(media_id)
media_info = await self.store.get_local_media(media_id)
if not media_info:
respond_404(request)
@@ -140,7 +136,7 @@ class ThumbnailResource(DirectServeResource):
respond_404(request)
return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
t_w = info["thumbnail_width"] == desired_width
t_h = info["thumbnail_height"] == desired_height
@@ -162,15 +158,15 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
responder = await self.media_storage.fetch_media(file_info)
if responder:
yield respond_with_responder(request, responder, t_type, t_length)
await respond_with_responder(request, responder, t_type, t_length)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_local_exact_thumbnail(
file_path = await self.media_repo.generate_local_exact_thumbnail(
media_id,
desired_width,
desired_height,
@@ -180,13 +176,12 @@ class ThumbnailResource(DirectServeResource):
)
if file_path:
yield respond_with_file(request, desired_type, file_path)
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
respond_404(request)
@defer.inlineCallbacks
def _select_or_generate_remote_thumbnail(
async def _select_or_generate_remote_thumbnail(
self,
request,
server_name,
@@ -196,9 +191,9 @@ class ThumbnailResource(DirectServeResource):
desired_method,
desired_type,
):
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
@@ -224,15 +219,15 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
responder = await self.media_storage.fetch_media(file_info)
if responder:
yield respond_with_responder(request, responder, t_type, t_length)
await respond_with_responder(request, responder, t_type, t_length)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
file_path = await self.media_repo.generate_remote_exact_thumbnail(
server_name,
file_id,
media_id,
@@ -243,21 +238,20 @@ class ThumbnailResource(DirectServeResource):
)
if file_path:
yield respond_with_file(request, desired_type, file_path)
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
respond_404(request)
@defer.inlineCallbacks
def _respond_remote_thumbnail(
async def _respond_remote_thumbnail(
self, request, server_name, media_id, width, height, method, m_type
):
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
@@ -278,8 +272,8 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(request, responder, t_type, t_length)
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Failed to find any generated thumbnails")
respond_404(request)

View File

@@ -85,6 +85,7 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.rest.media.v1.media_repository import (
MediaRepository,
MediaRepositoryResource,
@@ -199,6 +200,7 @@ class HomeServer(object):
"saml_handler",
"event_client_serializer",
"storage",
"replication_streamer",
]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -536,6 +538,9 @@ class HomeServer(object):
def build_storage(self) -> Storage:
return Storage(self, self.datastores)
def build_replication_streamer(self) -> ReplicationStreamer:
return ReplicationStreamer(self)
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
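Adding "replication_streamer" to the dependency list together with the build_replication_streamer method is all it takes for other code to call hs.get_replication_streamer(): for each name in DEPENDENCIES the HomeServer machinery generates a caching get_<name>() accessor that calls the matching build_<name>() once. A stripped-down toy of that convention (illustrative only, not Synapse's actual implementation):

class TinyHomeServerSketch:
    DEPENDENCIES = ["replication_streamer"]

    def __init__(self):
        self._built = {}

    def __getattr__(self, name):
        # Generate get_<dep>() lazily for registered dependencies.
        if name.startswith("get_") and name[4:] in self.DEPENDENCIES:
            dep = name[4:]

            def getter():
                if dep not in self._built:
                    self._built[dep] = getattr(self, "build_" + dep)()
                return self._built[dep]

            return getter
        raise AttributeError(name)

    def build_replication_streamer(self):
        return object()  # stand-in for ReplicationStreamer(self)


hs = TinyHomeServerSketch()
assert hs.get_replication_streamer() is hs.get_replication_streamer()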

View File

@@ -144,7 +144,10 @@ class DataStore(
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[("user_signature_stream", "stream_id")],
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
],
)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"

View File

@@ -32,7 +32,29 @@ logger = logging.getLogger(__name__)
CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationStore(SQLBaseStore):
class CacheInvalidationWorkerStore(SQLBaseStore):
def get_all_updated_caches(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
sql = (
"SELECT stream_id, cache_func, keys, invalidation_ts"
" FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, limit))
return txn.fetchall()
return self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
class CacheInvalidationStore(CacheInvalidationWorkerStore):
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore):
},
)
def get_all_updated_caches(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
sql = (
"SELECT stream_id, cache_func, keys, invalidation_ts"
" FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, limit))
return txn.fetchall()
return self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
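This split is the recurring pattern in the storage changes that follow: read-only catch-up queries move out of the master-only store class into a *WorkerStore base class, so processes that only load the worker stores can answer replication catch-up queries locally instead of asking the master. Schematically, with illustrative names:

class SomethingWorkerStore:
    """Loaded on every process, including workers."""

    def get_all_updated_things(self, last_id, current_id, limit):
        ...  # read-only query used to catch up a replication stream


class SomethingStore(SomethingWorkerStore):
    """Loaded only on the master process; keeps the write paths."""

    def update_thing(self, thing_id, new_value):
        ...  # writes stay master-only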

View File

@@ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
Args:
last_pos(int):
current_pos(int):
limit(int):
Returns:
A deferred list of rows from the device inbox
"""
if last_pos == current_pos:
return defer.succeed([])
def get_all_new_device_messages_txn(txn):
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
upper_pos = min(current_pos, last_pos + limit)
sql = (
"SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY user_id"
)
txn.execute(sql, (last_pos, upper_pos))
rows = txn.fetchall()
sql = (
"SELECT max(stream_id), destination"
" FROM device_federation_outbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY destination"
)
txn.execute(sql, (last_pos, upper_pos))
rows.extend(txn)
# Order by ascending stream ordering
rows.sort()
return rows
return self.db.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
@@ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((user_id, device_id, stream_id, message_json))
txn.executemany(sql, rows)
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
Args:
last_pos(int):
current_pos(int):
limit(int):
Returns:
A deferred list of rows from the device inbox
"""
if last_pos == current_pos:
return defer.succeed([])
def get_all_new_device_messages_txn(txn):
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
upper_pos = min(current_pos, last_pos + limit)
sql = (
"SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY user_id"
)
txn.execute(sql, (last_pos, upper_pos))
rows = txn.fetchall()
sql = (
"SELECT max(stream_id), destination"
" FROM device_federation_outbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY destination"
)
txn.execute(sql, (last_pos, upper_pos))
rows.extend(txn)
# Order by ascending stream ordering
rows.sort()
return rows
return self.db.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
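The bound upper_pos = min(current_pos, last_pos + limit) is worth spelling out: device_inbox can hold several rows per stream_id, so applying a plain SQL LIMIT to the rows could split a stream_id across two batches. Bounding the stream_id window instead guarantees that every stream_id handed back is complete. A worked example with made-up numbers:

last_pos, current_pos, limit = 100, 500, 50

upper_pos = min(current_pos, last_pos + limit)  # 150
# The queries then cover 100 < stream_id <= 150.  However many device_inbox or
# device_federation_outbox rows fall inside that window, nothing with a
# stream_id of 150 or below is left behind, and the next batch starts cleanly
# at 150.
assert upper_pos == 150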

View File

@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Tuple
from six import iteritems
@@ -31,7 +32,7 @@ from synapse.logging.opentracing import (
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.database import Database, LoggingTransaction
from synapse.types import Collection, get_verify_key_from_cross_signing_key
from synapse.util.caches.descriptors import (
Cache,
@@ -112,23 +113,13 @@ class DeviceWorkerStore(SQLBaseStore):
if not has_changed:
return now_stream_id, []
# We retrieve n+1 devices from the list of outbound pokes where n is
# our outbound device update limit. We then check if the very last
# device has the same stream_id as the second-to-last device. If so,
# then we ignore all devices with that stream_id and only send the
# devices with a lower stream_id.
#
# If when culling the list we end up with no devices afterwards, we
# consider the device update to be too large, and simply skip the
# stream_id; the rationale being that such a large device list update
# is likely an error.
updates = yield self.db.runInteraction(
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
limit + 1,
limit,
)
# Return an empty list if there are no updates
@@ -166,14 +157,6 @@ class DeviceWorkerStore(SQLBaseStore):
"device_id": verify_key.version,
}
# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
if len(updates) > limit:
stream_id_cutoff = updates[-1][2]
now_stream_id = stream_id_cutoff - 1
else:
stream_id_cutoff = None
# Perform the equivalent of a GROUP BY
#
# Iterate through the updates list and copy non-duplicate
@@ -192,10 +175,6 @@ class DeviceWorkerStore(SQLBaseStore):
query_map = {}
cross_signing_keys_by_user = {}
for user_id, device_id, update_stream_id, update_context in updates:
if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
# Stop processing updates
break
if (
user_id in master_key_by_user
and device_id == master_key_by_user[user_id]["device_id"]
@@ -218,17 +197,6 @@ class DeviceWorkerStore(SQLBaseStore):
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)
# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
# stream_id.
# That should only happen if a client is spamming the server with new
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
if not query_map and not cross_signing_keys_by_user:
return stream_id_cutoff, []
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
@@ -607,22 +575,33 @@ class DeviceWorkerStore(SQLBaseStore):
else:
return set()
def get_all_device_list_changes_for_remotes(self, from_key, to_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
async def get_all_device_list_changes_for_remotes(
self, from_key: int, to_key: int, limit: int,
) -> List[Tuple[int, str]]:
"""Return a list of `(stream_id, entity)` which is the combined list of
changes to devices and which destinations need to be poked. Entity is
either a user ID (starting with '@') or a remote destination.
"""
# We do a group by here as there can be a large number of duplicate
# entries, since we throw away device IDs.
# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
SELECT MAX(stream_id) AS stream_id, user_id, destination
FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
SELECT stream_id, entity FROM (
SELECT stream_id, user_id AS entity FROM device_lists_stream
UNION ALL
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id, destination
LIMIT ?
"""
return self.db.execute(
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
return await self.db.execute(
"get_all_device_list_changes_for_remotes",
None,
sql,
from_key,
to_key,
limit,
)
@cached(max_entries=10000)
@@ -1017,29 +996,49 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
with self._device_list_id_gen.get_next() as stream_id:
if not device_ids:
return
with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
yield self.db.runInteraction(
"add_device_change_to_streams",
self._add_device_change_txn,
"add_device_change_to_stream",
self._add_device_change_to_stream_txn,
user_id,
device_ids,
stream_ids,
)
if not hosts:
return stream_ids[-1]
context = get_active_span_text_map()
with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
yield self.db.runInteraction(
"add_device_outbound_poke_to_stream",
self._add_device_outbound_poke_to_stream_txn,
user_id,
device_ids,
hosts,
stream_id,
stream_ids,
context,
)
return stream_id
def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
now = self._clock.time_msec()
return stream_ids[-1]
def _add_device_change_to_stream_txn(
self,
txn: LoggingTransaction,
user_id: str,
device_ids: Collection[str],
stream_ids: List[int],
):
txn.call_after(
self._device_list_stream_cache.entity_has_changed, user_id, stream_id
self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
)
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host,
stream_id,
)
min_stream_id = stream_ids[0]
# Delete older entries in the table, as we really only care about
# when the latest change happened.
@@ -1048,7 +1047,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
""",
[(user_id, device_id, stream_id) for device_id in device_ids],
[(user_id, device_id, min_stream_id) for device_id in device_ids],
)
self.db.simple_insert_many_txn(
@@ -1056,11 +1055,22 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_stream",
values=[
{"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
for device_id in device_ids
for stream_id, device_id in zip(stream_ids, device_ids)
],
)
context = get_active_span_text_map()
def _add_device_outbound_poke_to_stream_txn(
self, txn, user_id, device_ids, hosts, stream_ids, context,
):
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host,
stream_ids[-1],
)
now = self._clock.time_msec()
next_stream_id = iter(stream_ids)
self.db.simple_insert_many_txn(
txn,
@@ -1068,7 +1078,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=[
{
"destination": destination,
"stream_id": stream_id,
"stream_id": next(next_stream_id),
"user_id": user_id,
"device_id": device_id,
"sent": False,

View File

@@ -537,7 +537,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return result
def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
"""Return a list of changes from the user signature stream to notify remotes.
Note that the user signature stream represents when a user signs their
device with their user-signing key, which is not published to other
@@ -552,13 +552,19 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
"""
sql = """
SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id
ORDER BY stream_id ASC
LIMIT ?
"""
return self.db.execute(
"get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
"get_all_user_signature_changes_for_remotes",
None,
sql,
from_key,
to_key,
limit,
)

View File

@@ -1267,104 +1267,6 @@ class EventsStore(
ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
return ret
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
return -self._backfill_id_gen.get_current_token()
def get_current_events_token(self):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
def get_all_new_forward_event_rows(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
new_event_updates = txn.fetchall()
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
else:
upper_bound = current_id
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (last_id, upper_bound))
new_event_updates.extend(txn)
return new_event_updates
return self.db.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_new_backfill_event_rows(txn):
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, limit))
new_event_updates = txn.fetchall()
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
else:
upper_bound = current_id
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound))
new_event_updates.extend(txn.fetchall())
return new_event_updates
return self.db.runInteraction(
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
@cached(num_args=5, max_entries=10)
def get_all_new_events(
self,
@@ -1850,22 +1752,6 @@ class EventsStore(
return (int(res["topological_ordering"]), int(res["stream_ordering"]))
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, limit))
return txn.fetchall()
return self.db.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
def insert_labels_for_event_txn(
self, txn, event_id, labels, room_id, topological_ordering
):

View File

@@ -35,7 +35,7 @@ from synapse.api.room_versions import (
)
from synapse.events import make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
@@ -409,7 +409,7 @@ class EventsWorkerStore(SQLBaseStore):
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
if missing_events_ids:
log_ctx = LoggingContext.current_context()
log_ctx = current_context()
log_ctx.record_event_fetch(len(missing_events_ids))
# Note that _get_events_from_db is also responsible for turning db rows
@@ -963,3 +963,117 @@ class EventsWorkerStore(SQLBaseStore):
complexity_v1 = round(state_events / 500, 2)
return {"v1": complexity_v1}
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
return -self._backfill_id_gen.get_current_token()
def get_current_events_token(self):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
def get_all_new_forward_event_rows(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
new_event_updates = txn.fetchall()
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
else:
upper_bound = current_id
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (last_id, upper_bound))
new_event_updates.extend(txn)
return new_event_updates
return self.db.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_new_backfill_event_rows(txn):
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, limit))
new_event_updates = txn.fetchall()
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
else:
upper_bound = current_id
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound))
new_event_updates.extend(txn.fetchall())
return new_event_updates
return self.db.runInteraction(
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, limit))
return txn.fetchall()
return self.db.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)

View File

@@ -60,7 +60,7 @@ class PresenceStore(SQLBaseStore):
"status_msg": state.status_msg,
"currently_active": state.currently_active,
}
for state in presence_states
for stream_id, state in zip(stream_orderings, presence_states)
],
)
@@ -73,19 +73,22 @@ class PresenceStore(SQLBaseStore):
)
txn.execute(sql + clause, [stream_id] + list(args))
def get_all_presence_updates(self, last_id, current_id):
def get_all_presence_updates(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_presence_updates_txn(txn):
sql = (
"SELECT stream_id, user_id, state, last_active_ts,"
" last_federation_update_ts, last_user_sync_ts, status_msg,"
" currently_active"
" FROM presence_stream"
" WHERE ? < stream_id AND stream_id <= ?"
)
txn.execute(sql, (last_id, current_id))
sql = """
SELECT stream_id, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts,
status_msg,
currently_active
FROM presence_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.db.runInteraction(

View File

@@ -732,6 +732,26 @@ class RoomWorkerStore(SQLBaseStore):
return total_media_quarantined
def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn):
sql = """
SELECT stream_id, room_id, visibility, appservice_id, network_id
FROM public_room_list_stream
WHERE stream_id > ? AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (prev_id, current_id, limit))
return txn.fetchall()
if prev_id == current_id:
return defer.succeed([])
return self.db.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)
class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1249,26 +1269,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn):
sql = """
SELECT stream_id, room_id, visibility, appservice_id, network_id
FROM public_room_list_stream
WHERE stream_id > ? AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (prev_id, current_id, limit))
return txn.fetchall()
if prev_id == current_id:
return defer.succeed([])
return self.db.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
"""Marks the room as blocked. Can be called multiple times.

View File

@@ -32,6 +32,7 @@ from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import (
LoggingContext,
LoggingContextOrSentinel,
current_context,
make_deferred_yieldable,
)
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -483,7 +484,7 @@ class Database(object):
end = monotonic_time()
duration = end - start
LoggingContext.current_context().add_database_transaction(duration)
current_context().add_database_transaction(duration)
transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
@@ -510,7 +511,7 @@ class Database(object):
after_callbacks = [] # type: List[_CallbackListEntry]
exception_callbacks = [] # type: List[_CallbackListEntry]
if LoggingContext.current_context() == LoggingContext.sentinel:
if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)
try:
@@ -547,10 +548,8 @@ class Database(object):
Returns:
Deferred: The result of func
"""
parent_context = (
LoggingContext.current_context()
) # type: Optional[LoggingContextOrSentinel]
if parent_context == LoggingContext.sentinel:
parent_context = current_context() # type: Optional[LoggingContextOrSentinel]
if not parent_context:
logger.warning(
"Starting db connection from sentinel context: metrics will be lost"
)
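These hunks are the call-site half of the logging-context cleanup: LoggingContext.current_context() becomes the module-level current_context(), and comparisons against the sentinel collapse to a truthiness check (the sentinel context is falsy, which is what the `if not current_context()` guard above relies on). The translation, as a sketch:

# Old call sites                         New call sites
#   LoggingContext.current_context()  ->   current_context()
#   ctx == LoggingContext.sentinel    ->   not ctx

from synapse.logging.context import current_context  # the import used in this file

def running_in_sentinel_context() -> bool:
    # True when no request/background logcontext is active, i.e. the situation
    # the "Starting db txn ... from sentinel context" warning above detects.
    return not current_context()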

View File

@@ -15,25 +15,28 @@
import platform
from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine
MYPY = False
def create_engine(database_config) -> BaseDatabaseEngine:
name = database_config["name"]
if name == "sqlite3":
if name == "sqlite3" or MYPY:
import sqlite3
from .sqlite import Sqlite3Engine
return Sqlite3Engine(sqlite3, database_config)
if name == "psycopg2":
if name == "psycopg2" or MYPY:
# pypy requires psycopg2cffi rather than psycopg2
if platform.python_implementation() == "PyPy":
import psycopg2cffi as psycopg2 # type: ignore
else:
import psycopg2 # type: ignore
from .postgres import PostgresEngine
return PostgresEngine(psycopg2, database_config)
raise RuntimeError("Unsupported database engine '%s'" % (name,))
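The MYPY flag here is doing the job of typing.TYPE_CHECKING: mypy special-cases a module-level constant named MYPY and treats it as true while type checking, so the sqlite3 and psycopg2 imports stay visible to the type checker without being imported at runtime on installs that do not need them. The same shape written with the more common constant (a sketch for comparison, not the code above):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import sqlite3  # only needed for annotations


class Sqlite3EngineSketch:
    def __init__(self, database_module, database_config):
        # The real module is injected by create_engine, so nothing outside of
        # type checking ever needs the top-level import.
        self.module = database_module

    def in_transaction(self, conn: "sqlite3.Connection") -> bool:
        return conn.in_transaction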

View File

@@ -12,12 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sqlite3
import struct
import threading
from synapse.storage.engines import BaseDatabaseEngine
MYPY = False
if MYPY:
import sqlite3
class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
def __init__(self, database_module, database_config):

View File

@@ -21,7 +21,7 @@ from prometheus_client import Counter
from twisted.internet import defer
from synapse.logging.context import LoggingContext
from synapse.logging.context import LoggingContext, current_context
from synapse.metrics import InFlightGauge
logger = logging.getLogger(__name__)
@@ -106,7 +106,7 @@ class Measure(object):
raise RuntimeError("Measure() objects cannot be re-used")
self.start = self.clock.time()
parent_context = LoggingContext.current_context()
parent_context = current_context()
self._logging_context = LoggingContext(
"Measure[%s]" % (self.name,), parent_context
)
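For context, `Measure` (touched above) is a context manager that opens a child `LoggingContext` named after the measured block and records its timings. A hedged usage sketch, assuming a homeserver `Clock` is available; the function and block name are illustrative:

from synapse.util.metrics import Measure

def process(clock, items):
    # wall-clock time and resource usage for this block are recorded
    # under the name "process_items"
    with Measure(clock, "process_items"):
        return [item.upper() for item in items]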

View File

@@ -32,7 +32,7 @@ def do_patch():
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
"""
from synapse.logging.context import LoggingContext
from synapse.logging.context import current_context
global _already_patched
@@ -43,35 +43,35 @@ def do_patch():
def new_inline_callbacks(f):
@functools.wraps(f)
def wrapped(*args, **kwargs):
start_context = LoggingContext.current_context()
start_context = current_context()
changes = [] # type: List[str]
orig = orig_inline_callbacks(_check_yield_points(f, changes))
try:
res = orig(*args, **kwargs)
except Exception:
if LoggingContext.current_context() != start_context:
if current_context() != start_context:
for err in changes:
print(err, file=sys.stderr)
err = "%s changed context from %s to %s on exception" % (
f,
start_context,
LoggingContext.current_context(),
current_context(),
)
print(err, file=sys.stderr)
raise Exception(err)
raise
if not isinstance(res, Deferred) or res.called:
if LoggingContext.current_context() != start_context:
if current_context() != start_context:
for err in changes:
print(err, file=sys.stderr)
err = "Completed %s changed context from %s to %s" % (
f,
start_context,
LoggingContext.current_context(),
current_context(),
)
# print the error to stderr because otherwise all we
# see in travis-ci is the 500 error
@@ -79,23 +79,23 @@ def do_patch():
raise Exception(err)
return res
if LoggingContext.current_context() != LoggingContext.sentinel:
if current_context():
err = (
"%s returned incomplete deferred in non-sentinel context "
"%s (start was %s)"
) % (f, LoggingContext.current_context(), start_context)
) % (f, current_context(), start_context)
print(err, file=sys.stderr)
raise Exception(err)
def check_ctx(r):
if LoggingContext.current_context() != start_context:
if current_context() != start_context:
for err in changes:
print(err, file=sys.stderr)
err = "%s completion of %s changed context from %s to %s" % (
"Failure" if isinstance(r, Failure) else "Success",
f,
start_context,
LoggingContext.current_context(),
current_context(),
)
print(err, file=sys.stderr)
raise Exception(err)
@@ -127,7 +127,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
function
"""
from synapse.logging.context import LoggingContext
from synapse.logging.context import current_context
@functools.wraps(f)
def check_yield_points_inner(*args, **kwargs):
@@ -136,7 +136,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
last_yield_line_no = gen.gi_frame.f_lineno
result = None # type: Any
while True:
expected_context = LoggingContext.current_context()
expected_context = current_context()
try:
isFailure = isinstance(result, Failure)
@@ -145,7 +145,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
else:
d = gen.send(result)
except (StopIteration, defer._DefGen_Return) as e:
if LoggingContext.current_context() != expected_context:
if current_context() != expected_context:
# This happens when the context is lost sometime *after* the
# final yield and returning. E.g. we forgot to yield on a
# function that returns a deferred.
@@ -159,7 +159,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
% (
f.__qualname__,
expected_context,
LoggingContext.current_context(),
current_context(),
f.__code__.co_filename,
last_yield_line_no,
)
@@ -173,13 +173,13 @@ def _check_yield_points(f: Callable, changes: List[str]):
# This happens if we yield on a deferred that doesn't follow
# the log context rules without wrapping in a `make_deferred_yieldable`.
# We raise here as this should never happen.
if LoggingContext.current_context() is not LoggingContext.sentinel:
if current_context():
err = (
"%s yielded with context %s rather than sentinel,"
" yielded on line %d in %s"
% (
frame.f_code.co_name,
LoggingContext.current_context(),
current_context(),
frame.f_lineno,
frame.f_code.co_filename,
)
@@ -191,7 +191,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
except Exception as e:
result = Failure(e)
if LoggingContext.current_context() != expected_context:
if current_context() != expected_context:
# This happens because the context is lost sometime *after* the
# previous yield and *after* the current yield. E.g. the
@@ -206,7 +206,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
% (
frame.f_code.co_name,
expected_context,
LoggingContext.current_context(),
current_context(),
last_yield_line_no,
frame.f_lineno,
frame.f_code.co_filename,

View File

@@ -21,9 +21,9 @@ from tests import unittest
class DatabaseConfigTestCase(unittest.TestCase):
def test_database_configured_correctly_no_database_conf_param(self):
def test_database_configured_correctly(self):
conf = yaml.safe_load(
DatabaseConfig().generate_config_section("/data_dir_path", None)
DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path")
)
expected_database_conf = {
@@ -32,21 +32,3 @@ class DatabaseConfigTestCase(unittest.TestCase):
}
self.assertEqual(conf["database"], expected_database_conf)
def test_database_configured_correctly_database_conf_param(self):
database_conf = {
"name": "my super fast datastore",
"args": {
"user": "matrix",
"password": "synapse_database_password",
"host": "synapse_database_host",
"database": "matrix",
},
}
conf = yaml.safe_load(
DatabaseConfig().generate_config_section("/data_dir_path", database_conf)
)
self.assertEqual(conf["database"], database_conf)

View File

@@ -34,6 +34,7 @@ from synapse.crypto.keyring import (
from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
current_context,
make_deferred_yieldable,
)
from synapse.storage.keys import FetchKeyResult
@@ -83,9 +84,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
)
def check_context(self, _, expected):
self.assertEquals(
getattr(LoggingContext.current_context(), "request", None), expected
)
self.assertEquals(getattr(current_context(), "request", None), expected)
def test_verify_json_objects_for_server_awaits_previous_requests(self):
key1 = signedjson.key.generate_signing_key(1)
@@ -105,7 +104,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def get_perspectives(**kwargs):
self.assertEquals(LoggingContext.current_context().request, "11")
self.assertEquals(current_context().request, "11")
with PreserveLoggingContext():
yield persp_deferred
return persp_resp
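The `PreserveLoggingContext()` block in the test above is the usual way to park the current logcontext while waiting on a deferred that is owned elsewhere: on entry it switches to the sentinel, and on exit it restores whatever context was active. A small illustrative sketch (the function itself is hypothetical):

from twisted.internet import defer
from synapse.logging.context import PreserveLoggingContext

@defer.inlineCallbacks
def wait_for_other_request(shared_deferred):
    # drop to the sentinel while something else owns the deferred,
    # then come back to our own context once it fires
    with PreserveLoggingContext():
        result = yield shared_deferred
    return result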

View File

@@ -38,7 +38,7 @@ from synapse.http.federation.well_known_resolver import (
WellKnownResolver,
_cache_period_from_headers,
)
from synapse.logging.context import LoggingContext
from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from synapse.util.caches.ttlcache import TTLCache
from tests import unittest
@@ -155,7 +155,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertNoResult(fetch_d)
# should have reset logcontext to the sentinel
_check_logcontext(LoggingContext.sentinel)
_check_logcontext(SENTINEL_CONTEXT)
try:
fetch_res = yield fetch_d
@@ -1197,7 +1197,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase):
def _check_logcontext(context):
current = LoggingContext.current_context()
current = current_context()
if current is not context:
raise AssertionError("Expected logcontext %s but was %s" % (context, current))

View File

@@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError
from twisted.names import dns, error
from synapse.http.federation.srv_resolver import SrvResolver
from synapse.logging.context import LoggingContext
from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from tests import unittest
from tests.utils import MockClock
@@ -54,12 +54,12 @@ class SrvResolverTestCase(unittest.TestCase):
self.assertNoResult(resolve_d)
# should have reset to the sentinel context
self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
self.assertIs(current_context(), SENTINEL_CONTEXT)
result = yield resolve_d
# should have restored our context
self.assertIs(LoggingContext.current_context(), ctx)
self.assertIs(current_context(), ctx)
return result

View File

@@ -29,14 +29,14 @@ from synapse.http.matrixfederationclient import (
MatrixFederationHttpClient,
MatrixFederationRequest,
)
from synapse.logging.context import LoggingContext
from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase
def check_logcontext(context):
current = LoggingContext.current_context()
current = current_context()
if current is not context:
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
@@ -64,7 +64,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertNoResult(fetch_d)
# should have reset logcontext to the sentinel
check_logcontext(LoggingContext.sentinel)
check_logcontext(SENTINEL_CONTEXT)
try:
fetch_res = yield fetch_d

View File

@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from synapse.replication.tcp.commands import ReplicateCommand
@@ -29,19 +30,37 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = server_factory.streamer
server = server_factory.buildProtocol(None)
self.server = server_factory.buildProtocol(None)
# build a replication client, with a dummy handler
handler_factory = Mock()
self.test_handler = TestReplicationClientHandler()
self.test_handler.factory = handler_factory
self.test_handler = Mock(wraps=TestReplicationClientHandler())
self.client = ClientReplicationStreamProtocol(
"client", "test", clock, self.test_handler
hs, "client", "test", clock, self.test_handler,
)
# wire them together
self.client.makeConnection(FakeTransport(server, reactor))
server.makeConnection(FakeTransport(self.client, reactor))
self._client_transport = None
self._server_transport = None
def reconnect(self):
if self._client_transport:
self.client.close()
if self._server_transport:
self.server.close()
self._client_transport = FakeTransport(self.server, self.reactor)
self.client.makeConnection(self._client_transport)
self._server_transport = FakeTransport(self.client, self.reactor)
self.server.makeConnection(self._server_transport)
def disconnect(self):
if self._client_transport:
self._client_transport = None
self.client.close()
if self._server_transport:
self._server_transport = None
self.server.close()
def replicate(self):
"""Tell the master side of replication that something has happened, and then
@@ -50,19 +69,24 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke()
self.pump(0.1)
def replicate_stream(self, stream, token="NOW"):
def replicate_stream(self):
"""Make the client end a REPLICATE command to set up a subscription to a stream"""
self.client.send_command(ReplicateCommand(stream, token))
self.client.send_command(ReplicateCommand())
class TestReplicationClientHandler(object):
"""Drop-in for ReplicationClientHandler which just collects RDATA rows"""
def __init__(self):
self.received_rdata_rows = []
self.streams = set()
self._received_rdata_rows = []
def get_streams_to_replicate(self):
return {}
positions = {s: 0 for s in self.streams}
for stream, token, _ in self._received_rdata_rows:
if stream in self.streams:
positions[stream] = max(token, positions.get(stream, 0))
return positions
def get_currently_syncing_users(self):
return []
@@ -73,6 +97,9 @@ class TestReplicationClientHandler(object):
def finished_connecting(self):
pass
async def on_position(self, stream_name, token):
"""Called when we get new position data."""
async def on_rdata(self, stream_name, token, rows):
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))
self._received_rdata_rows.append((stream_name, token, r))

View File

@@ -12,35 +12,69 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.tcp.streams._base import ReceiptsStreamRow
from synapse.replication.tcp.streams._base import ReceiptsStream
from tests.replication.tcp.streams._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
ROOM_ID = "!room:blue"
EVENT_ID = "$event:blue"
class ReceiptsStreamTestCase(BaseStreamTestCase):
def test_receipt(self):
self.reconnect()
# make the client subscribe to the receipts stream
self.replicate_stream("receipts", "NOW")
self.replicate_stream()
self.test_handler.streams.add("receipts")
# tell the master to send a new receipt
self.get_success(
self.hs.get_datastore().insert_receipt(
ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1}
"!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
)
)
self.replicate()
# there should be one RDATA command
rdata_rows = self.test_handler.received_rdata_rows
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows))
self.assertEqual(rdata_rows[0][0], "receipts")
row = rdata_rows[0][2] # type: ReceiptsStreamRow
self.assertEqual(ROOM_ID, row.room_id)
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
self.assertEqual("!room:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
self.assertEqual(EVENT_ID, row.event_id)
self.assertEqual("$event:blue", row.event_id)
self.assertEqual({"a": 1}, row.data)
# Now let's disconnect and insert some data.
self.disconnect()
self.test_handler.on_rdata.reset_mock()
self.get_success(
self.hs.get_datastore().insert_receipt(
"!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
)
)
self.replicate()
# Nothing should have happened as we are disconnected
self.test_handler.on_rdata.assert_not_called()
self.reconnect()
self.pump(0.1)
# We should now have caught up and received the missing data
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(token, 3)
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
self.assertEqual("!room2:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
self.assertEqual("$event2:foo", row.event_id)
self.assertEqual({"a": 2}, row.data)

View File

@@ -2,7 +2,7 @@ from mock import Mock, call
from twisted.internet import defer, reactor
from synapse.logging.context import LoggingContext
from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache
from synapse.util import Clock
@@ -52,14 +52,14 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def test():
with LoggingContext("c") as c1:
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
self.assertIs(LoggingContext.current_context(), c1)
self.assertIs(current_context(), c1)
self.assertEqual(res, "yay")
# run the test twice in parallel
d = defer.gatherResults([test(), test()])
self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
self.assertIs(current_context(), SENTINEL_CONTEXT)
yield d
self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
self.assertIs(current_context(), SENTINEL_CONTEXT)
@defer.inlineCallbacks
def test_does_not_cache_exceptions(self):
@@ -81,11 +81,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
yield self.cache.fetch_or_execute(self.mock_key, cb)
except Exception as e:
self.assertEqual(e.args[0], "boo")
self.assertIs(LoggingContext.current_context(), test_context)
self.assertIs(current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
self.assertEqual(res, self.mock_http_response)
self.assertIs(LoggingContext.current_context(), test_context)
self.assertIs(current_context(), test_context)
@defer.inlineCallbacks
def test_does_not_cache_failures(self):
@@ -107,11 +107,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
yield self.cache.fetch_or_execute(self.mock_key, cb)
except Exception as e:
self.assertEqual(e.args[0], "boo")
self.assertIs(LoggingContext.current_context(), test_context)
self.assertIs(current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
self.assertEqual(res, self.mock_http_response)
self.assertIs(LoggingContext.current_context(), test_context)
self.assertIs(current_context(), test_context)
@defer.inlineCallbacks
def test_cleans_up(self):

View File

@@ -88,51 +88,6 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
# Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)
@defer.inlineCallbacks
def test_get_device_updates_by_remote_limited(self):
# Test breaking the update limit in 1, 101, and 1 device_id segments
# first add one device
device_ids1 = ["device_id0"]
yield self.store.add_device_change_to_streams(
"user_id", device_ids1, ["someotherhost"]
)
# then add 101
device_ids2 = ["device_id" + str(i + 1) for i in range(101)]
yield self.store.add_device_change_to_streams(
"user_id", device_ids2, ["someotherhost"]
)
# then one more
device_ids3 = ["newdevice"]
yield self.store.add_device_change_to_streams(
"user_id", device_ids3, ["someotherhost"]
)
#
# now read them back.
#
# first we should get a single update
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", -1, limit=100
)
self._check_devices_in_updates(device_ids1, device_updates)
# Then we should get an empty list back as the 101 devices broke the limit
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", now_stream_id, limit=100
)
self.assertEqual(len(device_updates), 0)
# The 101 devices should've been cleared, so we should now just get one device
# update
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", now_stream_id, limit=100
)
self._check_devices_in_updates(device_ids3, device_updates)
def _check_devices_in_updates(self, expected_device_ids, device_updates):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))

View File

@@ -38,7 +38,11 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.federation.transport import server as federation_server
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import LoggingContext
from synapse.logging.context import (
SENTINEL_CONTEXT,
current_context,
set_current_context,
)
from synapse.server import HomeServer
from synapse.types import Requester, UserID, create_requester
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -97,10 +101,10 @@ class TestCase(unittest.TestCase):
def setUp(orig):
# if we're not starting in the sentinel logcontext, then to be honest
# all future bets are off.
if LoggingContext.current_context() is not LoggingContext.sentinel:
if current_context():
self.fail(
"Test starting with non-sentinel logging context %s"
% (LoggingContext.current_context(),)
% (current_context(),)
)
old_level = logging.getLogger().level
@@ -122,7 +126,7 @@ class TestCase(unittest.TestCase):
# force a GC to workaround problems with deferreds leaking logcontexts when
# they are GCed (see the logcontext docs)
gc.collect()
LoggingContext.set_current_context(LoggingContext.sentinel)
set_current_context(SENTINEL_CONTEXT)
return ret
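The tearDown path above amounts to the small reset step sketched below: force a GC so deferreds that still pin a logcontext are collected first, then use the new module-level setter to drop back to the sentinel. The helper name is illustrative only.

import gc
from synapse.logging.context import SENTINEL_CONTEXT, set_current_context

def reset_logcontext():
    # collect deferreds that may still hold a logcontext (see the logcontext docs)
    gc.collect()
    set_current_context(SENTINEL_CONTEXT)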

View File

@@ -22,8 +22,10 @@ from twisted.internet import defer, reactor
from synapse.api.errors import SynapseError
from synapse.logging.context import (
SENTINEL_CONTEXT,
LoggingContext,
PreserveLoggingContext,
current_context,
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
@@ -194,7 +196,7 @@ class DescriptorTestCase(unittest.TestCase):
with LoggingContext() as c1:
c1.name = "c1"
r = yield obj.fn(1)
self.assertEqual(LoggingContext.current_context(), c1)
self.assertEqual(current_context(), c1)
return r
def check_result(r):
@@ -204,12 +206,12 @@ class DescriptorTestCase(unittest.TestCase):
# set off a deferred which will do a cache lookup
d1 = do_lookup()
self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
self.assertEqual(current_context(), SENTINEL_CONTEXT)
d1.addCallback(check_result)
# and another
d2 = do_lookup()
self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
self.assertEqual(current_context(), SENTINEL_CONTEXT)
d2.addCallback(check_result)
# let the lookup complete
@@ -239,14 +241,14 @@ class DescriptorTestCase(unittest.TestCase):
try:
d = obj.fn(1)
self.assertEqual(
LoggingContext.current_context(), LoggingContext.sentinel
current_context(), SENTINEL_CONTEXT,
)
yield d
self.fail("No exception thrown")
except SynapseError:
pass
self.assertEqual(LoggingContext.current_context(), c1)
self.assertEqual(current_context(), c1)
# the cache should now be empty
self.assertEqual(len(obj.fn.cache.cache), 0)
@@ -255,7 +257,7 @@ class DescriptorTestCase(unittest.TestCase):
# set off a deferred which will do a cache lookup
d1 = do_lookup()
self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
self.assertEqual(current_context(), SENTINEL_CONTEXT)
return d1
@@ -366,10 +368,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2):
assert LoggingContext.current_context().request == "c1"
assert current_context().request == "c1"
# we want this to behave like an asynchronous function
yield run_on_reactor()
assert LoggingContext.current_context().request == "c1"
assert current_context().request == "c1"
return self.mock(args1, arg2)
with LoggingContext() as c1:
@@ -377,9 +379,9 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj = Cls()
obj.mock.return_value = {10: "fish", 20: "chips"}
d1 = obj.list_fn([10, 20], 2)
self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
self.assertEqual(current_context(), SENTINEL_CONTEXT)
r = yield d1
self.assertEqual(LoggingContext.current_context(), c1)
self.assertEqual(current_context(), c1)
obj.mock.assert_called_once_with([10, 20], 2)
self.assertEqual(r, {10: "fish", 20: "chips"})
obj.mock.reset_mock()

View File

@@ -16,7 +16,12 @@ from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred
from twisted.internet.task import Clock
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.context import (
SENTINEL_CONTEXT,
LoggingContext,
PreserveLoggingContext,
current_context,
)
from synapse.util.async_helpers import timeout_deferred
from tests.unittest import TestCase
@@ -79,10 +84,10 @@ class TimeoutDeferredTest(TestCase):
# the errbacks should be run in the test logcontext
def errback(res, deferred_name):
self.assertIs(
LoggingContext.current_context(),
current_context(),
context_one,
"errback %s run in unexpected logcontext %s"
% (deferred_name, LoggingContext.current_context()),
% (deferred_name, current_context()),
)
return res
@@ -90,7 +95,7 @@ class TimeoutDeferredTest(TestCase):
original_deferred.addErrback(errback, "orig")
timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock)
self.assertNoResult(timing_out_d)
self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
self.assertIs(current_context(), SENTINEL_CONTEXT)
timing_out_d.addErrback(errback, "timingout")
self.clock.pump((1.0,))
@@ -99,4 +104,4 @@ class TimeoutDeferredTest(TestCase):
blocking_was_cancelled[0], "non-completing deferred was not cancelled"
)
self.failureResultOf(timing_out_d, defer.TimeoutError)
self.assertIs(LoggingContext.current_context(), context_one)
self.assertIs(current_context(), context_one)

View File

@@ -19,7 +19,7 @@ from six.moves import range
from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError
from synapse.logging.context import LoggingContext
from synapse.logging.context import LoggingContext, current_context
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
@@ -54,11 +54,11 @@ class LinearizerTestCase(unittest.TestCase):
def func(i, sleep=False):
with LoggingContext("func(%s)" % i) as lc:
with (yield linearizer.queue("")):
self.assertEqual(LoggingContext.current_context(), lc)
self.assertEqual(current_context(), lc)
if sleep:
yield Clock(reactor).sleep(0)
self.assertEqual(LoggingContext.current_context(), lc)
self.assertEqual(current_context(), lc)
func(0, sleep=True)
for i in range(1, 100):

View File

@@ -2,8 +2,10 @@ import twisted.python.failure
from twisted.internet import defer, reactor
from synapse.logging.context import (
SENTINEL_CONTEXT,
LoggingContext,
PreserveLoggingContext,
current_context,
make_deferred_yieldable,
nested_logging_context,
run_in_background,
@@ -15,7 +17,7 @@ from .. import unittest
class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value):
self.assertEquals(LoggingContext.current_context().request, value)
self.assertEquals(current_context().request, value)
def test_with_context(self):
with LoggingContext() as context_one:
@@ -41,7 +43,7 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one")
def _test_run_in_background(self, function):
sentinel_context = LoggingContext.current_context()
sentinel_context = current_context()
callback_completed = [False]
@@ -71,7 +73,7 @@ class LoggingContextTestCase(unittest.TestCase):
# make sure that the context was reset before it got thrown back
# into the reactor
try:
self.assertIs(LoggingContext.current_context(), sentinel_context)
self.assertIs(current_context(), sentinel_context)
d2.callback(None)
except BaseException:
d2.errback(twisted.python.failure.Failure())
@@ -108,7 +110,7 @@ class LoggingContextTestCase(unittest.TestCase):
async def testfunc():
self._check_test_key("one")
d = Clock(reactor).sleep(0)
self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
self.assertIs(current_context(), SENTINEL_CONTEXT)
await d
self._check_test_key("one")
@@ -129,14 +131,14 @@ class LoggingContextTestCase(unittest.TestCase):
reactor.callLater(0, d.callback, None)
return d
sentinel_context = LoggingContext.current_context()
sentinel_context = current_context()
with LoggingContext() as context_one:
context_one.request = "one"
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(LoggingContext.current_context(), sentinel_context)
self.assertIs(current_context(), sentinel_context)
yield d1
@@ -145,14 +147,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_make_deferred_yieldable_with_chained_deferreds(self):
sentinel_context = LoggingContext.current_context()
sentinel_context = current_context()
with LoggingContext() as context_one:
context_one.request = "one"
d1 = make_deferred_yieldable(_chained_deferred_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(LoggingContext.current_context(), sentinel_context)
self.assertIs(current_context(), sentinel_context)
yield d1
@@ -189,14 +191,14 @@ class LoggingContextTestCase(unittest.TestCase):
reactor.callLater(0, d.callback, None)
await d
sentinel_context = LoggingContext.current_context()
sentinel_context = current_context()
with LoggingContext() as context_one:
context_one.request = "one"
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(LoggingContext.current_context(), sentinel_context)
self.assertIs(current_context(), sentinel_context)
yield d1

View File

@@ -35,7 +35,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.federation.transport import server as federation_server
from synapse.http.server import HttpServer
from synapse.logging.context import LoggingContext
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
@@ -493,10 +493,10 @@ class MockClock(object):
return self.time() * 1000
def call_later(self, delay, callback, *args, **kwargs):
current_context = LoggingContext.current_context()
ctx = current_context()
def wrapped_callback():
LoggingContext.thread_local.current_context = current_context
set_current_context(ctx)
callback(*args, **kwargs)
t = [self.now + delay, wrapped_callback, False]
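The `MockClock.call_later` change above follows the usual capture-and-restore idiom: grab the caller's logcontext when the callback is scheduled, then reinstate it with `set_current_context()` when the callback eventually runs. Sketched here against Twisted's real reactor; the wrapper name is made up.

from twisted.internet import reactor
from synapse.logging.context import current_context, set_current_context

def call_later_in_caller_context(delay, callback, *args, **kwargs):
    ctx = current_context()

    def wrapped():
        # restore the scheduling caller's logcontext before running
        set_current_context(ctx)
        callback(*args, **kwargs)

    return reactor.callLater(delay, wrapped)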