
Compare commits


1 Commit

| Author | SHA1 | Message | Date |
| ------ | ---- | ------- | ---- |
| Andrew Morgan | 5bd0d19b95 | Add push debug logging | 2020-03-30 14:05:23 +01:00 |
151 changed files with 2165 additions and 4248 deletions


@@ -2,6 +2,7 @@
- [Installing Synapse](#installing-synapse)
- [Installing from source](#installing-from-source)
- [Platform-Specific Instructions](#platform-specific-instructions)
- [Troubleshooting Installation](#troubleshooting-installation)
- [Prebuilt packages](#prebuilt-packages)
- [Setting up Synapse](#setting-up-synapse)
- [TLS certificates](#tls-certificates)
@@ -9,7 +10,6 @@
- [Registering a user](#registering-a-user)
- [Setting up a TURN server](#setting-up-a-turn-server)
- [URL previews](#url-previews)
- [Troubleshooting Installation](#troubleshooting-installation)
# Choosing your server name
@@ -70,7 +70,7 @@ pip install -U matrix-synapse
```
Before you can start Synapse, you will need to generate a configuration
file. To do this, run (in your virtualenv, as before):
file. To do this, run (in your virtualenv, as before)::
```
cd ~/synapse
@@ -84,24 +84,22 @@ python -m synapse.app.homeserver \
... substituting an appropriate value for `--server-name`.
This command will generate you a config file that you can then customise, but it will
also generate a set of keys for you. These keys will allow your homeserver to
identify itself to other homeservers, so don't lose or delete them. It would be
also generate a set of keys for you. These keys will allow your Home Server to
identify itself to other Home Servers, so don't lose or delete them. It would be
wise to back them up somewhere safe. (If, for whatever reason, you do need to
change your homeserver's keys, you may find that other homeservers have the
change your Home Server's keys, you may find that other Home Servers have the
old key cached. If you update the signing key, you should change the name of the
key in the `<server name>.signing.key` file (the second word) to something
different. See the
[spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys)
for more information on key management).
for more information on key management.)
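For illustration only, here is a minimal sketch (not an official Synapse tool) of that rename, assuming the usual one-line `<algorithm> <key_id> <base64_key>` layout of the signing key file:
```python
# Hypothetical helper: rewrite the key ID (the second word) in a signing key
# file. Assumes the one-line "<algorithm> <key_id> <base64_key>" layout.
from pathlib import Path

def rename_signing_key(path: str, new_key_id: str) -> None:
    key_file = Path(path)
    algorithm, _old_id, key_material = key_file.read_text().split()
    key_file.write_text(f"{algorithm} {new_key_id} {key_material}\n")

# e.g. rename_signing_key("/home/user/synapse/example.com.signing.key", "a_fresh")
```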
To actually run your new homeserver, pick a working directory for Synapse to
run (e.g. `~/synapse`), and:
run (e.g. `~/synapse`), and::
```
cd ~/synapse
source env/bin/activate
synctl start
```
### Platform-Specific Instructions
@@ -112,7 +110,7 @@ Installing prerequisites on Ubuntu or Debian:
```
sudo apt-get install build-essential python3-dev libffi-dev \
python3-pip python3-setuptools sqlite3 \
libssl-dev virtualenv libjpeg-dev libxslt1-dev
libssl-dev python3-virtualenv libjpeg-dev libxslt1-dev
```
#### ArchLinux
@@ -190,7 +188,7 @@ doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \
There is currently no port for OpenBSD. Additionally, OpenBSD's security
settings require a slightly more difficult installation process.
(XXX: I suspect this is out of date)
XXX: I suspect this is out of date.
1. Create a new directory in `/usr/local` called `_synapse`. Also, create a
new user called `_synapse` and set that directory as the new user's home.
@@ -198,7 +196,7 @@ settings require a slightly more difficult installation process.
write and execute permissions on the same memory space to be run from
`/usr/local`.
2. `su` to the new `_synapse` user and change to their home directory.
3. Create a new virtualenv: `virtualenv -p python3 ~/.synapse`
3. Create a new virtualenv: `virtualenv -p python2.7 ~/.synapse`
4. Source the virtualenv configuration located at
`/usr/local/_synapse/.synapse/bin/activate`. This is done in `ksh` by
using the `.` command, rather than `bash`'s `source`.
@@ -219,6 +217,45 @@ be found at https://docs.microsoft.com/en-us/windows/wsl/install-win10 for
Windows 10 and https://docs.microsoft.com/en-us/windows/wsl/install-on-server
for Windows Server.
### Troubleshooting Installation
XXX a bunch of this is no longer relevant.
Synapse requires pip 8 or later, so if your OS provides too old a version you
may need to manually upgrade it::
sudo pip install --upgrade pip
Installing may fail with `Could not find any downloads that satisfy the requirement pymacaroons-pynacl (from matrix-synapse==0.12.0)`.
You can fix this by manually upgrading pip and virtualenv::
sudo pip install --upgrade virtualenv
You can next rerun `virtualenv -p python3 synapse` to update the virtual env.
Installing may fail during installing virtualenv with `InsecurePlatformWarning: A true SSLContext object is not available. This prevents urllib3 from configuring SSL appropriately and may cause certain SSL connections to fail. For more information, see https://urllib3.readthedocs.org/en/latest/security.html#insecureplatformwarning.`
You can fix this by manually installing ndg-httpsclient::
pip install --upgrade ndg-httpsclient
Installing may fail with `mock requires setuptools>=17.1. Aborting installation`.
You can fix this by upgrading setuptools::
pip install --upgrade setuptools
If pip crashes mid-installation for any reason (e.g. lost terminal), pip may
refuse to run until you remove the temporary installation directory it
created. To reset the installation::
rm -rf /tmp/pip_install_matrix
pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
failing, e.g.::
pip install twisted
## Prebuilt packages
As an alternative to installing from source, prebuilt packages are available
@@ -277,7 +314,7 @@ For `buster` and `sid`, Synapse is available in the Debian repositories and
it should be possible to install it with simply:
```
sudo apt install matrix-synapse
```
There is also a version of `matrix-synapse` in `stretch-backports`. Please see
@@ -338,17 +375,15 @@ sudo pip install py-bcrypt
Synapse can be found in the void repositories as 'synapse':
```
xbps-install -Su
xbps-install -S synapse
```
### FreeBSD
Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from:
- Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean`
- Packages: `pkg install py37-matrix-synapse`
- Packages: `pkg install py27-matrix-synapse`
### NixOS
@@ -385,7 +420,6 @@ so, you will need to edit `homeserver.yaml`, as follows:
resources:
- names: [client, federation]
```
* You will also need to uncomment the `tls_certificate_path` and
`tls_private_key_path` lines under the `TLS` section. You can either
point these settings at an existing certificate and key, or you can
@@ -401,7 +435,7 @@ so, you will need to edit `homeserver.yaml`, as follows:
`cert.pem`).
For a more detailed guide to configuring your server for federation, see
[federate.md](docs/federate.md).
[federate.md](docs/federate.md)
## Email
@@ -448,7 +482,7 @@ on your server even if `enable_registration` is `false`.
## Setting up a TURN server
For reliable VoIP calls to be routed via this homeserver, you MUST configure
a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details.
## URL previews
@@ -457,24 +491,10 @@ turn it on you must enable the `url_preview_enabled: True` config parameter
and explicitly specify the IP ranges that Synapse is not allowed to spider for
previewing in the `url_preview_ip_range_blacklist` configuration parameter.
This is critical from a security perspective to stop arbitrary Matrix users
spidering 'internal' URLs on your network. At the very least we recommend that
your loopback and RFC1918 IP addresses are blacklisted.
This also requires the optional `lxml` and `netaddr` python dependencies to be
installed. This in turn requires the `libxml2` library to be available - on
This also requires the optional lxml and netaddr python dependencies to be
installed. This in turn requires the libxml2 library to be available - on
Debian/Ubuntu this means `apt-get install libxml2-dev`, or equivalent for
your OS.
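As a rough sketch of the kind of check that blacklist implies (using the optional `netaddr` dependency mentioned above; the ranges here are the usual loopback/RFC1918 suggestions, not Synapse's exact defaults):
```python
# Illustrative only: refuse to preview URLs that resolve to internal ranges.
from netaddr import IPAddress, IPSet

BLACKLIST = IPSet(["127.0.0.0/8", "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"])

def may_preview(resolved_ip: str) -> bool:
    return IPAddress(resolved_ip) not in BLACKLIST

print(may_preview("192.168.1.10"))  # False: internal host, do not spider
```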
# Troubleshooting Installation
`pip` seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
failing, e.g.:
```
pip install twisted
```
If you have any other problems, feel free to ask in
[#synapse:matrix.org](https://matrix.to/#/#synapse:matrix.org).


@@ -1 +0,0 @@
Don't attempt to use an invalid sqlite config if no database configuration is provided. Contributed by @nekatak.


@@ -1 +0,0 @@
Fix single-sign on with CAS systems: pass the same service URL when requesting the CAS ticket and when calling the `proxyValidate` URL. Contributed by @Naugrimm.


@@ -1 +0,0 @@
Update Debian installation instructions to recommend installing the `virtualenv` package instead of `python3-virtualenv`.


@@ -1 +0,0 @@
Improve the documentation for database configuration.


@@ -1 +0,0 @@
Set `Referrer-Policy` header to `no-referrer` on media downloads.


@@ -1 +0,0 @@
Change device list streams to have one row per ID.


@@ -1 +0,0 @@
Remove concept of a non-limited stream.


@@ -1 +0,0 @@
Move catchup of replication streams logic to worker.


@@ -1 +0,0 @@
Admin API `POST /_synapse/admin/v1/join/<roomIdOrAlias>` to join users to a room like `auto_join_rooms` for creation of users.


@@ -1 +0,0 @@
Ensure that a user interactive authentication session is tied to a single request.


@@ -1 +0,0 @@
Fix a bug in the federation API which could cause occasional "Failed to get PDU" errors.


@@ -1 +0,0 @@
Add options to prevent users from changing their profile or associated 3PIDs.


@@ -1 +0,0 @@
Update pre-built package name for FreeBSD.


@@ -1 +0,0 @@
Return the proper error (M_BAD_ALIAS) when a non-existent canonical alias is provided.


@@ -1 +0,0 @@
Convert some of synapse.rest.media to async/await.


@@ -1 +0,0 @@
De-duplicate / remove unused REST code for login and auth.


@@ -1 +0,0 @@
Convert `*StreamRow` classes to inner classes.


@@ -1 +0,0 @@
Fix a bug which meant that groups updates were not correctly replicated between workers.


@@ -1 +0,0 @@
Allow server admins to define and enforce a password policy (MSC2000).


@@ -1 +0,0 @@
Clean up some LoggingContext code.


@@ -1 +0,0 @@
Add explicit `instance_id` for USER_SYNC commands and remove implicit `conn_id` usage.


@@ -1 +0,0 @@
Fix starting workers when federation sending not split out.


@@ -1 +0,0 @@
Refactor replication command handling to reduce difference between server vs client.


@@ -1 +0,0 @@
Refactored the CAS authentication logic to a separate class.


@@ -1 +0,0 @@
Remove nonfunctional `captcha_bypass_secret` option from `homeserver.yaml`.


@@ -1 +0,0 @@
Clean up INSTALL.md a bit.


@@ -1 +0,0 @@
Add documentation for running a local CAS server for testing.


@@ -1 +0,0 @@
Ensure `is_verified` is a boolean in responses to `GET /_matrix/client/r0/room_keys/keys`. Also warn the user if they forgot the `version` query param.


@@ -1 +0,0 @@
Fix error page being shown when a custom SAML handler attempted to redirect when processing an auth response.


@@ -1 +0,0 @@
Improve the support for SSO authentication on the login fallback page.


@@ -1 +0,0 @@
Always whitelist the login fallback in the SSO configuration if `public_baseurl` is set.


@@ -1 +0,0 @@
Avoid importing `sqlite3` when using the postgres backend. Contributed by David Vo.


@@ -1 +0,0 @@
Add tests for outbound device pokes.


@@ -1 +0,0 @@
Always send users their own device updates.


@@ -1,34 +0,0 @@
# Edit Room Membership API
This API allows an administrator to join a user account with a given `user_id`
to a room with a given `room_id_or_alias`. You can only modify the membership of
local users. The server administrator must be in the room and have permission to
invite users.
## Parameters
The following parameters are available:
* `user_id` - Fully qualified user: for example, `@user:server.com`.
* `room_id_or_alias` - The room identifier or alias to join: for example,
`!636q39766251:server.com`.
## Usage
```
POST /_synapse/admin/v1/join/<room_id_or_alias>
{
"user_id": "@user:server.com"
}
```
Include an `access_token` of a server admin.
Response:
```
{
"room_id": "!636q39766251:server.com"
}
```
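For comparison, the same call from Python with `requests` (a sketch; the homeserver URL and token are placeholders, and the room ID must be URL-encoded):
```python
import requests

resp = requests.post(
    # %21 / %3A are the URL-encoded '!' and ':' of the room ID; the hostname
    # and token are placeholders.
    "https://homeserver.example.com/_synapse/admin/v1/join/%21636q39766251%3Aserver.com",
    headers={"Authorization": "Bearer <admin_access_token>"},
    json={"user_id": "@user:server.com"},
)
print(resp.json())  # expect {"room_id": "!636q39766251:server.com"}
```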


@@ -1,64 +0,0 @@
# How to test CAS as a developer without a server
The [django-mama-cas](https://github.com/jbittel/django-mama-cas) project is an
easy-to-run CAS implementation built on top of Django.
## Prerequisites
1. Create a new virtualenv: `python3 -m venv <your virtualenv>`
2. Activate your virtualenv: `source /path/to/your/virtualenv/bin/activate`
3. Install Django and django-mama-cas:
```
python -m pip install "django<3" "django-mama-cas==2.4.0"
```
4. Create a Django project in the current directory:
```
django-admin startproject cas_test .
```
5. Follow the [install directions](https://django-mama-cas.readthedocs.io/en/latest/installation.html#configuring) for django-mama-cas
6. Setup the SQLite database: `python manage.py migrate`
7. Create a user:
```
python manage.py createsuperuser
```
1. Use whatever you want as the username and password.
2. Leave the other fields blank.
8. Use the built-in Django test server to serve the CAS endpoints on port 8000:
```
python manage.py runserver
```
You should now have a Django project configured to serve CAS authentication with
a single user created.
## Configure Synapse (and Riot) to use CAS
1. Modify your `homeserver.yaml` to enable CAS and point it to your locally
running Django test server:
```yaml
cas_config:
enabled: true
server_url: "http://localhost:8000"
service_url: "http://localhost:8081"
#displayname_attribute: name
#required_attributes:
# name: value
```
2. Restart Synapse.
Note that the above configuration assumes the homeserver is running on port 8081
and that the CAS server is on port 8000, both on localhost.
## Testing the configuration
Then in Riot:
1. Visit the login page with a Riot pointing at your homeserver.
2. Click the Single Sign-On button.
3. Login using the credentials created with `createsuperuser`.
4. You should be logged in.
If you want to repeat this process you'll need to manually log out first:
1. Visit http://localhost:8000/admin/
2. Click "logout" in the top right.


@@ -18,13 +18,9 @@ To make Synapse (and therefore Riot) use it:
metadata:
local: ["samling.xml"]
```
5. Ensure that your `homeserver.yaml` has a setting for `public_baseurl`:
```yaml
public_baseurl: http://localhost:8080/
```
6. Run `apt-get install xmlsec1` and `pip install --upgrade --force 'pysaml2>=4.5.0'` to ensure
5. Run `apt-get install xmlsec1` and `pip install --upgrade --force 'pysaml2>=4.5.0'` to ensure
the dependencies are installed and ready to go.
7. Restart Synapse.
6. Restart Synapse.
Then in Riot:


@@ -29,13 +29,14 @@ from synapse.logging import context # omitted from future snippets
def handle_request(request_id):
request_context = context.LoggingContext()
calling_context = context.set_current_context(request_context)
calling_context = context.LoggingContext.current_context()
context.LoggingContext.set_current_context(request_context)
try:
request_context.request = request_id
do_request_handling()
logger.debug("finished")
finally:
context.set_current_context(calling_context)
context.LoggingContext.set_current_context(calling_context)
def do_request_handling():
logger.debug("phew") # this will be logged against request_id


@@ -72,7 +72,8 @@ underneath the database, or if a different version of the locale is used on any
replicas.
The safest way to fix the issue is to take a dump and recreate the database with
the correct `COLLATE` and `CTYPE` parameters (as shown above). It is also possible to change the
the correct `COLLATE` and `CTYPE` parameters (as per
[docs/postgres.md](docs/postgres.md)). It is also possible to change the
parameters on a live database and run a `REINDEX` on the entire database,
however extreme care must be taken to avoid database corruption.
@@ -104,41 +105,19 @@ of free memory the database host has available.
When you are ready to start using PostgreSQL, edit the `database`
section in your config file to match the following lines:
```yaml
database:
name: psycopg2
args:
user: <user>
password: <pass>
database: <db>
host: <host>
cp_min: 5
cp_max: 10
```
All keys and values in `args` are passed to the `psycopg2.connect(..)`
function, except keys beginning with `cp_`, which are consumed by the
twisted adbapi connection pool. See the [libpq
documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS)
for a list of options which can be passed.
You should consider tuning the `args.keepalives_*` options if there is any danger of
the connection between your homeserver and database dropping, otherwise Synapse
may block for an extended period while it waits for a response from the
database server. Example values might be:
```yaml
# seconds of inactivity after which TCP should send a keepalive message to the server
keepalives_idle: 10
# the number of seconds after which a TCP keepalive message that is not
# acknowledged by the server should be retransmitted
keepalives_interval: 10
# the number of TCP keepalives that can be lost before the client's connection
# to the server is considered dead
keepalives_count: 3
```
twisted adbapi connection pool.
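A simplified sketch of that split (illustrative, not Synapse's actual code): keys beginning `cp_` configure the pool, everything else is handed to `psycopg2.connect`:
```python
import psycopg2

args = {
    "user": "synapse", "password": "secretpassword", "database": "synapse",
    "host": "localhost", "cp_min": 5, "cp_max": 10,
}

# cp_* keys are consumed by the twisted adbapi connection pool, not libpq.
pool_opts = {k: v for k, v in args.items() if k.startswith("cp_")}
db_opts = {k: v for k, v in args.items() if not k.startswith("cp_")}

conn = psycopg2.connect(**db_opts)
```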
## Porting from SQLite


@@ -578,46 +578,13 @@ acme:
## Database ##
# The 'database' setting defines the database that synapse uses to store all of
# its data.
#
# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or
# 'psycopg2' (for PostgreSQL).
#
# 'args' gives options which are passed through to the database engine,
# except for options starting 'cp_', which are used to configure the Twisted
# connection pool. For a reference to valid arguments, see:
# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__
#
#
# Example SQLite configuration:
#
#database:
# name: sqlite3
# args:
# database: /path/to/homeserver.db
#
#
# Example Postgres configuration:
#
#database:
# name: psycopg2
# args:
# user: synapse
# password: secretpassword
# database: synapse
# host: localhost
# cp_min: 5
# cp_max: 10
#
# For more information on using Synapse with Postgres, see `docs/postgres.md`.
#
database:
name: sqlite3
# The database engine name
name: "sqlite3"
# Arguments to pass to the engine
args:
database: DATADIR/homeserver.db
# Path to the database
database: "DATADIR/homeserver.db"
# Number of events to cache in memory.
#
@@ -872,6 +839,10 @@ media_store_path: "DATADIR/media_store"
#
#enable_registration_captcha: false
# A secret key used to bypass the captcha test entirely.
#
#captcha_bypass_secret: "YOUR_SECRET_HERE"
# The API endpoint to use for verifying m.login.recaptcha responses.
#
#recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify"
@@ -1086,29 +1057,6 @@ account_threepid_delegates:
#email: https://example.com # Delegate email sending to example.com
#msisdn: http://localhost:8090 # Delegate SMS sending to this local process
# Whether users are allowed to change their displayname after it has
# been initially set. Useful when provisioning users based on the
# contents of a third-party directory.
#
# Does not apply to server administrators. Defaults to 'true'
#
#enable_set_displayname: false
# Whether users are allowed to change their avatar after it has been
# initially set. Useful when provisioning users based on the contents
# of a third-party directory.
#
# Does not apply to server administrators. Defaults to 'true'
#
#enable_set_avatar_url: false
# Whether users can change the 3PIDs associated with their accounts
# (email address and msisdn).
#
# Defaults to 'true'
#
#enable_3pid_changes: false
# Users who register on this homeserver will automatically be joined
# to these rooms
#
@@ -1444,10 +1392,6 @@ sso:
# phishing attacks from evil.site. To avoid this, include a slash after the
# hostname: "https://my.client/".
#
# If public_baseurl is set, then the login fallback page (used by clients
# that don't natively support the required login flows) is whitelisted in
# addition to any URLs in this list.
#
# By default, this list is empty.
#
#client_whitelist:
@@ -1509,41 +1453,6 @@ password_config:
#
#pepper: "EVEN_MORE_SECRET"
# Define and enforce a password policy. Each parameter is optional.
# This is an implementation of MSC2000.
#
policy:
# Whether to enforce the password policy.
# Defaults to 'false'.
#
#enabled: true
# Minimum accepted length for a password.
# Defaults to 0.
#
#minimum_length: 15
# Whether a password must contain at least one digit.
# Defaults to 'false'.
#
#require_digit: true
# Whether a password must contain at least one symbol.
# A symbol is any character that's not a number or a letter.
# Defaults to 'false'.
#
#require_symbol: true
# Whether a password must contain at least one lowercase letter.
# Defaults to 'false'.
#
#require_lowercase: true
# Whether a password must contain at least one uppercase letter.
# Defaults to 'false'.
#
#require_uppercase: true
# Configuration for sending emails from Synapse.
#
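A toy validator for the policy knobs above (MSC2000-style; a sketch for illustration, not Synapse's implementation):
```python
import re

def check_password(pw, minimum_length=15, require_digit=True,
                   require_symbol=True, require_lowercase=True,
                   require_uppercase=True):
    # Each check mirrors one of the optional policy parameters above.
    if len(pw) < minimum_length:
        return False
    rules = [
        (require_digit, r"\d"),
        (require_symbol, r"[^0-9A-Za-z]"),
        (require_lowercase, r"[a-z]"),
        (require_uppercase, r"[A-Z]"),
    ]
    return all(re.search(pattern, pw) for enabled, pattern in rules if enabled)

print(check_password("Correct-Horse-Battery-1"))  # True
```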


@@ -14,16 +14,16 @@ example flow would be (where '>' indicates master to worker and
'<' worker to master flows):
> SERVER example.com
< REPLICATE
> POSITION events 53
< REPLICATE events 53
> RDATA events 54 ["$foo1:bar.com", ...]
> RDATA events 55 ["$foo4:bar.com", ...]
The example shows the server accepting a new connection and sending its
identity with the `SERVER` command, followed by the client requesting
replication, to which the server responds with the position of each stream.
The server then periodically sends `RDATA` commands which have the format
`RDATA <stream_name> <token> <row>`, where the format of `<row>` is defined
by the individual streams.
The example shows the server accepting a new connection and sending its
identity with the `SERVER` command, followed by the client asking to
subscribe to the `events` stream from the token `53`. The server then
periodically sends `RDATA` commands which have the format
`RDATA <stream_name> <token> <row>`, where the format of `<row>` is
defined by the individual streams.
Error reporting happens by either the client or server sending an ERROR
command, and usually the connection will be closed.
@@ -32,6 +32,9 @@ Since the protocol is a simple line-based one, it's possible to manually
connect to the server using a tool like netcat. A few things should be
noted when manually using the protocol:
- When subscribing to a stream using `REPLICATE`, the special token
`NOW` can be used to get all future updates. The special stream name
`ALL` can be used with `NOW` to subscribe to all available streams.
- The federation stream is only available if federation sending has
been disabled on the main process.
- The server will only time connections out that have sent a `PING`
@@ -88,7 +91,9 @@ The client:
- Sends a `NAME` command, allowing the server to associate a human
friendly name with the connection. This is optional.
- Sends a `PING` as above
- Sends a `REPLICATE` to get the current position of all streams.
- For each stream the client wishes to subscribe to it sends a
`REPLICATE` with the `stream_name` and token it wants to subscribe
from.
- On receipt of a `SERVER` command, checks that the server name
matches the expected server name.
@@ -135,7 +140,9 @@ the wire:
> PING 1490197665618
< NAME synapse.app.appservice
< PING 1490197665618
< REPLICATE
< REPLICATE events 1
< REPLICATE backfill 1
< REPLICATE caches 1
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
@@ -174,9 +181,9 @@ client (C):
#### POSITION (S)
On receipt of a POSITION command clients should check if they have missed any
updates, and if so then fetch them out of band. Sent in response to a
REPLICATE command (but can happen at any time).
The position of the stream has been updated. Sent to the client
after all missing updates for a stream have been sent to the client
and they're now up to date.
#### ERROR (S, C)
@@ -192,18 +199,25 @@ client (C):
#### REPLICATE (C)
Asks the server for the current position of all streams.
Asks the server to replicate a given stream. The syntax is:
```
REPLICATE <stream_name> <token>
```
Where `<token>` may be either:
* a numeric stream_id to stream updates since (exclusive)
* `NOW` to stream all subsequent updates.
The `<stream_name>` is the name of a replication stream to subscribe
to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
of streams). It can also be `ALL` to subscribe to all known streams,
in which case the `<token>` must be set to `NOW`.
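For example, subscribing by hand, in the spirit of the netcat suggestion earlier (a sketch; the host and replication port are assumptions about your deployment):
```python
import socket

# Connect to the (assumed) replication listener and subscribe to a stream.
sock = socket.create_connection(("localhost", 9092))
sock.sendall(b"NAME manual-test\n")
sock.sendall(b"REPLICATE events NOW\n")  # all future updates for `events`
print(sock.recv(4096).decode())          # e.g. SERVER / RDATA lines
```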
#### USER_SYNC (C)
A user has started or stopped syncing
#### CLEAR_USER_SYNC (C)
The server should clear all associated user sync data from the worker.
This is used when a worker is shutting down.
#### FEDERATION_ACK (C)
Acknowledge receipt of some federation data


@@ -64,13 +64,6 @@ class Codes(object):
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION"
EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT"
PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT"
PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE"
PASSWORD_NO_LOWERCASE = "M_PASSWORD_NO_LOWERCASE"
PASSWORD_NO_SYMBOL = "M_PASSWORD_NO_SYMBOL"
PASSWORD_IN_DICTIONARY = "M_PASSWORD_IN_DICTIONARY"
WEAK_PASSWORD = "M_WEAK_PASSWORD"
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
BAD_ALIAS = "M_BAD_ALIAS"
@@ -446,20 +439,6 @@ class IncompatibleRoomVersionError(SynapseError):
return cs_error(self.msg, self.errcode, room_version=self._room_version)
class PasswordRefusedError(SynapseError):
"""A password has been refused, either during password reset/change or registration.
"""
def __init__(
self,
msg="This password doesn't comply with the server's policy",
errcode=Codes.WEAK_PASSWORD,
):
super(PasswordRefusedError, self).__init__(
code=400, msg=msg, errcode=errcode,
)
class RequestSendFailed(RuntimeError):
"""Sending a HTTP request over federation failed due to not being able to
talk to the remote server for some reason.


@@ -64,26 +64,13 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientFactory
from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.handler import WorkerReplicationDataHandler
from synapse.replication.tcp.streams import (
AccountDataStream,
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.replication.tcp.streams._base import (
DeviceListsStream,
GroupServerStream,
PresenceStream,
PushersStream,
PushRulesStream,
ReceiptsStream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
)
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.replication.tcp.streams.events import EventsStreamEventRow, EventsStreamRow
from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client.v1 import events
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
@@ -126,6 +113,7 @@ from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.stringutils import random_string
from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.generic_worker")
@@ -234,7 +222,6 @@ class GenericWorkerPresence(object):
self.user_to_num_current_syncs = {}
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self.instance_id = hs.get_instance_id()
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {state.user_id: state for state in active_presence}
@@ -247,24 +234,13 @@ class GenericWorkerPresence(object):
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
)
hs.get_reactor().addSystemEventTrigger(
"before",
"shutdown",
run_as_background_process,
"generic_presence.on_shutdown",
self._on_shutdown,
)
def _on_shutdown(self):
if self.hs.config.use_presence:
self.hs.get_tcp_replication().send_command(
ClearUserSyncsCommand(self.instance_id)
)
self.process_id = random_string(16)
logger.info("Presence process_id is %r", self.process_id)
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
if self.hs.config.use_presence:
self.hs.get_tcp_replication().send_user_sync(
self.instance_id, user_id, is_syncing, last_sync_ms
user_id, is_syncing, last_sync_ms
)
def mark_as_coming_online(self, user_id):
@@ -414,9 +390,6 @@ class GenericWorkerTyping(object):
self._room_serials[row.room_id] = token
self._room_typing[row.room_id] = row.user_ids
def get_current_token(self) -> int:
return self._latest_room_serial
class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
@@ -599,26 +572,25 @@ class GenericWorkerServer(HomeServer):
else:
logger.warning("Unrecognized listener type: %s", listener["type"])
factory = ReplicationClientFactory(self, self.config.worker_name)
host = self.config.worker_replication_host
port = self.config.worker_replication_port
self.get_reactor().connectTCP(host, port, factory)
self.get_tcp_replication().start_replication(self)
def remove_pusher(self, app_id, push_key, user_id):
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
def build_tcp_replication(self):
return GenericWorkerReplicationHandler(self)
def build_presence_handler(self):
return GenericWorkerPresence(self)
def build_typing_handler(self):
return GenericWorkerTyping(self)
def build_replication_data_handler(self):
return GenericWorkerReplicationHandler(self)
class GenericWorkerReplicationHandler(WorkerReplicationDataHandler):
class GenericWorkerReplicationHandler(ReplicationClientHandler):
def __init__(self, hs):
super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore())
self.store = hs.get_datastore()
self.typing_handler = hs.get_typing_handler()
# NB this is a SynchrotronPresence, not a normal PresenceHandler
@@ -646,12 +618,15 @@ class GenericWorkerReplicationHandler(WorkerReplicationDataHandler):
args.update(self.send_handler.stream_positions())
return args
def get_currently_syncing_users(self):
return self.presence_handler.get_currently_syncing_users()
async def process_and_notify(self, stream_name, token, rows):
try:
if self.send_handler:
self.send_handler.process_replication_rows(stream_name, token, rows)
if stream_name == EventsStream.NAME:
if stream_name == "events":
# We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows.
for row in rows:
@@ -674,44 +649,43 @@ class GenericWorkerReplicationHandler(WorkerReplicationDataHandler):
)
await self.pusher_pool.on_new_notifications(token, token)
elif stream_name == PushRulesStream.NAME:
elif stream_name == "push_rules":
self.notifier.on_new_event(
"push_rules_key", token, users=[row.user_id for row in rows]
)
elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
elif stream_name in ("account_data", "tag_account_data"):
self.notifier.on_new_event(
"account_data_key", token, users=[row.user_id for row in rows]
)
elif stream_name == ReceiptsStream.NAME:
elif stream_name == "receipts":
self.notifier.on_new_event(
"receipt_key", token, rooms=[row.room_id for row in rows]
)
await self.pusher_pool.on_new_receipts(
token, token, {row.room_id for row in rows}
)
elif stream_name == TypingStream.NAME:
elif stream_name == "typing":
self.typing_handler.process_replication_rows(token, rows)
self.notifier.on_new_event(
"typing_key", token, rooms=[row.room_id for row in rows]
)
elif stream_name == ToDeviceStream.NAME:
elif stream_name == "to_device":
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:
self.notifier.on_new_event("to_device_key", token, users=entities)
elif stream_name == DeviceListsStream.NAME:
elif stream_name == "device_lists":
all_room_ids = set()
for row in rows:
if row.entity.startswith("@"):
room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids)
room_ids = await self.store.get_rooms_for_user(row.user_id)
all_room_ids.update(room_ids)
self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
elif stream_name == PresenceStream.NAME:
elif stream_name == "presence":
await self.presence_handler.process_replication_rows(token, rows)
elif stream_name == GroupServerStream.NAME:
elif stream_name == "receipts":
self.notifier.on_new_event(
"groups_key", token, users=[row.user_id for row in rows]
)
elif stream_name == PushersStream.NAME:
elif stream_name == "pushers":
for row in rows:
if row.deleted:
self.stop_pusher(row.user_id, row.app_id, row.pushkey)
@@ -800,10 +774,7 @@ class FederationSenderHandler(object):
# ... as well as device updates and messages
elif stream_name == DeviceListsStream.NAME:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
hosts = {row.entity for row in rows if not row.entity.startswith("@")}
hosts = {row.destination for row in rows}
for host in hosts:
self.federation_sender.send_device_messages(host)
@@ -818,7 +789,7 @@ class FederationSenderHandler(object):
async def _on_new_receipts(self, rows):
"""
Args:
rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
rows (iterable[synapse.replication.tcp.streams.ReceiptsStreamRow]):
new receipts to be processed
"""
for receipt in rows:
@@ -889,9 +860,6 @@ def start(config_options):
# Force the appservice to start since they will be disabled in the main config
config.notify_appservices = True
else:
# For other worker types we force this to off.
config.notify_appservices = False
if config.worker_app == "synapse.app.pusher":
if config.start_pushers:
@@ -905,9 +873,6 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.start_pushers = True
else:
# For other worker types we force this to off.
config.start_pushers = False
if config.worker_app == "synapse.app.user_dir":
if config.update_user_directory:
@@ -921,9 +886,6 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.update_user_directory = True
else:
# For other worker types we force this to off.
config.update_user_directory = False
if config.worker_app == "synapse.app.federation_sender":
if config.send_federation:
@@ -937,9 +899,6 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.send_federation = True
else:
# For other worker types we force this to off.
config.send_federation = False
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts


@@ -294,6 +294,7 @@ class RootConfig(object):
report_stats=None,
open_private_ports=False,
listeners=None,
database_conf=None,
tls_certificate_path=None,
tls_private_key_path=None,
acme_domain=None,
@@ -366,6 +367,7 @@ class RootConfig(object):
report_stats=report_stats,
open_private_ports=open_private_ports,
listeners=listeners,
database_conf=database_conf,
tls_certificate_path=tls_certificate_path,
tls_private_key_path=tls_private_key_path,
acme_domain=acme_domain,


@@ -24,6 +24,7 @@ class CaptchaConfig(Config):
self.enable_registration_captcha = config.get(
"enable_registration_captcha", False
)
self.captcha_bypass_secret = config.get("captcha_bypass_secret")
self.recaptcha_siteverify_api = config.get(
"recaptcha_siteverify_api",
"https://www.recaptcha.net/recaptcha/api/siteverify",
@@ -48,6 +49,10 @@ class CaptchaConfig(Config):
#
#enable_registration_captcha: false
# A secret key used to bypass the captcha test entirely.
#
#captcha_bypass_secret: "YOUR_SECRET_HERE"
# The API endpoint to use for verifying m.login.recaptcha responses.
#
#recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify"


@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,65 +14,14 @@
# limitations under the License.
import logging
import os
from textwrap import indent
import yaml
from synapse.config._base import Config, ConfigError
logger = logging.getLogger(__name__)
NON_SQLITE_DATABASE_PATH_WARNING = """\
Ignoring 'database_path' setting: not using a sqlite3 database.
--------------------------------------------------------------------------------
"""
DEFAULT_CONFIG = """\
## Database ##
# The 'database' setting defines the database that synapse uses to store all of
# its data.
#
# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or
# 'psycopg2' (for PostgreSQL).
#
# 'args' gives options which are passed through to the database engine,
# except for options starting 'cp_', which are used to configure the Twisted
# connection pool. For a reference to valid arguments, see:
# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__
#
#
# Example SQLite configuration:
#
#database:
# name: sqlite3
# args:
# database: /path/to/homeserver.db
#
#
# Example Postgres configuration:
#
#database:
# name: psycopg2
# args:
# user: synapse
# password: secretpassword
# database: synapse
# host: localhost
# cp_min: 5
# cp_max: 10
#
# For more information on using Synapse with Postgres, see `docs/postgres.md`.
#
database:
name: sqlite3
args:
database: %(database_path)s
# Number of events to cache in memory.
#
#event_cache_size: 10K
"""
class DatabaseConnectionConfig:
"""Contains the connection config for a particular database.
@@ -88,12 +36,10 @@ class DatabaseConnectionConfig:
"""
def __init__(self, name: str, db_config: dict):
db_engine = db_config.get("name", "sqlite3")
if db_config["name"] not in ("sqlite3", "psycopg2"):
raise ConfigError("Unsupported database type %r" % (db_config["name"],))
if db_engine not in ("sqlite3", "psycopg2"):
raise ConfigError("Unsupported database type %r" % (db_engine,))
if db_engine == "sqlite3":
if db_config["name"] == "sqlite3":
db_config.setdefault("args", {}).update(
{"cp_min": 1, "cp_max": 1, "check_same_thread": False}
)
@@ -110,11 +56,6 @@ class DatabaseConnectionConfig:
class DatabaseConfig(Config):
section = "database"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.databases = []
def read_config(self, config, **kwargs):
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
@@ -135,13 +76,12 @@ class DatabaseConfig(Config):
multi_database_config = config.get("databases")
database_config = config.get("database")
database_path = config.get("database_path")
if multi_database_config and database_config:
raise ConfigError("Can't specify both 'database' and 'datbases' in config")
if multi_database_config:
if database_path:
if config.get("database_path"):
raise ConfigError("Can't specify 'database_path' with 'databases'")
self.databases = [
@@ -149,55 +89,65 @@ class DatabaseConfig(Config):
for name, db_conf in multi_database_config.items()
]
if database_config:
else:
if database_config is None:
database_config = {"name": "sqlite3", "args": {}}
self.databases = [DatabaseConnectionConfig("master", database_config)]
if database_path:
if self.databases and self.databases[0].name != "sqlite3":
logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
return
self.set_databasepath(config.get("database_path"))
database_config = {"name": "sqlite3", "args": {}}
self.databases = [DatabaseConnectionConfig("master", database_config)]
self.set_databasepath(database_path)
def generate_config_section(self, data_dir_path, database_conf, **kwargs):
if not database_conf:
database_path = os.path.join(data_dir_path, "homeserver.db")
database_conf = (
"""# The database engine name
name: "sqlite3"
# Arguments to pass to the engine
args:
# Path to the database
database: "%(database_path)s"
"""
% locals()
)
else:
database_conf = indent(yaml.dump(database_conf), " " * 10).lstrip()
def generate_config_section(self, data_dir_path, **kwargs):
return DEFAULT_CONFIG % {
"database_path": os.path.join(data_dir_path, "homeserver.db")
}
return (
"""\
## Database ##
database:
%(database_conf)s
# Number of events to cache in memory.
#
#event_cache_size: 10K
"""
% locals()
)
def read_arguments(self, args):
"""
Cases for the cli input:
- If no databases are configured and no database_path is set, raise.
- No databases and only database_path available ==> sqlite3 db.
- If there are multiple databases and a database_path raise an error.
- If the database set in the config file is sqlite then
overwrite with the command line argument.
"""
if args.database_path is None:
if not self.databases:
raise ConfigError("No database config provided")
return
if len(self.databases) == 0:
database_config = {"name": "sqlite3", "args": {}}
self.databases = [DatabaseConnectionConfig("master", database_config)]
self.set_databasepath(args.database_path)
return
if self.get_single_database().name == "sqlite3":
self.set_databasepath(args.database_path)
else:
logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
self.set_databasepath(args.database_path)
def set_databasepath(self, database_path):
if database_path is None:
return
if database_path != ":memory:":
database_path = self.abspath(database_path)
self.databases[0].config["args"]["database"] = database_path
# We only support setting a database path if we have a single sqlite3
# database.
if len(self.databases) != 1:
raise ConfigError("Cannot specify 'database_path' with multiple databases")
database = self.get_single_database()
if database.config["name"] != "sqlite3":
# We don't raise here as we haven't done so before for this case.
logger.warn("Ignoring 'database_path' for non-sqlite3 database")
return
database.config["args"]["database"] = database_path
@staticmethod
def add_arguments(parser):
@@ -212,7 +162,7 @@ class DatabaseConfig(Config):
def get_single_database(self) -> DatabaseConnectionConfig:
"""Returns the database if there is only one, useful for e.g. tests
"""
if not self.databases:
if len(self.databases) != 1:
raise Exception("More than one database exists")
return self.databases[0]


@@ -31,10 +31,6 @@ class PasswordConfig(Config):
self.password_localdb_enabled = password_config.get("localdb_enabled", True)
self.password_pepper = password_config.get("pepper", "")
# Password policy
self.password_policy = password_config.get("policy") or {}
self.password_policy_enabled = self.password_policy.get("enabled", False)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
password_config:
@@ -52,39 +48,4 @@ class PasswordConfig(Config):
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
#
#pepper: "EVEN_MORE_SECRET"
# Define and enforce a password policy. Each parameter is optional.
# This is an implementation of MSC2000.
#
policy:
# Whether to enforce the password policy.
# Defaults to 'false'.
#
#enabled: true
# Minimum accepted length for a password.
# Defaults to 0.
#
#minimum_length: 15
# Whether a password must contain at least one digit.
# Defaults to 'false'.
#
#require_digit: true
# Whether a password must contain at least one symbol.
# A symbol is any character that's not a number or a letter.
# Defaults to 'false'.
#
#require_symbol: true
# Whether a password must contain at least one lowercase letter.
# Defaults to 'false'.
#
#require_lowercase: true
# Whether a password must contain at least one uppercase letter.
# Defaults to 'false'.
#
#require_uppercase: true
"""


@@ -129,10 +129,6 @@ class RegistrationConfig(Config):
raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,))
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
self.enable_set_displayname = config.get("enable_set_displayname", True)
self.enable_set_avatar_url = config.get("enable_set_avatar_url", True)
self.enable_3pid_changes = config.get("enable_3pid_changes", True)
self.disable_msisdn_registration = config.get(
"disable_msisdn_registration", False
)
@@ -334,29 +330,6 @@ class RegistrationConfig(Config):
#email: https://example.com # Delegate email sending to example.com
#msisdn: http://localhost:8090 # Delegate SMS sending to this local process
# Whether users are allowed to change their displayname after it has
# been initially set. Useful when provisioning users based on the
# contents of a third-party directory.
#
# Does not apply to server administrators. Defaults to 'true'
#
#enable_set_displayname: false
# Whether users are allowed to change their avatar after it has been
# initially set. Useful when provisioning users based on the contents
# of a third-party directory.
#
# Does not apply to server administrators. Defaults to 'true'
#
#enable_set_avatar_url: false
# Whether users can change the 3PIDs associated with their accounts
# (email address and msisdn).
#
# Defaults to 'true'
#
#enable_3pid_changes: false
# Users who register on this homeserver will automatically be joined
# to these rooms
#


@@ -39,17 +39,6 @@ class SSOConfig(Config):
self.sso_client_whitelist = sso_config.get("client_whitelist") or []
# Attempt to also whitelist the server's login fallback, since that fallback sets
# the redirect URL to itself (so it can process the login token then return
# gracefully to the client). This would make it pointless to ask the user for
# confirmation, since the URL the confirmation page would be showing wouldn't be
# the client's.
# public_baseurl is an optional setting, so we only add the fallback's URL to the
# list if it's provided (because we can't figure out what that URL is otherwise).
if self.public_baseurl:
login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
self.sso_client_whitelist.append(login_fallback_url)
def generate_config_section(self, **kwargs):
return """\
# Additional settings to use with single-sign on systems such as SAML2 and CAS.
@@ -65,10 +54,6 @@ class SSOConfig(Config):
# phishing attacks from evil.site. To avoid this, include a slash after the
# hostname: "https://my.client/".
#
# If public_baseurl is set, then the login fallback page (used by clients
# that don't natively support the required login flows) is whitelisted in
# addition to any URLs in this list.
#
# By default, this list is empty.
#
#client_whitelist:


@@ -43,8 +43,8 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
current_context,
make_deferred_yieldable,
preserve_fn,
run_in_background,
@@ -236,7 +236,7 @@ class Keyring(object):
"""
try:
ctx = current_context()
ctx = LoggingContext.current_context()
# map from server name to a set of outstanding request ids
server_to_request_ids = {}


@@ -25,15 +25,19 @@ from twisted.python.failure import Failure
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
RoomVersion,
)
from synapse.crypto.event_signing import check_event_content_hash
from synapse.crypto.keyring import Keyring
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
current_context,
make_deferred_yieldable,
)
from synapse.types import JsonDict, get_domain_from_id
@@ -51,15 +55,13 @@ class FederationBase(object):
self.store = hs.get_datastore()
self._clock = hs.get_clock()
def _check_sigs_and_hash(
self, room_version: RoomVersion, pdu: EventBase
) -> Deferred:
def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
return make_deferred_yieldable(
self._check_sigs_and_hashes(room_version, [pdu])[0]
)
def _check_sigs_and_hashes(
self, room_version: RoomVersion, pdus: List[EventBase]
self, room_version: str, pdus: List[EventBase]
) -> List[Deferred]:
"""Checks that each of the received events is correctly signed by the
sending server.
@@ -78,7 +80,7 @@ class FederationBase(object):
"""
deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
ctx = current_context()
ctx = LoggingContext.current_context()
def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx):
@@ -144,7 +146,7 @@ class PduToCheckSig(
def _check_sigs_on_pdus(
keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
) -> List[Deferred]:
"""Check that the given events are correctly signed
@@ -189,6 +191,10 @@ def _check_sigs_on_pdus(
for p in pdus
]
v = KNOWN_ROOM_VERSIONS.get(room_version)
if not v:
raise RuntimeError("Unrecognized room version %s" % (room_version,))
# First we check that the sender event is signed by the sender's domain
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
@@ -198,7 +204,7 @@ def _check_sigs_on_pdus(
(
p.sender_domain,
p.redacted_pdu_json,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.origin_server_ts if v.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_sender
@@ -221,7 +227,7 @@ def _check_sigs_on_pdus(
# event id's domain (normally only the case for joins/leaves), and add additional
# checks. Only do this if the room version has a concept of event ID domain
# (ie, the room version uses old-style non-hash event IDs).
if room_version.event_format == EventFormatVersions.V1:
if v.event_format == EventFormatVersions.V1:
pdus_to_check_event_id = [
p
for p in pdus_to_check
@@ -233,7 +239,7 @@ def _check_sigs_on_pdus(
(
get_domain_from_id(p.pdu.event_id),
p.redacted_pdu_json,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.origin_server_ts if v.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_event_id


@@ -220,7 +220,8 @@ class FederationClient(FederationBase):
# FIXME: We should handle signature failures more gracefully.
pdus[:] = await make_deferred_yieldable(
defer.gatherResults(
self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True,
self._check_sigs_and_hashes(room_version.identifier, pdus),
consumeErrors=True,
).addErrback(unwrapFirstError)
)
@@ -290,7 +291,9 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
signed_pdu = await self._check_sigs_and_hash(room_version, pdu)
signed_pdu = await self._check_sigs_and_hash(
room_version.identifier, pdu
)
break
@@ -347,7 +350,7 @@ class FederationClient(FederationBase):
self,
origin: str,
pdus: List[EventBase],
room_version: RoomVersion,
room_version: str,
outlier: bool = False,
include_none: bool = False,
) -> List[EventBase]:
@@ -393,7 +396,7 @@ class FederationClient(FederationBase):
self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
room_version=room_version,
room_version=room_version, # type: ignore
outlier=outlier,
timeout=10000,
)
@@ -431,7 +434,7 @@ class FederationClient(FederationBase):
]
signed_auth = await self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True, room_version=room_version
destination, auth_chain, outlier=True, room_version=room_version.identifier
)
signed_auth.sort(key=lambda e: e.depth)
@@ -658,7 +661,7 @@ class FederationClient(FederationBase):
destination,
list(pdus.values()),
outlier=True,
room_version=room_version,
room_version=room_version.identifier,
)
valid_pdus_map = {p.event_id: p for p in valid_pdus}
@@ -753,7 +756,7 @@ class FederationClient(FederationBase):
pdu = event_from_pdu_json(pdu_dict, room_version)
# Check signatures are correct.
pdu = await self._check_sigs_and_hash(room_version, pdu)
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
# FIXME: We should handle signature failures more gracefully.
@@ -945,7 +948,7 @@ class FederationClient(FederationBase):
]
signed_events = await self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False, room_version=room_version
destination, events, outlier=False, room_version=room_version.identifier
)
except HttpResponseException as e:
if not e.code == 400:


@@ -409,7 +409,7 @@ class FederationServer(FederationBase):
pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, pdu.room_id)
pdu = await self._check_sigs_and_hash(room_version, pdu)
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version)
time_now = self._clock.time_msec()
return {"event": ret_pdu.get_pdu_json(time_now)}
@@ -425,7 +425,7 @@ class FederationServer(FederationBase):
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
pdu = await self._check_sigs_and_hash(room_version, pdu)
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
@@ -455,7 +455,7 @@ class FederationServer(FederationBase):
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
pdu = await self._check_sigs_and_hash(room_version, pdu)
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
await self.handler.on_send_leave_request(origin, pdu)
return {}
@@ -611,7 +611,7 @@ class FederationServer(FederationBase):
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
# We've already checked that we know the room version by this point
room_version = await self.store.get_room_version(pdu.room_id)
room_version = await self.store.get_room_version_id(pdu.room_id)
# Check signature.
try:


@@ -477,7 +477,7 @@ def process_rows_for_federation(transaction_queue, rows):
Args:
transaction_queue (FederationSender)
rows (list(synapse.replication.tcp.streams.federation.FederationStream.FederationStreamRow))
rows (list(synapse.replication.tcp.streams.FederationStreamRow))
"""
# The federation stream contains a bunch of different types of


@@ -499,13 +499,4 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction()
def get_current_token(self) -> int:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return 0
async def get_replication_rows(
self, from_token, to_token, limit, federation_ack=None
):
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return []


@@ -125,11 +125,7 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def validate_user_via_ui_auth(
self,
requester: Requester,
request: SynapseRequest,
request_body: Dict[str, Any],
clientip: str,
self, requester: Requester, request_body: Dict[str, Any], clientip: str
):
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -141,8 +137,6 @@ class AuthHandler(BaseHandler):
Args:
requester: The user, as given by the access token
request: The request sent by the client.
request_body: The body of the request sent by the client
clientip: The IP address of the client.
@@ -178,9 +172,7 @@ class AuthHandler(BaseHandler):
flows = [[login_type] for login_type in self._supported_login_types]
try:
result, params, _ = yield self.check_auth(
flows, request, request_body, clientip
)
result, params, _ = yield self.check_auth(flows, request_body, clientip)
except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
self._failed_uia_attempts_ratelimiter.can_do_action(
@@ -219,11 +211,7 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def check_auth(
self,
flows: List[List[str]],
request: SynapseRequest,
clientdict: Dict[str, Any],
clientip: str,
self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str
):
"""
Takes a dictionary sent by the client in the login / registration
@@ -243,8 +231,6 @@ class AuthHandler(BaseHandler):
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
request: The request sent by the client.
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
@@ -284,27 +270,13 @@ class AuthHandler(BaseHandler):
# email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects
# on a homeserver.
# Revisit: Assuming the REST APIs do sensible validation, the data
# Revisit: Assuming the REST APIs do sensible validation, the data
# isn't arbitrary.
session["clientdict"] = clientdict
self._save_session(session)
elif "clientdict" in session:
clientdict = session["clientdict"]
# Ensure that the queried operation does not vary between stages of
# the UI authentication session. This is done by generating a stable
# comparator based on the URI, method, and body (minus the auth dict)
# and storing it during the initial query. Subsequent queries ensure
# that this comparator has not changed.
comparator = (request.uri, request.method, clientdict)
if "ui_auth" not in session:
session["ui_auth"] = comparator
elif session["ui_auth"] != comparator:
raise SynapseError(
403,
"Requested operation has changed during the UI authentication session.",
)
if not authdict:
raise InteractiveAuthIncompleteError(
self._auth_dict_for_flows(flows, session)
@@ -350,7 +322,6 @@ class AuthHandler(BaseHandler):
creds,
list(clientdict),
)
return creds, clientdict, session["id"]
ret = self._auth_dict_for_flows(flows, session)
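The removed comparator above pins a UI-auth session to a single operation. A rough standalone sketch of the idea (hypothetical names, not Synapse's API):

```
# Sketch: a stable comparator built from the request, minus the auth dict.
def check_operation(session: dict, uri: str, method: str, body: dict) -> None:
    clientdict = {k: v for k, v in body.items() if k != "auth"}
    comparator = (uri, method, clientdict)
    if "ui_auth" not in session:
        session["ui_auth"] = comparator        # first stage: remember it
    elif session["ui_auth"] != comparator:     # later stages: must match
        raise ValueError(
            "Requested operation has changed during the UI authentication session."
        )

session = {}
check_operation(session, "/account/password", "POST", {"new_password": "x"})
# A later stage with the same URI/method/body (plus an auth dict) is accepted:
check_operation(session, "/account/password", "POST",
                {"new_password": "x", "auth": {"type": "m.login.password"}})
```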

View File

@@ -1,204 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import xml.etree.ElementTree as ET
from typing import AnyStr, Dict, Optional, Tuple
from six.moves import urllib
from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart
logger = logging.getLogger(__name__)
class CasHandler:
"""
Utility class to handle the response from a CAS SSO service.
Args:
hs (synapse.server.HomeServer)
"""
def __init__(self, hs):
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
self._cas_displayname_attribute = hs.config.cas_displayname_attribute
self._cas_required_attributes = hs.config.cas_required_attributes
self._http_client = hs.get_proxied_http_client()
def _build_service_param(self, client_redirect_url: AnyStr) -> str:
return "%s%s?%s" % (
self._cas_service_url,
"/_matrix/client/r0/login/cas/ticket",
urllib.parse.urlencode({"redirectUrl": client_redirect_url}),
)
async def _handle_cas_response(
self, request: SynapseRequest, cas_response_body: str, client_redirect_url: str
) -> None:
"""
Retrieves the user and display name from the CAS response and continues with the authentication.
Args:
request: The original client request.
cas_response_body: The response from the CAS server.
client_redirect_url: The URL to redirect the client to when
everything is done.
"""
user, attributes = self._parse_cas_response(cas_response_body)
displayname = attributes.pop(self._cas_displayname_attribute, None)
for required_attribute, required_value in self._cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
await self._on_successful_auth(user, request, client_redirect_url, displayname)
def _parse_cas_response(
self, cas_response_body: str
) -> Tuple[str, Dict[str, Optional[str]]]:
"""
Retrieve the user and other parameters from the CAS response.
Args:
cas_response_body: The response from the CAS query.
Returns:
A tuple of the user and a mapping of other attributes.
"""
user = None
attributes = {}
try:
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
success = root[0].tag.endswith("authenticationSuccess")
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
for attribute in child:
# ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace.
# We don't care about namespace here and it will always
# be encased in curly braces, so we remove them.
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text
if user is None:
raise Exception("CAS response does not contain user")
except Exception:
logger.exception("Error parsing CAS response")
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
)
return user, attributes
async def _on_successful_auth(
self,
username: str,
request: SynapseRequest,
client_redirect_url: str,
user_display_name: Optional[str] = None,
) -> None:
"""Called once the user has successfully authenticated with the SSO.
Registers the user if necessary, and then returns a redirect (with
a login token) to the client.
Args:
username: the remote user id. We'll map this onto
something sane for an MXID localpart.
request: the incoming request from the browser. We'll
respond to it with a redirect.
client_redirect_url: the redirect_url the client gave us when
it first started the process.
user_display_name: if set, and we have to register a new user,
we will set their displayname to this.
"""
localpart = map_username_to_mxid_localpart(username)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=user_display_name
)
self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url
)
def handle_redirect_request(self, client_redirect_url: bytes) -> bytes:
"""
Generates a URL to the CAS server where the client should be redirected.
Args:
client_redirect_url: The final URL the client should go to after the
user has negotiated SSO.
Returns:
The URL to redirect to.
"""
args = urllib.parse.urlencode(
{"service": self._build_service_param(client_redirect_url)}
)
return ("%s/login?%s" % (self._cas_server_url, args)).encode("ascii")
async def handle_ticket_request(
self, request: SynapseRequest, client_redirect_url: str, ticket: str
) -> None:
"""
Validates a CAS ticket sent by the client for login/registration.
On a successful request, writes a redirect to the request.
"""
uri = self._cas_server_url + "/proxyValidate"
args = {
"ticket": ticket,
"service": self._build_service_param(client_redirect_url),
}
try:
body = await self._http_client.get_raw(uri, args)
except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
await self._handle_cas_response(request, body, client_redirect_url)
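To make `_parse_cas_response` concrete, here is a hand-written `serviceResponse` document and the same namespace-stripping walk over it (the XML shape is an assumption; real CAS servers vary):

```
import xml.etree.ElementTree as ET

CAS_BODY = """<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
  <cas:authenticationSuccess>
    <cas:user>alice</cas:user>
    <cas:attributes>
      <cas:displayName>Alice</cas:displayName>
    </cas:attributes>
  </cas:authenticationSuccess>
</cas:serviceResponse>"""

root = ET.fromstring(CAS_BODY)
assert root.tag.endswith("serviceResponse")

user, attributes = None, {}
for child in root[0]:
    if child.tag.endswith("user"):
        user = child.text
    if child.tag.endswith("attributes"):
        for attribute in child:
            # ElementTree expands namespaces to "{url}tag"; keep only the tag.
            attributes[attribute.tag.split("}")[-1]] = attribute.text

assert user == "alice" and attributes == {"displayName": "Alice"}
```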

View File

@@ -125,14 +125,8 @@ class DeviceWorkerHandler(BaseHandler):
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
tracked_users = set(users_who_share_room)
# Always tell the user about their own devices
tracked_users.add(user_id)
changed = yield self.store.get_users_whose_devices_changed(
from_token.device_list_key, tracked_users
from_token.device_list_key, users_who_share_room
)
# Then work out if any users have since joined
@@ -462,11 +456,7 @@ class DeviceHandler(DeviceWorkerHandler):
room_ids = yield self.store.get_rooms_for_user(user_id)
# specify the user ID too since the user should always get their own device list
# updates, even if they aren't in any rooms.
yield self.notifier.on_new_event(
"device_list_key", position, users=[user_id], rooms=room_ids
)
yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids)
if hosts:
logger.info(
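The removed lines make the device-list query cover the requesting user as well as everyone they share a room with. In miniature (toy data, not the store API):

```
def users_to_track(user_id, users_who_share_room):
    tracked_users = set(users_who_share_room)
    tracked_users.add(user_id)  # always include the user's own devices
    return tracked_users

assert users_to_track("@alice:hs", ["@bob:hs"]) == {"@alice:hs", "@bob:hs"}
```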

View File

@@ -851,38 +851,6 @@ class EventCreationHandler(object):
self.store.remove_push_actions_from_staging, event.event_id
)
@defer.inlineCallbacks
def _validate_canonical_alias(
self, directory_handler, room_alias_str, expected_room_id
):
"""
Ensure that the given room alias points to the expected room ID.
Args:
directory_handler: The directory handler object.
room_alias_str: The room alias to check.
expected_room_id: The room ID that the alias should point to.
"""
room_alias = RoomAlias.from_string(room_alias_str)
try:
mapping = yield directory_handler.get_association(room_alias)
except SynapseError as e:
# Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
if e.errcode == Codes.NOT_FOUND:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (room_alias_str,),
Codes.BAD_ALIAS,
)
raise
if mapping["room_id"] != expected_room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (room_alias_str,),
Codes.BAD_ALIAS,
)
@defer.inlineCallbacks
def persist_and_notify_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
@@ -937,9 +905,15 @@ class EventCreationHandler(object):
room_alias_str = event.content.get("alias", None)
directory_handler = self.hs.get_handlers().directory_handler
if room_alias_str and room_alias_str != original_alias:
yield self._validate_canonical_alias(
directory_handler, room_alias_str, event.room_id
)
room_alias = RoomAlias.from_string(room_alias_str)
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (room_alias_str,),
Codes.BAD_ALIAS,
)
# Check that alt_aliases is the proper form.
alt_aliases = event.content.get("alt_aliases", [])
@@ -957,9 +931,16 @@ class EventCreationHandler(object):
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
if new_alt_aliases:
for alias_str in new_alt_aliases:
yield self._validate_canonical_alias(
directory_handler, alias_str, event.room_id
)
room_alias = RoomAlias.from_string(alias_str)
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room"
% (room_alias_str,),
Codes.BAD_ALIAS,
)
federation_handler = self.hs.get_handlers().federation_handler
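Both sides of this hunk enforce the same invariant: an alias named in an `m.room.canonical_alias` event must resolve to the room the event is sent in. A standalone sketch with a fake directory lookup:

```
# Hypothetical directory mapping: alias -> room ID.
DIRECTORY = {"#matrix:example.org": "!abc:example.org"}

class BadAliasError(Exception):
    pass

def validate_canonical_alias(alias: str, expected_room_id: str) -> None:
    mapping = DIRECTORY.get(alias)
    # Synapse turns both "not found" and "wrong room" into a 400 M_BAD_ALIAS.
    if mapping is None or mapping != expected_room_id:
        raise BadAliasError("Room alias %s does not point to the room" % (alias,))

validate_canonical_alias("#matrix:example.org", "!abc:example.org")  # passes
try:
    validate_canonical_alias("#matrix:example.org", "!other:example.org")
except BadAliasError as e:
    print(e)
```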

View File

@@ -1,93 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from synapse.api.errors import Codes, PasswordRefusedError
logger = logging.getLogger(__name__)
class PasswordPolicyHandler(object):
def __init__(self, hs):
self.policy = hs.config.password_policy
self.enabled = hs.config.password_policy_enabled
# Regexps for the spec'd policy parameters.
self.regexp_digit = re.compile("[0-9]")
self.regexp_symbol = re.compile("[^a-zA-Z0-9]")
self.regexp_uppercase = re.compile("[A-Z]")
self.regexp_lowercase = re.compile("[a-z]")
def validate_password(self, password):
"""Checks whether a given password complies with the server's policy.
Args:
password (str): The password to check against the server's policy.
Raises:
PasswordRefusedError: The password doesn't comply with the server's policy.
"""
if not self.enabled:
return
minimum_accepted_length = self.policy.get("minimum_length", 0)
if len(password) < minimum_accepted_length:
raise PasswordRefusedError(
msg=(
"The password must be at least %d characters long"
% minimum_accepted_length
),
errcode=Codes.PASSWORD_TOO_SHORT,
)
if (
self.policy.get("require_digit", False)
and self.regexp_digit.search(password) is None
):
raise PasswordRefusedError(
msg="The password must include at least one digit",
errcode=Codes.PASSWORD_NO_DIGIT,
)
if (
self.policy.get("require_symbol", False)
and self.regexp_symbol.search(password) is None
):
raise PasswordRefusedError(
msg="The password must include at least one symbol",
errcode=Codes.PASSWORD_NO_SYMBOL,
)
if (
self.policy.get("require_uppercase", False)
and self.regexp_uppercase.search(password) is None
):
raise PasswordRefusedError(
msg="The password must include at least one uppercase letter",
errcode=Codes.PASSWORD_NO_UPPERCASE,
)
if (
self.policy.get("require_lowercase", False)
and self.regexp_lowercase.search(password) is None
):
raise PasswordRefusedError(
msg="The password must include at least one lowercase letter",
errcode=Codes.PASSWORD_NO_LOWERCASE,
)
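Assuming a policy dict shaped like the one this handler reads from config, the checks compose as below (stdlib-only sketch, not the handler itself):

```
import re

policy = {"minimum_length": 8, "require_digit": True, "require_symbol": True}

def validate_password(password: str, policy: dict) -> None:
    if len(password) < policy.get("minimum_length", 0):
        raise ValueError("too short")
    if policy.get("require_digit") and re.search("[0-9]", password) is None:
        raise ValueError("needs a digit")
    if policy.get("require_symbol") and re.search("[^a-zA-Z0-9]", password) is None:
        raise ValueError("needs a symbol")

validate_password("hunter2!x", policy)       # satisfies all three rules
try:
    validate_password("hunter2", policy)     # 7 chars, no symbol
except ValueError as e:
    print(e)                                 # -> too short
```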

View File

@@ -747,7 +747,7 @@ class PresenceHandler(object):
return False
async def get_all_presence_updates(self, last_id, current_id, limit):
async def get_all_presence_updates(self, last_id, current_id):
"""
Gets a list of presence update rows from between the given stream ids.
Each row has:
@@ -762,7 +762,7 @@ class PresenceHandler(object):
"""
# TODO(markjh): replicate the unpersisted changes.
# This could use the in-memory stores for recent changes.
rows = await self.store.get_all_presence_updates(last_id, current_id, limit)
rows = await self.store.get_all_presence_updates(last_id, current_id)
return rows
def notify_new_event(self):

View File

@@ -157,15 +157,6 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
if not by_admin and not self.hs.config.enable_set_displayname:
profile = yield self.store.get_profileinfo(target_user.localpart)
if profile.display_name:
raise SynapseError(
400,
"Changing display name is disabled on this server",
Codes.FORBIDDEN,
)
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
raise SynapseError(
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
@@ -227,13 +218,6 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url")
if not by_admin and not self.hs.config.enable_set_avatar_url:
profile = yield self.store.get_profileinfo(target_user.localpart)
if profile.avatar_url:
raise SynapseError(
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
)
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
raise SynapseError(
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)

View File

@@ -26,7 +26,6 @@ from synapse.config import ConfigError
from synapse.http.server import finish_request
from synapse.http.servlet import parse_string
from synapse.module_api import ModuleApi
from synapse.module_api.errors import RedirectException
from synapse.types import (
UserID,
map_username_to_mxid_localpart,
@@ -120,9 +119,6 @@ class SamlHandler:
try:
user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
except RedirectException:
# Raise the exception as per the wishes of the SAML module response
raise
except Exception as e:
# If decoding the response or mapping it to a user failed, then log the
# error and tell the user that something went wrong.

View File

@@ -32,7 +32,6 @@ class SetPasswordHandler(BaseHandler):
super(SetPasswordHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
self._password_policy_handler = hs.get_password_policy_handler()
@defer.inlineCallbacks
def set_password(
@@ -45,7 +44,6 @@ class SetPasswordHandler(BaseHandler):
if not self.hs.config.password_localdb_enabled:
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
self._password_policy_handler.validate_password(new_password)
password_hash = yield self._auth_handler.hash(new_password)
try:

View File

@@ -26,7 +26,7 @@ from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import FilterCollection
from synapse.events import EventBase
from synapse.logging.context import current_context
from synapse.logging.context import LoggingContext
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
@@ -301,7 +301,7 @@ class SyncHandler(object):
else:
sync_type = "incremental_sync"
context = current_context()
context = LoggingContext.current_context()
if context:
context.tag = sync_type
@@ -1143,14 +1143,9 @@ class SyncHandler(object):
user_id
)
tracked_users = set(users_who_share_room)
# Always tell the user about their own devices
tracked_users.add(user_id)
# Step 1a, check for changes in devices of users we share a room with
users_that_have_changed = await self.store.get_users_whose_devices_changed(
since_token.device_list_key, tracked_users
since_token.device_list_key, users_who_share_room
)
# Step 1b, check for newly joined rooms

View File

@@ -15,7 +15,6 @@
import logging
from collections import namedtuple
from typing import List
from twisted.internet import defer
@@ -258,13 +257,7 @@ class TypingHandler(object):
"typing_key", self._latest_room_serial, rooms=[member.room_id]
)
async def get_all_typing_updates(
self, last_id: int, current_id: int, limit: int
) -> List[dict]:
"""Get up to `limit` typing updates between the given tokens, earliest
updates first.
"""
async def get_all_typing_updates(self, last_id, current_id):
if last_id == current_id:
return []
@@ -282,7 +275,7 @@ class TypingHandler(object):
typing = self._room_typing[room_id]
rows.append((serial, room_id, list(typing)))
rows.sort()
return rows[:limit]
return rows
def get_current_token(self):
return self._latest_room_serial
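The only behavioural difference between the two versions of `get_all_typing_updates` is the `limit` truncation after sorting; in miniature:

```
rows = [(3, "!b:hs", ["@y:hs"]), (1, "!a:hs", ["@x:hs"]), (2, "!a:hs", [])]
rows.sort()          # earliest stream serial first
limit = 2
print(rows[:limit])  # [(1, '!a:hs', ['@x:hs']), (2, '!a:hs', [])]
```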

View File

@@ -19,7 +19,7 @@ import threading
from prometheus_client.core import Counter, Histogram
from synapse.logging.context import current_context
from synapse.logging.context import LoggingContext
from synapse.metrics import LaterGauge
logger = logging.getLogger(__name__)
@@ -148,7 +148,7 @@ LaterGauge(
class RequestMetrics(object):
def start(self, time_sec, name, method):
self.start = time_sec
self.start_context = current_context()
self.start_context = LoggingContext.current_context()
self.name = name
self.method = method
@@ -163,7 +163,7 @@ class RequestMetrics(object):
with _in_flight_requests_lock:
_in_flight_requests.discard(self)
context = current_context()
context = LoggingContext.current_context()
tag = ""
if context:

View File

@@ -42,7 +42,7 @@ from synapse.logging._terse_json import (
TerseJSONToConsoleLogObserver,
TerseJSONToTCPLogObserver,
)
from synapse.logging.context import current_context
from synapse.logging.context import LoggingContext
def stdlib_log_level_to_twisted(level: str) -> LogLevel:
@@ -86,7 +86,7 @@ class LogContextObserver(object):
].startswith("Timing out client"):
return
context = current_context()
context = LoggingContext.current_context()
# Copy the context information to the log event.
if context is not None:

View File

@@ -175,54 +175,7 @@ class ContextResourceUsage(object):
return res
LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
class _Sentinel(object):
"""Sentinel to represent the root context"""
__slots__ = ["previous_context", "finished", "request", "scope", "tag"]
def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
self.previous_context = None
self.finished = False
self.request = None
self.scope = None
self.tag = None
def __str__(self):
return "sentinel"
def copy_to(self, record):
pass
def copy_to_twisted_log_entry(self, record):
record["request"] = None
record["scope"] = None
def start(self):
pass
def stop(self):
pass
def add_database_transaction(self, duration_sec):
pass
def add_database_scheduled(self, sched_sec):
pass
def record_event_fetch(self, event_count):
pass
def __nonzero__(self):
return False
__bool__ = __nonzero__ # python3
SENTINEL_CONTEXT = _Sentinel()
LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"]
class LoggingContext(object):
@@ -246,33 +199,76 @@ class LoggingContext(object):
"_resource_usage",
"usage_start",
"main_thread",
"finished",
"alive",
"request",
"tag",
"scope",
]
thread_local = threading.local()
class Sentinel(object):
"""Sentinel to represent the root context"""
__slots__ = ["previous_context", "alive", "request", "scope", "tag"]
def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
self.previous_context = None
self.alive = None
self.request = None
self.scope = None
self.tag = None
def __str__(self):
return "sentinel"
def copy_to(self, record):
pass
def copy_to_twisted_log_entry(self, record):
record["request"] = None
record["scope"] = None
def start(self):
pass
def stop(self):
pass
def add_database_transaction(self, duration_sec):
pass
def add_database_scheduled(self, sched_sec):
pass
def record_event_fetch(self, event_count):
pass
def __nonzero__(self):
return False
__bool__ = __nonzero__ # python3
sentinel = Sentinel()
def __init__(self, name=None, parent_context=None, request=None) -> None:
self.previous_context = current_context()
self.previous_context = LoggingContext.current_context()
self.name = name
# track the resources used by this context so far
self._resource_usage = ContextResourceUsage()
# The thread resource usage when the logcontext became active. None
# if the context is not currently active.
# If alive, holds the thread resource usage from when the logcontext
# last became active.
self.usage_start = None
self.main_thread = get_thread_id()
self.request = None
self.tag = ""
self.alive = True
self.scope = None # type: Optional[_LogContextScope]
# keep track of whether we have hit the __exit__ block for this context
# (suggesting that the thing that created the context thinks it should
# be finished, and that re-activating it would suggest an error).
self.finished = False
self.parent_context = parent_context
if self.parent_context is not None:
@@ -287,15 +283,44 @@ class LoggingContext(object):
return str(self.request)
return "%s@%x" % (self.name, id(self))
@classmethod
def current_context(cls) -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage
Returns:
LoggingContext: the current logging context
"""
return getattr(cls.thread_local, "current_context", cls.sentinel)
@classmethod
def set_current_context(
cls, context: LoggingContextOrSentinel
) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage
Args:
context(LoggingContext): The context to activate.
Returns:
The context that was previously active
"""
current = cls.current_context()
if current is not context:
current.stop()
cls.thread_local.current_context = context
context.start()
return current
def __enter__(self) -> "LoggingContext":
"""Enters this logging context into thread local storage"""
old_context = set_current_context(self)
old_context = self.set_current_context(self)
if self.previous_context != old_context:
logger.warning(
"Expected previous context %r, found %r",
self.previous_context,
old_context,
)
self.alive = True
return self
def __exit__(self, type, value, traceback) -> None:
@@ -304,19 +329,24 @@ class LoggingContext(object):
Returns:
None to avoid suppressing any exceptions that were thrown.
"""
current = set_current_context(self.previous_context)
current = self.set_current_context(self.previous_context)
if current is not self:
if current is SENTINEL_CONTEXT:
if current is self.sentinel:
logger.warning("Expected logging context %s was lost", self)
else:
logger.warning(
"Expected logging context %s but found %s", self, current
)
self.alive = False
# the fact that we are here suggests that the caller thinks that everything
# is done and dusted for this logcontext, and further activity will not get
# recorded against the correct metrics.
self.finished = True
# if we have a parent, pass our CPU usage stats on
if self.parent_context is not None and hasattr(
self.parent_context, "_resource_usage"
):
self.parent_context._resource_usage += self._resource_usage
# reset them in case we get entered again
self._resource_usage.reset()
def copy_to(self, record) -> None:
"""Copy logging fields from this context to a log record or
@@ -341,14 +371,9 @@ class LoggingContext(object):
logger.warning("Started logcontext %s on different thread", self)
return
if self.finished:
logger.warning("Re-starting finished log context %s", self)
# If we haven't already started record the thread resource usage so
# far
if self.usage_start:
logger.warning("Re-starting already-active log context %s", self)
else:
if not self.usage_start:
self.usage_start = get_thread_resource_usage()
def stop(self) -> None:
@@ -371,15 +396,6 @@ class LoggingContext(object):
self.usage_start = None
# if we have a parent, pass our CPU usage stats on
if self.parent_context is not None and hasattr(
self.parent_context, "_resource_usage"
):
self.parent_context._resource_usage += self._resource_usage
# reset them in case we get entered again
self._resource_usage.reset()
def get_resource_usage(self) -> ContextResourceUsage:
"""Get resources used by this logcontext so far.
@@ -393,7 +409,7 @@ class LoggingContext(object):
# If we are on the correct thread and we're currently running then we
# can include resource usage so far.
is_main_thread = get_thread_id() == self.main_thread
if self.usage_start and is_main_thread:
if self.alive and self.usage_start and is_main_thread:
utime_delta, stime_delta = self._get_cputime()
res.ru_utime += utime_delta
res.ru_stime += stime_delta
@@ -476,7 +492,7 @@ class LoggingContextFilter(logging.Filter):
Returns:
True to include the record in the log output.
"""
context = current_context()
context = LoggingContext.current_context()
for key, value in self.defaults.items():
setattr(record, key, value)
@@ -496,24 +512,27 @@ class PreserveLoggingContext(object):
__slots__ = ["current_context", "new_context", "has_parent"]
def __init__(
self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT
) -> None:
self.new_context = new_context
def __init__(self, new_context: Optional[LoggingContextOrSentinel] = None) -> None:
if new_context is None:
self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel
else:
self.new_context = new_context
def __enter__(self) -> None:
"""Captures the current logging context"""
self.current_context = set_current_context(self.new_context)
self.current_context = LoggingContext.set_current_context(self.new_context)
if self.current_context:
self.has_parent = self.current_context.previous_context is not None
if not self.current_context.alive:
logger.debug("Entering dead context: %s", self.current_context)
def __exit__(self, type, value, traceback) -> None:
"""Restores the current logging context"""
context = set_current_context(self.current_context)
context = LoggingContext.set_current_context(self.current_context)
if context != self.new_context:
if not context:
if context is LoggingContext.sentinel:
logger.warning("Expected logging context %s was lost", self.new_context)
else:
logger.warning(
@@ -522,30 +541,9 @@ class PreserveLoggingContext(object):
context,
)
_thread_local = threading.local()
_thread_local.current_context = SENTINEL_CONTEXT
def current_context() -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage"""
return getattr(_thread_local, "current_context", SENTINEL_CONTEXT)
def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage
Args:
context(LoggingContext): The context to activate.
Returns:
The context that was previously active
"""
current = current_context()
if current is not context:
current.stop()
_thread_local.current_context = context
context.start()
return current
if self.current_context is not LoggingContext.sentinel:
if not self.current_context.alive:
logger.debug("Restoring dead context: %s", self.current_context)
def nested_logging_context(
@@ -574,7 +572,7 @@ def nested_logging_context(
if parent_context is not None:
context = parent_context # type: LoggingContextOrSentinel
else:
context = current_context()
context = LoggingContext.current_context()
return LoggingContext(
parent_context=context, request=str(context.request) + "-" + suffix
)
@@ -606,7 +604,7 @@ def run_in_background(f, *args, **kwargs):
CRITICAL error about an unhandled error will be logged without much
indication about where it came from.
"""
current = current_context()
current = LoggingContext.current_context()
try:
res = f(*args, **kwargs)
except: # noqa: E722
@@ -627,7 +625,7 @@ def run_in_background(f, *args, **kwargs):
# The function may have reset the context before returning, so
# we need to restore it now.
ctx = set_current_context(current)
ctx = LoggingContext.set_current_context(current)
# The original context will be restored when the deferred
# completes, but there is nothing waiting for it, so it will
@@ -676,7 +674,7 @@ def make_deferred_yieldable(deferred):
# ok, we can't be sure that a yield won't block, so let's reset the
# logcontext, and add a callback to the deferred to restore it.
prev_context = set_current_context(SENTINEL_CONTEXT)
prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
deferred.addBoth(_set_context_cb, prev_context)
return deferred
@@ -686,7 +684,7 @@ ResultT = TypeVar("ResultT")
def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
"""A callback function which just sets the logging context"""
set_current_context(context)
LoggingContext.set_current_context(context)
return result
@@ -754,7 +752,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
Deferred: A Deferred which fires a callback with the result of `f`, or an
errback if `f` throws an exception.
"""
logcontext = current_context()
logcontext = LoggingContext.current_context()
def g():
with LoggingContext(parent_context=logcontext):
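Whichever spelling wins (module-level functions or classmethods), the underlying mechanics are thread-local storage plus restore-on-exit. A self-contained toy version, not Synapse's actual class:

```
import threading

_local = threading.local()
SENTINEL = "sentinel"

def current_context():
    return getattr(_local, "ctx", SENTINEL)

class Ctx:
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.previous = current_context()  # remember what to restore
        _local.ctx = self
        return self

    def __exit__(self, *exc):
        if current_context() is not self:  # mirrors the "context lost" warning
            print("expected context %s was lost" % self.name)
        _local.ctx = self.previous

with Ctx("request-1"):
    assert current_context().name == "request-1"
assert current_context() == SENTINEL
```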

View File

@@ -19,7 +19,7 @@ from opentracing import Scope, ScopeManager
import twisted
from synapse.logging.context import current_context, nested_logging_context
from synapse.logging.context import LoggingContext, nested_logging_context
logger = logging.getLogger(__name__)
@@ -49,8 +49,11 @@ class LogContextScopeManager(ScopeManager):
(Scope) : the Scope that is active, or None if not
available.
"""
ctx = current_context()
return ctx.scope
ctx = LoggingContext.current_context()
if ctx is LoggingContext.sentinel:
return None
else:
return ctx.scope
def activate(self, span, finish_on_close):
"""
@@ -67,9 +70,9 @@ class LogContextScopeManager(ScopeManager):
"""
enter_logcontext = False
ctx = current_context()
ctx = LoggingContext.current_context()
if not ctx:
if ctx is LoggingContext.sentinel:
# There is no active logging context, so there is nothing for this scope to affect.
logger.error("Tried to activate scope outside of loggingcontext")
return Scope(None, span)

View File

@@ -375,6 +375,7 @@ class HttpPusher(object):
if not notification_dict:
return []
try:
logger.info("SENDING PUSH EVENT to %s: %s", self.url, notification_dict)
resp = yield self.http_client.post_json_get_json(
self.url, notification_dict
)

View File

@@ -21,7 +21,6 @@ from synapse.replication.http import (
membership,
register,
send_event,
streams,
)
REPLICATION_PREFIX = "/_synapse/replication"
@@ -39,4 +38,3 @@ class ReplicationRestResource(JsonResource):
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)
streams.register_servlets(hs, self)

View File

@@ -1,78 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer
from synapse.replication.http._base import ReplicationEndpoint
logger = logging.getLogger(__name__)
class ReplicationGetStreamUpdates(ReplicationEndpoint):
"""Fetches stream updates from a server. Used for streams not persisted to
the database, e.g. typing notifications.
The API looks like:
GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&upto_token=10&limit=100
200 OK
{
updates: [ ... ],
upto_token: 10,
limited: False,
}
"""
NAME = "get_repl_stream_updates"
PATH_ARGS = ("stream_name",)
METHOD = "GET"
def __init__(self, hs):
super().__init__(hs)
# We pull the streams from the replication streamer (if we try and make
# them ourselves we end up in an import loop).
self.streams = hs.get_replication_streamer().get_streams()
@staticmethod
def _serialize_payload(stream_name, from_token, upto_token, limit):
return {"from_token": from_token, "upto_token": upto_token, "limit": limit}
async def _handle_request(self, request, stream_name):
stream = self.streams.get(stream_name)
if stream is None:
raise SynapseError(400, "Unknown stream")
from_token = parse_integer(request, "from_token", required=True)
upto_token = parse_integer(request, "upto_token", required=True)
limit = parse_integer(request, "limit", required=True)
updates, upto_token, limited = await stream.get_updates_since(
from_token, upto_token, limit
)
return (
200,
{"updates": updates, "upto_token": upto_token, "limited": limited},
)
def register_servlets(hs, http_server):
ReplicationGetStreamUpdates(hs).register(http_server)
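The `limited` flag in the response drives a catch-up loop on the caller's side (the same pattern appears in `on_POSITION` further down). A runnable sketch against a faked stream, with `get_updates_since` assumed to return `(updates, upto_token, limited)` as above:

```
import asyncio

class FakeStream:
    """Stand-in for a Synapse stream, serving canned (token, row) pairs."""
    def __init__(self, rows):
        self.rows = rows

    async def get_updates_since(self, from_token, upto_token, limit):
        batch = [r for r in self.rows if from_token < r[0] <= upto_token][:limit]
        new_token = batch[-1][0] if batch else upto_token
        limited = len(batch) == limit and new_token < upto_token
        return batch, new_token, limited

async def fetch_all(stream, from_token, upto_token, limit=2):
    updates, limited = [], True
    while limited:  # keep paging until the server stops reporting "limited"
        batch, from_token, limited = await stream.get_updates_since(
            from_token, upto_token, limit
        )
        updates.extend(batch)
    return updates

rows = [(i, "update-%d" % i) for i in range(1, 6)]
print(asyncio.run(fetch_all(FakeStream(rows), 0, 5)))  # all five updates
```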

View File

@@ -18,10 +18,8 @@ from typing import Dict, Optional
import six
from synapse.storage.data_stores.main.cache import (
CURRENT_STATE_CACHE_NAME,
CacheInvalidationWorkerStore,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
@@ -37,7 +35,7 @@ def __func__(inp):
return inp.__func__
class BaseSlavedStore(CacheInvalidationWorkerStore):
class BaseSlavedStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
@@ -62,12 +60,6 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
else:
return 0
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches":
if self._cache_id_gen:

View File

@@ -29,13 +29,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
],
db_conn, "device_lists_stream", "stream_id"
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
@@ -61,27 +55,23 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
def process_replication_rows(self, stream_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
self._invalidate_caches_for_devices(token, rows)
for row in rows:
self._invalidate_caches_for_devices(token, row.user_id, row.destination)
elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
def _invalidate_caches_for_devices(self, token, rows):
for row in rows:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
if row.entity.startswith("@"):
self._device_list_stream_cache.entity_has_changed(row.entity, token)
self.get_cached_devices_for_user.invalidate((row.entity,))
self._get_cached_user_device.invalidate_many((row.entity,))
self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
def _invalidate_caches_for_devices(self, token, user_id, destination):
self._device_list_stream_cache.entity_has_changed(user_id, token)
else:
self._device_list_federation_stream_cache.entity_has_changed(
row.entity, token
)
if destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, token
)
self.get_cached_devices_for_user.invalidate((user_id,))
self._get_cached_user_device.invalidate_many((user_id,))
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
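The two invalidation strategies differ in row shape: the newer stream carries a single `entity` (a user ID starting with `@`, or a remote server name), where the reverted code had separate `user_id`/`destination` fields. A toy dispatcher for the entity form:

```
def invalidate_for_entity(entity, token, user_cache, federation_cache):
    # User IDs start with '@'; anything else is a remote server name.
    if entity.startswith("@"):
        user_cache[entity] = token
    else:
        federation_cache[entity] = token

user_cache, federation_cache = {}, {}
invalidate_for_entity("@alice:hs", 7, user_cache, federation_cache)
invalidate_for_entity("matrix.org", 7, user_cache, federation_cache)
assert user_cache == {"@alice:hs": 7}
assert federation_cache == {"matrix.org": 7}
```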

View File

@@ -33,9 +33,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
result["pushers"] = self._pushers_id_gen.get_current_token()
return result
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "pushers":
self._pushers_id_gen.advance(token)

View File

@@ -16,10 +16,26 @@
"""
import logging
from typing import Dict, List, Optional
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.tcp.protocol import (
AbstractReplicationClientHandler,
ClientReplicationStreamProtocol,
)
from .commands import (
Command,
FederationAckCommand,
InvalidateCacheCommand,
RemoteServerUpCommand,
RemovePusherCommand,
UserIpCommand,
UserSyncCommand,
)
logger = logging.getLogger(__name__)
@@ -35,11 +51,10 @@ class ReplicationClientFactory(ReconnectingClientFactory):
initialDelay = 0.1
maxDelay = 1 # Try at least once every N seconds
def __init__(self, hs, client_name):
def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
self.client_name = client_name
self.handler = hs.get_tcp_replication()
self.handler = handler
self.server_name = hs.config.server_name
self.hs = hs
self._clock = hs.get_clock() # As self.clock is defined in super class
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -50,7 +65,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def buildProtocol(self, addr):
logger.info("Connected to replication: %r", addr)
return ClientReplicationStreamProtocol(
self.hs, self.client_name, self.server_name, self._clock, self.handler,
self.client_name, self.server_name, self._clock, self.handler
)
def clientConnectionLost(self, connector, reason):
@@ -60,3 +75,170 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def clientConnectionFailed(self, connector, reason):
logger.error("Failed to connect to replication: %r", reason)
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
class ReplicationClientHandler(AbstractReplicationClientHandler):
"""A base handler that can be passed to the ReplicationClientFactory.
By default proxies incoming replication data to the SlaveStore.
"""
def __init__(self, store: BaseSlavedStore):
self.store = store
# The current connection. None if we are currently (re)connecting
self.connection = None
# Any pending commands to be sent once a new connection has been
# established
self.pending_commands = [] # type: List[Command]
# Map from string -> deferred, to wake up when receiving a SYNC with
# the given string.
# Used for tests.
self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
# The factory used to create connections.
self.factory = None # type: Optional[ReplicationClientFactory]
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
using TCP.
"""
client_name = hs.config.worker_name
self.factory = ReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self.factory)
async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
handle more.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
self.store.process_replication_rows(stream_name, token, rows)
async def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
the slave store.
Can be overridden in subclasses to handle more.
"""
self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting
for the sync with the given data.
Used by tests.
"""
d = self.awaiting_syncs.pop(data, None)
if d:
d.callback(data)
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
room_account_data = args.pop("room_account_data", None)
if user_account_data:
args["account_data"] = user_account_data
elif room_account_data:
args["account_data"] = room_account_data
return args
def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users. (Overridden by the synchrotrons only.)
"""
return []
def send_command(self, cmd):
"""Send a command to master (when we get establish a connection if we
don't have one already.)
"""
if self.connection:
self.connection.send_command(cmd)
else:
logger.warning("Queuing command as not connected: %r", cmd.NAME)
self.pending_commands.append(cmd)
def send_federation_ack(self, token):
"""Ack data for the federation stream. This allows the master to drop
data stored purely in memory.
"""
self.send_command(FederationAckCommand(token))
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
"""Poke the master that a user has started/stopped syncing.
"""
self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
def send_remove_pusher(self, app_id, push_key, user_id):
"""Poke the master to remove a pusher for a user
"""
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
def send_invalidate_cache(self, cache_func, keys):
"""Poke the master to invalidate a cache.
"""
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
self.send_command(cmd)
def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
"""Tell the master that the user made a request.
"""
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
self.send_command(cmd)
def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))
def await_sync(self, data):
"""Returns a deferred that is resolved when we receive a SYNC command
with given data.
[Not currently] used by tests.
"""
return self.awaiting_syncs.setdefault(data, defer.Deferred())
def update_connection(self, connection):
"""Called when a connection has been established (or lost with None).
"""
self.connection = connection
if connection:
for cmd in self.pending_commands:
connection.send_command(cmd)
self.pending_commands = []
def finished_connecting(self):
"""Called when we have successfully subscribed and caught up to all
streams we're interested in.
"""
logger.info("Finished connecting to server")
# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
# server.
if self.factory:
self.factory.resetDelay()
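`get_streams_to_replicate` above folds the two account-data positions into a single `account_data` entry. Reduced to its data transformation (invented position values):

```
def streams_to_replicate(positions: dict) -> dict:
    args = dict(positions)
    user_account_data = args.pop("user_account_data", None)
    room_account_data = args.pop("room_account_data", None)
    # The combined account_data stream replaces the two finer-grained ones.
    if user_account_data:
        args["account_data"] = user_account_data
    elif room_account_data:
        args["account_data"] = room_account_data
    return args

print(streams_to_replicate(
    {"events": 100, "user_account_data": 5, "room_account_data": 3}
))  # -> {'events': 100, 'account_data': 5}
```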

View File

@@ -136,8 +136,8 @@ class PositionCommand(Command):
"""Sent by the server to tell the client the stream postition without
needing to send an RDATA.
On receipt of a POSITION command clients should check if they have missed
any updates, and if so then fetch them out of band.
Sent to the client after all missing updates for a stream have been
sent, at which point the client is up to date.
"""
NAME = "POSITION"
@@ -179,24 +179,42 @@ class NameCommand(Command):
class ReplicateCommand(Command):
"""Sent by the client to subscribe to streams.
"""Sent by the client to subscribe to the stream.
Format::
REPLICATE
REPLICATE <stream_name> <token>
Where <token> may be either:
* a numeric stream_id to stream updates from
* "NOW" to stream all subsequent updates.
The <stream_name> can be "ALL" to subscribe to all known streams, in which
case the <token> must be set to "NOW", i.e.::
REPLICATE ALL NOW
"""
NAME = "REPLICATE"
def __init__(self):
pass
def __init__(self, stream_name, token):
self.stream_name = stream_name
self.token = token
@classmethod
def from_line(cls, line):
return cls()
stream_name, token = line.split(" ", 1)
if token in ("NOW", "now"):
token = "NOW"
else:
token = int(token)
return cls(stream_name, token)
def to_line(self):
return ""
return " ".join((self.stream_name, str(self.token)))
def get_logcontext_id(self):
return "REPLICATE-" + self.stream_name
class UserSyncCommand(Command):
@@ -207,32 +225,30 @@ class UserSyncCommand(Command):
Format::
USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
USER_SYNC <user_id> <state> <last_sync_ms>
Where <state> is either "start" or "stop"
"""
NAME = "USER_SYNC"
def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
self.instance_id = instance_id
def __init__(self, user_id, is_syncing, last_sync_ms):
self.user_id = user_id
self.is_syncing = is_syncing
self.last_sync_ms = last_sync_ms
@classmethod
def from_line(cls, line):
instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
user_id, state, last_sync_ms = line.split(" ", 2)
if state not in ("start", "end"):
raise Exception("Invalid USER_SYNC state %r" % (state,))
return cls(instance_id, user_id, state == "start", int(last_sync_ms))
return cls(user_id, state == "start", int(last_sync_ms))
def to_line(self):
return " ".join(
(
self.instance_id,
self.user_id,
"start" if self.is_syncing else "end",
str(self.last_sync_ms),
@@ -240,30 +256,6 @@ class UserSyncCommand(Command):
)
class ClearUserSyncsCommand(Command):
"""Sent by the client to inform the server that it should drop all
information about syncing users sent by the client.
Mainly used when client is about to shut down.
Format::
CLEAR_USER_SYNC <instance_id>
"""
NAME = "CLEAR_USER_SYNC"
def __init__(self, instance_id):
self.instance_id = instance_id
@classmethod
def from_line(cls, line):
return cls(line)
def to_line(self):
return self.instance_id
class FederationAckCommand(Command):
"""Sent by the client when it has processed up to a given point in the
federation stream. This allows the master to drop in-memory caches of the
@@ -424,7 +416,6 @@ _COMMANDS = (
InvalidateCacheCommand,
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
) # type: Tuple[Type[Command], ...]
# Map of command name to command type.
@@ -447,7 +438,6 @@ VALID_CLIENT_COMMANDS = (
ReplicateCommand.NAME,
PingCommand.NAME,
UserSyncCommand.NAME,
ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME,
RemovePusherCommand.NAME,
InvalidateCacheCommand.NAME,
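A round trip through the reverted `REPLICATE <stream_name> <token>` wire format, reimplemented standalone (the replacement, parameterless command serialises to an empty line instead):

```
def replicate_to_line(stream_name, token):
    return " ".join((stream_name, str(token)))

def replicate_from_line(line):
    stream_name, token = line.split(" ", 1)
    return stream_name, "NOW" if token in ("NOW", "now") else int(token)

line = replicate_to_line("events", 59)
assert line == "events 59"
assert replicate_from_line(line) == ("events", 59)
assert replicate_from_line("ALL NOW") == ("ALL", "NOW")
# On the wire the command name is prepended: "REPLICATE events 59".
```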

View File

@@ -1,410 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A replication client for use by synapse workers.
"""
import logging
from typing import Any, Callable, Dict, List
from prometheus_client import Counter
from synapse.metrics import LaterGauge
from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command,
FederationAckCommand,
InvalidateCacheCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
RemovePusherCommand,
ReplicateCommand,
UserIpCommand,
UserSyncCommand,
)
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
invalidate_cache_counter = Counter(
"synapse_replication_tcp_resource_invalidate_cache", ""
)
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
class ReplicationClientHandler:
"""Handles incoming commands from replication.
Proxies data to `HomeServer.get_replication_data_handler()`.
"""
def __init__(self, hs):
self.replication_data_handler = hs.get_replication_data_handler()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.clock = hs.get_clock()
self.presence_handler = hs.get_presence_handler()
self.instance_id = hs.get_instance_id()
self._position_linearizer = Linearizer("replication_position")
self.connections = [] # type: List[Any]
self.streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
[],
lambda: len(self.connections),
)
LaterGauge(
"synapse_replication_tcp_resource_connections_per_stream",
"",
["stream_name"],
lambda: {
(stream_name,): len(
[
conn
for conn in self.connections
if stream_name in conn.replication_streams
]
)
for stream_name in self.streams
},
)
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
self.pending_batches = {} # type: Dict[str, List[Any]]
self.is_master = hs.config.worker_app is None
self.federation_sender = None
if self.is_master and not hs.config.send_federation:
self.federation_sender = hs.get_federation_sender()
self._server_notices_sender = None
if self.is_master:
self._server_notices_sender = hs.get_server_notices_sender()
self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
def new_connection(self, connection):
self.connections.append(connection)
def lost_connection(self, connection):
try:
self.connections.remove(connection)
except ValueError:
pass
def connected(self) -> bool:
"""Do we have any replication connections open?
Used to no-op if nothing is connected.
"""
return bool(self.connections)
async def on_REPLICATE(self, cmd: ReplicateCommand):
# Positions should only be announced by the writer of the streams.
# Currently this is just the master process.
if not self.is_master:
return
if not self.connections:
raise Exception("Not connected")
for stream_name, stream in self.streams.items():
current_token = stream.current_token()
self.send_command(PositionCommand(stream_name, current_token))
async def on_USER_SYNC(self, cmd: UserSyncCommand):
user_sync_counter.inc()
if self.is_master:
await self.presence_handler.update_external_syncs_row(
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
if self.is_master:
await self.presence_handler.update_external_syncs_clear(cmd.instance_id)
async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
federation_ack_counter.inc()
if self.federation_sender:
self.federation_sender.federation_ack(cmd.token)
async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
remove_pusher_counter.inc()
if self.is_master:
await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
)
self.notifier.on_new_replication_data()
async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
invalidate_cache_counter.inc()
if self.is_master:
# We invalidate the cache locally, but then also stream that to other
# workers.
await self.store.invalidate_cache_and_stream(
cmd.cache_func, tuple(cmd.keys)
)
async def on_USER_IP(self, cmd: UserIpCommand):
user_ip_cache_counter.inc()
if self.is_master:
await self.store.insert_client_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
cmd.user_agent,
cmd.device_id,
cmd.last_seen,
)
if self._server_notices_sender:
await self._server_notices_sender.on_user_ip(cmd.user_id)
async def on_RDATA(self, cmd: RdataCommand):
stream_name = cmd.stream_name
try:
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception("[%s] Failed to parse RDATA: %r", stream_name, cmd.row)
raise
if cmd.token is None:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
self.pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
await self.on_rdata(stream_name, cmd.token, rows)
async def on_rdata(self, stream_name: str, token: int, rows: list):
"""Called to handle a batch of replication data with a given stream token.
Args:
stream_name: name of the replication stream for this batch of rows
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
await self.replication_data_handler.on_rdata(stream_name, token, rows)
async def on_POSITION(self, cmd: PositionCommand):
stream = self.streams.get(cmd.stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
return
# We protect catching up with a linearizer in case the replication
# connection reconnects under us.
with await self._position_linearizer.queue(cmd.stream_name):
# Find where we previously streamed up to.
current_token = self.replication_data_handler.get_streams_to_replicate().get(
cmd.stream_name
)
if current_token is None:
logger.debug(
"Got POSITION for stream we're not subscribed to: %s",
cmd.stream_name,
)
return
# Fetch all updates between then and now.
limited = cmd.token != current_token
while limited:
updates, current_token, limited = await stream.get_updates_since(
current_token, cmd.token
)
if updates:
await self.on_rdata(
cmd.stream_name,
current_token,
[stream.parse_row(update[1]) for update in updates],
)
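# (Worked example with made-up tokens: if we last saw token 100 and the
#  POSITION advertises 500, we page through get_updates_since in chunks,
#  feeding each chunk to on_rdata, until `limited` is False and
#  current_token has reached 500.)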
# We've now caught up to the position sent to us; notify the handler.
await self.replication_data_handler.on_position(cmd.stream_name, cmd.token)
# Handle any RDATA that came in while we were catching up.
rows = self.pending_batches.pop(cmd.stream_name, [])
if rows:
await self.on_rdata(cmd.stream_name, rows[-1].token, rows)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
"""Called when get a new REMOTE_SERVER_UP command."""
if self.is_master:
self.notifier.notify_remote_server_up(cmd.data)
def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users.
"""
return self.presence_handler.get_currently_syncing_users()
def send_command(self, cmd: Command):
"""Send a command to master (when we get establish a connection if we
don't have one already.)
"""
for conn in self.connections:
conn.send_command(cmd)
def send_federation_ack(self, token: int):
"""Ack data for the federation stream. This allows the master to drop
data stored purely in memory.
"""
self.send_command(FederationAckCommand(token))
def send_user_sync(
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
):
"""Poke the master that a user has started/stopped syncing.
"""
self.send_command(
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
)
def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
"""Poke the master to remove a pusher for a user
"""
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
"""Poke the master to invalidate a cache.
"""
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
self.send_command(cmd)
def send_user_ip(
self,
user_id: str,
access_token: str,
ip: str,
user_agent: str,
device_id: str,
last_seen: int,
):
"""Tell the master that the user made a request.
"""
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
self.send_command(cmd)
def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))
def stream_update(self, stream_name: str, token: str, data: Any):
"""Called when a new update is available to stream to clients.
We need to check if the client is interested in the stream or not
"""
self.send_command(RdataCommand(stream_name, token, data))
class DummyReplicationDataHandler:
"""A replication data handler that simply discards all data.
"""
async def on_rdata(self, stream_name: str, token: int, rows: list):
"""Called to handle a batch of replication data with a given stream token.
This implementation simply discards the rows.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
pass
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
return {}
async def on_position(self, stream_name: str, token: int):
pass
class WorkerReplicationDataHandler:
"""A replication data handler that calls slave data stores.
"""
def __init__(self, store):
self.store = store
async def on_rdata(self, stream_name: str, token: int, rows: list):
"""Called to handle a batch of replication data with a given stream token.
This pokes the slave store. Can be overridden in subclasses to
handle more.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
self.store.process_replication_rows(stream_name, token, rows)
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
room_account_data = args.pop("room_account_data", None)
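# The store tracks user- and room-level positions separately, but both
# feed the single "account_data" replication stream, so collapse them
# into one entry here (preferring the user-level position if present).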
if user_account_data:
args["account_data"] = user_account_data
elif room_account_data:
args["account_data"] = room_account_data
return args
async def on_position(self, stream_name: str, token: int):
self.store.process_replication_rows(stream_name, token, [])

View File

@@ -35,7 +35,9 @@ indicate which side is sending, these are *not* included on the wire::
> PING 1490197665618
< NAME synapse.app.appservice
< PING 1490197665618
< REPLICATE
< REPLICATE events 1
< REPLICATE backfill 1
< REPLICATE caches 1
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
@@ -46,40 +48,45 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
import abc
import fcntl
import logging
import struct
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Set
from typing import Any, DefaultDict, Dict, List, Set, Tuple
from six import iteritems
from six import iteritems, iterkeys
from prometheus_client import Counter
from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
COMMAND_MAP,
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
Command,
ErrorCommand,
NameCommand,
PingCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
ReplicateCommand,
ServerCommand,
SyncCommand,
UserSyncCommand,
)
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string
MYPY = False
if MYPY:
from synapse.server import HomeServer
connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
)
@@ -120,11 +127,16 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
delimiter = b"\n"
# Valid commands we expect to receive
VALID_INBOUND_COMMANDS = [] # type: Collection[str]
# Valid commands we can send
VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]
max_line_buffer = 10000
def __init__(self, clock, handler):
def __init__(self, clock):
self.clock = clock
self.handler = handler
self.last_received_command = self.clock.time_msec()
self.last_sent_command = 0
@@ -164,8 +176,6 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# can time us out.
self.send_command(PingCommand(self.clock.time_msec()))
self.handler.new_connection(self)
def send_ping(self):
"""Periodically sends a ping and checks if we should close the connection
due to the other side timing out.
@@ -203,6 +213,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
line = line.decode("utf-8")
cmd_name, rest_of_line = line.split(" ", 1)
if cmd_name not in self.VALID_INBOUND_COMMANDS:
logger.error("[%s] invalid command %s", self.id(), cmd_name)
self.send_error("invalid command: %s", cmd_name)
return
self.last_received_command = self.clock.time_msec()
self.inbound_commands_counter[cmd_name] = (
@@ -234,23 +249,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
Args:
cmd: received command
"""
handled = False
# First call any command handlers on this instance. These are for TCP
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
if not handled:
logger.warning("Unhandled command: %r", cmd)
handler = getattr(self, "on_%s" % (cmd.NAME,))
await handler(cmd)
def close(self):
logger.warning("[%s] Closing connection", self.id())
@@ -258,9 +258,6 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.transport.loseConnection()
self.on_connection_closed()
def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))
def send_error(self, error_string, *args):
"""Send an error to remote and close the connection.
"""
@@ -382,8 +379,6 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.state = ConnectionStates.CLOSED
self.pending_commands = []
self.handler.lost_connection(self)
if self.transport:
self.transport.unregisterProducer()
@@ -407,66 +402,346 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
def __init__(self, hs, server_name, clock, handler):
BaseReplicationStreamProtocol.__init__(self, clock, handler) # Old style class
VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
def __init__(self, server_name, clock, streamer):
BaseReplicationStreamProtocol.__init__(self, clock) # Old style class
self.server_name = server_name
self.streamer = streamer
# The streams the client has subscribed to and is up to date with
self.replication_streams = set() # type: Set[str]
# The streams the client is currently subscribing to.
self.connecting_streams = set() # type: Set[str]
# Map from stream name to list of updates to send once we've finished
# subscribing the client to the stream.
self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
BaseReplicationStreamProtocol.connectionMade(self)
self.streamer.new_connection(self)
async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
async def on_USER_SYNC(self, cmd):
await self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
async def on_REPLICATE(self, cmd):
stream_name = cmd.stream_name
token = cmd.token
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
deferreds = [
run_in_background(self.subscribe_to_stream, stream, token)
for stream in iterkeys(self.streamer.streams_by_name)
]
await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else:
await self.subscribe_to_stream(stream_name, token)
async def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token)
async def on_REMOVE_PUSHER(self, cmd):
await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
async def on_INVALIDATE_CACHE(self, cmd):
await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.streamer.on_remote_server_up(cmd.data)
async def on_USER_IP(self, cmd):
self.streamer.on_user_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
cmd.user_agent,
cmd.device_id,
cmd.last_seen,
)
async def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a stream.
This involves checking if they've missed anything and sending those
updates down if they have. During that time new updates for the stream
are queued and sent once we've sent down any missed updates.
"""
self.replication_streams.discard(stream_name)
self.connecting_streams.add(stream_name)
try:
# Get missing updates
updates, current_token = await self.streamer.get_stream_updates(
stream_name, token
)
# Send all the missing updates
for update in updates:
token, row = update[0], update[1]
self.send_command(RdataCommand(stream_name, token, row))
# We send a POSITION command to ensure that they have an up-to-date
# token (especially useful if we didn't send any updates
# above)
self.send_command(PositionCommand(stream_name, current_token))
# Now we can send any updates that came in while we were subscribing
pending_rdata = self.pending_rdata.pop(stream_name, [])
updates = []
for token, update in pending_rdata:
# If the token is null, it is part of a batch update. Batches
# are multiple updates that share a single token. To denote
# this, the token is set to None for all tokens in the batch
# except for the last. If we find a None token, we keep looking
# through tokens until we find one that is not None and then
# process all previous updates in the batch as if they had the
# final token.
if token is None:
# Store this update as part of a batch
updates.append(update)
continue
if token <= current_token:
# This update or batch of updates is older than
# current_token, dismiss it
updates = []
continue
updates.append(update)
# Send all updates that are part of this batch with the
# found token
for update in updates:
self.send_command(RdataCommand(stream_name, token, update))
# Clear stored updates
updates = []
# They're now fully subscribed
self.replication_streams.add(stream_name)
except Exception as e:
logger.exception("[%s] Failed to handle REPLICATE command", self.id())
self.send_error("failed to handle replicate: %r", e)
finally:
self.connecting_streams.discard(stream_name)
def stream_update(self, stream_name, token, data):
"""Called when a new update is available to stream to clients.
We need to check if the client is interested in the stream or not
"""
if stream_name in self.replication_streams:
# The client is subscribed to the stream
self.send_command(RdataCommand(stream_name, token, data))
elif stream_name in self.connecting_streams:
# The client is being subscribed to the stream
logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
self.pending_rdata.setdefault(stream_name, []).append((token, data))
else:
# The client isn't subscribed
logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
def send_sync(self, data):
self.send_command(SyncCommand(data))
def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
self.streamer.lost_connection(self)
class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
"""
The interface for the handler that should be passed to
ClientReplicationStreamProtocol
"""
@abc.abstractmethod
async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
raise NotImplementedError()
@abc.abstractmethod
async def on_position(self, stream_name, token):
"""Called when we get new position data."""
raise NotImplementedError()
@abc.abstractmethod
def on_sync(self, data):
"""Called when get a new SYNC command."""
raise NotImplementedError()
@abc.abstractmethod
async def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
raise NotImplementedError()
@abc.abstractmethod
def get_streams_to_replicate(self):
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
raise NotImplementedError()
@abc.abstractmethod
def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users."""
raise NotImplementedError()
@abc.abstractmethod
def update_connection(self, connection):
"""Called when a connection has been established (or lost with None).
"""
raise NotImplementedError()
@abc.abstractmethod
def finished_connecting(self):
"""Called when we have successfully subscribed and caught up to all
streams we're interested in.
"""
raise NotImplementedError()
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
def __init__(
self,
hs: "HomeServer",
client_name: str,
server_name: str,
clock: Clock,
handler,
handler: AbstractReplicationClientHandler,
):
BaseReplicationStreamProtocol.__init__(self, clock, handler)
self.instance_id = hs.get_instance_id()
BaseReplicationStreamProtocol.__init__(self, clock)
self.client_name = client_name
self.server_name = server_name
self.streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
self.handler = handler
# Set of stream names that have been subscribed to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
self.streams_connecting = set(STREAMS_MAP) # type: Set[str]
self.streams_connecting = set() # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
self.pending_batches = {} # type: Dict[str, List[Any]]
self.pending_batches = {} # type: Dict[str, Any]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
BaseReplicationStreamProtocol.connectionMade(self)
self.send_command(NameCommand(self.client_name))
self.replicate()
# Once we've connected subscribe to the necessary streams
for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
self.replicate(stream_name, token)
# Tell the server if we have any users currently syncing (should only
# happen on synchrotrons)
currently_syncing = self.handler.get_currently_syncing_users()
now = self.clock.time_msec()
for user_id in currently_syncing:
self.send_command(UserSyncCommand(user_id, True, now))
# We've now finished connecting, so inform the client handler
self.handler.update_connection(self)
# This will happen if we don't actually subscribe to any streams
if not self.streams_connecting:
self.handler.finished_connecting()
async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
def replicate(self):
async def on_RDATA(self, cmd):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
try:
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception(
"[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
)
raise
if cmd.token is None:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non-None token
self.pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
await self.handler.on_rdata(stream_name, cmd.token, rows)
async def on_POSITION(self, cmd):
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
await self.handler.on_position(cmd.stream_name, cmd.token)
async def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.handler.on_remote_server_up(cmd.data)
def replicate(self, stream_name, token):
"""Send the subscription request to the server
"""
logger.info("[%s] Subscribing to replication streams", self.id())
if stream_name not in STREAMS_MAP:
raise Exception("Invalid stream name %r" % (stream_name,))
self.send_command(ReplicateCommand())
logger.info(
"[%s] Subscribing to replication stream: %r from %r",
self.id(),
stream_name,
token,
)
self.streams_connecting.add(stream_name)
self.send_command(ReplicateCommand(stream_name, token))
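# (For illustration, this sends a line such as "REPLICATE events 1",
#  as in the wire example in the module docstring above.)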
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
self.handler.update_connection(None)
# The following simply registers metrics for the replication connections

View File

@@ -17,7 +17,7 @@
import logging
import random
from typing import Dict
from typing import Any, List
from six import itervalues
@@ -25,16 +25,24 @@ from prometheus_client import Counter
from twisted.internet.protocol import Factory
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.metrics import Measure
from synapse.util.metrics import Measure, measure_func
from .protocol import ServerReplicationStreamProtocol
from .streams import STREAMS_MAP, Stream
from .streams import STREAMS_MAP
from .streams.federation import FederationStream
stream_updates_counter = Counter(
"synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
)
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
invalidate_cache_counter = Counter(
"synapse_replication_tcp_resource_invalidate_cache", ""
)
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
logger = logging.getLogger(__name__)
@@ -44,18 +52,13 @@ class ReplicationStreamProtocolFactory(Factory):
"""
def __init__(self, hs):
self.handler = hs.get_tcp_replication()
self.streamer = ReplicationStreamer(hs)
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
self.hs = hs
# Ensure the replication streamer is started if we register a
# replication server endpoint.
hs.get_replication_streamer()
def buildProtocol(self, addr):
return ServerReplicationStreamProtocol(
self.hs, self.server_name, self.clock, self.handler
self.server_name, self.clock, self.streamer
)
@@ -75,6 +78,16 @@ class ReplicationStreamer(object):
self._replication_torture_level = hs.config.replication_torture_level
# Current connections.
self.connections = [] # type: List[ServerReplicationStreamProtocol]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
[],
lambda: len(self.connections),
)
# List of streams that clients can subscribe to.
# We only support the federation stream if federation sending has been
# disabled on the master.
@@ -86,22 +99,39 @@ class ReplicationStreamer(object):
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
LaterGauge(
"synapse_replication_tcp_resource_connections_per_stream",
"",
["stream_name"],
lambda: {
(stream_name,): len(
[
conn
for conn in self.connections
if stream_name in conn.replication_streams
]
)
for stream_name in self.streams_by_name
},
)
self.federation_sender = None
if not hs.config.send_federation:
self.federation_sender = hs.get_federation_sender()
self.notifier.add_replication_callback(self.on_notifier_poke)
self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
self.pending_updates = False
self.client = hs.get_tcp_replication()
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
def get_streams(self) -> Dict[str, Stream]:
"""Get a mapp from stream name to stream instance.
"""
return self.streams_by_name
def on_shutdown(self):
# close all connections on shutdown
for conn in self.connections:
conn.send_error("server shutting down")
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
@@ -110,7 +140,7 @@ class ReplicationStreamer(object):
This should get called each time new data is available, even if it
is currently being executed, so that nothing gets missed
"""
if not self.client.connected():
if not self.connections:
# Don't bother if nothing is listening. We still need to advance
# the stream tokens, otherwise they'll fall behind forever
for stream in self.streams:
@@ -136,6 +166,11 @@ class ReplicationStreamer(object):
self.pending_updates = False
with Measure(self.clock, "repl.stream.get_updates"):
# First we tell the streams that they should update their
# current tokens.
for stream in self.streams:
stream.advance_current_token()
all_streams = self.streams
if self._replication_torture_level is not None:
@@ -145,7 +180,7 @@ class ReplicationStreamer(object):
random.shuffle(all_streams)
for stream in all_streams:
if stream.last_token == stream.current_token():
if stream.last_token == stream.upto_token:
continue
if self._replication_torture_level:
@@ -157,17 +192,18 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME,
stream.last_token,
stream.current_token(),
stream.upto_token,
)
try:
updates, current_token, limited = await stream.get_updates()
self.pending_updates |= limited
updates, current_token = await stream.get_updates()
except Exception:
logger.info("Failed to handle stream %s", stream.NAME)
raise
logger.debug(
"Sending %d updates", len(updates),
"Sending %d updates to %d connections",
len(updates),
len(self.connections),
)
if updates:
@@ -183,17 +219,116 @@ class ReplicationStreamer(object):
# token. See RdataCommand for more details.
batched_updates = _batch_updates(updates)
for token, row in batched_updates:
try:
self.client.stream_update(stream.NAME, token, row)
except Exception:
logger.exception("Failed to replicate")
for conn in self.connections:
for token, row in batched_updates:
try:
conn.stream_update(stream.NAME, token, row)
except Exception:
logger.exception("Failed to replicate")
logger.debug("No more pending updates, breaking poke loop")
finally:
self.pending_updates = False
self.is_looping = False
@measure_func("repl.get_stream_updates")
async def get_stream_updates(self, stream_name, token):
"""For a given stream get all updates since token. This is called when
a client first subscribes to a stream.
"""
stream = self.streams_by_name.get(stream_name, None)
if not stream:
raise Exception("unknown stream %s", stream_name)
return await stream.get_updates_since(token)
@measure_func("repl.federation_ack")
def federation_ack(self, token):
"""We've received an ack for federation stream from a client.
"""
federation_ack_counter.inc()
if self.federation_sender:
self.federation_sender.federation_ack(token)
@measure_func("repl.on_user_sync")
async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
"""A client has started/stopped syncing on a worker.
"""
user_sync_counter.inc()
await self.presence_handler.update_external_syncs_row(
conn_id, user_id, is_syncing, last_sync_ms
)
@measure_func("repl.on_remove_pusher")
async def on_remove_pusher(self, app_id, push_key, user_id):
"""A client has asked us to remove a pusher
"""
remove_pusher_counter.inc()
await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id=app_id, pushkey=push_key, user_id=user_id
)
self.notifier.on_new_replication_data()
@measure_func("repl.on_invalidate_cache")
async def on_invalidate_cache(self, cache_func: str, keys: List[Any]):
"""The client has asked us to invalidate a cache
"""
invalidate_cache_counter.inc()
# We invalidate the cache locally, but then also stream that to other
# workers.
await self.store.invalidate_cache_and_stream(cache_func, tuple(keys))
@measure_func("repl.on_user_ip")
async def on_user_ip(
self, user_id, access_token, ip, user_agent, device_id, last_seen
):
"""The client saw a user request
"""
user_ip_cache_counter.inc()
await self.store.insert_client_ip(
user_id, access_token, ip, user_agent, device_id, last_seen
)
await self._server_notices_sender.on_user_ip(user_id)
@measure_func("repl.on_remote_server_up")
def on_remote_server_up(self, server: str):
self.notifier.notify_remote_server_up(server)
def send_remote_server_up(self, server: str):
for conn in self.connections:
conn.send_remote_server_up(server)
def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients.
Used in tests.
"""
for conn in self.connections:
conn.send_sync(data)
def new_connection(self, connection):
"""A new client connection has been established
"""
self.connections.append(connection)
def lost_connection(self, connection):
"""A client connection has been lost
"""
try:
self.connections.remove(connection)
except ValueError:
pass
# We need to tell the presence handler that the connection has been
# lost so that it can handle any ongoing syncs on that connection.
run_as_background_process(
"update_external_syncs_clear",
self.presence_handler.update_external_syncs_clear,
connection.conn_id,
)
def _batch_updates(updates):
"""Takes a list of updates of form [(token, row)] and sets the token to

View File

@@ -25,66 +25,26 @@ Each stream is defined by the following information:
update_function: The function that returns a list of updates between two tokens
"""
from typing import Dict, Type
from synapse.replication.tcp.streams._base import (
AccountDataStream,
BackfillStream,
CachesStream,
DeviceListsStream,
GroupServerStream,
PresenceStream,
PublicRoomsStream,
PushersStream,
PushRulesStream,
ReceiptsStream,
Stream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
UserSignatureStream,
)
from synapse.replication.tcp.streams.events import EventsStream
from synapse.replication.tcp.streams.federation import FederationStream
from . import _base, events, federation
STREAMS_MAP = {
stream.NAME: stream
for stream in (
EventsStream,
BackfillStream,
PresenceStream,
TypingStream,
ReceiptsStream,
PushRulesStream,
PushersStream,
CachesStream,
PublicRoomsStream,
DeviceListsStream,
ToDeviceStream,
FederationStream,
TagAccountDataStream,
AccountDataStream,
GroupServerStream,
UserSignatureStream,
events.EventsStream,
_base.BackfillStream,
_base.PresenceStream,
_base.TypingStream,
_base.ReceiptsStream,
_base.PushRulesStream,
_base.PushersStream,
_base.CachesStream,
_base.PublicRoomsStream,
_base.DeviceListsStream,
_base.ToDeviceStream,
federation.FederationStream,
_base.TagAccountDataStream,
_base.AccountDataStream,
_base.GroupServerStream,
_base.UserSignatureStream,
)
} # type: Dict[str, Type[Stream]]
__all__ = [
"STREAMS_MAP",
"Stream",
"BackfillStream",
"PresenceStream",
"TypingStream",
"ReceiptsStream",
"PushRulesStream",
"PushersStream",
"CachesStream",
"PublicRoomsStream",
"DeviceListsStream",
"ToDeviceStream",
"TagAccountDataStream",
"AccountDataStream",
"GroupServerStream",
"UserSignatureStream",
]
}

View File

@@ -14,40 +14,114 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from collections import namedtuple
from typing import Any, Awaitable, Callable, List, Optional, Tuple
from typing import Any, List, Optional
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
MAX_EVENTS_BEHIND = 500000
BackfillStreamRow = namedtuple(
"BackfillStreamRow",
(
"event_id", # str
"room_id", # str
"type", # str
"state_key", # str, optional
"redacts", # str, optional
"relates_to", # str, optional
),
)
PresenceStreamRow = namedtuple(
"PresenceStreamRow",
(
"user_id", # str
"state", # str
"last_active_ts", # int
"last_federation_update_ts", # int
"last_user_sync_ts", # int
"status_msg", # str
"currently_active", # bool
),
)
TypingStreamRow = namedtuple(
"TypingStreamRow", ("room_id", "user_ids") # str # list(str)
)
ReceiptsStreamRow = namedtuple(
"ReceiptsStreamRow",
(
"room_id", # str
"receipt_type", # str
"user_id", # str
"event_id", # str
"data", # dict
),
)
PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
PushersStreamRow = namedtuple(
"PushersStreamRow",
("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
)
# Some type aliases to make things a bit easier.
# A stream position token
Token = int
@attr.s
class CachesStreamRow:
"""Stream to inform workers they should invalidate their cache.
# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
StreamRow = Tuple[Token, tuple]
Attributes:
cache_func: Name of the cached function.
keys: The entry in the cache to invalidate. If None then will
invalidate all.
invalidation_ts: Timestamp of when the invalidation took place.
"""
cache_func = attr.ib(type=str)
keys = attr.ib(type=Optional[List[Any]])
invalidation_ts = attr.ib(type=int)
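# (Hypothetical example row:
#  CachesStreamRow("get_rooms_for_user", ["@alice:example.com"], 1585000000000))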
PublicRoomsStreamRow = namedtuple(
"PublicRoomsStreamRow",
(
"room_id", # str
"visibility", # str
"appservice_id", # str, optional
"network_id", # str, optional
),
)
DeviceListsStreamRow = namedtuple(
"DeviceListsStreamRow", ("user_id", "destination") # str # str
)
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
)
AccountDataStreamRow = namedtuple(
"AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
)
GroupsStreamRow = namedtuple(
"GroupsStreamRow",
("group_id", "user_id", "type", "content"), # str # str # str # dict
)
UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
class Stream(object):
"""Base class for the streams.
Provides a `get_updates()` function that returns new updates since the last
time it was called.
time it was called up until the point `advance_current_token` was called.
"""
NAME = None # type: str # The name of the stream
# The type of the row. Used by the default impl of parse_row.
ROW_TYPE = None # type: Any
_LIMITED = True # Whether the update function takes a limit
@classmethod
def parse_row(cls, row):
@@ -65,56 +139,80 @@ class Stream(object):
return cls.ROW_TYPE(*row)
def __init__(self, hs):
# The token from which we last asked for updates
self.last_token = self.current_token()
# The token that we will get updates up to
self.upto_token = self.current_token()
def advance_current_token(self):
"""Updates `upto_token` to "now", which updates up until which point
get_updates[_since] will fetch rows till.
"""
self.upto_token = self.current_token()
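# (Sketch of the token lifecycle, with made-up numbers: last_token=10
#  and upto_token=10; new rows move current_token() to 15;
#  advance_current_token() sets upto_token=15; the next get_updates()
#  then returns the rows in (10, 15] and advances last_token to 15.)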
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
self.last_token = self.current_token()
self.upto_token = self.current_token()
self.last_token = self.upto_token
async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
async def get_updates(self):
"""Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before).
since the stream was constructed if it hadn't been called before),
until the `upto_token`.
Returns:
A triplet `(updates, new_last_token, limited)`, where `updates` is
a list of `(token, row)` entries, `new_last_token` is the new
position in stream, and `limited` is whether there are more updates
to fetch.
Deferred[Tuple[List[Tuple[int, Any]], int]]:
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
list of ``(token, row)`` entries. ``row`` will be json-serialised and
sent over the replication stream.
"""
current_token = self.current_token()
updates, current_token, limited = await self.get_updates_since(
self.last_token, current_token
)
updates, current_token = await self.get_updates_since(self.last_token)
self.last_token = current_token
return updates, current_token, limited
return updates, current_token
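# (Hypothetical example: a first call might return
#  ([(612, row), ..., (620, row)], 620) and set last_token = 620, so
#  the following call only yields rows after 620.)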
async def get_updates_since(
self, from_token: Token, upto_token: Token, limit: int = 100
) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
async def get_updates_since(self, from_token):
"""Like get_updates except allows specifying from when we should
stream updates
Returns:
A triplet `(updates, new_last_token, limited)`, where `updates` is
a list of `(token, row)` entries, `new_last_token` is the new
position in stream, and `limited` is whether there are more updates
to fetch.
Deferred[Tuple[List[Tuple[int, Any]], int]]:
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
list of ``(token, row)`` entries. ``row`` will be json-serialised and
sent over the replication stream.
"""
if from_token in ("NOW", "now"):
return [], self.upto_token
current_token = self.upto_token
from_token = int(from_token)
if from_token == upto_token:
return [], upto_token, False
if from_token == current_token:
return [], current_token
updates, upto_token, limited = await self.update_function(
from_token, upto_token, limit=limit,
)
return updates, upto_token, limited
logger.info("get_updates_since: %s", self.__class__)
if self._LIMITED:
rows = await self.update_function(
from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
)
# never turn more than MAX_EVENTS_BEHIND + 1 into updates.
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
else:
rows = await self.update_function(from_token, current_token)
updates = [(row[0], row[1:]) for row in rows]
# check we didn't get more rows than the limit.
# doing it like this allows the update_function to be a generator.
if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
raise Exception("stream %s has fallen behind" % (self.NAME))
return updates, current_token
def current_token(self):
"""Gets the current token of the underlying streams. Should be provided
@@ -125,8 +223,9 @@ class Stream(object):
"""
raise NotImplementedError()
def update_function(self, from_token, current_token, limit):
"""Get updates between from_token and to_token.
def update_function(self, from_token, current_token, limit=None):
"""Get updates between from_token and to_token. If Stream._LIMITED is
True then limit is provided, otherwise it's not.
Returns:
Deferred(list(tuple)): the first entry in the tuple is the token for
@@ -136,144 +235,52 @@ class Stream(object):
raise NotImplementedError()
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
async def update_function(from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) == limit:
upto_token = rows[-1][0]
limited = True
return updates, upto_token, limited
return update_function
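# Typical usage, as in the stream classes below: wrap a raw store query so
# it conforms to the update_function contract, e.g.
#   self.update_function = db_query_to_update_function(
#       store.get_all_updated_receipts
#   )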
def make_http_update_function(
hs, stream_name: str
) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
"""Makes a suitable function for use as an `update_function` that queries
the master process for updates.
"""
client = ReplicationGetStreamUpdates.make_client(hs)
async def update_function(
from_token: int, upto_token: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
return await client(
stream_name=stream_name,
from_token=from_token,
upto_token=upto_token,
limit=limit,
)
return update_function
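# This is used on workers for streams whose data lives on the master (e.g.
# the presence and typing streams below): the worker fetches updates over
# the replication HTTP endpoint instead of reading the database directly.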
class BackfillStream(Stream):
"""We fetched some old events and either we had never seen that event before
or it went from being an outlier to not.
"""
BackfillStreamRow = namedtuple(
"BackfillStreamRow",
(
"event_id", # str
"room_id", # str
"type", # str
"state_key", # str, optional
"redacts", # str, optional
"relates_to", # str, optional
),
)
NAME = "backfill"
ROW_TYPE = BackfillStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_current_backfill_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
self.update_function = store.get_all_new_backfill_event_rows # type: ignore
super(BackfillStream, self).__init__(hs)
class PresenceStream(Stream):
PresenceStreamRow = namedtuple(
"PresenceStreamRow",
(
"user_id", # str
"state", # str
"last_active_ts", # int
"last_federation_update_ts", # int
"last_user_sync_ts", # int
"status_msg", # str
"currently_active", # bool
),
)
NAME = "presence"
_LIMITED = False
ROW_TYPE = PresenceStreamRow
def __init__(self, hs):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
self._is_worker = hs.config.worker_app is not None
self.current_token = store.get_current_presence_token # type: ignore
if hs.config.worker_app is None:
self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
else:
# Query master process
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
self.update_function = presence_handler.get_all_presence_updates # type: ignore
super(PresenceStream, self).__init__(hs)
class TypingStream(Stream):
TypingStreamRow = namedtuple(
"TypingStreamRow", ("room_id", "user_ids") # str # list(str)
)
NAME = "typing"
_LIMITED = False
ROW_TYPE = TypingStreamRow
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
self.current_token = typing_handler.get_current_token # type: ignore
if hs.config.worker_app is None:
self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
else:
# Query master process
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
self.update_function = typing_handler.get_all_typing_updates # type: ignore
super(TypingStream, self).__init__(hs)
class ReceiptsStream(Stream):
ReceiptsStreamRow = namedtuple(
"ReceiptsStreamRow",
(
"room_id", # str
"receipt_type", # str
"user_id", # str
"event_id", # str
"data", # dict
),
)
NAME = "receipts"
ROW_TYPE = ReceiptsStreamRow
@@ -281,7 +288,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_receipt_stream_id # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
self.update_function = store.get_all_updated_receipts # type: ignore
super(ReceiptsStream, self).__init__(hs)
@@ -290,8 +297,6 @@ class PushRulesStream(Stream):
"""A user has changed their push rules
"""
PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
NAME = "push_rules"
ROW_TYPE = PushRulesStreamRow
@@ -305,24 +310,13 @@ class PushRulesStream(Stream):
async def update_function(self, from_token, to_token, limit):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
limited = False
if len(rows) == limit:
to_token = rows[-1][0]
limited = True
return [(row[0], (row[2],)) for row in rows], to_token, limited
return [(row[0], row[2]) for row in rows]
class PushersStream(Stream):
"""A user has added/changed/removed a pusher
"""
PushersStreamRow = namedtuple(
"PushersStreamRow",
("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
)
NAME = "pushers"
ROW_TYPE = PushersStreamRow
@@ -330,7 +324,7 @@ class PushersStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_pushers_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
self.update_function = store.get_all_updated_pushers_rows # type: ignore
super(PushersStream, self).__init__(hs)
@@ -340,21 +334,6 @@ class CachesStream(Stream):
the cache on the workers
"""
@attr.s
class CachesStreamRow:
"""Stream to inform workers they should invalidate their cache.
Attributes:
cache_func: Name of the cached function.
keys: The entry in the cache to invalidate. If None then will
invalidate all.
invalidation_ts: Timestamp of when the invalidation took place.
"""
cache_func = attr.ib(type=str)
keys = attr.ib(type=Optional[List[Any]])
invalidation_ts = attr.ib(type=int)
NAME = "caches"
ROW_TYPE = CachesStreamRow
@@ -362,7 +341,7 @@ class CachesStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_cache_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
self.update_function = store.get_all_updated_caches # type: ignore
super(CachesStream, self).__init__(hs)
@@ -371,16 +350,6 @@ class PublicRoomsStream(Stream):
"""The public rooms list changed
"""
PublicRoomsStreamRow = namedtuple(
"PublicRoomsStreamRow",
(
"room_id", # str
"visibility", # str
"appservice_id", # str, optional
"network_id", # str, optional
),
)
NAME = "public_rooms"
ROW_TYPE = PublicRoomsStreamRow
@@ -388,28 +357,24 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_current_public_room_stream_id # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
self.update_function = store.get_all_new_public_rooms # type: ignore
super(PublicRoomsStream, self).__init__(hs)
class DeviceListsStream(Stream):
"""Either a user has updated their devices or a remote server needs to be
told about a device update.
"""Someone added/changed/removed a device
"""
@attr.s
class DeviceListsStreamRow:
entity = attr.ib(type=str)
NAME = "device_lists"
_LIMITED = False
ROW_TYPE = DeviceListsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
super(DeviceListsStream, self).__init__(hs)
@@ -418,8 +383,6 @@ class ToDeviceStream(Stream):
"""New to_device messages for a client
"""
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
NAME = "to_device"
ROW_TYPE = ToDeviceStreamRow
@@ -427,7 +390,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_to_device_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
self.update_function = store.get_all_new_device_messages # type: ignore
super(ToDeviceStream, self).__init__(hs)
@@ -436,10 +399,6 @@ class TagAccountDataStream(Stream):
"""Someone added/removed a tag for a room
"""
TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
)
NAME = "tag_account_data"
ROW_TYPE = TagAccountDataStreamRow
@@ -447,7 +406,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_account_data_stream_id # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
self.update_function = store.get_all_updated_tags # type: ignore
super(TagAccountDataStream, self).__init__(hs)
@@ -456,10 +415,6 @@ class AccountDataStream(Stream):
"""Global or per room account data was changed
"""
AccountDataStreamRow = namedtuple(
"AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
)
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
@@ -467,11 +422,10 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore()
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(AccountDataStream, self).__init__(hs)
async def _update_function(self, from_token, to_token, limit):
async def update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit
)
@@ -486,11 +440,6 @@ class AccountDataStream(Stream):
class GroupServerStream(Stream):
GroupsStreamRow = namedtuple(
"GroupsStreamRow",
("group_id", "user_id", "type", "content"), # str # str # str # dict
)
NAME = "groups"
ROW_TYPE = GroupsStreamRow
@@ -498,7 +447,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_group_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
self.update_function = store.get_all_groups_changes # type: ignore
super(GroupServerStream, self).__init__(hs)
@@ -507,15 +456,14 @@ class UserSignatureStream(Stream):
"""A user has signed their own device with their user-signing key
"""
UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
NAME = "user_signature"
_LIMITED = False
ROW_TYPE = UserSignatureStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
super(UserSignatureStream, self).__init__(hs)

View File

@@ -19,7 +19,7 @@ from typing import Tuple, Type
import attr
from ._base import Stream, db_query_to_update_function
from ._base import Stream
"""Handling of the 'events' replication stream
@@ -117,11 +117,10 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
self.current_token = self._store.get_current_events_token # type: ignore
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(EventsStream, self).__init__(hs)
async def _update_function(self, from_token, current_token, limit=None):
async def update_function(self, from_token, current_token, limit=None):
event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, limit
)

View File

@@ -15,9 +15,15 @@
# limitations under the License.
from collections import namedtuple
from twisted.internet import defer
from ._base import Stream
from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
FederationStreamRow = namedtuple(
"FederationStreamRow",
(
"type", # str, the type of data as defined in the BaseFederationRows
"data", # dict, serialization of a federation.send_queue.BaseFederationRow
),
)
class FederationStream(Stream):
@@ -25,28 +31,13 @@ class FederationStream(Stream):
sending disabled.
"""
FederationStreamRow = namedtuple(
"FederationStreamRow",
(
"type", # str, the type of data as defined in the BaseFederationRows
"data", # dict, serialization of a federation.send_queue.BaseFederationRow
),
)
NAME = "federation"
ROW_TYPE = FederationStreamRow
_QUERY_MASTER = True
def __init__(self, hs):
# Not all synapse instances will have a federation sender instance,
# whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
# so we stub the stream out when that is the case.
if hs.config.worker_app is None or hs.should_send_federation():
federation_sender = hs.get_federation_sender()
self.current_token = federation_sender.get_current_token # type: ignore
self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
else:
self.current_token = lambda: 0 # type: ignore
self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, False)) # type: ignore
federation_sender = hs.get_federation_sender()
self.current_token = federation_sender.get_current_token # type: ignore
self.update_function = federation_sender.get_replication_rows # type: ignore
super(FederationStream, self).__init__(hs)

View File

@@ -41,7 +41,6 @@ from synapse.rest.client.v2_alpha import (
keys,
notifications,
openid,
password_policy,
read_marker,
receipts,
register,
@@ -119,7 +118,6 @@ class ClientRestResource(JsonResource):
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
password_policy.register_servlets(hs, client_resource)
# moving to /_synapse/admin
synapse.rest.admin.register_servlets_for_client_rest_resource(

View File

@@ -29,11 +29,7 @@ from synapse.rest.admin._base import (
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.rooms import (
JoinRoomAliasServlet,
ListRoomRestServlet,
ShutdownRoomRestServlet,
)
from synapse.rest.admin.rooms import ListRoomRestServlet, ShutdownRoomRestServlet
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.users import (
AccountValidityRenewServlet,
@@ -193,7 +189,6 @@ def register_servlets(hs, http_server):
"""
register_servlets_for_client_rest_resource(hs, http_server)
ListRoomRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
PurgeRoomServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)

View File

@@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Optional
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.api.constants import Membership
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -30,7 +29,7 @@ from synapse.rest.admin._base import (
historical_admin_path_patterns,
)
from synapse.storage.data_stores.main.room import RoomSortOrder
from synapse.types import RoomAlias, RoomID, UserID, create_requester
from synapse.types import create_requester
from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -238,75 +237,3 @@ class ListRoomRestServlet(RestServlet):
response["prev_batch"] = 0
return 200, response
class JoinRoomAliasServlet(RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
self.admin_handler = hs.get_handlers().admin_handler
self.state_handler = hs.get_state_handler()
async def on_POST(self, request, room_identifier):
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ["user_id"])
target_user = UserID.from_string(content["user_id"])
if not self.hs.is_mine(target_user):
raise SynapseError(400, "This endpoint can only be used with local users")
if not await self.admin_handler.get_user(target_user):
raise NotFoundError("User not found")
if RoomID.is_valid(room_identifier):
room_id = room_identifier
try:
remote_room_hosts = [
x.decode("ascii") for x in request.args[b"server_name"]
] # type: Optional[List[str]]
except Exception:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
fake_requester = create_requester(target_user)
# send invite if room has "JoinRules.INVITE"
room_state = await self.state_handler.get_current_state(room_id)
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
await self.room_member_handler.update_membership(
requester=requester,
target=fake_requester.user,
room_id=room_id,
action="invite",
remote_room_hosts=remote_room_hosts,
ratelimit=False,
)
await self.room_member_handler.update_membership(
requester=fake_requester,
target=fake_requester.user,
room_id=room_id,
action="join",
remote_room_hosts=remote_room_hosts,
ratelimit=False,
)
return 200, {"room_id": room_id}

View File

@@ -14,6 +14,11 @@
# limitations under the License.
import logging
import xml.etree.ElementTree as ET
from six.moves import urllib
from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -23,10 +28,10 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.push.mailer import load_jinja2_templates
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.msisdn import phone_number_to_msisdn
logger = logging.getLogger(__name__)
@@ -397,7 +402,7 @@ class BaseSSORedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def on_GET(self, request: SynapseRequest):
def on_GET(self, request):
args = request.args
if b"redirectUrl" not in args:
return 400, "Redirect URL not specified for SSO auth"
@@ -406,15 +411,15 @@ class BaseSSORedirectServlet(RestServlet):
request.redirect(sso_url)
finish_request(request)
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
def get_sso_url(self, client_redirect_url):
"""Get the URL to redirect to, to perform SSO auth
Args:
client_redirect_url: the URL that we should redirect the
client_redirect_url (bytes): the URL that we should redirect the
client to when everything is done
Returns:
URL to redirect to
bytes: URL to redirect to
"""
# to be implemented by subclasses
raise NotImplementedError()
@@ -422,10 +427,19 @@ class BaseSSORedirectServlet(RestServlet):
class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._cas_handler = hs.get_cas_handler()
super(CasRedirectServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url.encode("ascii")
self.cas_service_url = hs.config.cas_service_url.encode("ascii")
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
return self._cas_handler.handle_redirect_request(client_redirect_url)
def get_sso_url(self, client_redirect_url):
client_redirect_url_param = urllib.parse.urlencode(
{b"redirectUrl": client_redirect_url}
).encode("ascii")
hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
service_param = urllib.parse.urlencode(
{b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
).encode("ascii")
return b"%s/login?%s" % (self.cas_server_url, service_param)
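The `get_sso_url` implementation above nests two layers of URL encoding: the client's `redirectUrl` is encoded into the homeserver's ticket-callback URL, which is in turn encoded as the CAS `service` parameter. A standalone sketch of the same composition, with hypothetical CAS and homeserver URLs:

```python
# Reproduces the double URL-encoding performed by get_sso_url() above.
from urllib import parse

cas_server_url = b"https://cas.example.com"
cas_service_url = b"https://matrix.example.com"
client_redirect_url = b"https://client.example.com/?next=home"

inner = parse.urlencode({b"redirectUrl": client_redirect_url}).encode("ascii")
ticket_url = cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
outer = parse.urlencode({b"service": b"%s?%s" % (ticket_url, inner)}).encode("ascii")

print((b"%s/login?%s" % (cas_server_url, outer)).decode("ascii"))
# https://cas.example.com/login?service=https%3A%2F%2Fmatrix.example.com%2F_matrix%2F...
```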
class CasTicketServlet(RestServlet):
@@ -433,15 +447,81 @@ class CasTicketServlet(RestServlet):
def __init__(self, hs):
super(CasTicketServlet, self).__init__()
self._cas_handler = hs.get_cas_handler()
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs)
self._http_client = hs.get_proxied_http_client()
async def on_GET(self, request: SynapseRequest) -> None:
async def on_GET(self, request):
client_redirect_url = parse_string(request, "redirectUrl", required=True)
ticket = parse_string(request, "ticket", required=True)
await self._cas_handler.handle_ticket_request(
request, client_redirect_url, ticket
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url,
}
try:
body = await self._http_client.get_raw(uri, args)
except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
result = await self.handle_cas_response(request, body, client_redirect_url)
return result
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body)
displayname = attributes.pop(self.cas_displayname_attribute, None)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
return self._sso_auth_handler.on_successful_auth(
user, request, client_redirect_url, displayname
)
def parse_cas_response(self, cas_response_body):
user = None
attributes = {}
try:
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
success = root[0].tag.endswith("authenticationSuccess")
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
for attribute in child:
# ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace.
# We don't care about namespace here and it will always
# be encased in curly braces, so we remove them.
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text
if user is None:
raise Exception("CAS response does not contain user")
except Exception:
logger.exception("Error parsing CAS response")
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
)
return user, attributes
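The namespace handling in `parse_cas_response` above is easy to miss: ElementTree expands XML namespace prefixes into the tag name as `{uri}localname`, so the code keeps only the part after the closing brace. A small standalone illustration:

```python
# ElementTree namespace expansion, as described in parse_cas_response().
import xml.etree.ElementTree as ET

xml = (
    '<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">'
    "<cas:authenticationSuccess><cas:user>alice</cas:user>"
    "</cas:authenticationSuccess></cas:serviceResponse>"
)
root = ET.fromstring(xml)
print(root.tag)  # {http://www.yale.edu/tp/cas}serviceResponse

tag = root.tag.split("}")[1] if "}" in root.tag else root.tag
print(tag)  # serviceResponse
```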
class SAMLRedirectServlet(BaseSSORedirectServlet):
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@@ -449,10 +529,72 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._saml_handler = hs.get_saml_handler()
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
def get_sso_url(self, client_redirect_url):
return self._saml_handler.handle_redirect_request(client_redirect_url)
class SSOAuthHandler(object):
"""
Utility class for Resources and Servlets which handle the response from a SSO
service
Args:
hs (synapse.server.HomeServer)
"""
def __init__(self, hs):
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._macaroon_gen = hs.get_macaroon_generator()
# Load the redirect page HTML template
self._template = load_jinja2_templates(
hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
)[0]
self._server_name = hs.config.server_name
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
async def on_successful_auth(
self, username, request, client_redirect_url, user_display_name=None
):
"""Called once the user has successfully authenticated with the SSO.
Registers the user if necessary, and then returns a redirect (with
a login token) to the client.
Args:
username (unicode|bytes): the remote user id. We'll map this onto
something sane for a MXID localpath.
request (SynapseRequest): the incoming request from the browser. We'll
respond to it with a redirect.
client_redirect_url (unicode): the redirect_url the client gave us when
it first started the process.
user_display_name (unicode|None): if set, and we have to register a new user,
we will set their displayname to this.
Returns:
Deferred[None]: Completes once we have handled the request.
"""
localpart = map_username_to_mxid_localpart(username)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=user_display_name
)
self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url
)
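`on_successful_auth` above maps the raw SSO username onto a valid MXID localpart via `map_username_to_mxid_localpart` (imported from `synapse.types` in this diff) before checking for or registering the user. A hedged illustration of that mapping, which lowercases and hex-escapes disallowed characters; the exact escapes may vary by Synapse version:

```python
# Illustrates the username -> MXID mapping used by on_successful_auth().
from synapse.types import UserID, map_username_to_mxid_localpart

localpart = map_username_to_mxid_localpart("Alice.Smith@example.com")
print(localpart)  # e.g. alice.smith=40example.com ("@" escaped as "=40")
print(UserID(localpart, "matrix.example.com").to_string())
# @alice.smith=40example.com:matrix.example.com
```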
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
if hs.config.cas_enabled:

View File

@@ -234,16 +234,13 @@ class PasswordRestServlet(RestServlet):
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
params = await self.auth_handler.validate_user_via_ui_auth(
requester, request, body, self.hs.get_ip_from_request(request),
requester, body, self.hs.get_ip_from_request(request)
)
user_id = requester.user.to_string()
else:
requester = None
result, params, _ = await self.auth_handler.check_auth(
[[LoginType.EMAIL_IDENTITY]],
request,
body,
self.hs.get_ip_from_request(request),
[[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request)
)
if LoginType.EMAIL_IDENTITY in result:
@@ -311,7 +308,7 @@ class DeactivateAccountRestServlet(RestServlet):
return 200, {}
await self.auth_handler.validate_user_via_ui_auth(
requester, request, body, self.hs.get_ip_from_request(request),
requester, body, self.hs.get_ip_from_request(request)
)
result = await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase, id_server=body.get("id_server")
@@ -605,11 +602,6 @@ class ThreepidRestServlet(RestServlet):
return 200, {"threepids": threepids}
async def on_POST(self, request):
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -654,11 +646,6 @@ class ThreepidAddRestServlet(RestServlet):
@interactive_auth_handler
async def on_POST(self, request):
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -669,7 +656,7 @@ class ThreepidAddRestServlet(RestServlet):
assert_valid_client_secret(client_secret)
await self.auth_handler.validate_user_via_ui_auth(
requester, request, body, self.hs.get_ip_from_request(request),
requester, body, self.hs.get_ip_from_request(request)
)
validation_session = await self.identity_handler.validate_threepid_session(
@@ -754,16 +741,10 @@ class ThreepidDeleteRestServlet(RestServlet):
def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
async def on_POST(self, request):
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])


@@ -142,6 +142,14 @@ class AuthRestServlet(RestServlet):
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
"sitekey": self.hs.config.recaptcha_public_key,
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
elif stagetype == LoginType.TERMS:
html = TERMS_TEMPLATE % {
"session": session,
@@ -150,19 +158,17 @@ class AuthRestServlet(RestServlet):
"myurl": "%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
else:
raise SynapseError(404, "Unknown auth stage type")
# Render the HTML and return.
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
async def on_POST(self, request, stagetype):
session = parse_string(request, "session")
@@ -190,6 +196,15 @@ class AuthRestServlet(RestServlet):
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
"sitekey": self.hs.config.recaptcha_public_key,
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
elif stagetype == LoginType.TERMS:
authdict = {"session": session}
@@ -210,19 +225,17 @@ class AuthRestServlet(RestServlet):
"myurl": "%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
else:
raise SynapseError(404, "Unknown auth stage type")
# Render the HTML and return.
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
finish_request(request)
return None
def on_OPTIONS(self, _):
return 200, {}
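The hunks above repeat the same six-line "encode, set headers, write, finish" sequence at every branch; the other side of the diff collapses them into a single "Render the HTML and return" block. A minimal sketch of that repeated block factored into a helper (the helper name is hypothetical, not part of this diff):

```python
# The repeated HTML response block from AuthRestServlet, as a helper.
from synapse.http.server import finish_request

def respond_with_html(request, html: str) -> None:
    html_bytes = html.encode("utf8")
    request.setResponseCode(200)
    request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
    request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
    request.write(html_bytes)
    finish_request(request)
```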


@@ -81,7 +81,7 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"])
await self.auth_handler.validate_user_via_ui_auth(
requester, request, body, self.hs.get_ip_from_request(request),
requester, body, self.hs.get_ip_from_request(request)
)
await self.device_handler.delete_devices(
@@ -127,7 +127,7 @@ class DeviceRestServlet(RestServlet):
raise
await self.auth_handler.validate_user_via_ui_auth(
requester, request, body, self.hs.get_ip_from_request(request),
requester, body, self.hs.get_ip_from_request(request)
)
await self.device_handler.delete_device(requester.user.to_string(), device_id)


@@ -263,7 +263,7 @@ class SigningKeyUploadServlet(RestServlet):
body = parse_json_object_from_request(request)
await self.auth_handler.validate_user_via_ui_auth(
requester, request, body, self.hs.get_ip_from_request(request),
requester, body, self.hs.get_ip_from_request(request)
)
result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)


@@ -1,58 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.http.servlet import RestServlet
from ._base import client_patterns
logger = logging.getLogger(__name__)
class PasswordPolicyServlet(RestServlet):
PATTERNS = client_patterns("/password_policy$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(PasswordPolicyServlet, self).__init__()
self.policy = hs.config.password_policy
self.enabled = hs.config.password_policy_enabled
def on_GET(self, request):
if not self.enabled or not self.policy:
return (200, {})
policy = {}
for param in [
"minimum_length",
"require_digit",
"require_symbol",
"require_lowercase",
"require_uppercase",
]:
if param in self.policy:
policy["m.%s" % param] = self.policy[param]
return (200, policy)
def register_servlets(hs, http_server):
PasswordPolicyServlet(hs).register(http_server)
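The `on_GET` handler above exposes each configured policy key with an `m.` prefix, skipping any keys absent from the config. A hedged example of that mapping with illustrative values:

```python
# The config -> response mapping performed by PasswordPolicyServlet.on_GET.
policy = {"minimum_length": 12, "require_digit": True}  # example config

response = {
    "m.%s" % param: policy[param]
    for param in (
        "minimum_length", "require_digit", "require_symbol",
        "require_lowercase", "require_uppercase",
    )
    if param in policy
}
print(response)  # {'m.minimum_length': 12, 'm.require_digit': True}
```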


@@ -373,7 +373,6 @@ class RegisterRestServlet(RestServlet):
self.room_member_handler = hs.get_room_member_handler()
self.macaroon_gen = hs.get_macaroon_generator()
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
self._registration_flows = _calculate_registration_flows(
@@ -421,7 +420,6 @@ class RegisterRestServlet(RestServlet):
or len(body["password"]) > 512
):
raise SynapseError(400, "Invalid password")
self.password_policy_handler.validate_password(body["password"])
desired_username = None
if "username" in body:
@@ -501,10 +499,7 @@ class RegisterRestServlet(RestServlet):
)
auth_result, params, session_id = await self.auth_handler.check_auth(
self._registration_flows,
request,
body,
self.hs.get_ip_from_request(request),
self._registration_flows, body, self.hs.get_ip_from_request(request)
)
# Check that we're not trying to register a denied 3pid.


@@ -188,7 +188,7 @@ class RoomKeysServlet(RestServlet):
"""
requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
version = parse_string(request, "version", required=True)
version = parse_string(request, "version")
room_keys = await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id


@@ -50,9 +50,6 @@ class DownloadResource(DirectServeResource):
b" media-src 'self';"
b" object-src 'self';",
)
request.setHeader(
b"Referrer-Policy", b"no-referrer",
)
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
await self.media_repo.get_local_media(request, media_id, name)


@@ -24,6 +24,7 @@ from six import iteritems
import twisted.internet.error
import twisted.web.http
from twisted.internet import defer
from twisted.web.resource import Resource
from synapse.api.errors import (
@@ -113,14 +114,15 @@ class MediaRepository(object):
"update_recently_accessed_media", self._update_recently_accessed
)
async def _update_recently_accessed(self):
@defer.inlineCallbacks
def _update_recently_accessed(self):
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
local_media = self.recently_accessed_locals
self.recently_accessed_locals = set()
await self.store.update_cached_last_access_time(
yield self.store.update_cached_last_access_time(
local_media, remote_media, self.clock.time_msec()
)
@@ -136,7 +138,8 @@ class MediaRepository(object):
else:
self.recently_accessed_locals.add(media_id)
async def create_content(
@defer.inlineCallbacks
def create_content(
self, media_type, upload_name, content, content_length, auth_user
):
"""Store uploaded content for a local user and return the mxc URL
@@ -155,11 +158,11 @@ class MediaRepository(object):
file_info = FileInfo(server_name=None, file_id=media_id)
fname = await self.media_storage.store_file(content, file_info)
fname = yield self.media_storage.store_file(content, file_info)
logger.info("Stored local media in file %r", fname)
await self.store.store_local_media(
yield self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
@@ -168,11 +171,12 @@ class MediaRepository(object):
user_id=auth_user,
)
await self._generate_thumbnails(None, media_id, media_id, media_type)
yield self._generate_thumbnails(None, media_id, media_id, media_type)
return "mxc://%s/%s" % (self.server_name, media_id)
async def get_local_media(self, request, media_id, name):
@defer.inlineCallbacks
def get_local_media(self, request, media_id, name):
"""Responds to reqests for local media, if exists, or returns 404.
Args:
@@ -186,7 +190,7 @@ class MediaRepository(object):
Deferred: Resolves once a response has successfully been written
to request
"""
media_info = await self.store.get_local_media(media_id)
media_info = yield self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
@@ -200,12 +204,13 @@ class MediaRepository(object):
file_info = FileInfo(None, media_id, url_cache=url_cache)
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(
request, responder, media_type, media_length, upload_name
)
async def get_remote_media(self, request, server_name, media_id, name):
@defer.inlineCallbacks
def get_remote_media(self, request, server_name, media_id, name):
"""Respond to requests for remote media.
Args:
@@ -231,8 +236,8 @@ class MediaRepository(object):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
with (await self.remote_media_linearizer.queue(key)):
responder, media_info = await self._get_remote_media_impl(
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
server_name, media_id
)
@@ -241,13 +246,14 @@ class MediaRepository(object):
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
await respond_with_responder(
yield respond_with_responder(
request, responder, media_type, media_length, upload_name
)
else:
respond_404(request)
async def get_remote_media_info(self, server_name, media_id):
@defer.inlineCallbacks
def get_remote_media_info(self, server_name, media_id):
"""Gets the media info associated with the remote file, downloading
if necessary.
@@ -268,8 +274,8 @@ class MediaRepository(object):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
with (await self.remote_media_linearizer.queue(key)):
responder, media_info = await self._get_remote_media_impl(
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
server_name, media_id
)
@@ -280,7 +286,8 @@ class MediaRepository(object):
return media_info
async def _get_remote_media_impl(self, server_name, media_id):
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -292,7 +299,7 @@ class MediaRepository(object):
Returns:
Deferred[(Responder, media_info)]
"""
media_info = await self.store.get_cached_remote_media(server_name, media_id)
media_info = yield self.store.get_cached_remote_media(server_name, media_id)
# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise generate a new
@@ -310,18 +317,19 @@ class MediaRepository(object):
logger.info("Media is quarantined")
raise NotFoundError()
responder = await self.media_storage.fetch_media(file_info)
responder = yield self.media_storage.fetch_media(file_info)
if responder:
return responder, media_info
# Failed to find the file anywhere, lets download it.
media_info = await self._download_remote_file(server_name, media_id, file_id)
media_info = yield self._download_remote_file(server_name, media_id, file_id)
responder = await self.media_storage.fetch_media(file_info)
responder = yield self.media_storage.fetch_media(file_info)
return responder, media_info
async def _download_remote_file(self, server_name, media_id, file_id):
@defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id, file_id):
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -343,7 +351,7 @@ class MediaRepository(object):
("/_matrix/media/v1/download", server_name, media_id)
)
try:
length, headers = await self.client.get_file(
length, headers = yield self.client.get_file(
server_name,
request_path,
output_stream=f,
@@ -389,7 +397,7 @@ class MediaRepository(object):
)
raise SynapseError(502, "Failed to fetch remote media")
await finish()
yield finish()
media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
@@ -397,7 +405,7 @@ class MediaRepository(object):
logger.info("Stored remote media in file %r", fname)
await self.store.store_cached_remote_media(
yield self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
@@ -415,7 +423,7 @@ class MediaRepository(object):
"filesystem_id": file_id,
}
await self._generate_thumbnails(server_name, media_id, file_id, media_type)
yield self._generate_thumbnails(server_name, media_id, file_id, media_type)
return media_info
@@ -450,15 +458,16 @@ class MediaRepository(object):
return t_byte_source
async def generate_local_exact_thumbnail(
@defer.inlineCallbacks
def generate_local_exact_thumbnail(
self, media_id, t_width, t_height, t_method, t_type, url_cache
):
input_path = await self.media_storage.ensure_media_is_in_local_cache(
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
)
thumbnailer = Thumbnailer(input_path)
t_byte_source = await defer_to_thread(
t_byte_source = yield defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
@@ -481,7 +490,7 @@ class MediaRepository(object):
thumbnail_type=t_type,
)
output_path = await self.media_storage.store_file(
output_path = yield self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -491,21 +500,22 @@ class MediaRepository(object):
t_len = os.path.getsize(output_path)
await self.store.store_local_thumbnail(
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
return output_path
async def generate_remote_exact_thumbnail(
@defer.inlineCallbacks
def generate_remote_exact_thumbnail(
self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
):
input_path = await self.media_storage.ensure_media_is_in_local_cache(
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=False)
)
thumbnailer = Thumbnailer(input_path)
t_byte_source = await defer_to_thread(
t_byte_source = yield defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
@@ -527,7 +537,7 @@ class MediaRepository(object):
thumbnail_type=t_type,
)
output_path = await self.media_storage.store_file(
output_path = yield self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -537,7 +547,7 @@ class MediaRepository(object):
t_len = os.path.getsize(output_path)
await self.store.store_remote_media_thumbnail(
yield self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
@@ -550,7 +560,8 @@ class MediaRepository(object):
return output_path
async def _generate_thumbnails(
@defer.inlineCallbacks
def _generate_thumbnails(
self, server_name, media_id, file_id, media_type, url_cache=False
):
"""Generate and store thumbnails for an image.
@@ -571,7 +582,7 @@ class MediaRepository(object):
if not requirements:
return
input_path = await self.media_storage.ensure_media_is_in_local_cache(
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
)
@@ -589,7 +600,7 @@ class MediaRepository(object):
return
if thumbnailer.transpose_method is not None:
m_width, m_height = await defer_to_thread(
m_width, m_height = yield defer_to_thread(
self.hs.get_reactor(), thumbnailer.transpose
)
@@ -609,11 +620,11 @@ class MediaRepository(object):
for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail
if t_method == "crop":
t_byte_source = await defer_to_thread(
t_byte_source = yield defer_to_thread(
self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
)
elif t_method == "scale":
t_byte_source = await defer_to_thread(
t_byte_source = yield defer_to_thread(
self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
)
else:
@@ -635,7 +646,7 @@ class MediaRepository(object):
url_cache=url_cache,
)
output_path = await self.media_storage.store_file(
output_path = yield self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -645,7 +656,7 @@ class MediaRepository(object):
# Write to database
if server_name:
await self.store.store_remote_media_thumbnail(
yield self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
@@ -656,14 +667,15 @@ class MediaRepository(object):
t_len,
)
else:
await self.store.store_local_thumbnail(
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
return {"width": m_width, "height": m_height}
async def delete_old_remote_media(self, before_ts):
old_media = await self.store.get_remote_media_before(before_ts)
@defer.inlineCallbacks
def delete_old_remote_media(self, before_ts):
old_media = yield self.store.get_remote_media_before(before_ts)
deleted = 0
@@ -677,7 +689,7 @@ class MediaRepository(object):
# TODO: Should we delete from the backup store
with (await self.remote_media_linearizer.queue(key)):
with (yield self.remote_media_linearizer.queue(key)):
full_path = self.filepaths.remote_media_filepath(origin, file_id)
try:
os.remove(full_path)
@@ -693,7 +705,7 @@ class MediaRepository(object):
)
shutil.rmtree(thumbnail_dir, ignore_errors=True)
await self.store.delete_remote_media(origin, media_id)
yield self.store.delete_remote_media(origin, media_id)
deleted += 1
return {"deleted": deleted}
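Several methods in this file take the `remote_media_linearizer` lock before touching a remote media entry, so that concurrent requests for the same `(server_name, media_id)` key never download the same file twice. A hedged sketch of that Linearizer pattern in isolation:

```python
# The per-key serialisation pattern used throughout MediaRepository.
from synapse.util.async_helpers import Linearizer

linearizer = Linearizer(name="remote_media")

async def download_once(server_name: str, media_id: str):
    key = (server_name, media_id)
    # Only one coroutine at a time may hold the lock for this key; a second
    # request for the same media waits here instead of downloading it again.
    with (await linearizer.queue(key)):
        ...  # fetch and cache the media
```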


@@ -165,7 +165,8 @@ class PreviewUrlResource(DirectServeResource):
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
respond_with_json_bytes(request, 200, og, send_cors=True)
async def _do_preview(self, url, user, ts):
@defer.inlineCallbacks
def _do_preview(self, url, user, ts):
"""Check the db, and download the URL and build a preview
Args:
@@ -178,7 +179,7 @@ class PreviewUrlResource(DirectServeResource):
"""
# check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
cache_result = await self.store.get_url_cache(url, ts)
cache_result = yield self.store.get_url_cache(url, ts)
if (
cache_result
and cache_result["expires_ts"] > ts
@@ -191,13 +192,13 @@ class PreviewUrlResource(DirectServeResource):
og = og.encode("utf8")
return og
media_info = await self._download_url(url, user)
media_info = yield self._download_url(url, user)
logger.debug("got media_info of '%s'", media_info)
if _is_media(media_info["media_type"]):
file_id = media_info["filesystem_id"]
dims = await self.media_repo._generate_thumbnails(
dims = yield self.media_repo._generate_thumbnails(
None, file_id, file_id, media_info["media_type"], url_cache=True
)
@@ -247,14 +248,14 @@ class PreviewUrlResource(DirectServeResource):
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if "og:image" in og and og["og:image"]:
image_info = await self._download_url(
image_info = yield self._download_url(
_rebase_url(og["og:image"], media_info["uri"]), user
)
if _is_media(image_info["media_type"]):
# TODO: make sure we don't choke on white-on-transparent images
file_id = image_info["filesystem_id"]
dims = await self.media_repo._generate_thumbnails(
dims = yield self.media_repo._generate_thumbnails(
None, file_id, file_id, image_info["media_type"], url_cache=True
)
if dims:
@@ -292,7 +293,7 @@ class PreviewUrlResource(DirectServeResource):
jsonog = json.dumps(og)
# store OG in history-aware DB cache
await self.store.store_url_cache(
yield self.store.store_url_cache(
url,
media_info["response_code"],
media_info["etag"],
@@ -304,7 +305,8 @@ class PreviewUrlResource(DirectServeResource):
return jsonog.encode("utf8")
async def _download_url(self, url, user):
@defer.inlineCallbacks
def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -316,7 +318,7 @@ class PreviewUrlResource(DirectServeResource):
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
logger.debug("Trying to get url '%s'", url)
length, headers, uri, code = await self.client.get_file(
length, headers, uri, code = yield self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size
)
except SynapseError:
@@ -343,7 +345,7 @@ class PreviewUrlResource(DirectServeResource):
% (traceback.format_exception_only(sys.exc_info()[0], e),),
Codes.UNKNOWN,
)
await finish()
yield finish()
try:
if b"Content-Type" in headers:
@@ -354,7 +356,7 @@ class PreviewUrlResource(DirectServeResource):
download_name = get_filename_from_headers(headers)
await self.store.store_local_media(
yield self.store.store_local_media(
media_id=file_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
@@ -391,7 +393,8 @@ class PreviewUrlResource(DirectServeResource):
"expire_url_cache_data", self._expire_url_cache_data
)
async def _expire_url_cache_data(self):
@defer.inlineCallbacks
def _expire_url_cache_data(self):
"""Clean up expired url cache content, media and thumbnails.
"""
# TODO: Delete from backup media store
@@ -400,12 +403,12 @@ class PreviewUrlResource(DirectServeResource):
logger.info("Running url preview cache expiry")
if not (await self.store.db.updates.has_completed_background_updates()):
if not (yield self.store.db.updates.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry")
return
# First we delete expired url cache entries
media_ids = await self.store.get_expired_url_cache(now)
media_ids = yield self.store.get_expired_url_cache(now)
removed_media = []
for media_id in media_ids:
@@ -427,7 +430,7 @@ class PreviewUrlResource(DirectServeResource):
except Exception:
pass
await self.store.delete_url_cache(removed_media)
yield self.store.delete_url_cache(removed_media)
if removed_media:
logger.info("Deleted %d entries from url cache", len(removed_media))
@@ -437,7 +440,7 @@ class PreviewUrlResource(DirectServeResource):
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
expire_before = now - 2 * 24 * 60 * 60 * 1000
media_ids = await self.store.get_url_cache_media_before(expire_before)
media_ids = yield self.store.get_url_cache_media_before(expire_before)
removed_media = []
for media_id in media_ids:
@@ -475,7 +478,7 @@ class PreviewUrlResource(DirectServeResource):
except Exception:
pass
await self.store.delete_url_cache_media(removed_media)
yield self.store.delete_url_cache_media(removed_media)
logger.info("Deleted %d media from url cache", len(removed_media))
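As the comment in `_expire_url_cache_data` above notes, expired cache entries are deleted immediately, but the downloaded media itself is kept for two extra days in case a client still has the preview open. A small illustration of that retention-window arithmetic:

```python
# The two-day retention window computed in _expire_url_cache_data().
import time

now = int(time.time() * 1000)                   # Synapse clocks are in ms
expire_before = now - 2 * 24 * 60 * 60 * 1000   # two days ago, in ms
print((now - expire_before) / 1000 / 3600)      # 48.0 hours
```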
