Compare commits
59 Commits
anoa/debug
...
erikj/repl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
11f5579129 | ||
|
|
22fc68762e | ||
|
|
4f21c33be3 | ||
|
|
07569f25d1 | ||
|
|
104844c1e1 | ||
|
|
6486c96b65 | ||
|
|
c5f89fba55 | ||
|
|
7406477525 | ||
|
|
9fc588e6dc | ||
|
|
b7da598a61 | ||
|
|
84f7eaed16 | ||
|
|
fb69690761 | ||
|
|
8327eb9280 | ||
|
|
ae219fb411 | ||
|
|
12aa5a7fa7 | ||
|
|
fbf0782c63 | ||
|
|
16ee97988a | ||
|
|
a07e03ce90 | ||
|
|
d9965fb8d6 | ||
|
|
09cc058a4c | ||
|
|
665630fcaa | ||
|
|
7496d3d2f6 | ||
|
|
fa4f12102d | ||
|
|
55ca6cf88c | ||
|
|
825fb5d0a5 | ||
|
|
060e7dce09 | ||
|
|
e8e2ddb60a | ||
|
|
1c1242acba | ||
|
|
6ca5e56fd1 | ||
|
|
4cff617df1 | ||
|
|
7bab642707 | ||
|
|
b1cfaf08af | ||
|
|
28d9d6e8a9 | ||
|
|
39230d2171 | ||
|
|
1fcf9c6f95 | ||
|
|
d6828c129f | ||
|
|
c816072d47 | ||
|
|
190ab593b7 | ||
|
|
e341518f92 | ||
|
|
a564b92d37 | ||
|
|
5126cb1253 | ||
|
|
229eb81498 | ||
|
|
b3cee0ce67 | ||
|
|
96071eea8f | ||
|
|
477c4f5b1c | ||
|
|
c165c1233b | ||
|
|
fdb1344716 | ||
|
|
caec7d4fa0 | ||
|
|
c2db6599c8 | ||
|
|
a319cb1dd1 | ||
|
|
6e6476ef07 | ||
|
|
4ce50519cd | ||
|
|
65a941d1f8 | ||
|
|
e53744c737 | ||
|
|
f70f44abc7 | ||
|
|
59ad93d2a4 | ||
|
|
9ce4e344a8 | ||
|
|
f5caa1864e | ||
|
|
c3c6c0e622 |
102
INSTALL.md
102
INSTALL.md
@@ -2,7 +2,6 @@
|
||||
- [Installing Synapse](#installing-synapse)
|
||||
- [Installing from source](#installing-from-source)
|
||||
- [Platform-Specific Instructions](#platform-specific-instructions)
|
||||
- [Troubleshooting Installation](#troubleshooting-installation)
|
||||
- [Prebuilt packages](#prebuilt-packages)
|
||||
- [Setting up Synapse](#setting-up-synapse)
|
||||
- [TLS certificates](#tls-certificates)
|
||||
@@ -10,6 +9,7 @@
|
||||
- [Registering a user](#registering-a-user)
|
||||
- [Setting up a TURN server](#setting-up-a-turn-server)
|
||||
- [URL previews](#url-previews)
|
||||
- [Troubleshooting Installation](#troubleshooting-installation)
|
||||
|
||||
# Choosing your server name
|
||||
|
||||
@@ -70,7 +70,7 @@ pip install -U matrix-synapse
|
||||
```
|
||||
|
||||
Before you can start Synapse, you will need to generate a configuration
|
||||
file. To do this, run (in your virtualenv, as before)::
|
||||
file. To do this, run (in your virtualenv, as before):
|
||||
|
||||
```
|
||||
cd ~/synapse
|
||||
@@ -84,22 +84,24 @@ python -m synapse.app.homeserver \
|
||||
... substituting an appropriate value for `--server-name`.
|
||||
|
||||
This command will generate you a config file that you can then customise, but it will
|
||||
also generate a set of keys for you. These keys will allow your Home Server to
|
||||
identify itself to other Home Servers, so don't lose or delete them. It would be
|
||||
also generate a set of keys for you. These keys will allow your homeserver to
|
||||
identify itself to other homeserver, so don't lose or delete them. It would be
|
||||
wise to back them up somewhere safe. (If, for whatever reason, you do need to
|
||||
change your Home Server's keys, you may find that other Home Servers have the
|
||||
change your homeserver's keys, you may find that other homeserver have the
|
||||
old key cached. If you update the signing key, you should change the name of the
|
||||
key in the `<server name>.signing.key` file (the second word) to something
|
||||
different. See the
|
||||
[spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys)
|
||||
for more information on key management.)
|
||||
for more information on key management).
|
||||
|
||||
To actually run your new homeserver, pick a working directory for Synapse to
|
||||
run (e.g. `~/synapse`), and::
|
||||
run (e.g. `~/synapse`), and:
|
||||
|
||||
cd ~/synapse
|
||||
source env/bin/activate
|
||||
synctl start
|
||||
```
|
||||
cd ~/synapse
|
||||
source env/bin/activate
|
||||
synctl start
|
||||
```
|
||||
|
||||
### Platform-Specific Instructions
|
||||
|
||||
@@ -110,7 +112,7 @@ Installing prerequisites on Ubuntu or Debian:
|
||||
```
|
||||
sudo apt-get install build-essential python3-dev libffi-dev \
|
||||
python3-pip python3-setuptools sqlite3 \
|
||||
libssl-dev python3-virtualenv libjpeg-dev libxslt1-dev
|
||||
libssl-dev virtualenv libjpeg-dev libxslt1-dev
|
||||
```
|
||||
|
||||
#### ArchLinux
|
||||
@@ -188,7 +190,7 @@ doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \
|
||||
There is currently no port for OpenBSD. Additionally, OpenBSD's security
|
||||
settings require a slightly more difficult installation process.
|
||||
|
||||
XXX: I suspect this is out of date.
|
||||
(XXX: I suspect this is out of date)
|
||||
|
||||
1. Create a new directory in `/usr/local` called `_synapse`. Also, create a
|
||||
new user called `_synapse` and set that directory as the new user's home.
|
||||
@@ -196,7 +198,7 @@ XXX: I suspect this is out of date.
|
||||
write and execute permissions on the same memory space to be run from
|
||||
`/usr/local`.
|
||||
2. `su` to the new `_synapse` user and change to their home directory.
|
||||
3. Create a new virtualenv: `virtualenv -p python2.7 ~/.synapse`
|
||||
3. Create a new virtualenv: `virtualenv -p python3 ~/.synapse`
|
||||
4. Source the virtualenv configuration located at
|
||||
`/usr/local/_synapse/.synapse/bin/activate`. This is done in `ksh` by
|
||||
using the `.` command, rather than `bash`'s `source`.
|
||||
@@ -217,45 +219,6 @@ be found at https://docs.microsoft.com/en-us/windows/wsl/install-win10 for
|
||||
Windows 10 and https://docs.microsoft.com/en-us/windows/wsl/install-on-server
|
||||
for Windows Server.
|
||||
|
||||
### Troubleshooting Installation
|
||||
|
||||
XXX a bunch of this is no longer relevant.
|
||||
|
||||
Synapse requires pip 8 or later, so if your OS provides too old a version you
|
||||
may need to manually upgrade it::
|
||||
|
||||
sudo pip install --upgrade pip
|
||||
|
||||
Installing may fail with `Could not find any downloads that satisfy the requirement pymacaroons-pynacl (from matrix-synapse==0.12.0)`.
|
||||
You can fix this by manually upgrading pip and virtualenv::
|
||||
|
||||
sudo pip install --upgrade virtualenv
|
||||
|
||||
You can next rerun `virtualenv -p python3 synapse` to update the virtual env.
|
||||
|
||||
Installing may fail during installing virtualenv with `InsecurePlatformWarning: A true SSLContext object is not available. This prevents urllib3 from configuring SSL appropriately and may cause certain SSL connections to fail. For more information, see https://urllib3.readthedocs.org/en/latest/security.html#insecureplatformwarning.`
|
||||
You can fix this by manually installing ndg-httpsclient::
|
||||
|
||||
pip install --upgrade ndg-httpsclient
|
||||
|
||||
Installing may fail with `mock requires setuptools>=17.1. Aborting installation`.
|
||||
You can fix this by upgrading setuptools::
|
||||
|
||||
pip install --upgrade setuptools
|
||||
|
||||
If pip crashes mid-installation for reason (e.g. lost terminal), pip may
|
||||
refuse to run until you remove the temporary installation directory it
|
||||
created. To reset the installation::
|
||||
|
||||
rm -rf /tmp/pip_install_matrix
|
||||
|
||||
pip seems to leak *lots* of memory during installation. For instance, a Linux
|
||||
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
|
||||
happens, you will have to individually install the dependencies which are
|
||||
failing, e.g.::
|
||||
|
||||
pip install twisted
|
||||
|
||||
## Prebuilt packages
|
||||
|
||||
As an alternative to installing from source, prebuilt packages are available
|
||||
@@ -314,7 +277,7 @@ For `buster` and `sid`, Synapse is available in the Debian repositories and
|
||||
it should be possible to install it with simply:
|
||||
|
||||
```
|
||||
sudo apt install matrix-synapse
|
||||
sudo apt install matrix-synapse
|
||||
```
|
||||
|
||||
There is also a version of `matrix-synapse` in `stretch-backports`. Please see
|
||||
@@ -375,15 +338,17 @@ sudo pip install py-bcrypt
|
||||
|
||||
Synapse can be found in the void repositories as 'synapse':
|
||||
|
||||
xbps-install -Su
|
||||
xbps-install -S synapse
|
||||
```
|
||||
xbps-install -Su
|
||||
xbps-install -S synapse
|
||||
```
|
||||
|
||||
### FreeBSD
|
||||
|
||||
Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from:
|
||||
|
||||
- Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean`
|
||||
- Packages: `pkg install py27-matrix-synapse`
|
||||
- Packages: `pkg install py37-matrix-synapse`
|
||||
|
||||
|
||||
### NixOS
|
||||
@@ -420,6 +385,7 @@ so, you will need to edit `homeserver.yaml`, as follows:
|
||||
resources:
|
||||
- names: [client, federation]
|
||||
```
|
||||
|
||||
* You will also need to uncomment the `tls_certificate_path` and
|
||||
`tls_private_key_path` lines under the `TLS` section. You can either
|
||||
point these settings at an existing certificate and key, or you can
|
||||
@@ -435,7 +401,7 @@ so, you will need to edit `homeserver.yaml`, as follows:
|
||||
`cert.pem`).
|
||||
|
||||
For a more detailed guide to configuring your server for federation, see
|
||||
[federate.md](docs/federate.md)
|
||||
[federate.md](docs/federate.md).
|
||||
|
||||
|
||||
## Email
|
||||
@@ -482,7 +448,7 @@ on your server even if `enable_registration` is `false`.
|
||||
## Setting up a TURN server
|
||||
|
||||
For reliable VoIP calls to be routed via this homeserver, you MUST configure
|
||||
a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details.
|
||||
a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details.
|
||||
|
||||
## URL previews
|
||||
|
||||
@@ -491,10 +457,24 @@ turn it on you must enable the `url_preview_enabled: True` config parameter
|
||||
and explicitly specify the IP ranges that Synapse is not allowed to spider for
|
||||
previewing in the `url_preview_ip_range_blacklist` configuration parameter.
|
||||
This is critical from a security perspective to stop arbitrary Matrix users
|
||||
spidering 'internal' URLs on your network. At the very least we recommend that
|
||||
spidering 'internal' URLs on your network. At the very least we recommend that
|
||||
your loopback and RFC1918 IP addresses are blacklisted.
|
||||
|
||||
This also requires the optional lxml and netaddr python dependencies to be
|
||||
installed. This in turn requires the libxml2 library to be available - on
|
||||
This also requires the optional `lxml` and `netaddr` python dependencies to be
|
||||
installed. This in turn requires the `libxml2` library to be available - on
|
||||
Debian/Ubuntu this means `apt-get install libxml2-dev`, or equivalent for
|
||||
your OS.
|
||||
|
||||
# Troubleshooting Installation
|
||||
|
||||
`pip` seems to leak *lots* of memory during installation. For instance, a Linux
|
||||
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
|
||||
happens, you will have to individually install the dependencies which are
|
||||
failing, e.g.:
|
||||
|
||||
```
|
||||
pip install twisted
|
||||
```
|
||||
|
||||
If you have any other problems, feel free to ask in
|
||||
[#synapse:matrix.org](https://matrix.to/#/#synapse:matrix.org).
|
||||
|
||||
1
changelog.d/6573.bugfix
Normal file
1
changelog.d/6573.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Don't attempt to use an invalid sqlite config if no database configuration is provided. Contributed by @nekatak.
|
||||
1
changelog.d/6634.bugfix
Normal file
1
changelog.d/6634.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix single-sign on with CAS systems: pass the same service URL when requesting the CAS ticket and when calling the `proxyValidate` URL. Contributed by @Naugrimm.
|
||||
1
changelog.d/6892.doc
Normal file
1
changelog.d/6892.doc
Normal file
@@ -0,0 +1 @@
|
||||
Update Debian installation instructions to recommend installing the `virtualenv` package instead of `python3-virtualenv`.
|
||||
1
changelog.d/6988.doc
Normal file
1
changelog.d/6988.doc
Normal file
@@ -0,0 +1 @@
|
||||
Improve the documentation for database configuration.
|
||||
1
changelog.d/7009.feature
Normal file
1
changelog.d/7009.feature
Normal file
@@ -0,0 +1 @@
|
||||
Set `Referrer-Policy` header to `no-referrer` on media downloads.
|
||||
1
changelog.d/7010.misc
Normal file
1
changelog.d/7010.misc
Normal file
@@ -0,0 +1 @@
|
||||
Change device list streams to have one row per ID.
|
||||
1
changelog.d/7011.misc
Normal file
1
changelog.d/7011.misc
Normal file
@@ -0,0 +1 @@
|
||||
Remove concept of a non-limited stream.
|
||||
1
changelog.d/7024.misc
Normal file
1
changelog.d/7024.misc
Normal file
@@ -0,0 +1 @@
|
||||
Move catchup of replication streams logic to worker.
|
||||
1
changelog.d/7051.feature
Normal file
1
changelog.d/7051.feature
Normal file
@@ -0,0 +1 @@
|
||||
Admin API `POST /_synapse/admin/v1/join/<roomIdOrAlias>` to join users to a room like `auto_join_rooms` for creation of users.
|
||||
1
changelog.d/7068.bugfix
Normal file
1
changelog.d/7068.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Ensure that a user inteactive authentication session is tied to a single request.
|
||||
1
changelog.d/7089.bugfix
Normal file
1
changelog.d/7089.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a bug in the federation API which could cause occasional "Failed to get PDU" errors.
|
||||
1
changelog.d/7096.feature
Normal file
1
changelog.d/7096.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add options to prevent users from changing their profile or associated 3PIDs.
|
||||
1
changelog.d/7107.doc
Normal file
1
changelog.d/7107.doc
Normal file
@@ -0,0 +1 @@
|
||||
Update pre-built package name for FreeBSD.
|
||||
1
changelog.d/7109.bugfix
Normal file
1
changelog.d/7109.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Return the proper error (M_BAD_ALIAS) when a non-existant canonical alias is provided.
|
||||
1
changelog.d/7110.misc
Normal file
1
changelog.d/7110.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert some of synapse.rest.media to async/await.
|
||||
1
changelog.d/7115.misc
Normal file
1
changelog.d/7115.misc
Normal file
@@ -0,0 +1 @@
|
||||
De-duplicate / remove unused REST code for login and auth.
|
||||
1
changelog.d/7116.misc
Normal file
1
changelog.d/7116.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert `*StreamRow` classes to inner classes.
|
||||
1
changelog.d/7117.bugfix
Normal file
1
changelog.d/7117.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a bug which meant that groups updates were not correctly replicated between workers.
|
||||
1
changelog.d/7118.feature
Normal file
1
changelog.d/7118.feature
Normal file
@@ -0,0 +1 @@
|
||||
Allow server admins to define and enforce a password policy (MSC2000).
|
||||
1
changelog.d/7120.misc
Normal file
1
changelog.d/7120.misc
Normal file
@@ -0,0 +1 @@
|
||||
Clean up some LoggingContext code.
|
||||
1
changelog.d/7128.misc
Normal file
1
changelog.d/7128.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add explicit `instance_id` for USER_SYNC commands and remove implicit `conn_id` usage.
|
||||
1
changelog.d/7133.bugfix
Normal file
1
changelog.d/7133.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix starting workers when federation sending not split out.
|
||||
1
changelog.d/7134.misc
Normal file
1
changelog.d/7134.misc
Normal file
@@ -0,0 +1 @@
|
||||
Refactor replication command handling to reduce difference between server vs client.
|
||||
1
changelog.d/7136.misc
Normal file
1
changelog.d/7136.misc
Normal file
@@ -0,0 +1 @@
|
||||
Refactored the CAS authentication logic to a separate class.
|
||||
1
changelog.d/7137.removal
Normal file
1
changelog.d/7137.removal
Normal file
@@ -0,0 +1 @@
|
||||
Remove nonfunctional `captcha_bypass_secret` option from `homeserver.yaml`.
|
||||
1
changelog.d/7141.doc
Normal file
1
changelog.d/7141.doc
Normal file
@@ -0,0 +1 @@
|
||||
Clean up INSTALL.md a bit.
|
||||
1
changelog.d/7147.doc
Normal file
1
changelog.d/7147.doc
Normal file
@@ -0,0 +1 @@
|
||||
Add documentation for running a local CAS server for testing.
|
||||
1
changelog.d/7150.bugfix
Normal file
1
changelog.d/7150.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Ensure `is_verified` is a boolean in responses to `GET /_matrix/client/r0/room_keys/keys`. Also warn the user if they forgot the `version` query param.
|
||||
1
changelog.d/7151.bugfix
Normal file
1
changelog.d/7151.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix error page being shown when a custom SAML handler attempted to redirect when processing an auth response.
|
||||
1
changelog.d/7152.feature
Normal file
1
changelog.d/7152.feature
Normal file
@@ -0,0 +1 @@
|
||||
Improve the support for SSO authentication on the login fallback page.
|
||||
1
changelog.d/7153.feature
Normal file
1
changelog.d/7153.feature
Normal file
@@ -0,0 +1 @@
|
||||
Always whitelist the login fallback in the SSO configuration if `public_baseurl` is set.
|
||||
1
changelog.d/7155.bugfix
Normal file
1
changelog.d/7155.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Avoid importing `sqlite3` when using the postgres backend. Contributed by David Vo.
|
||||
1
changelog.d/7157.misc
Normal file
1
changelog.d/7157.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add tests for outbound device pokes.
|
||||
1
changelog.d/7160.feature
Normal file
1
changelog.d/7160.feature
Normal file
@@ -0,0 +1 @@
|
||||
Always send users their own device updates.
|
||||
34
docs/admin_api/room_membership.md
Normal file
34
docs/admin_api/room_membership.md
Normal file
@@ -0,0 +1,34 @@
|
||||
# Edit Room Membership API
|
||||
|
||||
This API allows an administrator to join an user account with a given `user_id`
|
||||
to a room with a given `room_id_or_alias`. You can only modify the membership of
|
||||
local users. The server administrator must be in the room and have permission to
|
||||
invite users.
|
||||
|
||||
## Parameters
|
||||
|
||||
The following parameters are available:
|
||||
|
||||
* `user_id` - Fully qualified user: for example, `@user:server.com`.
|
||||
* `room_id_or_alias` - The room identifier or alias to join: for example,
|
||||
`!636q39766251:server.com`.
|
||||
|
||||
## Usage
|
||||
|
||||
```
|
||||
POST /_synapse/admin/v1/join/<room_id_or_alias>
|
||||
|
||||
{
|
||||
"user_id": "@user:server.com"
|
||||
}
|
||||
```
|
||||
|
||||
Including an `access_token` of a server admin.
|
||||
|
||||
Response:
|
||||
|
||||
```
|
||||
{
|
||||
"room_id": "!636q39766251:server.com"
|
||||
}
|
||||
```
|
||||
64
docs/dev/cas.md
Normal file
64
docs/dev/cas.md
Normal file
@@ -0,0 +1,64 @@
|
||||
# How to test CAS as a developer without a server
|
||||
|
||||
The [django-mama-cas](https://github.com/jbittel/django-mama-cas) project is an
|
||||
easy to run CAS implementation built on top of Django.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. Create a new virtualenv: `python3 -m venv <your virtualenv>`
|
||||
2. Activate your virtualenv: `source /path/to/your/virtualenv/bin/activate`
|
||||
3. Install Django and django-mama-cas:
|
||||
```
|
||||
python -m pip install "django<3" "django-mama-cas==2.4.0"
|
||||
```
|
||||
4. Create a Django project in the current directory:
|
||||
```
|
||||
django-admin startproject cas_test .
|
||||
```
|
||||
5. Follow the [install directions](https://django-mama-cas.readthedocs.io/en/latest/installation.html#configuring) for django-mama-cas
|
||||
6. Setup the SQLite database: `python manage.py migrate`
|
||||
7. Create a user:
|
||||
```
|
||||
python manage.py createsuperuser
|
||||
```
|
||||
1. Use whatever you want as the username and password.
|
||||
2. Leave the other fields blank.
|
||||
8. Use the built-in Django test server to serve the CAS endpoints on port 8000:
|
||||
```
|
||||
python manage.py runserver
|
||||
```
|
||||
|
||||
You should now have a Django project configured to serve CAS authentication with
|
||||
a single user created.
|
||||
|
||||
## Configure Synapse (and Riot) to use CAS
|
||||
|
||||
1. Modify your `homeserver.yaml` to enable CAS and point it to your locally
|
||||
running Django test server:
|
||||
```yaml
|
||||
cas_config:
|
||||
enabled: true
|
||||
server_url: "http://localhost:8000"
|
||||
service_url: "http://localhost:8081"
|
||||
#displayname_attribute: name
|
||||
#required_attributes:
|
||||
# name: value
|
||||
```
|
||||
2. Restart Synapse.
|
||||
|
||||
Note that the above configuration assumes the homeserver is running on port 8081
|
||||
and that the CAS server is on port 8000, both on localhost.
|
||||
|
||||
## Testing the configuration
|
||||
|
||||
Then in Riot:
|
||||
|
||||
1. Visit the login page with a Riot pointing at your homeserver.
|
||||
2. Click the Single Sign-On button.
|
||||
3. Login using the credentials created with `createsuperuser`.
|
||||
4. You should be logged in.
|
||||
|
||||
If you want to repeat this process you'll need to manually logout first:
|
||||
|
||||
1. http://localhost:8000/admin/
|
||||
2. Click "logout" in the top right.
|
||||
@@ -18,9 +18,13 @@ To make Synapse (and therefore Riot) use it:
|
||||
metadata:
|
||||
local: ["samling.xml"]
|
||||
```
|
||||
5. Run `apt-get install xmlsec1` and `pip install --upgrade --force 'pysaml2>=4.5.0'` to ensure
|
||||
5. Ensure that your `homeserver.yaml` has a setting for `public_baseurl`:
|
||||
```yaml
|
||||
public_baseurl: http://localhost:8080/
|
||||
```
|
||||
6. Run `apt-get install xmlsec1` and `pip install --upgrade --force 'pysaml2>=4.5.0'` to ensure
|
||||
the dependencies are installed and ready to go.
|
||||
6. Restart Synapse.
|
||||
7. Restart Synapse.
|
||||
|
||||
Then in Riot:
|
||||
|
||||
|
||||
@@ -29,14 +29,13 @@ from synapse.logging import context # omitted from future snippets
|
||||
def handle_request(request_id):
|
||||
request_context = context.LoggingContext()
|
||||
|
||||
calling_context = context.LoggingContext.current_context()
|
||||
context.LoggingContext.set_current_context(request_context)
|
||||
calling_context = context.set_current_context(request_context)
|
||||
try:
|
||||
request_context.request = request_id
|
||||
do_request_handling()
|
||||
logger.debug("finished")
|
||||
finally:
|
||||
context.LoggingContext.set_current_context(calling_context)
|
||||
context.set_current_context(calling_context)
|
||||
|
||||
def do_request_handling():
|
||||
logger.debug("phew") # this will be logged against request_id
|
||||
|
||||
@@ -72,8 +72,7 @@ underneath the database, or if a different version of the locale is used on any
|
||||
replicas.
|
||||
|
||||
The safest way to fix the issue is to take a dump and recreate the database with
|
||||
the correct `COLLATE` and `CTYPE` parameters (as per
|
||||
[docs/postgres.md](docs/postgres.md)). It is also possible to change the
|
||||
the correct `COLLATE` and `CTYPE` parameters (as shown above). It is also possible to change the
|
||||
parameters on a live database and run a `REINDEX` on the entire database,
|
||||
however extreme care must be taken to avoid database corruption.
|
||||
|
||||
@@ -105,19 +104,41 @@ of free memory the database host has available.
|
||||
When you are ready to start using PostgreSQL, edit the `database`
|
||||
section in your config file to match the following lines:
|
||||
|
||||
database:
|
||||
name: psycopg2
|
||||
args:
|
||||
user: <user>
|
||||
password: <pass>
|
||||
database: <db>
|
||||
host: <host>
|
||||
cp_min: 5
|
||||
cp_max: 10
|
||||
```yaml
|
||||
database:
|
||||
name: psycopg2
|
||||
args:
|
||||
user: <user>
|
||||
password: <pass>
|
||||
database: <db>
|
||||
host: <host>
|
||||
cp_min: 5
|
||||
cp_max: 10
|
||||
```
|
||||
|
||||
All key, values in `args` are passed to the `psycopg2.connect(..)`
|
||||
function, except keys beginning with `cp_`, which are consumed by the
|
||||
twisted adbapi connection pool.
|
||||
twisted adbapi connection pool. See the [libpq
|
||||
documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS)
|
||||
for a list of options which can be passed.
|
||||
|
||||
You should consider tuning the `args.keepalives_*` options if there is any danger of
|
||||
the connection between your homeserver and database dropping, otherwise Synapse
|
||||
may block for an extended period while it waits for a response from the
|
||||
database server. Example values might be:
|
||||
|
||||
```yaml
|
||||
# seconds of inactivity after which TCP should send a keepalive message to the server
|
||||
keepalives_idle: 10
|
||||
|
||||
# the number of seconds after which a TCP keepalive message that is not
|
||||
# acknowledged by the server should be retransmitted
|
||||
keepalives_interval: 10
|
||||
|
||||
# the number of TCP keepalives that can be lost before the client's connection
|
||||
# to the server is considered dead
|
||||
keepalives_count: 3
|
||||
```
|
||||
|
||||
## Porting from SQLite
|
||||
|
||||
|
||||
@@ -578,13 +578,46 @@ acme:
|
||||
|
||||
## Database ##
|
||||
|
||||
# The 'database' setting defines the database that synapse uses to store all of
|
||||
# its data.
|
||||
#
|
||||
# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or
|
||||
# 'psycopg2' (for PostgreSQL).
|
||||
#
|
||||
# 'args' gives options which are passed through to the database engine,
|
||||
# except for options starting 'cp_', which are used to configure the Twisted
|
||||
# connection pool. For a reference to valid arguments, see:
|
||||
# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
|
||||
# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
|
||||
# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__
|
||||
#
|
||||
#
|
||||
# Example SQLite configuration:
|
||||
#
|
||||
#database:
|
||||
# name: sqlite3
|
||||
# args:
|
||||
# database: /path/to/homeserver.db
|
||||
#
|
||||
#
|
||||
# Example Postgres configuration:
|
||||
#
|
||||
#database:
|
||||
# name: psycopg2
|
||||
# args:
|
||||
# user: synapse
|
||||
# password: secretpassword
|
||||
# database: synapse
|
||||
# host: localhost
|
||||
# cp_min: 5
|
||||
# cp_max: 10
|
||||
#
|
||||
# For more information on using Synapse with Postgres, see `docs/postgres.md`.
|
||||
#
|
||||
database:
|
||||
# The database engine name
|
||||
name: "sqlite3"
|
||||
# Arguments to pass to the engine
|
||||
name: sqlite3
|
||||
args:
|
||||
# Path to the database
|
||||
database: "DATADIR/homeserver.db"
|
||||
database: DATADIR/homeserver.db
|
||||
|
||||
# Number of events to cache in memory.
|
||||
#
|
||||
@@ -839,10 +872,6 @@ media_store_path: "DATADIR/media_store"
|
||||
#
|
||||
#enable_registration_captcha: false
|
||||
|
||||
# A secret key used to bypass the captcha test entirely.
|
||||
#
|
||||
#captcha_bypass_secret: "YOUR_SECRET_HERE"
|
||||
|
||||
# The API endpoint to use for verifying m.login.recaptcha responses.
|
||||
#
|
||||
#recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify"
|
||||
@@ -1057,6 +1086,29 @@ account_threepid_delegates:
|
||||
#email: https://example.com # Delegate email sending to example.com
|
||||
#msisdn: http://localhost:8090 # Delegate SMS sending to this local process
|
||||
|
||||
# Whether users are allowed to change their displayname after it has
|
||||
# been initially set. Useful when provisioning users based on the
|
||||
# contents of a third-party directory.
|
||||
#
|
||||
# Does not apply to server administrators. Defaults to 'true'
|
||||
#
|
||||
#enable_set_displayname: false
|
||||
|
||||
# Whether users are allowed to change their avatar after it has been
|
||||
# initially set. Useful when provisioning users based on the contents
|
||||
# of a third-party directory.
|
||||
#
|
||||
# Does not apply to server administrators. Defaults to 'true'
|
||||
#
|
||||
#enable_set_avatar_url: false
|
||||
|
||||
# Whether users can change the 3PIDs associated with their accounts
|
||||
# (email address and msisdn).
|
||||
#
|
||||
# Defaults to 'true'
|
||||
#
|
||||
#enable_3pid_changes: false
|
||||
|
||||
# Users who register on this homeserver will automatically be joined
|
||||
# to these rooms
|
||||
#
|
||||
@@ -1392,6 +1444,10 @@ sso:
|
||||
# phishing attacks from evil.site. To avoid this, include a slash after the
|
||||
# hostname: "https://my.client/".
|
||||
#
|
||||
# If public_baseurl is set, then the login fallback page (used by clients
|
||||
# that don't natively support the required login flows) is whitelisted in
|
||||
# addition to any URLs in this list.
|
||||
#
|
||||
# By default, this list is empty.
|
||||
#
|
||||
#client_whitelist:
|
||||
@@ -1453,6 +1509,41 @@ password_config:
|
||||
#
|
||||
#pepper: "EVEN_MORE_SECRET"
|
||||
|
||||
# Define and enforce a password policy. Each parameter is optional.
|
||||
# This is an implementation of MSC2000.
|
||||
#
|
||||
policy:
|
||||
# Whether to enforce the password policy.
|
||||
# Defaults to 'false'.
|
||||
#
|
||||
#enabled: true
|
||||
|
||||
# Minimum accepted length for a password.
|
||||
# Defaults to 0.
|
||||
#
|
||||
#minimum_length: 15
|
||||
|
||||
# Whether a password must contain at least one digit.
|
||||
# Defaults to 'false'.
|
||||
#
|
||||
#require_digit: true
|
||||
|
||||
# Whether a password must contain at least one symbol.
|
||||
# A symbol is any character that's not a number or a letter.
|
||||
# Defaults to 'false'.
|
||||
#
|
||||
#require_symbol: true
|
||||
|
||||
# Whether a password must contain at least one lowercase letter.
|
||||
# Defaults to 'false'.
|
||||
#
|
||||
#require_lowercase: true
|
||||
|
||||
# Whether a password must contain at least one lowercase letter.
|
||||
# Defaults to 'false'.
|
||||
#
|
||||
#require_uppercase: true
|
||||
|
||||
|
||||
# Configuration for sending emails from Synapse.
|
||||
#
|
||||
|
||||
@@ -14,16 +14,16 @@ example flow would be (where '>' indicates master to worker and
|
||||
'<' worker to master flows):
|
||||
|
||||
> SERVER example.com
|
||||
< REPLICATE events 53
|
||||
< REPLICATE
|
||||
> POSITION events 53
|
||||
> RDATA events 54 ["$foo1:bar.com", ...]
|
||||
> RDATA events 55 ["$foo4:bar.com", ...]
|
||||
|
||||
The example shows the server accepting a new connection and sending its
|
||||
identity with the `SERVER` command, followed by the client asking to
|
||||
subscribe to the `events` stream from the token `53`. The server then
|
||||
periodically sends `RDATA` commands which have the format
|
||||
`RDATA <stream_name> <token> <row>`, where the format of `<row>` is
|
||||
defined by the individual streams.
|
||||
The example shows the server accepting a new connection and sending its identity
|
||||
with the `SERVER` command, followed by the client server to respond with the
|
||||
position of all streams. The server then periodically sends `RDATA` commands
|
||||
which have the format `RDATA <stream_name> <token> <row>`, where the format of
|
||||
`<row>` is defined by the individual streams.
|
||||
|
||||
Error reporting happens by either the client or server sending an ERROR
|
||||
command, and usually the connection will be closed.
|
||||
@@ -32,9 +32,6 @@ Since the protocol is a simple line based, its possible to manually
|
||||
connect to the server using a tool like netcat. A few things should be
|
||||
noted when manually using the protocol:
|
||||
|
||||
- When subscribing to a stream using `REPLICATE`, the special token
|
||||
`NOW` can be used to get all future updates. The special stream name
|
||||
`ALL` can be used with `NOW` to subscribe to all available streams.
|
||||
- The federation stream is only available if federation sending has
|
||||
been disabled on the main process.
|
||||
- The server will only time connections out that have sent a `PING`
|
||||
@@ -91,9 +88,7 @@ The client:
|
||||
- Sends a `NAME` command, allowing the server to associate a human
|
||||
friendly name with the connection. This is optional.
|
||||
- Sends a `PING` as above
|
||||
- For each stream the client wishes to subscribe to it sends a
|
||||
`REPLICATE` with the `stream_name` and token it wants to subscribe
|
||||
from.
|
||||
- Sends a `REPLICATE` to get the current position of all streams.
|
||||
- On receipt of a `SERVER` command, checks that the server name
|
||||
matches the expected server name.
|
||||
|
||||
@@ -140,9 +135,7 @@ the wire:
|
||||
> PING 1490197665618
|
||||
< NAME synapse.app.appservice
|
||||
< PING 1490197665618
|
||||
< REPLICATE events 1
|
||||
< REPLICATE backfill 1
|
||||
< REPLICATE caches 1
|
||||
< REPLICATE
|
||||
> POSITION events 1
|
||||
> POSITION backfill 1
|
||||
> POSITION caches 1
|
||||
@@ -181,9 +174,9 @@ client (C):
|
||||
|
||||
#### POSITION (S)
|
||||
|
||||
The position of the stream has been updated. Sent to the client
|
||||
after all missing updates for a stream have been sent to the client
|
||||
and they're now up to date.
|
||||
On receipt of a POSITION command clients should check if they have missed any
|
||||
updates, and if so then fetch them out of band. Sent in response to a
|
||||
REPLICATE command (but can happen at any time).
|
||||
|
||||
#### ERROR (S, C)
|
||||
|
||||
@@ -199,25 +192,18 @@ client (C):
|
||||
|
||||
#### REPLICATE (C)
|
||||
|
||||
Asks the server to replicate a given stream. The syntax is:
|
||||
|
||||
```
|
||||
REPLICATE <stream_name> <token>
|
||||
```
|
||||
|
||||
Where `<token>` may be either:
|
||||
* a numeric stream_id to stream updates since (exclusive)
|
||||
* `NOW` to stream all subsequent updates.
|
||||
|
||||
The `<stream_name>` is the name of a replication stream to subscribe
|
||||
to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
|
||||
of streams). It can also be `ALL` to subscribe to all known streams,
|
||||
in which case the `<token>` must be set to `NOW`.
|
||||
Asks the server for the current position of all streams.
|
||||
|
||||
#### USER_SYNC (C)
|
||||
|
||||
A user has started or stopped syncing
|
||||
|
||||
#### CLEAR_USER_SYNC (C)
|
||||
|
||||
The server should clear all associated user sync data from the worker.
|
||||
|
||||
This is used when a worker is shutting down.
|
||||
|
||||
#### FEDERATION_ACK (C)
|
||||
|
||||
Acknowledge receipt of some federation data
|
||||
|
||||
@@ -64,6 +64,13 @@ class Codes(object):
|
||||
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
|
||||
WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION"
|
||||
EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
|
||||
PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT"
|
||||
PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT"
|
||||
PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE"
|
||||
PASSWORD_NO_LOWERCASE = "M_PASSWORD_NO_LOWERCASE"
|
||||
PASSWORD_NO_SYMBOL = "M_PASSWORD_NO_SYMBOL"
|
||||
PASSWORD_IN_DICTIONARY = "M_PASSWORD_IN_DICTIONARY"
|
||||
WEAK_PASSWORD = "M_WEAK_PASSWORD"
|
||||
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
|
||||
USER_DEACTIVATED = "M_USER_DEACTIVATED"
|
||||
BAD_ALIAS = "M_BAD_ALIAS"
|
||||
@@ -439,6 +446,20 @@ class IncompatibleRoomVersionError(SynapseError):
|
||||
return cs_error(self.msg, self.errcode, room_version=self._room_version)
|
||||
|
||||
|
||||
class PasswordRefusedError(SynapseError):
|
||||
"""A password has been refused, either during password reset/change or registration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
msg="This password doesn't comply with the server's policy",
|
||||
errcode=Codes.WEAK_PASSWORD,
|
||||
):
|
||||
super(PasswordRefusedError, self).__init__(
|
||||
code=400, msg=msg, errcode=errcode,
|
||||
)
|
||||
|
||||
|
||||
class RequestSendFailed(RuntimeError):
|
||||
"""Sending a HTTP request over federation failed due to not being able to
|
||||
talk to the remote server for some reason.
|
||||
|
||||
@@ -64,13 +64,26 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.replication.tcp.streams._base import (
|
||||
from synapse.replication.tcp.client import ReplicationClientFactory
|
||||
from synapse.replication.tcp.commands import ClearUserSyncsCommand
|
||||
from synapse.replication.tcp.handler import WorkerReplicationDataHandler
|
||||
from synapse.replication.tcp.streams import (
|
||||
AccountDataStream,
|
||||
DeviceListsStream,
|
||||
GroupServerStream,
|
||||
PresenceStream,
|
||||
PushersStream,
|
||||
PushRulesStream,
|
||||
ReceiptsStream,
|
||||
TagAccountDataStream,
|
||||
ToDeviceStream,
|
||||
TypingStream,
|
||||
)
|
||||
from synapse.replication.tcp.streams.events import (
|
||||
EventsStream,
|
||||
EventsStreamEventRow,
|
||||
EventsStreamRow,
|
||||
)
|
||||
from synapse.replication.tcp.streams.events import EventsStreamEventRow, EventsStreamRow
|
||||
from synapse.rest.admin import register_servlets_for_media_repo
|
||||
from synapse.rest.client.v1 import events
|
||||
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
|
||||
@@ -113,7 +126,6 @@ from synapse.types import ReadReceipt
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
||||
logger = logging.getLogger("synapse.app.generic_worker")
|
||||
@@ -222,6 +234,7 @@ class GenericWorkerPresence(object):
|
||||
self.user_to_num_current_syncs = {}
|
||||
self.clock = hs.get_clock()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.instance_id = hs.get_instance_id()
|
||||
|
||||
active_presence = self.store.take_presence_startup_info()
|
||||
self.user_to_current_state = {state.user_id: state for state in active_presence}
|
||||
@@ -234,13 +247,24 @@ class GenericWorkerPresence(object):
|
||||
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
|
||||
)
|
||||
|
||||
self.process_id = random_string(16)
|
||||
logger.info("Presence process_id is %r", self.process_id)
|
||||
hs.get_reactor().addSystemEventTrigger(
|
||||
"before",
|
||||
"shutdown",
|
||||
run_as_background_process,
|
||||
"generic_presence.on_shutdown",
|
||||
self._on_shutdown,
|
||||
)
|
||||
|
||||
def _on_shutdown(self):
|
||||
if self.hs.config.use_presence:
|
||||
self.hs.get_tcp_replication().send_command(
|
||||
ClearUserSyncsCommand(self.instance_id)
|
||||
)
|
||||
|
||||
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
|
||||
if self.hs.config.use_presence:
|
||||
self.hs.get_tcp_replication().send_user_sync(
|
||||
user_id, is_syncing, last_sync_ms
|
||||
self.instance_id, user_id, is_syncing, last_sync_ms
|
||||
)
|
||||
|
||||
def mark_as_coming_online(self, user_id):
|
||||
@@ -390,6 +414,9 @@ class GenericWorkerTyping(object):
|
||||
self._room_serials[row.room_id] = token
|
||||
self._room_typing[row.room_id] = row.user_ids
|
||||
|
||||
def get_current_token(self) -> int:
|
||||
return self._latest_room_serial
|
||||
|
||||
|
||||
class GenericWorkerSlavedStore(
|
||||
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
|
||||
@@ -572,25 +599,26 @@ class GenericWorkerServer(HomeServer):
|
||||
else:
|
||||
logger.warning("Unrecognized listener type: %s", listener["type"])
|
||||
|
||||
self.get_tcp_replication().start_replication(self)
|
||||
factory = ReplicationClientFactory(self, self.config.worker_name)
|
||||
host = self.config.worker_replication_host
|
||||
port = self.config.worker_replication_port
|
||||
self.get_reactor().connectTCP(host, port, factory)
|
||||
|
||||
def remove_pusher(self, app_id, push_key, user_id):
|
||||
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
|
||||
|
||||
def build_tcp_replication(self):
|
||||
return GenericWorkerReplicationHandler(self)
|
||||
|
||||
def build_presence_handler(self):
|
||||
return GenericWorkerPresence(self)
|
||||
|
||||
def build_typing_handler(self):
|
||||
return GenericWorkerTyping(self)
|
||||
|
||||
def build_replication_data_handler(self):
|
||||
return GenericWorkerReplicationHandler(self)
|
||||
|
||||
class GenericWorkerReplicationHandler(ReplicationClientHandler):
|
||||
|
||||
class GenericWorkerReplicationHandler(WorkerReplicationDataHandler):
|
||||
def __init__(self, hs):
|
||||
super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore())
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.typing_handler = hs.get_typing_handler()
|
||||
# NB this is a SynchrotronPresence, not a normal PresenceHandler
|
||||
@@ -618,15 +646,12 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
|
||||
args.update(self.send_handler.stream_positions())
|
||||
return args
|
||||
|
||||
def get_currently_syncing_users(self):
|
||||
return self.presence_handler.get_currently_syncing_users()
|
||||
|
||||
async def process_and_notify(self, stream_name, token, rows):
|
||||
try:
|
||||
if self.send_handler:
|
||||
self.send_handler.process_replication_rows(stream_name, token, rows)
|
||||
|
||||
if stream_name == "events":
|
||||
if stream_name == EventsStream.NAME:
|
||||
# We shouldn't get multiple rows per token for events stream, so
|
||||
# we don't need to optimise this for multiple rows.
|
||||
for row in rows:
|
||||
@@ -649,43 +674,44 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
|
||||
)
|
||||
|
||||
await self.pusher_pool.on_new_notifications(token, token)
|
||||
elif stream_name == "push_rules":
|
||||
elif stream_name == PushRulesStream.NAME:
|
||||
self.notifier.on_new_event(
|
||||
"push_rules_key", token, users=[row.user_id for row in rows]
|
||||
)
|
||||
elif stream_name in ("account_data", "tag_account_data"):
|
||||
elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
|
||||
self.notifier.on_new_event(
|
||||
"account_data_key", token, users=[row.user_id for row in rows]
|
||||
)
|
||||
elif stream_name == "receipts":
|
||||
elif stream_name == ReceiptsStream.NAME:
|
||||
self.notifier.on_new_event(
|
||||
"receipt_key", token, rooms=[row.room_id for row in rows]
|
||||
)
|
||||
await self.pusher_pool.on_new_receipts(
|
||||
token, token, {row.room_id for row in rows}
|
||||
)
|
||||
elif stream_name == "typing":
|
||||
elif stream_name == TypingStream.NAME:
|
||||
self.typing_handler.process_replication_rows(token, rows)
|
||||
self.notifier.on_new_event(
|
||||
"typing_key", token, rooms=[row.room_id for row in rows]
|
||||
)
|
||||
elif stream_name == "to_device":
|
||||
elif stream_name == ToDeviceStream.NAME:
|
||||
entities = [row.entity for row in rows if row.entity.startswith("@")]
|
||||
if entities:
|
||||
self.notifier.on_new_event("to_device_key", token, users=entities)
|
||||
elif stream_name == "device_lists":
|
||||
elif stream_name == DeviceListsStream.NAME:
|
||||
all_room_ids = set()
|
||||
for row in rows:
|
||||
room_ids = await self.store.get_rooms_for_user(row.user_id)
|
||||
all_room_ids.update(room_ids)
|
||||
if row.entity.startswith("@"):
|
||||
room_ids = await self.store.get_rooms_for_user(row.entity)
|
||||
all_room_ids.update(room_ids)
|
||||
self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
|
||||
elif stream_name == "presence":
|
||||
elif stream_name == PresenceStream.NAME:
|
||||
await self.presence_handler.process_replication_rows(token, rows)
|
||||
elif stream_name == "receipts":
|
||||
elif stream_name == GroupServerStream.NAME:
|
||||
self.notifier.on_new_event(
|
||||
"groups_key", token, users=[row.user_id for row in rows]
|
||||
)
|
||||
elif stream_name == "pushers":
|
||||
elif stream_name == PushersStream.NAME:
|
||||
for row in rows:
|
||||
if row.deleted:
|
||||
self.stop_pusher(row.user_id, row.app_id, row.pushkey)
|
||||
@@ -774,7 +800,10 @@ class FederationSenderHandler(object):
|
||||
|
||||
# ... as well as device updates and messages
|
||||
elif stream_name == DeviceListsStream.NAME:
|
||||
hosts = {row.destination for row in rows}
|
||||
# The entities are either user IDs (starting with '@') whose devices
|
||||
# have changed, or remote servers that we need to tell about
|
||||
# changes.
|
||||
hosts = {row.entity for row in rows if not row.entity.startswith("@")}
|
||||
for host in hosts:
|
||||
self.federation_sender.send_device_messages(host)
|
||||
|
||||
@@ -789,7 +818,7 @@ class FederationSenderHandler(object):
|
||||
async def _on_new_receipts(self, rows):
|
||||
"""
|
||||
Args:
|
||||
rows (iterable[synapse.replication.tcp.streams.ReceiptsStreamRow]):
|
||||
rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
|
||||
new receipts to be processed
|
||||
"""
|
||||
for receipt in rows:
|
||||
@@ -860,6 +889,9 @@ def start(config_options):
|
||||
|
||||
# Force the appservice to start since they will be disabled in the main config
|
||||
config.notify_appservices = True
|
||||
else:
|
||||
# For other worker types we force this to off.
|
||||
config.notify_appservices = False
|
||||
|
||||
if config.worker_app == "synapse.app.pusher":
|
||||
if config.start_pushers:
|
||||
@@ -873,6 +905,9 @@ def start(config_options):
|
||||
|
||||
# Force the pushers to start since they will be disabled in the main config
|
||||
config.start_pushers = True
|
||||
else:
|
||||
# For other worker types we force this to off.
|
||||
config.start_pushers = False
|
||||
|
||||
if config.worker_app == "synapse.app.user_dir":
|
||||
if config.update_user_directory:
|
||||
@@ -886,6 +921,9 @@ def start(config_options):
|
||||
|
||||
# Force the pushers to start since they will be disabled in the main config
|
||||
config.update_user_directory = True
|
||||
else:
|
||||
# For other worker types we force this to off.
|
||||
config.update_user_directory = False
|
||||
|
||||
if config.worker_app == "synapse.app.federation_sender":
|
||||
if config.send_federation:
|
||||
@@ -899,6 +937,9 @@ def start(config_options):
|
||||
|
||||
# Force the pushers to start since they will be disabled in the main config
|
||||
config.send_federation = True
|
||||
else:
|
||||
# For other worker types we force this to off.
|
||||
config.send_federation = False
|
||||
|
||||
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
|
||||
@@ -294,7 +294,6 @@ class RootConfig(object):
|
||||
report_stats=None,
|
||||
open_private_ports=False,
|
||||
listeners=None,
|
||||
database_conf=None,
|
||||
tls_certificate_path=None,
|
||||
tls_private_key_path=None,
|
||||
acme_domain=None,
|
||||
@@ -367,7 +366,6 @@ class RootConfig(object):
|
||||
report_stats=report_stats,
|
||||
open_private_ports=open_private_ports,
|
||||
listeners=listeners,
|
||||
database_conf=database_conf,
|
||||
tls_certificate_path=tls_certificate_path,
|
||||
tls_private_key_path=tls_private_key_path,
|
||||
acme_domain=acme_domain,
|
||||
|
||||
@@ -24,7 +24,6 @@ class CaptchaConfig(Config):
|
||||
self.enable_registration_captcha = config.get(
|
||||
"enable_registration_captcha", False
|
||||
)
|
||||
self.captcha_bypass_secret = config.get("captcha_bypass_secret")
|
||||
self.recaptcha_siteverify_api = config.get(
|
||||
"recaptcha_siteverify_api",
|
||||
"https://www.recaptcha.net/recaptcha/api/siteverify",
|
||||
@@ -49,10 +48,6 @@ class CaptchaConfig(Config):
|
||||
#
|
||||
#enable_registration_captcha: false
|
||||
|
||||
# A secret key used to bypass the captcha test entirely.
|
||||
#
|
||||
#captcha_bypass_secret: "YOUR_SECRET_HERE"
|
||||
|
||||
# The API endpoint to use for verifying m.login.recaptcha responses.
|
||||
#
|
||||
#recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,14 +15,65 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
from textwrap import indent
|
||||
|
||||
import yaml
|
||||
|
||||
from synapse.config._base import Config, ConfigError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NON_SQLITE_DATABASE_PATH_WARNING = """\
|
||||
Ignoring 'database_path' setting: not using a sqlite3 database.
|
||||
--------------------------------------------------------------------------------
|
||||
"""
|
||||
|
||||
DEFAULT_CONFIG = """\
|
||||
## Database ##
|
||||
|
||||
# The 'database' setting defines the database that synapse uses to store all of
|
||||
# its data.
|
||||
#
|
||||
# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or
|
||||
# 'psycopg2' (for PostgreSQL).
|
||||
#
|
||||
# 'args' gives options which are passed through to the database engine,
|
||||
# except for options starting 'cp_', which are used to configure the Twisted
|
||||
# connection pool. For a reference to valid arguments, see:
|
||||
# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
|
||||
# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
|
||||
# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__
|
||||
#
|
||||
#
|
||||
# Example SQLite configuration:
|
||||
#
|
||||
#database:
|
||||
# name: sqlite3
|
||||
# args:
|
||||
# database: /path/to/homeserver.db
|
||||
#
|
||||
#
|
||||
# Example Postgres configuration:
|
||||
#
|
||||
#database:
|
||||
# name: psycopg2
|
||||
# args:
|
||||
# user: synapse
|
||||
# password: secretpassword
|
||||
# database: synapse
|
||||
# host: localhost
|
||||
# cp_min: 5
|
||||
# cp_max: 10
|
||||
#
|
||||
# For more information on using Synapse with Postgres, see `docs/postgres.md`.
|
||||
#
|
||||
database:
|
||||
name: sqlite3
|
||||
args:
|
||||
database: %(database_path)s
|
||||
|
||||
# Number of events to cache in memory.
|
||||
#
|
||||
#event_cache_size: 10K
|
||||
"""
|
||||
|
||||
|
||||
class DatabaseConnectionConfig:
|
||||
"""Contains the connection config for a particular database.
|
||||
@@ -36,10 +88,12 @@ class DatabaseConnectionConfig:
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, db_config: dict):
|
||||
if db_config["name"] not in ("sqlite3", "psycopg2"):
|
||||
raise ConfigError("Unsupported database type %r" % (db_config["name"],))
|
||||
db_engine = db_config.get("name", "sqlite3")
|
||||
|
||||
if db_config["name"] == "sqlite3":
|
||||
if db_engine not in ("sqlite3", "psycopg2"):
|
||||
raise ConfigError("Unsupported database type %r" % (db_engine,))
|
||||
|
||||
if db_engine == "sqlite3":
|
||||
db_config.setdefault("args", {}).update(
|
||||
{"cp_min": 1, "cp_max": 1, "check_same_thread": False}
|
||||
)
|
||||
@@ -56,6 +110,11 @@ class DatabaseConnectionConfig:
|
||||
class DatabaseConfig(Config):
|
||||
section = "database"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.databases = []
|
||||
|
||||
def read_config(self, config, **kwargs):
|
||||
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
|
||||
|
||||
@@ -76,12 +135,13 @@ class DatabaseConfig(Config):
|
||||
|
||||
multi_database_config = config.get("databases")
|
||||
database_config = config.get("database")
|
||||
database_path = config.get("database_path")
|
||||
|
||||
if multi_database_config and database_config:
|
||||
raise ConfigError("Can't specify both 'database' and 'datbases' in config")
|
||||
|
||||
if multi_database_config:
|
||||
if config.get("database_path"):
|
||||
if database_path:
|
||||
raise ConfigError("Can't specify 'database_path' with 'databases'")
|
||||
|
||||
self.databases = [
|
||||
@@ -89,65 +149,55 @@ class DatabaseConfig(Config):
|
||||
for name, db_conf in multi_database_config.items()
|
||||
]
|
||||
|
||||
else:
|
||||
if database_config is None:
|
||||
database_config = {"name": "sqlite3", "args": {}}
|
||||
|
||||
if database_config:
|
||||
self.databases = [DatabaseConnectionConfig("master", database_config)]
|
||||
|
||||
self.set_databasepath(config.get("database_path"))
|
||||
if database_path:
|
||||
if self.databases and self.databases[0].name != "sqlite3":
|
||||
logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
|
||||
return
|
||||
|
||||
def generate_config_section(self, data_dir_path, database_conf, **kwargs):
|
||||
if not database_conf:
|
||||
database_path = os.path.join(data_dir_path, "homeserver.db")
|
||||
database_conf = (
|
||||
"""# The database engine name
|
||||
name: "sqlite3"
|
||||
# Arguments to pass to the engine
|
||||
args:
|
||||
# Path to the database
|
||||
database: "%(database_path)s"
|
||||
"""
|
||||
% locals()
|
||||
)
|
||||
else:
|
||||
database_conf = indent(yaml.dump(database_conf), " " * 10).lstrip()
|
||||
database_config = {"name": "sqlite3", "args": {}}
|
||||
self.databases = [DatabaseConnectionConfig("master", database_config)]
|
||||
self.set_databasepath(database_path)
|
||||
|
||||
return (
|
||||
"""\
|
||||
## Database ##
|
||||
|
||||
database:
|
||||
%(database_conf)s
|
||||
# Number of events to cache in memory.
|
||||
#
|
||||
#event_cache_size: 10K
|
||||
"""
|
||||
% locals()
|
||||
)
|
||||
def generate_config_section(self, data_dir_path, **kwargs):
|
||||
return DEFAULT_CONFIG % {
|
||||
"database_path": os.path.join(data_dir_path, "homeserver.db")
|
||||
}
|
||||
|
||||
def read_arguments(self, args):
|
||||
self.set_databasepath(args.database_path)
|
||||
"""
|
||||
Cases for the cli input:
|
||||
- If no databases are configured and no database_path is set, raise.
|
||||
- No databases and only database_path available ==> sqlite3 db.
|
||||
- If there are multiple databases and a database_path raise an error.
|
||||
- If the database set in the config file is sqlite then
|
||||
overwrite with the command line argument.
|
||||
"""
|
||||
|
||||
if args.database_path is None:
|
||||
if not self.databases:
|
||||
raise ConfigError("No database config provided")
|
||||
return
|
||||
|
||||
if len(self.databases) == 0:
|
||||
database_config = {"name": "sqlite3", "args": {}}
|
||||
self.databases = [DatabaseConnectionConfig("master", database_config)]
|
||||
self.set_databasepath(args.database_path)
|
||||
return
|
||||
|
||||
if self.get_single_database().name == "sqlite3":
|
||||
self.set_databasepath(args.database_path)
|
||||
else:
|
||||
logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
|
||||
|
||||
def set_databasepath(self, database_path):
|
||||
if database_path is None:
|
||||
return
|
||||
|
||||
if database_path != ":memory:":
|
||||
database_path = self.abspath(database_path)
|
||||
|
||||
# We only support setting a database path if we have a single sqlite3
|
||||
# database.
|
||||
if len(self.databases) != 1:
|
||||
raise ConfigError("Cannot specify 'database_path' with multiple databases")
|
||||
|
||||
database = self.get_single_database()
|
||||
if database.config["name"] != "sqlite3":
|
||||
# We don't raise here as we haven't done so before for this case.
|
||||
logger.warn("Ignoring 'database_path' for non-sqlite3 database")
|
||||
return
|
||||
|
||||
database.config["args"]["database"] = database_path
|
||||
self.databases[0].config["args"]["database"] = database_path
|
||||
|
||||
@staticmethod
|
||||
def add_arguments(parser):
|
||||
@@ -162,7 +212,7 @@ class DatabaseConfig(Config):
|
||||
def get_single_database(self) -> DatabaseConnectionConfig:
|
||||
"""Returns the database if there is only one, useful for e.g. tests
|
||||
"""
|
||||
if len(self.databases) != 1:
|
||||
if not self.databases:
|
||||
raise Exception("More than one database exists")
|
||||
|
||||
return self.databases[0]
|
||||
|
||||
@@ -31,6 +31,10 @@ class PasswordConfig(Config):
|
||||
self.password_localdb_enabled = password_config.get("localdb_enabled", True)
|
||||
self.password_pepper = password_config.get("pepper", "")
|
||||
|
||||
# Password policy
|
||||
self.password_policy = password_config.get("policy") or {}
|
||||
self.password_policy_enabled = self.password_policy.get("enabled", False)
|
||||
|
||||
def generate_config_section(self, config_dir_path, server_name, **kwargs):
|
||||
return """\
|
||||
password_config:
|
||||
@@ -48,4 +52,39 @@ class PasswordConfig(Config):
|
||||
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
|
||||
#
|
||||
#pepper: "EVEN_MORE_SECRET"
|
||||
|
||||
# Define and enforce a password policy. Each parameter is optional.
|
||||
# This is an implementation of MSC2000.
|
||||
#
|
||||
policy:
|
||||
# Whether to enforce the password policy.
|
||||
# Defaults to 'false'.
|
||||
#
|
||||
#enabled: true
|
||||
|
||||
# Minimum accepted length for a password.
|
||||
# Defaults to 0.
|
||||
#
|
||||
#minimum_length: 15
|
||||
|
||||
# Whether a password must contain at least one digit.
|
||||
# Defaults to 'false'.
|
||||
#
|
||||
#require_digit: true
|
||||
|
||||
# Whether a password must contain at least one symbol.
|
||||
# A symbol is any character that's not a number or a letter.
|
||||
# Defaults to 'false'.
|
||||
#
|
||||
#require_symbol: true
|
||||
|
||||
# Whether a password must contain at least one lowercase letter.
|
||||
# Defaults to 'false'.
|
||||
#
|
||||
#require_lowercase: true
|
||||
|
||||
# Whether a password must contain at least one lowercase letter.
|
||||
# Defaults to 'false'.
|
||||
#
|
||||
#require_uppercase: true
|
||||
"""
|
||||
|
||||
@@ -129,6 +129,10 @@ class RegistrationConfig(Config):
|
||||
raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,))
|
||||
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
|
||||
|
||||
self.enable_set_displayname = config.get("enable_set_displayname", True)
|
||||
self.enable_set_avatar_url = config.get("enable_set_avatar_url", True)
|
||||
self.enable_3pid_changes = config.get("enable_3pid_changes", True)
|
||||
|
||||
self.disable_msisdn_registration = config.get(
|
||||
"disable_msisdn_registration", False
|
||||
)
|
||||
@@ -330,6 +334,29 @@ class RegistrationConfig(Config):
|
||||
#email: https://example.com # Delegate email sending to example.com
|
||||
#msisdn: http://localhost:8090 # Delegate SMS sending to this local process
|
||||
|
||||
# Whether users are allowed to change their displayname after it has
|
||||
# been initially set. Useful when provisioning users based on the
|
||||
# contents of a third-party directory.
|
||||
#
|
||||
# Does not apply to server administrators. Defaults to 'true'
|
||||
#
|
||||
#enable_set_displayname: false
|
||||
|
||||
# Whether users are allowed to change their avatar after it has been
|
||||
# initially set. Useful when provisioning users based on the contents
|
||||
# of a third-party directory.
|
||||
#
|
||||
# Does not apply to server administrators. Defaults to 'true'
|
||||
#
|
||||
#enable_set_avatar_url: false
|
||||
|
||||
# Whether users can change the 3PIDs associated with their accounts
|
||||
# (email address and msisdn).
|
||||
#
|
||||
# Defaults to 'true'
|
||||
#
|
||||
#enable_3pid_changes: false
|
||||
|
||||
# Users who register on this homeserver will automatically be joined
|
||||
# to these rooms
|
||||
#
|
||||
|
||||
@@ -39,6 +39,17 @@ class SSOConfig(Config):
|
||||
|
||||
self.sso_client_whitelist = sso_config.get("client_whitelist") or []
|
||||
|
||||
# Attempt to also whitelist the server's login fallback, since that fallback sets
|
||||
# the redirect URL to itself (so it can process the login token then return
|
||||
# gracefully to the client). This would make it pointless to ask the user for
|
||||
# confirmation, since the URL the confirmation page would be showing wouldn't be
|
||||
# the client's.
|
||||
# public_baseurl is an optional setting, so we only add the fallback's URL to the
|
||||
# list if it's provided (because we can't figure out what that URL is otherwise).
|
||||
if self.public_baseurl:
|
||||
login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
|
||||
self.sso_client_whitelist.append(login_fallback_url)
|
||||
|
||||
def generate_config_section(self, **kwargs):
|
||||
return """\
|
||||
# Additional settings to use with single-sign on systems such as SAML2 and CAS.
|
||||
@@ -54,6 +65,10 @@ class SSOConfig(Config):
|
||||
# phishing attacks from evil.site. To avoid this, include a slash after the
|
||||
# hostname: "https://my.client/".
|
||||
#
|
||||
# If public_baseurl is set, then the login fallback page (used by clients
|
||||
# that don't natively support the required login flows) is whitelisted in
|
||||
# addition to any URLs in this list.
|
||||
#
|
||||
# By default, this list is empty.
|
||||
#
|
||||
#client_whitelist:
|
||||
|
||||
@@ -43,8 +43,8 @@ from synapse.api.errors import (
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.logging.context import (
|
||||
LoggingContext,
|
||||
PreserveLoggingContext,
|
||||
current_context,
|
||||
make_deferred_yieldable,
|
||||
preserve_fn,
|
||||
run_in_background,
|
||||
@@ -236,7 +236,7 @@ class Keyring(object):
|
||||
"""
|
||||
|
||||
try:
|
||||
ctx = LoggingContext.current_context()
|
||||
ctx = current_context()
|
||||
|
||||
# map from server name to a set of outstanding request ids
|
||||
server_to_request_ids = {}
|
||||
|
||||
@@ -25,19 +25,15 @@ from twisted.python.failure import Failure
|
||||
|
||||
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.api.room_versions import (
|
||||
KNOWN_ROOM_VERSIONS,
|
||||
EventFormatVersions,
|
||||
RoomVersion,
|
||||
)
|
||||
from synapse.api.room_versions import EventFormatVersions, RoomVersion
|
||||
from synapse.crypto.event_signing import check_event_content_hash
|
||||
from synapse.crypto.keyring import Keyring
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.http.servlet import assert_params_in_dict
|
||||
from synapse.logging.context import (
|
||||
LoggingContext,
|
||||
PreserveLoggingContext,
|
||||
current_context,
|
||||
make_deferred_yieldable,
|
||||
)
|
||||
from synapse.types import JsonDict, get_domain_from_id
|
||||
@@ -55,13 +51,15 @@ class FederationBase(object):
|
||||
self.store = hs.get_datastore()
|
||||
self._clock = hs.get_clock()
|
||||
|
||||
def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
|
||||
def _check_sigs_and_hash(
|
||||
self, room_version: RoomVersion, pdu: EventBase
|
||||
) -> Deferred:
|
||||
return make_deferred_yieldable(
|
||||
self._check_sigs_and_hashes(room_version, [pdu])[0]
|
||||
)
|
||||
|
||||
def _check_sigs_and_hashes(
|
||||
self, room_version: str, pdus: List[EventBase]
|
||||
self, room_version: RoomVersion, pdus: List[EventBase]
|
||||
) -> List[Deferred]:
|
||||
"""Checks that each of the received events is correctly signed by the
|
||||
sending server.
|
||||
@@ -80,7 +78,7 @@ class FederationBase(object):
|
||||
"""
|
||||
deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
|
||||
|
||||
ctx = LoggingContext.current_context()
|
||||
ctx = current_context()
|
||||
|
||||
def callback(_, pdu: EventBase):
|
||||
with PreserveLoggingContext(ctx):
|
||||
@@ -146,7 +144,7 @@ class PduToCheckSig(
|
||||
|
||||
|
||||
def _check_sigs_on_pdus(
|
||||
keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
|
||||
keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
|
||||
) -> List[Deferred]:
|
||||
"""Check that the given events are correctly signed
|
||||
|
||||
@@ -191,10 +189,6 @@ def _check_sigs_on_pdus(
|
||||
for p in pdus
|
||||
]
|
||||
|
||||
v = KNOWN_ROOM_VERSIONS.get(room_version)
|
||||
if not v:
|
||||
raise RuntimeError("Unrecognized room version %s" % (room_version,))
|
||||
|
||||
# First we check that the sender event is signed by the sender's domain
|
||||
# (except if its a 3pid invite, in which case it may be sent by any server)
|
||||
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
|
||||
@@ -204,7 +198,7 @@ def _check_sigs_on_pdus(
|
||||
(
|
||||
p.sender_domain,
|
||||
p.redacted_pdu_json,
|
||||
p.pdu.origin_server_ts if v.enforce_key_validity else 0,
|
||||
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
|
||||
p.pdu.event_id,
|
||||
)
|
||||
for p in pdus_to_check_sender
|
||||
@@ -227,7 +221,7 @@ def _check_sigs_on_pdus(
|
||||
# event id's domain (normally only the case for joins/leaves), and add additional
|
||||
# checks. Only do this if the room version has a concept of event ID domain
|
||||
# (ie, the room version uses old-style non-hash event IDs).
|
||||
if v.event_format == EventFormatVersions.V1:
|
||||
if room_version.event_format == EventFormatVersions.V1:
|
||||
pdus_to_check_event_id = [
|
||||
p
|
||||
for p in pdus_to_check
|
||||
@@ -239,7 +233,7 @@ def _check_sigs_on_pdus(
|
||||
(
|
||||
get_domain_from_id(p.pdu.event_id),
|
||||
p.redacted_pdu_json,
|
||||
p.pdu.origin_server_ts if v.enforce_key_validity else 0,
|
||||
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
|
||||
p.pdu.event_id,
|
||||
)
|
||||
for p in pdus_to_check_event_id
|
||||
|
||||
@@ -220,8 +220,7 @@ class FederationClient(FederationBase):
|
||||
# FIXME: We should handle signature failures more gracefully.
|
||||
pdus[:] = await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
self._check_sigs_and_hashes(room_version.identifier, pdus),
|
||||
consumeErrors=True,
|
||||
self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
)
|
||||
|
||||
@@ -291,9 +290,7 @@ class FederationClient(FederationBase):
|
||||
pdu = pdu_list[0]
|
||||
|
||||
# Check signatures are correct.
|
||||
signed_pdu = await self._check_sigs_and_hash(
|
||||
room_version.identifier, pdu
|
||||
)
|
||||
signed_pdu = await self._check_sigs_and_hash(room_version, pdu)
|
||||
|
||||
break
|
||||
|
||||
@@ -350,7 +347,7 @@ class FederationClient(FederationBase):
|
||||
self,
|
||||
origin: str,
|
||||
pdus: List[EventBase],
|
||||
room_version: str,
|
||||
room_version: RoomVersion,
|
||||
outlier: bool = False,
|
||||
include_none: bool = False,
|
||||
) -> List[EventBase]:
|
||||
@@ -396,7 +393,7 @@ class FederationClient(FederationBase):
|
||||
self.get_pdu(
|
||||
destinations=[pdu.origin],
|
||||
event_id=pdu.event_id,
|
||||
room_version=room_version, # type: ignore
|
||||
room_version=room_version,
|
||||
outlier=outlier,
|
||||
timeout=10000,
|
||||
)
|
||||
@@ -434,7 +431,7 @@ class FederationClient(FederationBase):
|
||||
]
|
||||
|
||||
signed_auth = await self._check_sigs_and_hash_and_fetch(
|
||||
destination, auth_chain, outlier=True, room_version=room_version.identifier
|
||||
destination, auth_chain, outlier=True, room_version=room_version
|
||||
)
|
||||
|
||||
signed_auth.sort(key=lambda e: e.depth)
|
||||
@@ -661,7 +658,7 @@ class FederationClient(FederationBase):
|
||||
destination,
|
||||
list(pdus.values()),
|
||||
outlier=True,
|
||||
room_version=room_version.identifier,
|
||||
room_version=room_version,
|
||||
)
|
||||
|
||||
valid_pdus_map = {p.event_id: p for p in valid_pdus}
|
||||
@@ -756,7 +753,7 @@ class FederationClient(FederationBase):
|
||||
pdu = event_from_pdu_json(pdu_dict, room_version)
|
||||
|
||||
# Check signatures are correct.
|
||||
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
|
||||
pdu = await self._check_sigs_and_hash(room_version, pdu)
|
||||
|
||||
# FIXME: We should handle signature failures more gracefully.
|
||||
|
||||
@@ -948,7 +945,7 @@ class FederationClient(FederationBase):
|
||||
]
|
||||
|
||||
signed_events = await self._check_sigs_and_hash_and_fetch(
|
||||
destination, events, outlier=False, room_version=room_version.identifier
|
||||
destination, events, outlier=False, room_version=room_version
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
if not e.code == 400:
|
||||
|
||||
@@ -409,7 +409,7 @@ class FederationServer(FederationBase):
|
||||
pdu = event_from_pdu_json(content, room_version)
|
||||
origin_host, _ = parse_server_name(origin)
|
||||
await self.check_server_matches_acl(origin_host, pdu.room_id)
|
||||
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
|
||||
pdu = await self._check_sigs_and_hash(room_version, pdu)
|
||||
ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version)
|
||||
time_now = self._clock.time_msec()
|
||||
return {"event": ret_pdu.get_pdu_json(time_now)}
|
||||
@@ -425,7 +425,7 @@ class FederationServer(FederationBase):
|
||||
|
||||
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
|
||||
|
||||
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
|
||||
pdu = await self._check_sigs_and_hash(room_version, pdu)
|
||||
|
||||
res_pdus = await self.handler.on_send_join_request(origin, pdu)
|
||||
time_now = self._clock.time_msec()
|
||||
@@ -455,7 +455,7 @@ class FederationServer(FederationBase):
|
||||
|
||||
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
|
||||
|
||||
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
|
||||
pdu = await self._check_sigs_and_hash(room_version, pdu)
|
||||
|
||||
await self.handler.on_send_leave_request(origin, pdu)
|
||||
return {}
|
||||
@@ -611,7 +611,7 @@ class FederationServer(FederationBase):
|
||||
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
|
||||
|
||||
# We've already checked that we know the room version by this point
|
||||
room_version = await self.store.get_room_version_id(pdu.room_id)
|
||||
room_version = await self.store.get_room_version(pdu.room_id)
|
||||
|
||||
# Check signature.
|
||||
try:
|
||||
|
||||
@@ -477,7 +477,7 @@ def process_rows_for_federation(transaction_queue, rows):
|
||||
|
||||
Args:
|
||||
transaction_queue (FederationSender)
|
||||
rows (list(synapse.replication.tcp.streams.FederationStreamRow))
|
||||
rows (list(synapse.replication.tcp.streams.federation.FederationStream.FederationStreamRow))
|
||||
"""
|
||||
|
||||
# The federation stream contains a bunch of different types of
|
||||
|
||||
@@ -499,4 +499,13 @@ class FederationSender(object):
|
||||
self._get_per_destination_queue(destination).attempt_new_transaction()
|
||||
|
||||
def get_current_token(self) -> int:
|
||||
# Dummy implementation for case where federation sender isn't offloaded
|
||||
# to a worker.
|
||||
return 0
|
||||
|
||||
async def get_replication_rows(
|
||||
self, from_token, to_token, limit, federation_ack=None
|
||||
):
|
||||
# Dummy implementation for case where federation sender isn't offloaded
|
||||
# to a worker.
|
||||
return []
|
||||
|
||||
@@ -125,7 +125,11 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def validate_user_via_ui_auth(
|
||||
self, requester: Requester, request_body: Dict[str, Any], clientip: str
|
||||
self,
|
||||
requester: Requester,
|
||||
request: SynapseRequest,
|
||||
request_body: Dict[str, Any],
|
||||
clientip: str,
|
||||
):
|
||||
"""
|
||||
Checks that the user is who they claim to be, via a UI auth.
|
||||
@@ -137,6 +141,8 @@ class AuthHandler(BaseHandler):
|
||||
Args:
|
||||
requester: The user, as given by the access token
|
||||
|
||||
request: The request sent by the client.
|
||||
|
||||
request_body: The body of the request sent by the client
|
||||
|
||||
clientip: The IP address of the client.
|
||||
@@ -172,7 +178,9 @@ class AuthHandler(BaseHandler):
|
||||
flows = [[login_type] for login_type in self._supported_login_types]
|
||||
|
||||
try:
|
||||
result, params, _ = yield self.check_auth(flows, request_body, clientip)
|
||||
result, params, _ = yield self.check_auth(
|
||||
flows, request, request_body, clientip
|
||||
)
|
||||
except LoginError:
|
||||
# Update the ratelimite to say we failed (`can_do_action` doesn't raise).
|
||||
self._failed_uia_attempts_ratelimiter.can_do_action(
|
||||
@@ -211,7 +219,11 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_auth(
|
||||
self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str
|
||||
self,
|
||||
flows: List[List[str]],
|
||||
request: SynapseRequest,
|
||||
clientdict: Dict[str, Any],
|
||||
clientip: str,
|
||||
):
|
||||
"""
|
||||
Takes a dictionary sent by the client in the login / registration
|
||||
@@ -231,6 +243,8 @@ class AuthHandler(BaseHandler):
|
||||
strings representing auth-types. At least one full
|
||||
flow must be completed in order for auth to be successful.
|
||||
|
||||
request: The request sent by the client.
|
||||
|
||||
clientdict: The dictionary from the client root level, not the
|
||||
'auth' key: this method prompts for auth if none is sent.
|
||||
|
||||
@@ -270,13 +284,27 @@ class AuthHandler(BaseHandler):
|
||||
# email auth link on there). It's probably too open to abuse
|
||||
# because it lets unauthenticated clients store arbitrary objects
|
||||
# on a homeserver.
|
||||
# Revisit: Assumimg the REST APIs do sensible validation, the data
|
||||
# Revisit: Assuming the REST APIs do sensible validation, the data
|
||||
# isn't arbintrary.
|
||||
session["clientdict"] = clientdict
|
||||
self._save_session(session)
|
||||
elif "clientdict" in session:
|
||||
clientdict = session["clientdict"]
|
||||
|
||||
# Ensure that the queried operation does not vary between stages of
|
||||
# the UI authentication session. This is done by generating a stable
|
||||
# comparator based on the URI, method, and body (minus the auth dict)
|
||||
# and storing it during the initial query. Subsequent queries ensure
|
||||
# that this comparator has not changed.
|
||||
comparator = (request.uri, request.method, clientdict)
|
||||
if "ui_auth" not in session:
|
||||
session["ui_auth"] = comparator
|
||||
elif session["ui_auth"] != comparator:
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Requested operation has changed during the UI authentication session.",
|
||||
)
|
||||
|
||||
if not authdict:
|
||||
raise InteractiveAuthIncompleteError(
|
||||
self._auth_dict_for_flows(flows, session)
|
||||
@@ -322,6 +350,7 @@ class AuthHandler(BaseHandler):
|
||||
creds,
|
||||
list(clientdict),
|
||||
)
|
||||
|
||||
return creds, clientdict, session["id"]
|
||||
|
||||
ret = self._auth_dict_for_flows(flows, session)
|
||||
|
||||
204
synapse/handlers/cas_handler.py
Normal file
204
synapse/handlers/cas_handler.py
Normal file
@@ -0,0 +1,204 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import AnyStr, Dict, Optional, Tuple
|
||||
|
||||
from six.moves import urllib
|
||||
|
||||
from twisted.web.client import PartialDownloadError
|
||||
|
||||
from synapse.api.errors import Codes, LoginError
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import UserID, map_username_to_mxid_localpart
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CasHandler:
|
||||
"""
|
||||
Utility class for to handle the response from a CAS SSO service.
|
||||
|
||||
Args:
|
||||
hs (synapse.server.HomeServer)
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self._hostname = hs.hostname
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._registration_handler = hs.get_registration_handler()
|
||||
|
||||
self._cas_server_url = hs.config.cas_server_url
|
||||
self._cas_service_url = hs.config.cas_service_url
|
||||
self._cas_displayname_attribute = hs.config.cas_displayname_attribute
|
||||
self._cas_required_attributes = hs.config.cas_required_attributes
|
||||
|
||||
self._http_client = hs.get_proxied_http_client()
|
||||
|
||||
def _build_service_param(self, client_redirect_url: AnyStr) -> str:
|
||||
return "%s%s?%s" % (
|
||||
self._cas_service_url,
|
||||
"/_matrix/client/r0/login/cas/ticket",
|
||||
urllib.parse.urlencode({"redirectUrl": client_redirect_url}),
|
||||
)
|
||||
|
||||
async def _handle_cas_response(
|
||||
self, request: SynapseRequest, cas_response_body: str, client_redirect_url: str
|
||||
) -> None:
|
||||
"""
|
||||
Retrieves the user and display name from the CAS response and continues with the authentication.
|
||||
|
||||
Args:
|
||||
request: The original client request.
|
||||
cas_response_body: The response from the CAS server.
|
||||
client_redirect_url: The URl to redirect the client to when
|
||||
everything is done.
|
||||
"""
|
||||
user, attributes = self._parse_cas_response(cas_response_body)
|
||||
displayname = attributes.pop(self._cas_displayname_attribute, None)
|
||||
|
||||
for required_attribute, required_value in self._cas_required_attributes.items():
|
||||
# If required attribute was not in CAS Response - Forbidden
|
||||
if required_attribute not in attributes:
|
||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
# Also need to check value
|
||||
if required_value is not None:
|
||||
actual_value = attributes[required_attribute]
|
||||
# If required attribute value does not match expected - Forbidden
|
||||
if required_value != actual_value:
|
||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
await self._on_successful_auth(user, request, client_redirect_url, displayname)
|
||||
|
||||
def _parse_cas_response(
|
||||
self, cas_response_body: str
|
||||
) -> Tuple[str, Dict[str, Optional[str]]]:
|
||||
"""
|
||||
Retrieve the user and other parameters from the CAS response.
|
||||
|
||||
Args:
|
||||
cas_response_body: The response from the CAS query.
|
||||
|
||||
Returns:
|
||||
A tuple of the user and a mapping of other attributes.
|
||||
"""
|
||||
user = None
|
||||
attributes = {}
|
||||
try:
|
||||
root = ET.fromstring(cas_response_body)
|
||||
if not root.tag.endswith("serviceResponse"):
|
||||
raise Exception("root of CAS response is not serviceResponse")
|
||||
success = root[0].tag.endswith("authenticationSuccess")
|
||||
for child in root[0]:
|
||||
if child.tag.endswith("user"):
|
||||
user = child.text
|
||||
if child.tag.endswith("attributes"):
|
||||
for attribute in child:
|
||||
# ElementTree library expands the namespace in
|
||||
# attribute tags to the full URL of the namespace.
|
||||
# We don't care about namespace here and it will always
|
||||
# be encased in curly braces, so we remove them.
|
||||
tag = attribute.tag
|
||||
if "}" in tag:
|
||||
tag = tag.split("}")[1]
|
||||
attributes[tag] = attribute.text
|
||||
if user is None:
|
||||
raise Exception("CAS response does not contain user")
|
||||
except Exception:
|
||||
logger.exception("Error parsing CAS response")
|
||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
if not success:
|
||||
raise LoginError(
|
||||
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
|
||||
)
|
||||
return user, attributes
|
||||
|
||||
async def _on_successful_auth(
|
||||
self,
|
||||
username: str,
|
||||
request: SynapseRequest,
|
||||
client_redirect_url: str,
|
||||
user_display_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Called once the user has successfully authenticated with the SSO.
|
||||
|
||||
Registers the user if necessary, and then returns a redirect (with
|
||||
a login token) to the client.
|
||||
|
||||
Args:
|
||||
username: the remote user id. We'll map this onto
|
||||
something sane for a MXID localpath.
|
||||
|
||||
request: the incoming request from the browser. We'll
|
||||
respond to it with a redirect.
|
||||
|
||||
client_redirect_url: the redirect_url the client gave us when
|
||||
it first started the process.
|
||||
|
||||
user_display_name: if set, and we have to register a new user,
|
||||
we will set their displayname to this.
|
||||
"""
|
||||
localpart = map_username_to_mxid_localpart(username)
|
||||
user_id = UserID(localpart, self._hostname).to_string()
|
||||
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
||||
if not registered_user_id:
|
||||
registered_user_id = await self._registration_handler.register_user(
|
||||
localpart=localpart, default_display_name=user_display_name
|
||||
)
|
||||
|
||||
self._auth_handler.complete_sso_login(
|
||||
registered_user_id, request, client_redirect_url
|
||||
)
|
||||
|
||||
def handle_redirect_request(self, client_redirect_url: bytes) -> bytes:
|
||||
"""
|
||||
Generates a URL to the CAS server where the client should be redirected.
|
||||
|
||||
Args:
|
||||
client_redirect_url: The final URL the client should go to after the
|
||||
user has negotiated SSO.
|
||||
|
||||
Returns:
|
||||
The URL to redirect to.
|
||||
"""
|
||||
args = urllib.parse.urlencode(
|
||||
{"service": self._build_service_param(client_redirect_url)}
|
||||
)
|
||||
|
||||
return ("%s/login?%s" % (self._cas_server_url, args)).encode("ascii")
|
||||
|
||||
async def handle_ticket_request(
|
||||
self, request: SynapseRequest, client_redirect_url: str, ticket: str
|
||||
) -> None:
|
||||
"""
|
||||
Validates a CAS ticket sent by the client for login/registration.
|
||||
|
||||
On a successful request, writes a redirect to the request.
|
||||
"""
|
||||
uri = self._cas_server_url + "/proxyValidate"
|
||||
args = {
|
||||
"ticket": ticket,
|
||||
"service": self._build_service_param(client_redirect_url),
|
||||
}
|
||||
try:
|
||||
body = await self._http_client.get_raw(uri, args)
|
||||
except PartialDownloadError as pde:
|
||||
# Twisted raises this error if the connection is closed,
|
||||
# even if that's being used old-http style to signal end-of-data
|
||||
body = pde.response
|
||||
|
||||
await self._handle_cas_response(request, body, client_redirect_url)
|
||||
@@ -125,8 +125,14 @@ class DeviceWorkerHandler(BaseHandler):
|
||||
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
|
||||
user_id
|
||||
)
|
||||
|
||||
tracked_users = set(users_who_share_room)
|
||||
|
||||
# Always tell the user about their own devices
|
||||
tracked_users.add(user_id)
|
||||
|
||||
changed = yield self.store.get_users_whose_devices_changed(
|
||||
from_token.device_list_key, users_who_share_room
|
||||
from_token.device_list_key, tracked_users
|
||||
)
|
||||
|
||||
# Then work out if any users have since joined
|
||||
@@ -456,7 +462,11 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||
|
||||
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||
|
||||
yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids)
|
||||
# specify the user ID too since the user should always get their own device list
|
||||
# updates, even if they aren't in any rooms.
|
||||
yield self.notifier.on_new_event(
|
||||
"device_list_key", position, users=[user_id], rooms=room_ids
|
||||
)
|
||||
|
||||
if hosts:
|
||||
logger.info(
|
||||
|
||||
@@ -851,6 +851,38 @@ class EventCreationHandler(object):
|
||||
self.store.remove_push_actions_from_staging, event.event_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _validate_canonical_alias(
|
||||
self, directory_handler, room_alias_str, expected_room_id
|
||||
):
|
||||
"""
|
||||
Ensure that the given room alias points to the expected room ID.
|
||||
|
||||
Args:
|
||||
directory_handler: The directory handler object.
|
||||
room_alias_str: The room alias to check.
|
||||
expected_room_id: The room ID that the alias should point to.
|
||||
"""
|
||||
room_alias = RoomAlias.from_string(room_alias_str)
|
||||
try:
|
||||
mapping = yield directory_handler.get_association(room_alias)
|
||||
except SynapseError as e:
|
||||
# Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
|
||||
if e.errcode == Codes.NOT_FOUND:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Room alias %s does not point to the room" % (room_alias_str,),
|
||||
Codes.BAD_ALIAS,
|
||||
)
|
||||
raise
|
||||
|
||||
if mapping["room_id"] != expected_room_id:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Room alias %s does not point to the room" % (room_alias_str,),
|
||||
Codes.BAD_ALIAS,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def persist_and_notify_client_event(
|
||||
self, requester, event, context, ratelimit=True, extra_users=[]
|
||||
@@ -905,15 +937,9 @@ class EventCreationHandler(object):
|
||||
room_alias_str = event.content.get("alias", None)
|
||||
directory_handler = self.hs.get_handlers().directory_handler
|
||||
if room_alias_str and room_alias_str != original_alias:
|
||||
room_alias = RoomAlias.from_string(room_alias_str)
|
||||
mapping = yield directory_handler.get_association(room_alias)
|
||||
|
||||
if mapping["room_id"] != event.room_id:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Room alias %s does not point to the room" % (room_alias_str,),
|
||||
Codes.BAD_ALIAS,
|
||||
)
|
||||
yield self._validate_canonical_alias(
|
||||
directory_handler, room_alias_str, event.room_id
|
||||
)
|
||||
|
||||
# Check that alt_aliases is the proper form.
|
||||
alt_aliases = event.content.get("alt_aliases", [])
|
||||
@@ -931,16 +957,9 @@ class EventCreationHandler(object):
|
||||
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
|
||||
if new_alt_aliases:
|
||||
for alias_str in new_alt_aliases:
|
||||
room_alias = RoomAlias.from_string(alias_str)
|
||||
mapping = yield directory_handler.get_association(room_alias)
|
||||
|
||||
if mapping["room_id"] != event.room_id:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Room alias %s does not point to the room"
|
||||
% (room_alias_str,),
|
||||
Codes.BAD_ALIAS,
|
||||
)
|
||||
yield self._validate_canonical_alias(
|
||||
directory_handler, alias_str, event.room_id
|
||||
)
|
||||
|
||||
federation_handler = self.hs.get_handlers().federation_handler
|
||||
|
||||
|
||||
93
synapse/handlers/password_policy.py
Normal file
93
synapse/handlers/password_policy.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
from synapse.api.errors import Codes, PasswordRefusedError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PasswordPolicyHandler(object):
|
||||
def __init__(self, hs):
|
||||
self.policy = hs.config.password_policy
|
||||
self.enabled = hs.config.password_policy_enabled
|
||||
|
||||
# Regexps for the spec'd policy parameters.
|
||||
self.regexp_digit = re.compile("[0-9]")
|
||||
self.regexp_symbol = re.compile("[^a-zA-Z0-9]")
|
||||
self.regexp_uppercase = re.compile("[A-Z]")
|
||||
self.regexp_lowercase = re.compile("[a-z]")
|
||||
|
||||
def validate_password(self, password):
|
||||
"""Checks whether a given password complies with the server's policy.
|
||||
|
||||
Args:
|
||||
password (str): The password to check against the server's policy.
|
||||
|
||||
Raises:
|
||||
PasswordRefusedError: The password doesn't comply with the server's policy.
|
||||
"""
|
||||
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
minimum_accepted_length = self.policy.get("minimum_length", 0)
|
||||
if len(password) < minimum_accepted_length:
|
||||
raise PasswordRefusedError(
|
||||
msg=(
|
||||
"The password must be at least %d characters long"
|
||||
% minimum_accepted_length
|
||||
),
|
||||
errcode=Codes.PASSWORD_TOO_SHORT,
|
||||
)
|
||||
|
||||
if (
|
||||
self.policy.get("require_digit", False)
|
||||
and self.regexp_digit.search(password) is None
|
||||
):
|
||||
raise PasswordRefusedError(
|
||||
msg="The password must include at least one digit",
|
||||
errcode=Codes.PASSWORD_NO_DIGIT,
|
||||
)
|
||||
|
||||
if (
|
||||
self.policy.get("require_symbol", False)
|
||||
and self.regexp_symbol.search(password) is None
|
||||
):
|
||||
raise PasswordRefusedError(
|
||||
msg="The password must include at least one symbol",
|
||||
errcode=Codes.PASSWORD_NO_SYMBOL,
|
||||
)
|
||||
|
||||
if (
|
||||
self.policy.get("require_uppercase", False)
|
||||
and self.regexp_uppercase.search(password) is None
|
||||
):
|
||||
raise PasswordRefusedError(
|
||||
msg="The password must include at least one uppercase letter",
|
||||
errcode=Codes.PASSWORD_NO_UPPERCASE,
|
||||
)
|
||||
|
||||
if (
|
||||
self.policy.get("require_lowercase", False)
|
||||
and self.regexp_lowercase.search(password) is None
|
||||
):
|
||||
raise PasswordRefusedError(
|
||||
msg="The password must include at least one lowercase letter",
|
||||
errcode=Codes.PASSWORD_NO_LOWERCASE,
|
||||
)
|
||||
@@ -747,7 +747,7 @@ class PresenceHandler(object):
|
||||
|
||||
return False
|
||||
|
||||
async def get_all_presence_updates(self, last_id, current_id):
|
||||
async def get_all_presence_updates(self, last_id, current_id, limit):
|
||||
"""
|
||||
Gets a list of presence update rows from between the given stream ids.
|
||||
Each row has:
|
||||
@@ -762,7 +762,7 @@ class PresenceHandler(object):
|
||||
"""
|
||||
# TODO(markjh): replicate the unpersisted changes.
|
||||
# This could use the in-memory stores for recent changes.
|
||||
rows = await self.store.get_all_presence_updates(last_id, current_id)
|
||||
rows = await self.store.get_all_presence_updates(last_id, current_id, limit)
|
||||
return rows
|
||||
|
||||
def notify_new_event(self):
|
||||
|
||||
@@ -157,6 +157,15 @@ class BaseProfileHandler(BaseHandler):
|
||||
if not by_admin and target_user != requester.user:
|
||||
raise AuthError(400, "Cannot set another user's displayname")
|
||||
|
||||
if not by_admin and not self.hs.config.enable_set_displayname:
|
||||
profile = yield self.store.get_profileinfo(target_user.localpart)
|
||||
if profile.display_name:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Changing display name is disabled on this server",
|
||||
Codes.FORBIDDEN,
|
||||
)
|
||||
|
||||
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
|
||||
raise SynapseError(
|
||||
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
|
||||
@@ -218,6 +227,13 @@ class BaseProfileHandler(BaseHandler):
|
||||
if not by_admin and target_user != requester.user:
|
||||
raise AuthError(400, "Cannot set another user's avatar_url")
|
||||
|
||||
if not by_admin and not self.hs.config.enable_set_avatar_url:
|
||||
profile = yield self.store.get_profileinfo(target_user.localpart)
|
||||
if profile.avatar_url:
|
||||
raise SynapseError(
|
||||
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
|
||||
raise SynapseError(
|
||||
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
|
||||
|
||||
@@ -26,6 +26,7 @@ from synapse.config import ConfigError
|
||||
from synapse.http.server import finish_request
|
||||
from synapse.http.servlet import parse_string
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.module_api.errors import RedirectException
|
||||
from synapse.types import (
|
||||
UserID,
|
||||
map_username_to_mxid_localpart,
|
||||
@@ -119,6 +120,9 @@ class SamlHandler:
|
||||
|
||||
try:
|
||||
user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
|
||||
except RedirectException:
|
||||
# Raise the exception as per the wishes of the SAML module response
|
||||
raise
|
||||
except Exception as e:
|
||||
# If decoding the response or mapping it to a user failed, then log the
|
||||
# error and tell the user that something went wrong.
|
||||
|
||||
@@ -32,6 +32,7 @@ class SetPasswordHandler(BaseHandler):
|
||||
super(SetPasswordHandler, self).__init__(hs)
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
self._password_policy_handler = hs.get_password_policy_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_password(
|
||||
@@ -44,6 +45,7 @@ class SetPasswordHandler(BaseHandler):
|
||||
if not self.hs.config.password_localdb_enabled:
|
||||
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
|
||||
|
||||
self._password_policy_handler.validate_password(new_password)
|
||||
password_hash = yield self._auth_handler.hash(new_password)
|
||||
|
||||
try:
|
||||
|
||||
@@ -26,7 +26,7 @@ from prometheus_client import Counter
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.filtering import FilterCollection
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging.context import LoggingContext
|
||||
from synapse.logging.context import current_context
|
||||
from synapse.push.clientformat import format_push_rules_for_user
|
||||
from synapse.storage.roommember import MemberSummary
|
||||
from synapse.storage.state import StateFilter
|
||||
@@ -301,7 +301,7 @@ class SyncHandler(object):
|
||||
else:
|
||||
sync_type = "incremental_sync"
|
||||
|
||||
context = LoggingContext.current_context()
|
||||
context = current_context()
|
||||
if context:
|
||||
context.tag = sync_type
|
||||
|
||||
@@ -1143,9 +1143,14 @@ class SyncHandler(object):
|
||||
user_id
|
||||
)
|
||||
|
||||
tracked_users = set(users_who_share_room)
|
||||
|
||||
# Always tell the user about their own devices
|
||||
tracked_users.add(user_id)
|
||||
|
||||
# Step 1a, check for changes in devices of users we share a room with
|
||||
users_that_have_changed = await self.store.get_users_whose_devices_changed(
|
||||
since_token.device_list_key, users_who_share_room
|
||||
since_token.device_list_key, tracked_users
|
||||
)
|
||||
|
||||
# Step 1b, check for newly joined rooms
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import List
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
@@ -257,7 +258,13 @@ class TypingHandler(object):
|
||||
"typing_key", self._latest_room_serial, rooms=[member.room_id]
|
||||
)
|
||||
|
||||
async def get_all_typing_updates(self, last_id, current_id):
|
||||
async def get_all_typing_updates(
|
||||
self, last_id: int, current_id: int, limit: int
|
||||
) -> List[dict]:
|
||||
"""Get up to `limit` typing updates between the given tokens, earliest
|
||||
updates first.
|
||||
"""
|
||||
|
||||
if last_id == current_id:
|
||||
return []
|
||||
|
||||
@@ -275,7 +282,7 @@ class TypingHandler(object):
|
||||
typing = self._room_typing[room_id]
|
||||
rows.append((serial, room_id, list(typing)))
|
||||
rows.sort()
|
||||
return rows
|
||||
return rows[:limit]
|
||||
|
||||
def get_current_token(self):
|
||||
return self._latest_room_serial
|
||||
|
||||
@@ -19,7 +19,7 @@ import threading
|
||||
|
||||
from prometheus_client.core import Counter, Histogram
|
||||
|
||||
from synapse.logging.context import LoggingContext
|
||||
from synapse.logging.context import current_context
|
||||
from synapse.metrics import LaterGauge
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -148,7 +148,7 @@ LaterGauge(
|
||||
class RequestMetrics(object):
|
||||
def start(self, time_sec, name, method):
|
||||
self.start = time_sec
|
||||
self.start_context = LoggingContext.current_context()
|
||||
self.start_context = current_context()
|
||||
self.name = name
|
||||
self.method = method
|
||||
|
||||
@@ -163,7 +163,7 @@ class RequestMetrics(object):
|
||||
with _in_flight_requests_lock:
|
||||
_in_flight_requests.discard(self)
|
||||
|
||||
context = LoggingContext.current_context()
|
||||
context = current_context()
|
||||
|
||||
tag = ""
|
||||
if context:
|
||||
|
||||
@@ -42,7 +42,7 @@ from synapse.logging._terse_json import (
|
||||
TerseJSONToConsoleLogObserver,
|
||||
TerseJSONToTCPLogObserver,
|
||||
)
|
||||
from synapse.logging.context import LoggingContext
|
||||
from synapse.logging.context import current_context
|
||||
|
||||
|
||||
def stdlib_log_level_to_twisted(level: str) -> LogLevel:
|
||||
@@ -86,7 +86,7 @@ class LogContextObserver(object):
|
||||
].startswith("Timing out client"):
|
||||
return
|
||||
|
||||
context = LoggingContext.current_context()
|
||||
context = current_context()
|
||||
|
||||
# Copy the context information to the log event.
|
||||
if context is not None:
|
||||
|
||||
@@ -175,7 +175,54 @@ class ContextResourceUsage(object):
|
||||
return res
|
||||
|
||||
|
||||
LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"]
|
||||
LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
|
||||
|
||||
|
||||
class _Sentinel(object):
|
||||
"""Sentinel to represent the root context"""
|
||||
|
||||
__slots__ = ["previous_context", "finished", "request", "scope", "tag"]
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Minimal set for compatibility with LoggingContext
|
||||
self.previous_context = None
|
||||
self.finished = False
|
||||
self.request = None
|
||||
self.scope = None
|
||||
self.tag = None
|
||||
|
||||
def __str__(self):
|
||||
return "sentinel"
|
||||
|
||||
def copy_to(self, record):
|
||||
pass
|
||||
|
||||
def copy_to_twisted_log_entry(self, record):
|
||||
record["request"] = None
|
||||
record["scope"] = None
|
||||
|
||||
def start(self):
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
def add_database_transaction(self, duration_sec):
|
||||
pass
|
||||
|
||||
def add_database_scheduled(self, sched_sec):
|
||||
pass
|
||||
|
||||
def record_event_fetch(self, event_count):
|
||||
pass
|
||||
|
||||
def __nonzero__(self):
|
||||
return False
|
||||
|
||||
__bool__ = __nonzero__ # python3
|
||||
|
||||
|
||||
SENTINEL_CONTEXT = _Sentinel()
|
||||
|
||||
|
||||
class LoggingContext(object):
|
||||
@@ -199,76 +246,33 @@ class LoggingContext(object):
|
||||
"_resource_usage",
|
||||
"usage_start",
|
||||
"main_thread",
|
||||
"alive",
|
||||
"finished",
|
||||
"request",
|
||||
"tag",
|
||||
"scope",
|
||||
]
|
||||
|
||||
thread_local = threading.local()
|
||||
|
||||
class Sentinel(object):
|
||||
"""Sentinel to represent the root context"""
|
||||
|
||||
__slots__ = ["previous_context", "alive", "request", "scope", "tag"]
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Minimal set for compatibility with LoggingContext
|
||||
self.previous_context = None
|
||||
self.alive = None
|
||||
self.request = None
|
||||
self.scope = None
|
||||
self.tag = None
|
||||
|
||||
def __str__(self):
|
||||
return "sentinel"
|
||||
|
||||
def copy_to(self, record):
|
||||
pass
|
||||
|
||||
def copy_to_twisted_log_entry(self, record):
|
||||
record["request"] = None
|
||||
record["scope"] = None
|
||||
|
||||
def start(self):
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
def add_database_transaction(self, duration_sec):
|
||||
pass
|
||||
|
||||
def add_database_scheduled(self, sched_sec):
|
||||
pass
|
||||
|
||||
def record_event_fetch(self, event_count):
|
||||
pass
|
||||
|
||||
def __nonzero__(self):
|
||||
return False
|
||||
|
||||
__bool__ = __nonzero__ # python3
|
||||
|
||||
sentinel = Sentinel()
|
||||
|
||||
def __init__(self, name=None, parent_context=None, request=None) -> None:
|
||||
self.previous_context = LoggingContext.current_context()
|
||||
self.previous_context = current_context()
|
||||
self.name = name
|
||||
|
||||
# track the resources used by this context so far
|
||||
self._resource_usage = ContextResourceUsage()
|
||||
|
||||
# If alive has the thread resource usage when the logcontext last
|
||||
# became active.
|
||||
# The thread resource usage when the logcontext became active. None
|
||||
# if the context is not currently active.
|
||||
self.usage_start = None
|
||||
|
||||
self.main_thread = get_thread_id()
|
||||
self.request = None
|
||||
self.tag = ""
|
||||
self.alive = True
|
||||
self.scope = None # type: Optional[_LogContextScope]
|
||||
|
||||
# keep track of whether we have hit the __exit__ block for this context
|
||||
# (suggesting that the the thing that created the context thinks it should
|
||||
# be finished, and that re-activating it would suggest an error).
|
||||
self.finished = False
|
||||
|
||||
self.parent_context = parent_context
|
||||
|
||||
if self.parent_context is not None:
|
||||
@@ -283,44 +287,15 @@ class LoggingContext(object):
|
||||
return str(self.request)
|
||||
return "%s@%x" % (self.name, id(self))
|
||||
|
||||
@classmethod
|
||||
def current_context(cls) -> LoggingContextOrSentinel:
|
||||
"""Get the current logging context from thread local storage
|
||||
|
||||
Returns:
|
||||
LoggingContext: the current logging context
|
||||
"""
|
||||
return getattr(cls.thread_local, "current_context", cls.sentinel)
|
||||
|
||||
@classmethod
|
||||
def set_current_context(
|
||||
cls, context: LoggingContextOrSentinel
|
||||
) -> LoggingContextOrSentinel:
|
||||
"""Set the current logging context in thread local storage
|
||||
Args:
|
||||
context(LoggingContext): The context to activate.
|
||||
Returns:
|
||||
The context that was previously active
|
||||
"""
|
||||
current = cls.current_context()
|
||||
|
||||
if current is not context:
|
||||
current.stop()
|
||||
cls.thread_local.current_context = context
|
||||
context.start()
|
||||
return current
|
||||
|
||||
def __enter__(self) -> "LoggingContext":
|
||||
"""Enters this logging context into thread local storage"""
|
||||
old_context = self.set_current_context(self)
|
||||
old_context = set_current_context(self)
|
||||
if self.previous_context != old_context:
|
||||
logger.warning(
|
||||
"Expected previous context %r, found %r",
|
||||
self.previous_context,
|
||||
old_context,
|
||||
)
|
||||
self.alive = True
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback) -> None:
|
||||
@@ -329,24 +304,19 @@ class LoggingContext(object):
|
||||
Returns:
|
||||
None to avoid suppressing any exceptions that were thrown.
|
||||
"""
|
||||
current = self.set_current_context(self.previous_context)
|
||||
current = set_current_context(self.previous_context)
|
||||
if current is not self:
|
||||
if current is self.sentinel:
|
||||
if current is SENTINEL_CONTEXT:
|
||||
logger.warning("Expected logging context %s was lost", self)
|
||||
else:
|
||||
logger.warning(
|
||||
"Expected logging context %s but found %s", self, current
|
||||
)
|
||||
self.alive = False
|
||||
|
||||
# if we have a parent, pass our CPU usage stats on
|
||||
if self.parent_context is not None and hasattr(
|
||||
self.parent_context, "_resource_usage"
|
||||
):
|
||||
self.parent_context._resource_usage += self._resource_usage
|
||||
|
||||
# reset them in case we get entered again
|
||||
self._resource_usage.reset()
|
||||
# the fact that we are here suggests that the caller thinks that everything
|
||||
# is done and dusted for this logcontext, and further activity will not get
|
||||
# recorded against the correct metrics.
|
||||
self.finished = True
|
||||
|
||||
def copy_to(self, record) -> None:
|
||||
"""Copy logging fields from this context to a log record or
|
||||
@@ -371,9 +341,14 @@ class LoggingContext(object):
|
||||
logger.warning("Started logcontext %s on different thread", self)
|
||||
return
|
||||
|
||||
if self.finished:
|
||||
logger.warning("Re-starting finished log context %s", self)
|
||||
|
||||
# If we haven't already started record the thread resource usage so
|
||||
# far
|
||||
if not self.usage_start:
|
||||
if self.usage_start:
|
||||
logger.warning("Re-starting already-active log context %s", self)
|
||||
else:
|
||||
self.usage_start = get_thread_resource_usage()
|
||||
|
||||
def stop(self) -> None:
|
||||
@@ -396,6 +371,15 @@ class LoggingContext(object):
|
||||
|
||||
self.usage_start = None
|
||||
|
||||
# if we have a parent, pass our CPU usage stats on
|
||||
if self.parent_context is not None and hasattr(
|
||||
self.parent_context, "_resource_usage"
|
||||
):
|
||||
self.parent_context._resource_usage += self._resource_usage
|
||||
|
||||
# reset them in case we get entered again
|
||||
self._resource_usage.reset()
|
||||
|
||||
def get_resource_usage(self) -> ContextResourceUsage:
|
||||
"""Get resources used by this logcontext so far.
|
||||
|
||||
@@ -409,7 +393,7 @@ class LoggingContext(object):
|
||||
# If we are on the correct thread and we're currently running then we
|
||||
# can include resource usage so far.
|
||||
is_main_thread = get_thread_id() == self.main_thread
|
||||
if self.alive and self.usage_start and is_main_thread:
|
||||
if self.usage_start and is_main_thread:
|
||||
utime_delta, stime_delta = self._get_cputime()
|
||||
res.ru_utime += utime_delta
|
||||
res.ru_stime += stime_delta
|
||||
@@ -492,7 +476,7 @@ class LoggingContextFilter(logging.Filter):
|
||||
Returns:
|
||||
True to include the record in the log output.
|
||||
"""
|
||||
context = LoggingContext.current_context()
|
||||
context = current_context()
|
||||
for key, value in self.defaults.items():
|
||||
setattr(record, key, value)
|
||||
|
||||
@@ -512,27 +496,24 @@ class PreserveLoggingContext(object):
|
||||
|
||||
__slots__ = ["current_context", "new_context", "has_parent"]
|
||||
|
||||
def __init__(self, new_context: Optional[LoggingContextOrSentinel] = None) -> None:
|
||||
if new_context is None:
|
||||
self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel
|
||||
else:
|
||||
self.new_context = new_context
|
||||
def __init__(
|
||||
self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT
|
||||
) -> None:
|
||||
self.new_context = new_context
|
||||
|
||||
def __enter__(self) -> None:
|
||||
"""Captures the current logging context"""
|
||||
self.current_context = LoggingContext.set_current_context(self.new_context)
|
||||
self.current_context = set_current_context(self.new_context)
|
||||
|
||||
if self.current_context:
|
||||
self.has_parent = self.current_context.previous_context is not None
|
||||
if not self.current_context.alive:
|
||||
logger.debug("Entering dead context: %s", self.current_context)
|
||||
|
||||
def __exit__(self, type, value, traceback) -> None:
|
||||
"""Restores the current logging context"""
|
||||
context = LoggingContext.set_current_context(self.current_context)
|
||||
context = set_current_context(self.current_context)
|
||||
|
||||
if context != self.new_context:
|
||||
if context is LoggingContext.sentinel:
|
||||
if not context:
|
||||
logger.warning("Expected logging context %s was lost", self.new_context)
|
||||
else:
|
||||
logger.warning(
|
||||
@@ -541,9 +522,30 @@ class PreserveLoggingContext(object):
|
||||
context,
|
||||
)
|
||||
|
||||
if self.current_context is not LoggingContext.sentinel:
|
||||
if not self.current_context.alive:
|
||||
logger.debug("Restoring dead context: %s", self.current_context)
|
||||
|
||||
_thread_local = threading.local()
|
||||
_thread_local.current_context = SENTINEL_CONTEXT
|
||||
|
||||
|
||||
def current_context() -> LoggingContextOrSentinel:
|
||||
"""Get the current logging context from thread local storage"""
|
||||
return getattr(_thread_local, "current_context", SENTINEL_CONTEXT)
|
||||
|
||||
|
||||
def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
|
||||
"""Set the current logging context in thread local storage
|
||||
Args:
|
||||
context(LoggingContext): The context to activate.
|
||||
Returns:
|
||||
The context that was previously active
|
||||
"""
|
||||
current = current_context()
|
||||
|
||||
if current is not context:
|
||||
current.stop()
|
||||
_thread_local.current_context = context
|
||||
context.start()
|
||||
return current
|
||||
|
||||
|
||||
def nested_logging_context(
|
||||
@@ -572,7 +574,7 @@ def nested_logging_context(
|
||||
if parent_context is not None:
|
||||
context = parent_context # type: LoggingContextOrSentinel
|
||||
else:
|
||||
context = LoggingContext.current_context()
|
||||
context = current_context()
|
||||
return LoggingContext(
|
||||
parent_context=context, request=str(context.request) + "-" + suffix
|
||||
)
|
||||
@@ -604,7 +606,7 @@ def run_in_background(f, *args, **kwargs):
|
||||
CRITICAL error about an unhandled error will be logged without much
|
||||
indication about where it came from.
|
||||
"""
|
||||
current = LoggingContext.current_context()
|
||||
current = current_context()
|
||||
try:
|
||||
res = f(*args, **kwargs)
|
||||
except: # noqa: E722
|
||||
@@ -625,7 +627,7 @@ def run_in_background(f, *args, **kwargs):
|
||||
|
||||
# The function may have reset the context before returning, so
|
||||
# we need to restore it now.
|
||||
ctx = LoggingContext.set_current_context(current)
|
||||
ctx = set_current_context(current)
|
||||
|
||||
# The original context will be restored when the deferred
|
||||
# completes, but there is nothing waiting for it, so it will
|
||||
@@ -674,7 +676,7 @@ def make_deferred_yieldable(deferred):
|
||||
|
||||
# ok, we can't be sure that a yield won't block, so let's reset the
|
||||
# logcontext, and add a callback to the deferred to restore it.
|
||||
prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
|
||||
prev_context = set_current_context(SENTINEL_CONTEXT)
|
||||
deferred.addBoth(_set_context_cb, prev_context)
|
||||
return deferred
|
||||
|
||||
@@ -684,7 +686,7 @@ ResultT = TypeVar("ResultT")
|
||||
|
||||
def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
|
||||
"""A callback function which just sets the logging context"""
|
||||
LoggingContext.set_current_context(context)
|
||||
set_current_context(context)
|
||||
return result
|
||||
|
||||
|
||||
@@ -752,7 +754,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
|
||||
Deferred: A Deferred which fires a callback with the result of `f`, or an
|
||||
errback if `f` throws an exception.
|
||||
"""
|
||||
logcontext = LoggingContext.current_context()
|
||||
logcontext = current_context()
|
||||
|
||||
def g():
|
||||
with LoggingContext(parent_context=logcontext):
|
||||
|
||||
@@ -19,7 +19,7 @@ from opentracing import Scope, ScopeManager
|
||||
|
||||
import twisted
|
||||
|
||||
from synapse.logging.context import LoggingContext, nested_logging_context
|
||||
from synapse.logging.context import current_context, nested_logging_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -49,11 +49,8 @@ class LogContextScopeManager(ScopeManager):
|
||||
(Scope) : the Scope that is active, or None if not
|
||||
available.
|
||||
"""
|
||||
ctx = LoggingContext.current_context()
|
||||
if ctx is LoggingContext.sentinel:
|
||||
return None
|
||||
else:
|
||||
return ctx.scope
|
||||
ctx = current_context()
|
||||
return ctx.scope
|
||||
|
||||
def activate(self, span, finish_on_close):
|
||||
"""
|
||||
@@ -70,9 +67,9 @@ class LogContextScopeManager(ScopeManager):
|
||||
"""
|
||||
|
||||
enter_logcontext = False
|
||||
ctx = LoggingContext.current_context()
|
||||
ctx = current_context()
|
||||
|
||||
if ctx is LoggingContext.sentinel:
|
||||
if not ctx:
|
||||
# We don't want this scope to affect.
|
||||
logger.error("Tried to activate scope outside of loggingcontext")
|
||||
return Scope(None, span)
|
||||
|
||||
@@ -21,6 +21,7 @@ from synapse.replication.http import (
|
||||
membership,
|
||||
register,
|
||||
send_event,
|
||||
streams,
|
||||
)
|
||||
|
||||
REPLICATION_PREFIX = "/_synapse/replication"
|
||||
@@ -38,3 +39,4 @@ class ReplicationRestResource(JsonResource):
|
||||
login.register_servlets(hs, self)
|
||||
register.register_servlets(hs, self)
|
||||
devices.register_servlets(hs, self)
|
||||
streams.register_servlets(hs, self)
|
||||
|
||||
78
synapse/replication/http/streams.py
Normal file
78
synapse/replication/http/streams.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.servlet import parse_integer
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
||||
"""Fetches stream updates from a server. Used for streams not persisted to
|
||||
the database, e.g. typing notifications.
|
||||
|
||||
The API looks like:
|
||||
|
||||
GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&to_token=10&limit=100
|
||||
|
||||
200 OK
|
||||
|
||||
{
|
||||
updates: [ ... ],
|
||||
upto_token: 10,
|
||||
limited: False,
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
NAME = "get_repl_stream_updates"
|
||||
PATH_ARGS = ("stream_name",)
|
||||
METHOD = "GET"
|
||||
|
||||
def __init__(self, hs):
|
||||
super().__init__(hs)
|
||||
|
||||
# We pull the streams from the replication steamer (if we try and make
|
||||
# them ourselves we end up in an import loop).
|
||||
self.streams = hs.get_replication_streamer().get_streams()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(stream_name, from_token, upto_token, limit):
|
||||
return {"from_token": from_token, "upto_token": upto_token, "limit": limit}
|
||||
|
||||
async def _handle_request(self, request, stream_name):
|
||||
stream = self.streams.get(stream_name)
|
||||
if stream is None:
|
||||
raise SynapseError(400, "Unknown stream")
|
||||
|
||||
from_token = parse_integer(request, "from_token", required=True)
|
||||
upto_token = parse_integer(request, "upto_token", required=True)
|
||||
limit = parse_integer(request, "limit", required=True)
|
||||
|
||||
updates, upto_token, limited = await stream.get_updates_since(
|
||||
from_token, upto_token, limit
|
||||
)
|
||||
|
||||
return (
|
||||
200,
|
||||
{"updates": updates, "upto_token": upto_token, "limited": limited},
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
ReplicationGetStreamUpdates(hs).register(http_server)
|
||||
@@ -18,8 +18,10 @@ from typing import Dict, Optional
|
||||
|
||||
import six
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
|
||||
from synapse.storage.data_stores.main.cache import (
|
||||
CURRENT_STATE_CACHE_NAME,
|
||||
CacheInvalidationWorkerStore,
|
||||
)
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
|
||||
@@ -35,7 +37,7 @@ def __func__(inp):
|
||||
return inp.__func__
|
||||
|
||||
|
||||
class BaseSlavedStore(SQLBaseStore):
|
||||
class BaseSlavedStore(CacheInvalidationWorkerStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
@@ -60,6 +62,12 @@ class BaseSlavedStore(SQLBaseStore):
|
||||
pos["caches"] = self._cache_id_gen.get_current_token()
|
||||
return pos
|
||||
|
||||
def get_cache_stream_token(self):
|
||||
if self._cache_id_gen:
|
||||
return self._cache_id_gen.get_current_token()
|
||||
else:
|
||||
return 0
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "caches":
|
||||
if self._cache_id_gen:
|
||||
|
||||
@@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
||||
self.hs = hs
|
||||
|
||||
self._device_list_id_gen = SlavedIdTracker(
|
||||
db_conn, "device_lists_stream", "stream_id"
|
||||
db_conn,
|
||||
"device_lists_stream",
|
||||
"stream_id",
|
||||
extra_tables=[
|
||||
("user_signature_stream", "stream_id"),
|
||||
("device_lists_outbound_pokes", "stream_id"),
|
||||
],
|
||||
)
|
||||
device_list_max = self._device_list_id_gen.get_current_token()
|
||||
self._device_list_stream_cache = StreamChangeCache(
|
||||
@@ -55,23 +61,27 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == DeviceListsStream.NAME:
|
||||
self._device_list_id_gen.advance(token)
|
||||
for row in rows:
|
||||
self._invalidate_caches_for_devices(token, row.user_id, row.destination)
|
||||
self._invalidate_caches_for_devices(token, rows)
|
||||
elif stream_name == UserSignatureStream.NAME:
|
||||
self._device_list_id_gen.advance(token)
|
||||
for row in rows:
|
||||
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
|
||||
return super(SlavedDeviceStore, self).process_replication_rows(
|
||||
stream_name, token, rows
|
||||
)
|
||||
|
||||
def _invalidate_caches_for_devices(self, token, user_id, destination):
|
||||
self._device_list_stream_cache.entity_has_changed(user_id, token)
|
||||
def _invalidate_caches_for_devices(self, token, rows):
|
||||
for row in rows:
|
||||
# The entities are either user IDs (starting with '@') whose devices
|
||||
# have changed, or remote servers that we need to tell about
|
||||
# changes.
|
||||
if row.entity.startswith("@"):
|
||||
self._device_list_stream_cache.entity_has_changed(row.entity, token)
|
||||
self.get_cached_devices_for_user.invalidate((row.entity,))
|
||||
self._get_cached_user_device.invalidate_many((row.entity,))
|
||||
self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
|
||||
|
||||
if destination:
|
||||
self._device_list_federation_stream_cache.entity_has_changed(
|
||||
destination, token
|
||||
)
|
||||
|
||||
self.get_cached_devices_for_user.invalidate((user_id,))
|
||||
self._get_cached_user_device.invalidate_many((user_id,))
|
||||
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
|
||||
else:
|
||||
self._device_list_federation_stream_cache.entity_has_changed(
|
||||
row.entity, token
|
||||
)
|
||||
|
||||
@@ -33,6 +33,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
||||
result["pushers"] = self._pushers_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def get_pushers_stream_token(self):
|
||||
return self._pushers_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "pushers":
|
||||
self._pushers_id_gen.advance(token)
|
||||
|
||||
@@ -16,26 +16,10 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.protocol import ReconnectingClientFactory
|
||||
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.tcp.protocol import (
|
||||
AbstractReplicationClientHandler,
|
||||
ClientReplicationStreamProtocol,
|
||||
)
|
||||
|
||||
from .commands import (
|
||||
Command,
|
||||
FederationAckCommand,
|
||||
InvalidateCacheCommand,
|
||||
RemoteServerUpCommand,
|
||||
RemovePusherCommand,
|
||||
UserIpCommand,
|
||||
UserSyncCommand,
|
||||
)
|
||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,10 +35,11 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
||||
initialDelay = 0.1
|
||||
maxDelay = 1 # Try at least once every N seconds
|
||||
|
||||
def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
|
||||
def __init__(self, hs, client_name):
|
||||
self.client_name = client_name
|
||||
self.handler = handler
|
||||
self.handler = hs.get_tcp_replication()
|
||||
self.server_name = hs.config.server_name
|
||||
self.hs = hs
|
||||
self._clock = hs.get_clock() # As self.clock is defined in super class
|
||||
|
||||
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
|
||||
@@ -65,7 +50,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
||||
def buildProtocol(self, addr):
|
||||
logger.info("Connected to replication: %r", addr)
|
||||
return ClientReplicationStreamProtocol(
|
||||
self.client_name, self.server_name, self._clock, self.handler
|
||||
self.hs, self.client_name, self.server_name, self._clock, self.handler,
|
||||
)
|
||||
|
||||
def clientConnectionLost(self, connector, reason):
|
||||
@@ -75,170 +60,3 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
||||
def clientConnectionFailed(self, connector, reason):
|
||||
logger.error("Failed to connect to replication: %r", reason)
|
||||
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
|
||||
|
||||
|
||||
class ReplicationClientHandler(AbstractReplicationClientHandler):
|
||||
"""A base handler that can be passed to the ReplicationClientFactory.
|
||||
|
||||
By default proxies incoming replication data to the SlaveStore.
|
||||
"""
|
||||
|
||||
def __init__(self, store: BaseSlavedStore):
|
||||
self.store = store
|
||||
|
||||
# The current connection. None if we are currently (re)connecting
|
||||
self.connection = None
|
||||
|
||||
# Any pending commands to be sent once a new connection has been
|
||||
# established
|
||||
self.pending_commands = [] # type: List[Command]
|
||||
|
||||
# Map from string -> deferred, to wake up when receiveing a SYNC with
|
||||
# the given string.
|
||||
# Used for tests.
|
||||
self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
|
||||
|
||||
# The factory used to create connections.
|
||||
self.factory = None # type: Optional[ReplicationClientFactory]
|
||||
|
||||
def start_replication(self, hs):
|
||||
"""Helper method to start a replication connection to the remote server
|
||||
using TCP.
|
||||
"""
|
||||
client_name = hs.config.worker_name
|
||||
self.factory = ReplicationClientFactory(hs, client_name, self)
|
||||
host = hs.config.worker_replication_host
|
||||
port = hs.config.worker_replication_port
|
||||
hs.get_reactor().connectTCP(host, port, self.factory)
|
||||
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
By default this just pokes the slave store. Can be overridden in subclasses to
|
||||
handle more.
|
||||
|
||||
Args:
|
||||
stream_name (str): name of the replication stream for this batch of rows
|
||||
token (int): stream token for this batch of rows
|
||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
"""
|
||||
logger.debug("Received rdata %s -> %s", stream_name, token)
|
||||
self.store.process_replication_rows(stream_name, token, rows)
|
||||
|
||||
async def on_position(self, stream_name, token):
|
||||
"""Called when we get new position data. By default this just pokes
|
||||
the slave store.
|
||||
|
||||
Can be overriden in subclasses to handle more.
|
||||
"""
|
||||
self.store.process_replication_rows(stream_name, token, [])
|
||||
|
||||
def on_sync(self, data):
|
||||
"""When we received a SYNC we wake up any deferreds that were waiting
|
||||
for the sync with the given data.
|
||||
|
||||
Used by tests.
|
||||
"""
|
||||
d = self.awaiting_syncs.pop(data, None)
|
||||
if d:
|
||||
d.callback(data)
|
||||
|
||||
def on_remote_server_up(self, server: str):
|
||||
"""Called when get a new REMOTE_SERVER_UP command."""
|
||||
|
||||
def get_streams_to_replicate(self) -> Dict[str, int]:
|
||||
"""Called when a new connection has been established and we need to
|
||||
subscribe to streams.
|
||||
|
||||
Returns:
|
||||
map from stream name to the most recent update we have for
|
||||
that stream (ie, the point we want to start replicating from)
|
||||
"""
|
||||
args = self.store.stream_positions()
|
||||
user_account_data = args.pop("user_account_data", None)
|
||||
room_account_data = args.pop("room_account_data", None)
|
||||
if user_account_data:
|
||||
args["account_data"] = user_account_data
|
||||
elif room_account_data:
|
||||
args["account_data"] = room_account_data
|
||||
|
||||
return args
|
||||
|
||||
def get_currently_syncing_users(self):
|
||||
"""Get the list of currently syncing users (if any). This is called
|
||||
when a connection has been established and we need to send the
|
||||
currently syncing users. (Overriden by the synchrotron's only)
|
||||
"""
|
||||
return []
|
||||
|
||||
def send_command(self, cmd):
|
||||
"""Send a command to master (when we get establish a connection if we
|
||||
don't have one already.)
|
||||
"""
|
||||
if self.connection:
|
||||
self.connection.send_command(cmd)
|
||||
else:
|
||||
logger.warning("Queuing command as not connected: %r", cmd.NAME)
|
||||
self.pending_commands.append(cmd)
|
||||
|
||||
def send_federation_ack(self, token):
|
||||
"""Ack data for the federation stream. This allows the master to drop
|
||||
data stored purely in memory.
|
||||
"""
|
||||
self.send_command(FederationAckCommand(token))
|
||||
|
||||
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
|
||||
"""Poke the master that a user has started/stopped syncing.
|
||||
"""
|
||||
self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
|
||||
|
||||
def send_remove_pusher(self, app_id, push_key, user_id):
|
||||
"""Poke the master to remove a pusher for a user
|
||||
"""
|
||||
cmd = RemovePusherCommand(app_id, push_key, user_id)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_invalidate_cache(self, cache_func, keys):
|
||||
"""Poke the master to invalidate a cache.
|
||||
"""
|
||||
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
|
||||
"""Tell the master that the user made a request.
|
||||
"""
|
||||
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_remote_server_up(self, server: str):
|
||||
self.send_command(RemoteServerUpCommand(server))
|
||||
|
||||
def await_sync(self, data):
|
||||
"""Returns a deferred that is resolved when we receive a SYNC command
|
||||
with given data.
|
||||
|
||||
[Not currently] used by tests.
|
||||
"""
|
||||
return self.awaiting_syncs.setdefault(data, defer.Deferred())
|
||||
|
||||
def update_connection(self, connection):
|
||||
"""Called when a connection has been established (or lost with None).
|
||||
"""
|
||||
self.connection = connection
|
||||
if connection:
|
||||
for cmd in self.pending_commands:
|
||||
connection.send_command(cmd)
|
||||
self.pending_commands = []
|
||||
|
||||
def finished_connecting(self):
|
||||
"""Called when we have successfully subscribed and caught up to all
|
||||
streams we're interested in.
|
||||
"""
|
||||
logger.info("Finished connecting to server")
|
||||
|
||||
# We don't reset the delay any earlier as otherwise if there is a
|
||||
# problem during start up we'll end up tight looping connecting to the
|
||||
# server.
|
||||
if self.factory:
|
||||
self.factory.resetDelay()
|
||||
|
||||
@@ -136,8 +136,8 @@ class PositionCommand(Command):
|
||||
"""Sent by the server to tell the client the stream postition without
|
||||
needing to send an RDATA.
|
||||
|
||||
Sent to the client after all missing updates for a stream have been sent
|
||||
to the client and they're now up to date.
|
||||
On receipt of a POSITION command clients should check if they have missed
|
||||
any updates, and if so then fetch them out of band.
|
||||
"""
|
||||
|
||||
NAME = "POSITION"
|
||||
@@ -179,42 +179,24 @@ class NameCommand(Command):
|
||||
|
||||
|
||||
class ReplicateCommand(Command):
|
||||
"""Sent by the client to subscribe to the stream.
|
||||
"""Sent by the client to subscribe to streams.
|
||||
|
||||
Format::
|
||||
|
||||
REPLICATE <stream_name> <token>
|
||||
|
||||
Where <token> may be either:
|
||||
* a numeric stream_id to stream updates from
|
||||
* "NOW" to stream all subsequent updates.
|
||||
|
||||
The <stream_name> can be "ALL" to subscribe to all known streams, in which
|
||||
case the <token> must be set to "NOW", i.e.::
|
||||
|
||||
REPLICATE ALL NOW
|
||||
REPLICATE
|
||||
"""
|
||||
|
||||
NAME = "REPLICATE"
|
||||
|
||||
def __init__(self, stream_name, token):
|
||||
self.stream_name = stream_name
|
||||
self.token = token
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
stream_name, token = line.split(" ", 1)
|
||||
if token in ("NOW", "now"):
|
||||
token = "NOW"
|
||||
else:
|
||||
token = int(token)
|
||||
return cls(stream_name, token)
|
||||
return cls()
|
||||
|
||||
def to_line(self):
|
||||
return " ".join((self.stream_name, str(self.token)))
|
||||
|
||||
def get_logcontext_id(self):
|
||||
return "REPLICATE-" + self.stream_name
|
||||
return ""
|
||||
|
||||
|
||||
class UserSyncCommand(Command):
|
||||
@@ -225,30 +207,32 @@ class UserSyncCommand(Command):
|
||||
|
||||
Format::
|
||||
|
||||
USER_SYNC <user_id> <state> <last_sync_ms>
|
||||
USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
|
||||
|
||||
Where <state> is either "start" or "stop"
|
||||
"""
|
||||
|
||||
NAME = "USER_SYNC"
|
||||
|
||||
def __init__(self, user_id, is_syncing, last_sync_ms):
|
||||
def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
|
||||
self.instance_id = instance_id
|
||||
self.user_id = user_id
|
||||
self.is_syncing = is_syncing
|
||||
self.last_sync_ms = last_sync_ms
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
user_id, state, last_sync_ms = line.split(" ", 2)
|
||||
instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
|
||||
|
||||
if state not in ("start", "end"):
|
||||
raise Exception("Invalid USER_SYNC state %r" % (state,))
|
||||
|
||||
return cls(user_id, state == "start", int(last_sync_ms))
|
||||
return cls(instance_id, user_id, state == "start", int(last_sync_ms))
|
||||
|
||||
def to_line(self):
|
||||
return " ".join(
|
||||
(
|
||||
self.instance_id,
|
||||
self.user_id,
|
||||
"start" if self.is_syncing else "end",
|
||||
str(self.last_sync_ms),
|
||||
@@ -256,6 +240,30 @@ class UserSyncCommand(Command):
|
||||
)
|
||||
|
||||
|
||||
class ClearUserSyncsCommand(Command):
|
||||
"""Sent by the client to inform the server that it should drop all
|
||||
information about syncing users sent by the client.
|
||||
|
||||
Mainly used when client is about to shut down.
|
||||
|
||||
Format::
|
||||
|
||||
CLEAR_USER_SYNC <instance_id>
|
||||
"""
|
||||
|
||||
NAME = "CLEAR_USER_SYNC"
|
||||
|
||||
def __init__(self, instance_id):
|
||||
self.instance_id = instance_id
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
return cls(line)
|
||||
|
||||
def to_line(self):
|
||||
return self.instance_id
|
||||
|
||||
|
||||
class FederationAckCommand(Command):
|
||||
"""Sent by the client when it has processed up to a given point in the
|
||||
federation stream. This allows the master to drop in-memory caches of the
|
||||
@@ -416,6 +424,7 @@ _COMMANDS = (
|
||||
InvalidateCacheCommand,
|
||||
UserIpCommand,
|
||||
RemoteServerUpCommand,
|
||||
ClearUserSyncsCommand,
|
||||
) # type: Tuple[Type[Command], ...]
|
||||
|
||||
# Map of command name to command type.
|
||||
@@ -438,6 +447,7 @@ VALID_CLIENT_COMMANDS = (
|
||||
ReplicateCommand.NAME,
|
||||
PingCommand.NAME,
|
||||
UserSyncCommand.NAME,
|
||||
ClearUserSyncsCommand.NAME,
|
||||
FederationAckCommand.NAME,
|
||||
RemovePusherCommand.NAME,
|
||||
InvalidateCacheCommand.NAME,
|
||||
|
||||
410
synapse/replication/tcp/handler.py
Normal file
410
synapse/replication/tcp/handler.py
Normal file
@@ -0,0 +1,410 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""A replication client for use by synapse workers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.replication.tcp.commands import (
|
||||
ClearUserSyncsCommand,
|
||||
Command,
|
||||
FederationAckCommand,
|
||||
InvalidateCacheCommand,
|
||||
PositionCommand,
|
||||
RdataCommand,
|
||||
RemoteServerUpCommand,
|
||||
RemovePusherCommand,
|
||||
ReplicateCommand,
|
||||
UserIpCommand,
|
||||
UserSyncCommand,
|
||||
)
|
||||
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
|
||||
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
|
||||
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
|
||||
invalidate_cache_counter = Counter(
|
||||
"synapse_replication_tcp_resource_invalidate_cache", ""
|
||||
)
|
||||
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
|
||||
|
||||
|
||||
class ReplicationClientHandler:
|
||||
"""Handles incoming commands from replication.
|
||||
|
||||
Proxies data to `HomeServer.get_replication_data_handler()`.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.replication_data_handler = hs.get_replication_data_handler()
|
||||
self.store = hs.get_datastore()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.clock = hs.get_clock()
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.instance_id = hs.get_instance_id()
|
||||
|
||||
self._position_linearizer = Linearizer("replication_position")
|
||||
|
||||
self.connections = [] # type: List[Any]
|
||||
|
||||
self.streams = {
|
||||
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
|
||||
} # type: Dict[str, Stream]
|
||||
|
||||
LaterGauge(
|
||||
"synapse_replication_tcp_resource_total_connections",
|
||||
"",
|
||||
[],
|
||||
lambda: len(self.connections),
|
||||
)
|
||||
|
||||
LaterGauge(
|
||||
"synapse_replication_tcp_resource_connections_per_stream",
|
||||
"",
|
||||
["stream_name"],
|
||||
lambda: {
|
||||
(stream_name,): len(
|
||||
[
|
||||
conn
|
||||
for conn in self.connections
|
||||
if stream_name in conn.replication_streams
|
||||
]
|
||||
)
|
||||
for stream_name in self.streams
|
||||
},
|
||||
)
|
||||
|
||||
# Map of stream to batched updates. See RdataCommand for info on how
|
||||
# batching works.
|
||||
self.pending_batches = {} # type: Dict[str, List[Any]]
|
||||
|
||||
self.is_master = hs.config.worker_app is None
|
||||
|
||||
self.federation_sender = None
|
||||
if self.is_master and not hs.config.send_federation:
|
||||
self.federation_sender = hs.get_federation_sender()
|
||||
|
||||
self._server_notices_sender = None
|
||||
if self.is_master:
|
||||
self._server_notices_sender = hs.get_server_notices_sender()
|
||||
self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
|
||||
|
||||
def new_connection(self, connection):
|
||||
self.connections.append(connection)
|
||||
|
||||
def lost_connection(self, connection):
|
||||
try:
|
||||
self.connections.remove(connection)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def connected(self) -> bool:
|
||||
"""Do we have any replication connections open?
|
||||
|
||||
Used to no-op if nothing is connected.
|
||||
"""
|
||||
return bool(self.connections)
|
||||
|
||||
async def on_REPLICATE(self, cmd: ReplicateCommand):
|
||||
# We only want to announce positions by the writer of the streams.
|
||||
# Currently this is just the master process.
|
||||
if not self.is_master:
|
||||
return
|
||||
|
||||
if not self.connections:
|
||||
raise Exception("Not connected")
|
||||
|
||||
for stream_name, stream in self.streams.items():
|
||||
current_token = stream.current_token()
|
||||
self.send_command(PositionCommand(stream_name, current_token))
|
||||
|
||||
async def on_USER_SYNC(self, cmd: UserSyncCommand):
|
||||
user_sync_counter.inc()
|
||||
|
||||
if self.is_master:
|
||||
await self.presence_handler.update_external_syncs_row(
|
||||
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
|
||||
)
|
||||
|
||||
async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
|
||||
if self.is_master:
|
||||
await self.presence_handler.update_external_syncs_clear(cmd.instance_id)
|
||||
|
||||
async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
|
||||
federation_ack_counter.inc()
|
||||
|
||||
if self.federation_sender:
|
||||
self.federation_sender.federation_ack(cmd.token)
|
||||
|
||||
async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
|
||||
remove_pusher_counter.inc()
|
||||
|
||||
if self.is_master:
|
||||
await self.store.delete_pusher_by_app_id_pushkey_user_id(
|
||||
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
|
||||
)
|
||||
|
||||
self.notifier.on_new_replication_data()
|
||||
|
||||
async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
|
||||
invalidate_cache_counter.inc()
|
||||
|
||||
if self.is_master:
|
||||
# We invalidate the cache locally, but then also stream that to other
|
||||
# workers.
|
||||
await self.store.invalidate_cache_and_stream(
|
||||
cmd.cache_func, tuple(cmd.keys)
|
||||
)
|
||||
|
||||
async def on_USER_IP(self, cmd: UserIpCommand):
|
||||
user_ip_cache_counter.inc()
|
||||
|
||||
if self.is_master:
|
||||
await self.store.insert_client_ip(
|
||||
cmd.user_id,
|
||||
cmd.access_token,
|
||||
cmd.ip,
|
||||
cmd.user_agent,
|
||||
cmd.device_id,
|
||||
cmd.last_seen,
|
||||
)
|
||||
|
||||
if self._server_notices_sender:
|
||||
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
||||
|
||||
async def on_RDATA(self, cmd: RdataCommand):
|
||||
stream_name = cmd.stream_name
|
||||
|
||||
try:
|
||||
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
|
||||
except Exception:
|
||||
logger.exception("[%s] Failed to parse RDATA: %r", stream_name, cmd.row)
|
||||
raise
|
||||
|
||||
if cmd.token is None:
|
||||
# I.e. this is part of a batch of updates for this stream. Batch
|
||||
# until we get an update for the stream with a non None token
|
||||
self.pending_batches.setdefault(stream_name, []).append(row)
|
||||
else:
|
||||
# Check if this is the last of a batch of updates
|
||||
rows = self.pending_batches.pop(stream_name, [])
|
||||
rows.append(row)
|
||||
await self.on_rdata(stream_name, cmd.token, rows)
|
||||
|
||||
async def on_rdata(self, stream_name: str, token: int, rows: list):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
Args:
|
||||
stream_name: name of the replication stream for this batch of rows
|
||||
token: stream token for this batch of rows
|
||||
rows: a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
"""
|
||||
logger.debug("Received rdata %s -> %s", stream_name, token)
|
||||
await self.replication_data_handler.on_rdata(stream_name, token, rows)
|
||||
|
||||
async def on_POSITION(self, cmd: PositionCommand):
|
||||
stream = self.streams.get(cmd.stream_name)
|
||||
if not stream:
|
||||
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
|
||||
return
|
||||
|
||||
# We protect catching up with a linearizer in case the replicaiton
|
||||
# connection reconnects under us.
|
||||
with await self._position_linearizer.queue(cmd.stream_name):
|
||||
# Find where we previously streamed up to.
|
||||
current_token = self.replication_data_handler.get_streams_to_replicate().get(
|
||||
cmd.stream_name
|
||||
)
|
||||
if current_token is None:
|
||||
logger.debug(
|
||||
"Got POSITION for stream we're not subscribed to: %s",
|
||||
cmd.stream_name,
|
||||
)
|
||||
return
|
||||
|
||||
# Fetch all updates between then and now.
|
||||
limited = cmd.token != current_token
|
||||
while limited:
|
||||
updates, current_token, limited = await stream.get_updates_since(
|
||||
current_token, cmd.token
|
||||
)
|
||||
if updates:
|
||||
await self.on_rdata(
|
||||
cmd.stream_name,
|
||||
current_token,
|
||||
[stream.parse_row(update[1]) for update in updates],
|
||||
)
|
||||
|
||||
# We've now caught up to position sent to us, notify handler.
|
||||
await self.replication_data_handler.on_position(cmd.stream_name, cmd.token)
|
||||
|
||||
# Handle any RDATA that came in while we were catching up.
|
||||
rows = self.pending_batches.pop(cmd.stream_name, [])
|
||||
if rows:
|
||||
await self.on_rdata(cmd.stream_name, rows[-1].token, rows)
|
||||
|
||||
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
|
||||
"""Called when get a new REMOTE_SERVER_UP command."""
|
||||
if self.is_master:
|
||||
self.notifier.notify_remote_server_up(cmd.data)
|
||||
|
||||
def get_currently_syncing_users(self):
|
||||
"""Get the list of currently syncing users (if any). This is called
|
||||
when a connection has been established and we need to send the
|
||||
currently syncing users.
|
||||
"""
|
||||
return self.presence_handler.get_currently_syncing_users()
|
||||
|
||||
def send_command(self, cmd: Command):
|
||||
"""Send a command to master (when we get establish a connection if we
|
||||
don't have one already.)
|
||||
"""
|
||||
for conn in self.connections:
|
||||
conn.send_command(cmd)
|
||||
|
||||
def send_federation_ack(self, token: int):
|
||||
"""Ack data for the federation stream. This allows the master to drop
|
||||
data stored purely in memory.
|
||||
"""
|
||||
self.send_command(FederationAckCommand(token))
|
||||
|
||||
def send_user_sync(
|
||||
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
|
||||
):
|
||||
"""Poke the master that a user has started/stopped syncing.
|
||||
"""
|
||||
self.send_command(
|
||||
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
|
||||
)
|
||||
|
||||
def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
|
||||
"""Poke the master to remove a pusher for a user
|
||||
"""
|
||||
cmd = RemovePusherCommand(app_id, push_key, user_id)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
|
||||
"""Poke the master to invalidate a cache.
|
||||
"""
|
||||
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_user_ip(
|
||||
self,
|
||||
user_id: str,
|
||||
access_token: str,
|
||||
ip: str,
|
||||
user_agent: str,
|
||||
device_id: str,
|
||||
last_seen: int,
|
||||
):
|
||||
"""Tell the master that the user made a request.
|
||||
"""
|
||||
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_remote_server_up(self, server: str):
|
||||
self.send_command(RemoteServerUpCommand(server))
|
||||
|
||||
def stream_update(self, stream_name: str, token: str, data: Any):
|
||||
"""Called when a new update is available to stream to clients.
|
||||
|
||||
We need to check if the client is interested in the stream or not
|
||||
"""
|
||||
self.send_command(RdataCommand(stream_name, token, data))
|
||||
|
||||
|
||||
class DummyReplicationDataHandler:
|
||||
"""A replication data handler that simply discards all data.
|
||||
"""
|
||||
|
||||
async def on_rdata(self, stream_name: str, token: int, rows: list):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
By default this just pokes the slave store. Can be overridden in subclasses to
|
||||
handle more.
|
||||
|
||||
Args:
|
||||
stream_name (str): name of the replication stream for this batch of rows
|
||||
token (int): stream token for this batch of rows
|
||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_streams_to_replicate(self) -> Dict[str, int]:
|
||||
"""Called when a new connection has been established and we need to
|
||||
subscribe to streams.
|
||||
|
||||
Returns:
|
||||
map from stream name to the most recent update we have for
|
||||
that stream (ie, the point we want to start replicating from)
|
||||
"""
|
||||
return {}
|
||||
|
||||
async def on_position(self, stream_name: str, token: int):
|
||||
pass
|
||||
|
||||
|
||||
class WorkerReplicationDataHandler:
|
||||
"""A replication data handler that calls slave data stores.
|
||||
"""
|
||||
|
||||
def __init__(self, store):
|
||||
self.store = store
|
||||
|
||||
async def on_rdata(self, stream_name: str, token: int, rows: list):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
By default this just pokes the slave store. Can be overridden in subclasses to
|
||||
handle more.
|
||||
|
||||
Args:
|
||||
stream_name (str): name of the replication stream for this batch of rows
|
||||
token (int): stream token for this batch of rows
|
||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
"""
|
||||
self.store.process_replication_rows(stream_name, token, rows)
|
||||
|
||||
def get_streams_to_replicate(self) -> Dict[str, int]:
|
||||
"""Called when a new connection has been established and we need to
|
||||
subscribe to streams.
|
||||
|
||||
Returns:
|
||||
map from stream name to the most recent update we have for
|
||||
that stream (ie, the point we want to start replicating from)
|
||||
"""
|
||||
args = self.store.stream_positions()
|
||||
user_account_data = args.pop("user_account_data", None)
|
||||
room_account_data = args.pop("room_account_data", None)
|
||||
if user_account_data:
|
||||
args["account_data"] = user_account_data
|
||||
elif room_account_data:
|
||||
args["account_data"] = room_account_data
|
||||
return args
|
||||
|
||||
async def on_position(self, stream_name: str, token: int):
|
||||
self.store.process_replication_rows(stream_name, token, [])
|
||||
@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
|
||||
> PING 1490197665618
|
||||
< NAME synapse.app.appservice
|
||||
< PING 1490197665618
|
||||
< REPLICATE events 1
|
||||
< REPLICATE backfill 1
|
||||
< REPLICATE caches 1
|
||||
< REPLICATE
|
||||
> POSITION events 1
|
||||
> POSITION backfill 1
|
||||
> POSITION caches 1
|
||||
@@ -48,45 +46,40 @@ indicate which side is sending, these are *not* included on the wire::
|
||||
> ERROR server stopping
|
||||
* connection closed by server *
|
||||
"""
|
||||
import abc
|
||||
import fcntl
|
||||
import logging
|
||||
import struct
|
||||
from collections import defaultdict
|
||||
from typing import Any, DefaultDict, Dict, List, Set, Tuple
|
||||
from typing import Any, DefaultDict, Dict, List, Set
|
||||
|
||||
from six import iteritems, iterkeys
|
||||
from six import iteritems
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.protocols.basic import LineOnlyReceiver
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.tcp.commands import (
|
||||
COMMAND_MAP,
|
||||
VALID_CLIENT_COMMANDS,
|
||||
VALID_SERVER_COMMANDS,
|
||||
Command,
|
||||
ErrorCommand,
|
||||
NameCommand,
|
||||
PingCommand,
|
||||
PositionCommand,
|
||||
RdataCommand,
|
||||
RemoteServerUpCommand,
|
||||
ReplicateCommand,
|
||||
ServerCommand,
|
||||
SyncCommand,
|
||||
UserSyncCommand,
|
||||
)
|
||||
from synapse.replication.tcp.streams import STREAMS_MAP
|
||||
from synapse.types import Collection
|
||||
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
MYPY = False
|
||||
if MYPY:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
connection_close_counter = Counter(
|
||||
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
|
||||
)
|
||||
@@ -127,16 +120,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
|
||||
delimiter = b"\n"
|
||||
|
||||
# Valid commands we expect to receive
|
||||
VALID_INBOUND_COMMANDS = [] # type: Collection[str]
|
||||
|
||||
# Valid commands we can send
|
||||
VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]
|
||||
|
||||
max_line_buffer = 10000
|
||||
|
||||
def __init__(self, clock):
|
||||
def __init__(self, clock, handler):
|
||||
self.clock = clock
|
||||
self.handler = handler
|
||||
|
||||
self.last_received_command = self.clock.time_msec()
|
||||
self.last_sent_command = 0
|
||||
@@ -176,6 +164,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
# can time us out.
|
||||
self.send_command(PingCommand(self.clock.time_msec()))
|
||||
|
||||
self.handler.new_connection(self)
|
||||
|
||||
def send_ping(self):
|
||||
"""Periodically sends a ping and checks if we should close the connection
|
||||
due to the other side timing out.
|
||||
@@ -213,11 +203,6 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
line = line.decode("utf-8")
|
||||
cmd_name, rest_of_line = line.split(" ", 1)
|
||||
|
||||
if cmd_name not in self.VALID_INBOUND_COMMANDS:
|
||||
logger.error("[%s] invalid command %s", self.id(), cmd_name)
|
||||
self.send_error("invalid command: %s", cmd_name)
|
||||
return
|
||||
|
||||
self.last_received_command = self.clock.time_msec()
|
||||
|
||||
self.inbound_commands_counter[cmd_name] = (
|
||||
@@ -249,8 +234,23 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
Args:
|
||||
cmd: received command
|
||||
"""
|
||||
handler = getattr(self, "on_%s" % (cmd.NAME,))
|
||||
await handler(cmd)
|
||||
handled = False
|
||||
|
||||
# First call any command handlers on this instance. These are for TCP
|
||||
# specific handling.
|
||||
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(cmd)
|
||||
handled = True
|
||||
|
||||
# Then call out to the handler.
|
||||
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(cmd)
|
||||
handled = True
|
||||
|
||||
if not handled:
|
||||
logger.warning("Unhandled command: %r", cmd)
|
||||
|
||||
def close(self):
|
||||
logger.warning("[%s] Closing connection", self.id())
|
||||
@@ -258,6 +258,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
self.transport.loseConnection()
|
||||
self.on_connection_closed()
|
||||
|
||||
def send_remote_server_up(self, server: str):
|
||||
self.send_command(RemoteServerUpCommand(server))
|
||||
|
||||
def send_error(self, error_string, *args):
|
||||
"""Send an error to remote and close the connection.
|
||||
"""
|
||||
@@ -379,6 +382,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
self.state = ConnectionStates.CLOSED
|
||||
self.pending_commands = []
|
||||
|
||||
self.handler.lost_connection(self)
|
||||
|
||||
if self.transport:
|
||||
self.transport.unregisterProducer()
|
||||
|
||||
@@ -402,346 +407,66 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
|
||||
|
||||
class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
|
||||
VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
|
||||
|
||||
def __init__(self, server_name, clock, streamer):
|
||||
BaseReplicationStreamProtocol.__init__(self, clock) # Old style class
|
||||
def __init__(self, hs, server_name, clock, handler):
|
||||
BaseReplicationStreamProtocol.__init__(self, clock, handler) # Old style class
|
||||
|
||||
self.server_name = server_name
|
||||
self.streamer = streamer
|
||||
|
||||
# The streams the client has subscribed to and is up to date with
|
||||
self.replication_streams = set() # type: Set[str]
|
||||
|
||||
# The streams the client is currently subscribing to.
|
||||
self.connecting_streams = set() # type: Set[str]
|
||||
|
||||
# Map from stream name to list of updates to send once we've finished
|
||||
# subscribing the client to the stream.
|
||||
self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
|
||||
|
||||
def connectionMade(self):
|
||||
self.send_command(ServerCommand(self.server_name))
|
||||
BaseReplicationStreamProtocol.connectionMade(self)
|
||||
self.streamer.new_connection(self)
|
||||
|
||||
async def on_NAME(self, cmd):
|
||||
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
|
||||
self.name = cmd.data
|
||||
|
||||
async def on_USER_SYNC(self, cmd):
|
||||
await self.streamer.on_user_sync(
|
||||
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
|
||||
)
|
||||
|
||||
async def on_REPLICATE(self, cmd):
|
||||
stream_name = cmd.stream_name
|
||||
token = cmd.token
|
||||
|
||||
if stream_name == "ALL":
|
||||
# Subscribe to all streams we're publishing to.
|
||||
deferreds = [
|
||||
run_in_background(self.subscribe_to_stream, stream, token)
|
||||
for stream in iterkeys(self.streamer.streams_by_name)
|
||||
]
|
||||
|
||||
await make_deferred_yieldable(
|
||||
defer.gatherResults(deferreds, consumeErrors=True)
|
||||
)
|
||||
else:
|
||||
await self.subscribe_to_stream(stream_name, token)
|
||||
|
||||
async def on_FEDERATION_ACK(self, cmd):
|
||||
self.streamer.federation_ack(cmd.token)
|
||||
|
||||
async def on_REMOVE_PUSHER(self, cmd):
|
||||
await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
|
||||
|
||||
async def on_INVALIDATE_CACHE(self, cmd):
|
||||
await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
|
||||
|
||||
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
|
||||
self.streamer.on_remote_server_up(cmd.data)
|
||||
|
||||
async def on_USER_IP(self, cmd):
|
||||
self.streamer.on_user_ip(
|
||||
cmd.user_id,
|
||||
cmd.access_token,
|
||||
cmd.ip,
|
||||
cmd.user_agent,
|
||||
cmd.device_id,
|
||||
cmd.last_seen,
|
||||
)
|
||||
|
||||
async def subscribe_to_stream(self, stream_name, token):
|
||||
"""Subscribe the remote to a stream.
|
||||
|
||||
This invloves checking if they've missed anything and sending those
|
||||
updates down if they have. During that time new updates for the stream
|
||||
are queued and sent once we've sent down any missed updates.
|
||||
"""
|
||||
self.replication_streams.discard(stream_name)
|
||||
self.connecting_streams.add(stream_name)
|
||||
|
||||
try:
|
||||
# Get missing updates
|
||||
updates, current_token = await self.streamer.get_stream_updates(
|
||||
stream_name, token
|
||||
)
|
||||
|
||||
# Send all the missing updates
|
||||
for update in updates:
|
||||
token, row = update[0], update[1]
|
||||
self.send_command(RdataCommand(stream_name, token, row))
|
||||
|
||||
# We send a POSITION command to ensure that they have an up to
|
||||
# date token (especially useful if we didn't send any updates
|
||||
# above)
|
||||
self.send_command(PositionCommand(stream_name, current_token))
|
||||
|
||||
# Now we can send any updates that came in while we were subscribing
|
||||
pending_rdata = self.pending_rdata.pop(stream_name, [])
|
||||
updates = []
|
||||
for token, update in pending_rdata:
|
||||
# If the token is null, it is part of a batch update. Batches
|
||||
# are multiple updates that share a single token. To denote
|
||||
# this, the token is set to None for all tokens in the batch
|
||||
# except for the last. If we find a None token, we keep looking
|
||||
# through tokens until we find one that is not None and then
|
||||
# process all previous updates in the batch as if they had the
|
||||
# final token.
|
||||
if token is None:
|
||||
# Store this update as part of a batch
|
||||
updates.append(update)
|
||||
continue
|
||||
|
||||
if token <= current_token:
|
||||
# This update or batch of updates is older than
|
||||
# current_token, dismiss it
|
||||
updates = []
|
||||
continue
|
||||
|
||||
updates.append(update)
|
||||
|
||||
# Send all updates that are part of this batch with the
|
||||
# found token
|
||||
for update in updates:
|
||||
self.send_command(RdataCommand(stream_name, token, update))
|
||||
|
||||
# Clear stored updates
|
||||
updates = []
|
||||
|
||||
# They're now fully subscribed
|
||||
self.replication_streams.add(stream_name)
|
||||
except Exception as e:
|
||||
logger.exception("[%s] Failed to handle REPLICATE command", self.id())
|
||||
self.send_error("failed to handle replicate: %r", e)
|
||||
finally:
|
||||
self.connecting_streams.discard(stream_name)
|
||||
|
||||
def stream_update(self, stream_name, token, data):
|
||||
"""Called when a new update is available to stream to clients.
|
||||
|
||||
We need to check if the client is interested in the stream or not
|
||||
"""
|
||||
if stream_name in self.replication_streams:
|
||||
# The client is subscribed to the stream
|
||||
self.send_command(RdataCommand(stream_name, token, data))
|
||||
elif stream_name in self.connecting_streams:
|
||||
# The client is being subscribed to the stream
|
||||
logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
|
||||
self.pending_rdata.setdefault(stream_name, []).append((token, data))
|
||||
else:
|
||||
# The client isn't subscribed
|
||||
logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
|
||||
|
||||
def send_sync(self, data):
|
||||
self.send_command(SyncCommand(data))
|
||||
|
||||
def send_remote_server_up(self, server: str):
|
||||
self.send_command(RemoteServerUpCommand(server))
|
||||
|
||||
def on_connection_closed(self):
|
||||
BaseReplicationStreamProtocol.on_connection_closed(self)
|
||||
self.streamer.lost_connection(self)
|
||||
|
||||
|
||||
class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
The interface for the handler that should be passed to
|
||||
ClientReplicationStreamProtocol
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
Args:
|
||||
stream_name (str): name of the replication stream for this batch of rows
|
||||
token (int): stream token for this batch of rows
|
||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def on_position(self, stream_name, token):
|
||||
"""Called when we get new position data."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_sync(self, data):
|
||||
"""Called when get a new SYNC command."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def on_remote_server_up(self, server: str):
|
||||
"""Called when get a new REMOTE_SERVER_UP command."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_streams_to_replicate(self):
|
||||
"""Called when a new connection has been established and we need to
|
||||
subscribe to streams.
|
||||
|
||||
Returns:
|
||||
map from stream name to the most recent update we have for
|
||||
that stream (ie, the point we want to start replicating from)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_currently_syncing_users(self):
|
||||
"""Get the list of currently syncing users (if any). This is called
|
||||
when a connection has been established and we need to send the
|
||||
currently syncing users."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_connection(self, connection):
|
||||
"""Called when a connection has been established (or lost with None).
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def finished_connecting(self):
|
||||
"""Called when we have successfully subscribed and caught up to all
|
||||
streams we're interested in.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
|
||||
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hs: "HomeServer",
|
||||
client_name: str,
|
||||
server_name: str,
|
||||
clock: Clock,
|
||||
handler: AbstractReplicationClientHandler,
|
||||
handler,
|
||||
):
|
||||
BaseReplicationStreamProtocol.__init__(self, clock)
|
||||
BaseReplicationStreamProtocol.__init__(self, clock, handler)
|
||||
|
||||
self.instance_id = hs.get_instance_id()
|
||||
|
||||
self.client_name = client_name
|
||||
self.server_name = server_name
|
||||
self.handler = handler
|
||||
|
||||
self.streams = {
|
||||
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
|
||||
} # type: Dict[str, Stream]
|
||||
|
||||
# Set of stream names that have been subscribe to, but haven't yet
|
||||
# caught up with. This is used to track when the client has been fully
|
||||
# connected to the remote.
|
||||
self.streams_connecting = set() # type: Set[str]
|
||||
self.streams_connecting = set(STREAMS_MAP) # type: Set[str]
|
||||
|
||||
# Map of stream to batched updates. See RdataCommand for info on how
|
||||
# batching works.
|
||||
self.pending_batches = {} # type: Dict[str, Any]
|
||||
self.pending_batches = {} # type: Dict[str, List[Any]]
|
||||
|
||||
def connectionMade(self):
|
||||
self.send_command(NameCommand(self.client_name))
|
||||
BaseReplicationStreamProtocol.connectionMade(self)
|
||||
|
||||
# Once we've connected subscribe to the necessary streams
|
||||
for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
|
||||
self.replicate(stream_name, token)
|
||||
|
||||
# Tell the server if we have any users currently syncing (should only
|
||||
# happen on synchrotrons)
|
||||
currently_syncing = self.handler.get_currently_syncing_users()
|
||||
now = self.clock.time_msec()
|
||||
for user_id in currently_syncing:
|
||||
self.send_command(UserSyncCommand(user_id, True, now))
|
||||
|
||||
# We've now finished connecting to so inform the client handler
|
||||
self.handler.update_connection(self)
|
||||
|
||||
# This will happen if we don't actually subscribe to any streams
|
||||
if not self.streams_connecting:
|
||||
self.handler.finished_connecting()
|
||||
self.send_command(NameCommand(self.client_name))
|
||||
self.replicate()
|
||||
|
||||
async def on_SERVER(self, cmd):
|
||||
if cmd.data != self.server_name:
|
||||
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
|
||||
self.send_error("Wrong remote")
|
||||
|
||||
async def on_RDATA(self, cmd):
|
||||
stream_name = cmd.stream_name
|
||||
inbound_rdata_count.labels(stream_name).inc()
|
||||
|
||||
try:
|
||||
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
|
||||
)
|
||||
raise
|
||||
|
||||
if cmd.token is None:
|
||||
# I.e. this is part of a batch of updates for this stream. Batch
|
||||
# until we get an update for the stream with a non None token
|
||||
self.pending_batches.setdefault(stream_name, []).append(row)
|
||||
else:
|
||||
# Check if this is the last of a batch of updates
|
||||
rows = self.pending_batches.pop(stream_name, [])
|
||||
rows.append(row)
|
||||
await self.handler.on_rdata(stream_name, cmd.token, rows)
|
||||
|
||||
async def on_POSITION(self, cmd):
|
||||
# When we get a `POSITION` command it means we've finished getting
|
||||
# missing updates for the given stream, and are now up to date.
|
||||
self.streams_connecting.discard(cmd.stream_name)
|
||||
if not self.streams_connecting:
|
||||
self.handler.finished_connecting()
|
||||
|
||||
await self.handler.on_position(cmd.stream_name, cmd.token)
|
||||
|
||||
async def on_SYNC(self, cmd):
|
||||
self.handler.on_sync(cmd.data)
|
||||
|
||||
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
|
||||
self.handler.on_remote_server_up(cmd.data)
|
||||
|
||||
def replicate(self, stream_name, token):
|
||||
def replicate(self):
|
||||
"""Send the subscription request to the server
|
||||
"""
|
||||
if stream_name not in STREAMS_MAP:
|
||||
raise Exception("Invalid stream name %r" % (stream_name,))
|
||||
logger.info("[%s] Subscribing to replication streams", self.id())
|
||||
|
||||
logger.info(
|
||||
"[%s] Subscribing to replication stream: %r from %r",
|
||||
self.id(),
|
||||
stream_name,
|
||||
token,
|
||||
)
|
||||
|
||||
self.streams_connecting.add(stream_name)
|
||||
|
||||
self.send_command(ReplicateCommand(stream_name, token))
|
||||
|
||||
def on_connection_closed(self):
|
||||
BaseReplicationStreamProtocol.on_connection_closed(self)
|
||||
self.handler.update_connection(None)
|
||||
self.send_command(ReplicateCommand())
|
||||
|
||||
|
||||
# The following simply registers metrics for the replication connections
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, List
|
||||
from typing import Dict
|
||||
|
||||
from six import itervalues
|
||||
|
||||
@@ -25,24 +25,16 @@ from prometheus_client import Counter
|
||||
|
||||
from twisted.internet.protocol import Factory
|
||||
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util.metrics import Measure, measure_func
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
from .protocol import ServerReplicationStreamProtocol
|
||||
from .streams import STREAMS_MAP
|
||||
from .streams import STREAMS_MAP, Stream
|
||||
from .streams.federation import FederationStream
|
||||
|
||||
stream_updates_counter = Counter(
|
||||
"synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
|
||||
)
|
||||
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
|
||||
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
|
||||
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
|
||||
invalidate_cache_counter = Counter(
|
||||
"synapse_replication_tcp_resource_invalidate_cache", ""
|
||||
)
|
||||
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -52,13 +44,18 @@ class ReplicationStreamProtocolFactory(Factory):
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.streamer = ReplicationStreamer(hs)
|
||||
self.handler = hs.get_tcp_replication()
|
||||
self.clock = hs.get_clock()
|
||||
self.server_name = hs.config.server_name
|
||||
self.hs = hs
|
||||
|
||||
# Ensure the replication streamer is started if we register a
|
||||
# replication server endpoint.
|
||||
hs.get_replication_streamer()
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
return ServerReplicationStreamProtocol(
|
||||
self.server_name, self.clock, self.streamer
|
||||
self.hs, self.server_name, self.clock, self.handler
|
||||
)
|
||||
|
||||
|
||||
@@ -78,16 +75,6 @@ class ReplicationStreamer(object):
|
||||
|
||||
self._replication_torture_level = hs.config.replication_torture_level
|
||||
|
||||
# Current connections.
|
||||
self.connections = [] # type: List[ServerReplicationStreamProtocol]
|
||||
|
||||
LaterGauge(
|
||||
"synapse_replication_tcp_resource_total_connections",
|
||||
"",
|
||||
[],
|
||||
lambda: len(self.connections),
|
||||
)
|
||||
|
||||
# List of streams that clients can subscribe to.
|
||||
# We only support federation stream if federation sending hase been
|
||||
# disabled on the master.
|
||||
@@ -99,39 +86,22 @@ class ReplicationStreamer(object):
|
||||
|
||||
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
|
||||
|
||||
LaterGauge(
|
||||
"synapse_replication_tcp_resource_connections_per_stream",
|
||||
"",
|
||||
["stream_name"],
|
||||
lambda: {
|
||||
(stream_name,): len(
|
||||
[
|
||||
conn
|
||||
for conn in self.connections
|
||||
if stream_name in conn.replication_streams
|
||||
]
|
||||
)
|
||||
for stream_name in self.streams_by_name
|
||||
},
|
||||
)
|
||||
|
||||
self.federation_sender = None
|
||||
if not hs.config.send_federation:
|
||||
self.federation_sender = hs.get_federation_sender()
|
||||
|
||||
self.notifier.add_replication_callback(self.on_notifier_poke)
|
||||
self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
|
||||
|
||||
# Keeps track of whether we are currently checking for updates
|
||||
self.is_looping = False
|
||||
self.pending_updates = False
|
||||
|
||||
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
|
||||
self.client = hs.get_tcp_replication()
|
||||
|
||||
def on_shutdown(self):
|
||||
# close all connections on shutdown
|
||||
for conn in self.connections:
|
||||
conn.send_error("server shutting down")
|
||||
def get_streams(self) -> Dict[str, Stream]:
|
||||
"""Get a mapp from stream name to stream instance.
|
||||
"""
|
||||
return self.streams_by_name
|
||||
|
||||
def on_notifier_poke(self):
|
||||
"""Checks if there is actually any new data and sends it to the
|
||||
@@ -140,7 +110,7 @@ class ReplicationStreamer(object):
|
||||
This should get called each time new data is available, even if it
|
||||
is currently being executed, so that nothing gets missed
|
||||
"""
|
||||
if not self.connections:
|
||||
if not self.client.connected():
|
||||
# Don't bother if nothing is listening. We still need to advance
|
||||
# the stream tokens otherwise they'll fall beihind forever
|
||||
for stream in self.streams:
|
||||
@@ -166,11 +136,6 @@ class ReplicationStreamer(object):
|
||||
self.pending_updates = False
|
||||
|
||||
with Measure(self.clock, "repl.stream.get_updates"):
|
||||
# First we tell the streams that they should update their
|
||||
# current tokens.
|
||||
for stream in self.streams:
|
||||
stream.advance_current_token()
|
||||
|
||||
all_streams = self.streams
|
||||
|
||||
if self._replication_torture_level is not None:
|
||||
@@ -180,7 +145,7 @@ class ReplicationStreamer(object):
|
||||
random.shuffle(all_streams)
|
||||
|
||||
for stream in all_streams:
|
||||
if stream.last_token == stream.upto_token:
|
||||
if stream.last_token == stream.current_token():
|
||||
continue
|
||||
|
||||
if self._replication_torture_level:
|
||||
@@ -192,18 +157,17 @@ class ReplicationStreamer(object):
|
||||
"Getting stream: %s: %s -> %s",
|
||||
stream.NAME,
|
||||
stream.last_token,
|
||||
stream.upto_token,
|
||||
stream.current_token(),
|
||||
)
|
||||
try:
|
||||
updates, current_token = await stream.get_updates()
|
||||
updates, current_token, limited = await stream.get_updates()
|
||||
self.pending_updates |= limited
|
||||
except Exception:
|
||||
logger.info("Failed to handle stream %s", stream.NAME)
|
||||
raise
|
||||
|
||||
logger.debug(
|
||||
"Sending %d updates to %d connections",
|
||||
len(updates),
|
||||
len(self.connections),
|
||||
"Sending %d updates", len(updates),
|
||||
)
|
||||
|
||||
if updates:
|
||||
@@ -219,116 +183,17 @@ class ReplicationStreamer(object):
|
||||
# token. See RdataCommand for more details.
|
||||
batched_updates = _batch_updates(updates)
|
||||
|
||||
for conn in self.connections:
|
||||
for token, row in batched_updates:
|
||||
try:
|
||||
conn.stream_update(stream.NAME, token, row)
|
||||
except Exception:
|
||||
logger.exception("Failed to replicate")
|
||||
for token, row in batched_updates:
|
||||
try:
|
||||
self.client.stream_update(stream.NAME, token, row)
|
||||
except Exception:
|
||||
logger.exception("Failed to replicate")
|
||||
|
||||
logger.debug("No more pending updates, breaking poke loop")
|
||||
finally:
|
||||
self.pending_updates = False
|
||||
self.is_looping = False
|
||||
|
||||
@measure_func("repl.get_stream_updates")
|
||||
async def get_stream_updates(self, stream_name, token):
|
||||
"""For a given stream get all updates since token. This is called when
|
||||
a client first subscribes to a stream.
|
||||
"""
|
||||
stream = self.streams_by_name.get(stream_name, None)
|
||||
if not stream:
|
||||
raise Exception("unknown stream %s", stream_name)
|
||||
|
||||
return await stream.get_updates_since(token)
|
||||
|
||||
@measure_func("repl.federation_ack")
|
||||
def federation_ack(self, token):
|
||||
"""We've received an ack for federation stream from a client.
|
||||
"""
|
||||
federation_ack_counter.inc()
|
||||
if self.federation_sender:
|
||||
self.federation_sender.federation_ack(token)
|
||||
|
||||
@measure_func("repl.on_user_sync")
|
||||
async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
|
||||
"""A client has started/stopped syncing on a worker.
|
||||
"""
|
||||
user_sync_counter.inc()
|
||||
await self.presence_handler.update_external_syncs_row(
|
||||
conn_id, user_id, is_syncing, last_sync_ms
|
||||
)
|
||||
|
||||
@measure_func("repl.on_remove_pusher")
|
||||
async def on_remove_pusher(self, app_id, push_key, user_id):
|
||||
"""A client has asked us to remove a pusher
|
||||
"""
|
||||
remove_pusher_counter.inc()
|
||||
await self.store.delete_pusher_by_app_id_pushkey_user_id(
|
||||
app_id=app_id, pushkey=push_key, user_id=user_id
|
||||
)
|
||||
|
||||
self.notifier.on_new_replication_data()
|
||||
|
||||
@measure_func("repl.on_invalidate_cache")
|
||||
async def on_invalidate_cache(self, cache_func: str, keys: List[Any]):
|
||||
"""The client has asked us to invalidate a cache
|
||||
"""
|
||||
invalidate_cache_counter.inc()
|
||||
|
||||
# We invalidate the cache locally, but then also stream that to other
|
||||
# workers.
|
||||
await self.store.invalidate_cache_and_stream(cache_func, tuple(keys))
|
||||
|
||||
@measure_func("repl.on_user_ip")
|
||||
async def on_user_ip(
|
||||
self, user_id, access_token, ip, user_agent, device_id, last_seen
|
||||
):
|
||||
"""The client saw a user request
|
||||
"""
|
||||
user_ip_cache_counter.inc()
|
||||
await self.store.insert_client_ip(
|
||||
user_id, access_token, ip, user_agent, device_id, last_seen
|
||||
)
|
||||
await self._server_notices_sender.on_user_ip(user_id)
|
||||
|
||||
@measure_func("repl.on_remote_server_up")
|
||||
def on_remote_server_up(self, server: str):
|
||||
self.notifier.notify_remote_server_up(server)
|
||||
|
||||
def send_remote_server_up(self, server: str):
|
||||
for conn in self.connections:
|
||||
conn.send_remote_server_up(server)
|
||||
|
||||
def send_sync_to_all_connections(self, data):
|
||||
"""Sends a SYNC command to all clients.
|
||||
|
||||
Used in tests.
|
||||
"""
|
||||
for conn in self.connections:
|
||||
conn.send_sync(data)
|
||||
|
||||
def new_connection(self, connection):
|
||||
"""A new client connection has been established
|
||||
"""
|
||||
self.connections.append(connection)
|
||||
|
||||
def lost_connection(self, connection):
|
||||
"""A client connection has been lost
|
||||
"""
|
||||
try:
|
||||
self.connections.remove(connection)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# We need to tell the presence handler that the connection has been
|
||||
# lost so that it can handle any ongoing syncs on that connection.
|
||||
run_as_background_process(
|
||||
"update_external_syncs_clear",
|
||||
self.presence_handler.update_external_syncs_clear,
|
||||
connection.conn_id,
|
||||
)
|
||||
|
||||
|
||||
def _batch_updates(updates):
|
||||
"""Takes a list of updates of form [(token, row)] and sets the token to
|
||||
|
||||
@@ -25,26 +25,66 @@ Each stream is defined by the following information:
|
||||
update_function: The function that returns a list of updates between two tokens
|
||||
"""
|
||||
|
||||
from . import _base, events, federation
|
||||
from typing import Dict, Type
|
||||
|
||||
from synapse.replication.tcp.streams._base import (
|
||||
AccountDataStream,
|
||||
BackfillStream,
|
||||
CachesStream,
|
||||
DeviceListsStream,
|
||||
GroupServerStream,
|
||||
PresenceStream,
|
||||
PublicRoomsStream,
|
||||
PushersStream,
|
||||
PushRulesStream,
|
||||
ReceiptsStream,
|
||||
Stream,
|
||||
TagAccountDataStream,
|
||||
ToDeviceStream,
|
||||
TypingStream,
|
||||
UserSignatureStream,
|
||||
)
|
||||
from synapse.replication.tcp.streams.events import EventsStream
|
||||
from synapse.replication.tcp.streams.federation import FederationStream
|
||||
|
||||
STREAMS_MAP = {
|
||||
stream.NAME: stream
|
||||
for stream in (
|
||||
events.EventsStream,
|
||||
_base.BackfillStream,
|
||||
_base.PresenceStream,
|
||||
_base.TypingStream,
|
||||
_base.ReceiptsStream,
|
||||
_base.PushRulesStream,
|
||||
_base.PushersStream,
|
||||
_base.CachesStream,
|
||||
_base.PublicRoomsStream,
|
||||
_base.DeviceListsStream,
|
||||
_base.ToDeviceStream,
|
||||
federation.FederationStream,
|
||||
_base.TagAccountDataStream,
|
||||
_base.AccountDataStream,
|
||||
_base.GroupServerStream,
|
||||
_base.UserSignatureStream,
|
||||
EventsStream,
|
||||
BackfillStream,
|
||||
PresenceStream,
|
||||
TypingStream,
|
||||
ReceiptsStream,
|
||||
PushRulesStream,
|
||||
PushersStream,
|
||||
CachesStream,
|
||||
PublicRoomsStream,
|
||||
DeviceListsStream,
|
||||
ToDeviceStream,
|
||||
FederationStream,
|
||||
TagAccountDataStream,
|
||||
AccountDataStream,
|
||||
GroupServerStream,
|
||||
UserSignatureStream,
|
||||
)
|
||||
}
|
||||
} # type: Dict[str, Type[Stream]]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"STREAMS_MAP",
|
||||
"Stream",
|
||||
"BackfillStream",
|
||||
"PresenceStream",
|
||||
"TypingStream",
|
||||
"ReceiptsStream",
|
||||
"PushRulesStream",
|
||||
"PushersStream",
|
||||
"CachesStream",
|
||||
"PublicRoomsStream",
|
||||
"DeviceListsStream",
|
||||
"ToDeviceStream",
|
||||
"TagAccountDataStream",
|
||||
"AccountDataStream",
|
||||
"GroupServerStream",
|
||||
"UserSignatureStream",
|
||||
]
|
||||
|
||||
@@ -14,114 +14,40 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Awaitable, Callable, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.replication.http.streams import ReplicationGetStreamUpdates
|
||||
from synapse.types import JsonDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MAX_EVENTS_BEHIND = 500000
|
||||
|
||||
BackfillStreamRow = namedtuple(
|
||||
"BackfillStreamRow",
|
||||
(
|
||||
"event_id", # str
|
||||
"room_id", # str
|
||||
"type", # str
|
||||
"state_key", # str, optional
|
||||
"redacts", # str, optional
|
||||
"relates_to", # str, optional
|
||||
),
|
||||
)
|
||||
PresenceStreamRow = namedtuple(
|
||||
"PresenceStreamRow",
|
||||
(
|
||||
"user_id", # str
|
||||
"state", # str
|
||||
"last_active_ts", # int
|
||||
"last_federation_update_ts", # int
|
||||
"last_user_sync_ts", # int
|
||||
"status_msg", # str
|
||||
"currently_active", # bool
|
||||
),
|
||||
)
|
||||
TypingStreamRow = namedtuple(
|
||||
"TypingStreamRow", ("room_id", "user_ids") # str # list(str)
|
||||
)
|
||||
ReceiptsStreamRow = namedtuple(
|
||||
"ReceiptsStreamRow",
|
||||
(
|
||||
"room_id", # str
|
||||
"receipt_type", # str
|
||||
"user_id", # str
|
||||
"event_id", # str
|
||||
"data", # dict
|
||||
),
|
||||
)
|
||||
PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
|
||||
PushersStreamRow = namedtuple(
|
||||
"PushersStreamRow",
|
||||
("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
|
||||
)
|
||||
|
||||
# Some type aliases to make things a bit easier.
|
||||
|
||||
@attr.s
|
||||
class CachesStreamRow:
|
||||
"""Stream to inform workers they should invalidate their cache.
|
||||
# A stream position token
|
||||
Token = int
|
||||
|
||||
Attributes:
|
||||
cache_func: Name of the cached function.
|
||||
keys: The entry in the cache to invalidate. If None then will
|
||||
invalidate all.
|
||||
invalidation_ts: Timestamp of when the invalidation took place.
|
||||
"""
|
||||
|
||||
cache_func = attr.ib(type=str)
|
||||
keys = attr.ib(type=Optional[List[Any]])
|
||||
invalidation_ts = attr.ib(type=int)
|
||||
|
||||
|
||||
PublicRoomsStreamRow = namedtuple(
|
||||
"PublicRoomsStreamRow",
|
||||
(
|
||||
"room_id", # str
|
||||
"visibility", # str
|
||||
"appservice_id", # str, optional
|
||||
"network_id", # str, optional
|
||||
),
|
||||
)
|
||||
DeviceListsStreamRow = namedtuple(
|
||||
"DeviceListsStreamRow", ("user_id", "destination") # str # str
|
||||
)
|
||||
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
|
||||
TagAccountDataStreamRow = namedtuple(
|
||||
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
|
||||
)
|
||||
AccountDataStreamRow = namedtuple(
|
||||
"AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
|
||||
)
|
||||
GroupsStreamRow = namedtuple(
|
||||
"GroupsStreamRow",
|
||||
("group_id", "user_id", "type", "content"), # str # str # str # dict
|
||||
)
|
||||
UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
|
||||
# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
|
||||
StreamRow = Tuple[Token, tuple]
|
||||
|
||||
|
||||
class Stream(object):
|
||||
"""Base class for the streams.
|
||||
|
||||
Provides a `get_updates()` function that returns new updates since the last
|
||||
time it was called up until the point `advance_current_token` was called.
|
||||
time it was called.
|
||||
"""
|
||||
|
||||
NAME = None # type: str # The name of the stream
|
||||
# The type of the row. Used by the default impl of parse_row.
|
||||
ROW_TYPE = None # type: Any
|
||||
_LIMITED = True # Whether the update function takes a limit
|
||||
|
||||
@classmethod
|
||||
def parse_row(cls, row):
|
||||
@@ -139,80 +65,56 @@ class Stream(object):
|
||||
return cls.ROW_TYPE(*row)
|
||||
|
||||
def __init__(self, hs):
|
||||
|
||||
# The token from which we last asked for updates
|
||||
self.last_token = self.current_token()
|
||||
|
||||
# The token that we will get updates up to
|
||||
self.upto_token = self.current_token()
|
||||
|
||||
def advance_current_token(self):
|
||||
"""Updates `upto_token` to "now", which updates up until which point
|
||||
get_updates[_since] will fetch rows till.
|
||||
"""
|
||||
self.upto_token = self.current_token()
|
||||
|
||||
def discard_updates_and_advance(self):
|
||||
"""Called when the stream should advance but the updates would be discarded,
|
||||
e.g. when there are no currently connected workers.
|
||||
"""
|
||||
self.upto_token = self.current_token()
|
||||
self.last_token = self.upto_token
|
||||
self.last_token = self.current_token()
|
||||
|
||||
async def get_updates(self):
|
||||
async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
|
||||
"""Gets all updates since the last time this function was called (or
|
||||
since the stream was constructed if it hadn't been called before),
|
||||
until the `upto_token`
|
||||
since the stream was constructed if it hadn't been called before).
|
||||
|
||||
Returns:
|
||||
Deferred[Tuple[List[Tuple[int, Any]], int]:
|
||||
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
|
||||
list of ``(token, row)`` entries. ``row`` will be json-serialised and
|
||||
sent over the replication steam.
|
||||
A triplet `(updates, new_last_token, limited)`, where `updates` is
|
||||
a list of `(token, row)` entries, `new_last_token` is the new
|
||||
position in stream, and `limited` is whether there are more updates
|
||||
to fetch.
|
||||
"""
|
||||
updates, current_token = await self.get_updates_since(self.last_token)
|
||||
current_token = self.current_token()
|
||||
updates, current_token, limited = await self.get_updates_since(
|
||||
self.last_token, current_token
|
||||
)
|
||||
self.last_token = current_token
|
||||
|
||||
return updates, current_token
|
||||
return updates, current_token, limited
|
||||
|
||||
async def get_updates_since(self, from_token):
|
||||
async def get_updates_since(
|
||||
self, from_token: Token, upto_token: Token, limit: int = 100
|
||||
) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
|
||||
"""Like get_updates except allows specifying from when we should
|
||||
stream updates
|
||||
|
||||
Returns:
|
||||
Deferred[Tuple[List[Tuple[int, Any]], int]:
|
||||
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
|
||||
list of ``(token, row)`` entries. ``row`` will be json-serialised and
|
||||
sent over the replication steam.
|
||||
A triplet `(updates, new_last_token, limited)`, where `updates` is
|
||||
a list of `(token, row)` entries, `new_last_token` is the new
|
||||
position in stream, and `limited` is whether there are more updates
|
||||
to fetch.
|
||||
"""
|
||||
if from_token in ("NOW", "now"):
|
||||
return [], self.upto_token
|
||||
|
||||
current_token = self.upto_token
|
||||
|
||||
from_token = int(from_token)
|
||||
|
||||
if from_token == current_token:
|
||||
return [], current_token
|
||||
if from_token == upto_token:
|
||||
return [], upto_token, False
|
||||
|
||||
logger.info("get_updates_since: %s", self.__class__)
|
||||
if self._LIMITED:
|
||||
rows = await self.update_function(
|
||||
from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
|
||||
)
|
||||
|
||||
# never turn more than MAX_EVENTS_BEHIND + 1 into updates.
|
||||
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
|
||||
else:
|
||||
rows = await self.update_function(from_token, current_token)
|
||||
|
||||
updates = [(row[0], row[1:]) for row in rows]
|
||||
|
||||
# check we didn't get more rows than the limit.
|
||||
# doing it like this allows the update_function to be a generator.
|
||||
if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
|
||||
raise Exception("stream %s has fallen behind" % (self.NAME))
|
||||
|
||||
return updates, current_token
|
||||
updates, upto_token, limited = await self.update_function(
|
||||
from_token, upto_token, limit=limit,
|
||||
)
|
||||
return updates, upto_token, limited
|
||||
|
||||
def current_token(self):
|
||||
"""Gets the current token of the underlying streams. Should be provided
|
||||
@@ -223,9 +125,8 @@ class Stream(object):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_function(self, from_token, current_token, limit=None):
|
||||
"""Get updates between from_token and to_token. If Stream._LIMITED is
|
||||
True then limit is provided, otherwise it's not.
|
||||
def update_function(self, from_token, current_token, limit):
|
||||
"""Get updates between from_token and to_token.
|
||||
|
||||
Returns:
|
||||
Deferred(list(tuple)): the first entry in the tuple is the token for
|
||||
@@ -235,52 +136,144 @@ class Stream(object):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def db_query_to_update_function(
|
||||
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
|
||||
) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
|
||||
"""Wraps a db query function which returns a list of rows to make it
|
||||
suitable for use as an `update_function` for the Stream class
|
||||
"""
|
||||
|
||||
async def update_function(from_token, upto_token, limit):
|
||||
rows = await query_function(from_token, upto_token, limit)
|
||||
updates = [(row[0], row[1:]) for row in rows]
|
||||
limited = False
|
||||
if len(updates) == limit:
|
||||
upto_token = rows[-1][0]
|
||||
limited = True
|
||||
|
||||
return updates, upto_token, limited
|
||||
|
||||
return update_function
|
||||
|
||||
|
||||
def make_http_update_function(
|
||||
hs, stream_name: str
|
||||
) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
|
||||
"""Makes a suitable function for use as an `update_function` that queries
|
||||
the master process for updates.
|
||||
"""
|
||||
|
||||
client = ReplicationGetStreamUpdates.make_client(hs)
|
||||
|
||||
async def update_function(
|
||||
from_token: int, upto_token: int, limit: int
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
return await client(
|
||||
stream_name=stream_name,
|
||||
from_token=from_token,
|
||||
upto_token=upto_token,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return update_function
|
||||
|
||||
|
||||
class BackfillStream(Stream):
|
||||
"""We fetched some old events and either we had never seen that event before
|
||||
or it went from being an outlier to not.
|
||||
"""
|
||||
|
||||
BackfillStreamRow = namedtuple(
|
||||
"BackfillStreamRow",
|
||||
(
|
||||
"event_id", # str
|
||||
"room_id", # str
|
||||
"type", # str
|
||||
"state_key", # str, optional
|
||||
"redacts", # str, optional
|
||||
"relates_to", # str, optional
|
||||
),
|
||||
)
|
||||
|
||||
NAME = "backfill"
|
||||
ROW_TYPE = BackfillStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
self.current_token = store.get_current_backfill_token # type: ignore
|
||||
self.update_function = store.get_all_new_backfill_event_rows # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
|
||||
|
||||
super(BackfillStream, self).__init__(hs)
|
||||
|
||||
|
||||
class PresenceStream(Stream):
|
||||
PresenceStreamRow = namedtuple(
|
||||
"PresenceStreamRow",
|
||||
(
|
||||
"user_id", # str
|
||||
"state", # str
|
||||
"last_active_ts", # int
|
||||
"last_federation_update_ts", # int
|
||||
"last_user_sync_ts", # int
|
||||
"status_msg", # str
|
||||
"currently_active", # bool
|
||||
),
|
||||
)
|
||||
|
||||
NAME = "presence"
|
||||
_LIMITED = False
|
||||
ROW_TYPE = PresenceStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
presence_handler = hs.get_presence_handler()
|
||||
|
||||
self._is_worker = hs.config.worker_app is not None
|
||||
|
||||
self.current_token = store.get_current_presence_token # type: ignore
|
||||
self.update_function = presence_handler.get_all_presence_updates # type: ignore
|
||||
|
||||
if hs.config.worker_app is None:
|
||||
self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
|
||||
else:
|
||||
# Query master process
|
||||
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
|
||||
|
||||
super(PresenceStream, self).__init__(hs)
|
||||
|
||||
|
||||
class TypingStream(Stream):
|
||||
TypingStreamRow = namedtuple(
|
||||
"TypingStreamRow", ("room_id", "user_ids") # str # list(str)
|
||||
)
|
||||
|
||||
NAME = "typing"
|
||||
_LIMITED = False
|
||||
ROW_TYPE = TypingStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
typing_handler = hs.get_typing_handler()
|
||||
|
||||
self.current_token = typing_handler.get_current_token # type: ignore
|
||||
self.update_function = typing_handler.get_all_typing_updates # type: ignore
|
||||
|
||||
if hs.config.worker_app is None:
|
||||
self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
|
||||
else:
|
||||
# Query master process
|
||||
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
|
||||
|
||||
super(TypingStream, self).__init__(hs)
|
||||
|
||||
|
||||
class ReceiptsStream(Stream):
|
||||
ReceiptsStreamRow = namedtuple(
|
||||
"ReceiptsStreamRow",
|
||||
(
|
||||
"room_id", # str
|
||||
"receipt_type", # str
|
||||
"user_id", # str
|
||||
"event_id", # str
|
||||
"data", # dict
|
||||
),
|
||||
)
|
||||
|
||||
NAME = "receipts"
|
||||
ROW_TYPE = ReceiptsStreamRow
|
||||
|
||||
@@ -288,7 +281,7 @@ class ReceiptsStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_max_receipt_stream_id # type: ignore
|
||||
self.update_function = store.get_all_updated_receipts # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
|
||||
|
||||
super(ReceiptsStream, self).__init__(hs)
|
||||
|
||||
@@ -297,6 +290,8 @@ class PushRulesStream(Stream):
|
||||
"""A user has changed their push rules
|
||||
"""
|
||||
|
||||
PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
|
||||
|
||||
NAME = "push_rules"
|
||||
ROW_TYPE = PushRulesStreamRow
|
||||
|
||||
@@ -310,13 +305,24 @@ class PushRulesStream(Stream):
|
||||
|
||||
async def update_function(self, from_token, to_token, limit):
|
||||
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
|
||||
return [(row[0], row[2]) for row in rows]
|
||||
|
||||
limited = False
|
||||
if len(rows) == limit:
|
||||
to_token = rows[-1][0]
|
||||
limited = True
|
||||
|
||||
return [(row[0], (row[2],)) for row in rows], to_token, limited
|
||||
|
||||
|
||||
class PushersStream(Stream):
|
||||
"""A user has added/changed/removed a pusher
|
||||
"""
|
||||
|
||||
PushersStreamRow = namedtuple(
|
||||
"PushersStreamRow",
|
||||
("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
|
||||
)
|
||||
|
||||
NAME = "pushers"
|
||||
ROW_TYPE = PushersStreamRow
|
||||
|
||||
@@ -324,7 +330,7 @@ class PushersStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_pushers_stream_token # type: ignore
|
||||
self.update_function = store.get_all_updated_pushers_rows # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
|
||||
|
||||
super(PushersStream, self).__init__(hs)
|
||||
|
||||
@@ -334,6 +340,21 @@ class CachesStream(Stream):
|
||||
the cache on the workers
|
||||
"""
|
||||
|
||||
@attr.s
|
||||
class CachesStreamRow:
|
||||
"""Stream to inform workers they should invalidate their cache.
|
||||
|
||||
Attributes:
|
||||
cache_func: Name of the cached function.
|
||||
keys: The entry in the cache to invalidate. If None then will
|
||||
invalidate all.
|
||||
invalidation_ts: Timestamp of when the invalidation took place.
|
||||
"""
|
||||
|
||||
cache_func = attr.ib(type=str)
|
||||
keys = attr.ib(type=Optional[List[Any]])
|
||||
invalidation_ts = attr.ib(type=int)
|
||||
|
||||
NAME = "caches"
|
||||
ROW_TYPE = CachesStreamRow
|
||||
|
||||
@@ -341,7 +362,7 @@ class CachesStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_cache_stream_token # type: ignore
|
||||
self.update_function = store.get_all_updated_caches # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
|
||||
|
||||
super(CachesStream, self).__init__(hs)
|
||||
|
||||
@@ -350,6 +371,16 @@ class PublicRoomsStream(Stream):
|
||||
"""The public rooms list changed
|
||||
"""
|
||||
|
||||
PublicRoomsStreamRow = namedtuple(
|
||||
"PublicRoomsStreamRow",
|
||||
(
|
||||
"room_id", # str
|
||||
"visibility", # str
|
||||
"appservice_id", # str, optional
|
||||
"network_id", # str, optional
|
||||
),
|
||||
)
|
||||
|
||||
NAME = "public_rooms"
|
||||
ROW_TYPE = PublicRoomsStreamRow
|
||||
|
||||
@@ -357,24 +388,28 @@ class PublicRoomsStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_current_public_room_stream_id # type: ignore
|
||||
self.update_function = store.get_all_new_public_rooms # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
|
||||
|
||||
super(PublicRoomsStream, self).__init__(hs)
|
||||
|
||||
|
||||
class DeviceListsStream(Stream):
|
||||
"""Someone added/changed/removed a device
|
||||
"""Either a user has updated their devices or a remote server needs to be
|
||||
told about a device update.
|
||||
"""
|
||||
|
||||
@attr.s
|
||||
class DeviceListsStreamRow:
|
||||
entity = attr.ib(type=str)
|
||||
|
||||
NAME = "device_lists"
|
||||
_LIMITED = False
|
||||
ROW_TYPE = DeviceListsStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_device_stream_token # type: ignore
|
||||
self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
|
||||
|
||||
super(DeviceListsStream, self).__init__(hs)
|
||||
|
||||
@@ -383,6 +418,8 @@ class ToDeviceStream(Stream):
|
||||
"""New to_device messages for a client
|
||||
"""
|
||||
|
||||
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
|
||||
|
||||
NAME = "to_device"
|
||||
ROW_TYPE = ToDeviceStreamRow
|
||||
|
||||
@@ -390,7 +427,7 @@ class ToDeviceStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_to_device_stream_token # type: ignore
|
||||
self.update_function = store.get_all_new_device_messages # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
|
||||
|
||||
super(ToDeviceStream, self).__init__(hs)
|
||||
|
||||
@@ -399,6 +436,10 @@ class TagAccountDataStream(Stream):
|
||||
"""Someone added/removed a tag for a room
|
||||
"""
|
||||
|
||||
TagAccountDataStreamRow = namedtuple(
|
||||
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
|
||||
)
|
||||
|
||||
NAME = "tag_account_data"
|
||||
ROW_TYPE = TagAccountDataStreamRow
|
||||
|
||||
@@ -406,7 +447,7 @@ class TagAccountDataStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_max_account_data_stream_id # type: ignore
|
||||
self.update_function = store.get_all_updated_tags # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
|
||||
|
||||
super(TagAccountDataStream, self).__init__(hs)
|
||||
|
||||
@@ -415,6 +456,10 @@ class AccountDataStream(Stream):
|
||||
"""Global or per room account data was changed
|
||||
"""
|
||||
|
||||
AccountDataStreamRow = namedtuple(
|
||||
"AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
|
||||
)
|
||||
|
||||
NAME = "account_data"
|
||||
ROW_TYPE = AccountDataStreamRow
|
||||
|
||||
@@ -422,10 +467,11 @@ class AccountDataStream(Stream):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
|
||||
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
|
||||
|
||||
super(AccountDataStream, self).__init__(hs)
|
||||
|
||||
async def update_function(self, from_token, to_token, limit):
|
||||
async def _update_function(self, from_token, to_token, limit):
|
||||
global_results, room_results = await self.store.get_all_updated_account_data(
|
||||
from_token, from_token, to_token, limit
|
||||
)
|
||||
@@ -440,6 +486,11 @@ class AccountDataStream(Stream):
|
||||
|
||||
|
||||
class GroupServerStream(Stream):
|
||||
GroupsStreamRow = namedtuple(
|
||||
"GroupsStreamRow",
|
||||
("group_id", "user_id", "type", "content"), # str # str # str # dict
|
||||
)
|
||||
|
||||
NAME = "groups"
|
||||
ROW_TYPE = GroupsStreamRow
|
||||
|
||||
@@ -447,7 +498,7 @@ class GroupServerStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_group_stream_token # type: ignore
|
||||
self.update_function = store.get_all_groups_changes # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
|
||||
|
||||
super(GroupServerStream, self).__init__(hs)
|
||||
|
||||
@@ -456,14 +507,15 @@ class UserSignatureStream(Stream):
|
||||
"""A user has signed their own device with their user-signing key
|
||||
"""
|
||||
|
||||
UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
|
||||
|
||||
NAME = "user_signature"
|
||||
_LIMITED = False
|
||||
ROW_TYPE = UserSignatureStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_device_stream_token # type: ignore
|
||||
self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
|
||||
|
||||
super(UserSignatureStream, self).__init__(hs)
|
||||
|
||||
@@ -19,7 +19,7 @@ from typing import Tuple, Type
|
||||
|
||||
import attr
|
||||
|
||||
from ._base import Stream
|
||||
from ._base import Stream, db_query_to_update_function
|
||||
|
||||
|
||||
"""Handling of the 'events' replication stream
|
||||
@@ -117,10 +117,11 @@ class EventsStream(Stream):
|
||||
def __init__(self, hs):
|
||||
self._store = hs.get_datastore()
|
||||
self.current_token = self._store.get_current_events_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
|
||||
|
||||
super(EventsStream, self).__init__(hs)
|
||||
|
||||
async def update_function(self, from_token, current_token, limit=None):
|
||||
async def _update_function(self, from_token, current_token, limit=None):
|
||||
event_rows = await self._store.get_all_new_forward_event_rows(
|
||||
from_token, current_token, limit
|
||||
)
|
||||
|
||||
@@ -15,15 +15,9 @@
|
||||
# limitations under the License.
|
||||
from collections import namedtuple
|
||||
|
||||
from ._base import Stream
|
||||
from twisted.internet import defer
|
||||
|
||||
FederationStreamRow = namedtuple(
|
||||
"FederationStreamRow",
|
||||
(
|
||||
"type", # str, the type of data as defined in the BaseFederationRows
|
||||
"data", # dict, serialization of a federation.send_queue.BaseFederationRow
|
||||
),
|
||||
)
|
||||
from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
|
||||
|
||||
|
||||
class FederationStream(Stream):
|
||||
@@ -31,13 +25,28 @@ class FederationStream(Stream):
|
||||
sending disabled.
|
||||
"""
|
||||
|
||||
FederationStreamRow = namedtuple(
|
||||
"FederationStreamRow",
|
||||
(
|
||||
"type", # str, the type of data as defined in the BaseFederationRows
|
||||
"data", # dict, serialization of a federation.send_queue.BaseFederationRow
|
||||
),
|
||||
)
|
||||
|
||||
NAME = "federation"
|
||||
ROW_TYPE = FederationStreamRow
|
||||
_QUERY_MASTER = True
|
||||
|
||||
def __init__(self, hs):
|
||||
federation_sender = hs.get_federation_sender()
|
||||
|
||||
self.current_token = federation_sender.get_current_token # type: ignore
|
||||
self.update_function = federation_sender.get_replication_rows # type: ignore
|
||||
# Not all synapse instances will have a federation sender instance,
|
||||
# whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
|
||||
# so we stub the stream out when that is the case.
|
||||
if hs.config.worker_app is None or hs.should_send_federation():
|
||||
federation_sender = hs.get_federation_sender()
|
||||
self.current_token = federation_sender.get_current_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
|
||||
else:
|
||||
self.current_token = lambda: 0 # type: ignore
|
||||
self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool)) # type: ignore
|
||||
|
||||
super(FederationStream, self).__init__(hs)
|
||||
|
||||
@@ -41,6 +41,7 @@ from synapse.rest.client.v2_alpha import (
|
||||
keys,
|
||||
notifications,
|
||||
openid,
|
||||
password_policy,
|
||||
read_marker,
|
||||
receipts,
|
||||
register,
|
||||
@@ -118,6 +119,7 @@ class ClientRestResource(JsonResource):
|
||||
capabilities.register_servlets(hs, client_resource)
|
||||
account_validity.register_servlets(hs, client_resource)
|
||||
relations.register_servlets(hs, client_resource)
|
||||
password_policy.register_servlets(hs, client_resource)
|
||||
|
||||
# moving to /_synapse/admin
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource(
|
||||
|
||||
@@ -29,7 +29,11 @@ from synapse.rest.admin._base import (
|
||||
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
|
||||
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
|
||||
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
|
||||
from synapse.rest.admin.rooms import ListRoomRestServlet, ShutdownRoomRestServlet
|
||||
from synapse.rest.admin.rooms import (
|
||||
JoinRoomAliasServlet,
|
||||
ListRoomRestServlet,
|
||||
ShutdownRoomRestServlet,
|
||||
)
|
||||
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
|
||||
from synapse.rest.admin.users import (
|
||||
AccountValidityRenewServlet,
|
||||
@@ -189,6 +193,7 @@ def register_servlets(hs, http_server):
|
||||
"""
|
||||
register_servlets_for_client_rest_resource(hs, http_server)
|
||||
ListRoomRestServlet(hs).register(http_server)
|
||||
JoinRoomAliasServlet(hs).register(http_server)
|
||||
PurgeRoomServlet(hs).register(http_server)
|
||||
SendServerNoticeServlet(hs).register(http_server)
|
||||
VersionServlet(hs).register(http_server)
|
||||
|
||||
@@ -13,9 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
@@ -29,7 +30,7 @@ from synapse.rest.admin._base import (
|
||||
historical_admin_path_patterns,
|
||||
)
|
||||
from synapse.storage.data_stores.main.room import RoomSortOrder
|
||||
from synapse.types import create_requester
|
||||
from synapse.types import RoomAlias, RoomID, UserID, create_requester
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -237,3 +238,75 @@ class ListRoomRestServlet(RestServlet):
|
||||
response["prev_batch"] = 0
|
||||
|
||||
return 200, response
|
||||
|
||||
|
||||
class JoinRoomAliasServlet(RestServlet):
|
||||
|
||||
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.room_member_handler = hs.get_room_member_handler()
|
||||
self.admin_handler = hs.get_handlers().admin_handler
|
||||
self.state_handler = hs.get_state_handler()
|
||||
|
||||
async def on_POST(self, request, room_identifier):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
assert_params_in_dict(content, ["user_id"])
|
||||
target_user = UserID.from_string(content["user_id"])
|
||||
|
||||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "This endpoint can only be used with local users")
|
||||
|
||||
if not await self.admin_handler.get_user(target_user):
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
if RoomID.is_valid(room_identifier):
|
||||
room_id = room_identifier
|
||||
try:
|
||||
remote_room_hosts = [
|
||||
x.decode("ascii") for x in request.args[b"server_name"]
|
||||
] # type: Optional[List[str]]
|
||||
except Exception:
|
||||
remote_room_hosts = None
|
||||
elif RoomAlias.is_valid(room_identifier):
|
||||
handler = self.room_member_handler
|
||||
room_alias = RoomAlias.from_string(room_identifier)
|
||||
room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
|
||||
room_id = room_id.to_string()
|
||||
else:
|
||||
raise SynapseError(
|
||||
400, "%s was not legal room ID or room alias" % (room_identifier,)
|
||||
)
|
||||
|
||||
fake_requester = create_requester(target_user)
|
||||
|
||||
# send invite if room has "JoinRules.INVITE"
|
||||
room_state = await self.state_handler.get_current_state(room_id)
|
||||
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
|
||||
if join_rules_event:
|
||||
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
|
||||
await self.room_member_handler.update_membership(
|
||||
requester=requester,
|
||||
target=fake_requester.user,
|
||||
room_id=room_id,
|
||||
action="invite",
|
||||
remote_room_hosts=remote_room_hosts,
|
||||
ratelimit=False,
|
||||
)
|
||||
|
||||
await self.room_member_handler.update_membership(
|
||||
requester=fake_requester,
|
||||
target=fake_requester.user,
|
||||
room_id=room_id,
|
||||
action="join",
|
||||
remote_room_hosts=remote_room_hosts,
|
||||
ratelimit=False,
|
||||
)
|
||||
|
||||
return 200, {"room_id": room_id}
|
||||
|
||||
@@ -14,11 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from six.moves import urllib
|
||||
|
||||
from twisted.web.client import PartialDownloadError
|
||||
|
||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
@@ -28,10 +23,10 @@ from synapse.http.servlet import (
|
||||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.push.mailer import load_jinja2_templates
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||
from synapse.rest.well_known import WellKnownBuilder
|
||||
from synapse.types import UserID, map_username_to_mxid_localpart
|
||||
from synapse.types import UserID
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -402,7 +397,7 @@ class BaseSSORedirectServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
|
||||
|
||||
def on_GET(self, request):
|
||||
def on_GET(self, request: SynapseRequest):
|
||||
args = request.args
|
||||
if b"redirectUrl" not in args:
|
||||
return 400, "Redirect URL not specified for SSO auth"
|
||||
@@ -411,15 +406,15 @@ class BaseSSORedirectServlet(RestServlet):
|
||||
request.redirect(sso_url)
|
||||
finish_request(request)
|
||||
|
||||
def get_sso_url(self, client_redirect_url):
|
||||
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
|
||||
"""Get the URL to redirect to, to perform SSO auth
|
||||
|
||||
Args:
|
||||
client_redirect_url (bytes): the URL that we should redirect the
|
||||
client_redirect_url: the URL that we should redirect the
|
||||
client to when everything is done
|
||||
|
||||
Returns:
|
||||
bytes: URL to redirect to
|
||||
URL to redirect to
|
||||
"""
|
||||
# to be implemented by subclasses
|
||||
raise NotImplementedError()
|
||||
@@ -427,19 +422,10 @@ class BaseSSORedirectServlet(RestServlet):
|
||||
|
||||
class CasRedirectServlet(BaseSSORedirectServlet):
|
||||
def __init__(self, hs):
|
||||
super(CasRedirectServlet, self).__init__()
|
||||
self.cas_server_url = hs.config.cas_server_url.encode("ascii")
|
||||
self.cas_service_url = hs.config.cas_service_url.encode("ascii")
|
||||
self._cas_handler = hs.get_cas_handler()
|
||||
|
||||
def get_sso_url(self, client_redirect_url):
|
||||
client_redirect_url_param = urllib.parse.urlencode(
|
||||
{b"redirectUrl": client_redirect_url}
|
||||
).encode("ascii")
|
||||
hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
|
||||
service_param = urllib.parse.urlencode(
|
||||
{b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
|
||||
).encode("ascii")
|
||||
return b"%s/login?%s" % (self.cas_server_url, service_param)
|
||||
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
|
||||
return self._cas_handler.handle_redirect_request(client_redirect_url)
|
||||
|
||||
|
||||
class CasTicketServlet(RestServlet):
|
||||
@@ -447,81 +433,15 @@ class CasTicketServlet(RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(CasTicketServlet, self).__init__()
|
||||
self.cas_server_url = hs.config.cas_server_url
|
||||
self.cas_service_url = hs.config.cas_service_url
|
||||
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
|
||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||
self._sso_auth_handler = SSOAuthHandler(hs)
|
||||
self._http_client = hs.get_proxied_http_client()
|
||||
self._cas_handler = hs.get_cas_handler()
|
||||
|
||||
async def on_GET(self, request):
|
||||
async def on_GET(self, request: SynapseRequest) -> None:
|
||||
client_redirect_url = parse_string(request, "redirectUrl", required=True)
|
||||
uri = self.cas_server_url + "/proxyValidate"
|
||||
args = {
|
||||
"ticket": parse_string(request, "ticket", required=True),
|
||||
"service": self.cas_service_url,
|
||||
}
|
||||
try:
|
||||
body = await self._http_client.get_raw(uri, args)
|
||||
except PartialDownloadError as pde:
|
||||
# Twisted raises this error if the connection is closed,
|
||||
# even if that's being used old-http style to signal end-of-data
|
||||
body = pde.response
|
||||
result = await self.handle_cas_response(request, body, client_redirect_url)
|
||||
return result
|
||||
|
||||
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
|
||||
user, attributes = self.parse_cas_response(cas_response_body)
|
||||
displayname = attributes.pop(self.cas_displayname_attribute, None)
|
||||
|
||||
for required_attribute, required_value in self.cas_required_attributes.items():
|
||||
# If required attribute was not in CAS Response - Forbidden
|
||||
if required_attribute not in attributes:
|
||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
# Also need to check value
|
||||
if required_value is not None:
|
||||
actual_value = attributes[required_attribute]
|
||||
# If required attribute value does not match expected - Forbidden
|
||||
if required_value != actual_value:
|
||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
return self._sso_auth_handler.on_successful_auth(
|
||||
user, request, client_redirect_url, displayname
|
||||
ticket = parse_string(request, "ticket", required=True)
|
||||
await self._cas_handler.handle_ticket_request(
|
||||
request, client_redirect_url, ticket
|
||||
)
|
||||
|
||||
def parse_cas_response(self, cas_response_body):
|
||||
user = None
|
||||
attributes = {}
|
||||
try:
|
||||
root = ET.fromstring(cas_response_body)
|
||||
if not root.tag.endswith("serviceResponse"):
|
||||
raise Exception("root of CAS response is not serviceResponse")
|
||||
success = root[0].tag.endswith("authenticationSuccess")
|
||||
for child in root[0]:
|
||||
if child.tag.endswith("user"):
|
||||
user = child.text
|
||||
if child.tag.endswith("attributes"):
|
||||
for attribute in child:
|
||||
# ElementTree library expands the namespace in
|
||||
# attribute tags to the full URL of the namespace.
|
||||
# We don't care about namespace here and it will always
|
||||
# be encased in curly braces, so we remove them.
|
||||
tag = attribute.tag
|
||||
if "}" in tag:
|
||||
tag = tag.split("}")[1]
|
||||
attributes[tag] = attribute.text
|
||||
if user is None:
|
||||
raise Exception("CAS response does not contain user")
|
||||
except Exception:
|
||||
logger.exception("Error parsing CAS response")
|
||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
if not success:
|
||||
raise LoginError(
|
||||
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
|
||||
)
|
||||
return user, attributes
|
||||
|
||||
|
||||
class SAMLRedirectServlet(BaseSSORedirectServlet):
|
||||
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
|
||||
@@ -529,72 +449,10 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
|
||||
def __init__(self, hs):
|
||||
self._saml_handler = hs.get_saml_handler()
|
||||
|
||||
def get_sso_url(self, client_redirect_url):
|
||||
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
|
||||
return self._saml_handler.handle_redirect_request(client_redirect_url)
|
||||
|
||||
|
||||
class SSOAuthHandler(object):
|
||||
"""
|
||||
Utility class for Resources and Servlets which handle the response from a SSO
|
||||
service
|
||||
|
||||
Args:
|
||||
hs (synapse.server.HomeServer)
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self._hostname = hs.hostname
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._registration_handler = hs.get_registration_handler()
|
||||
self._macaroon_gen = hs.get_macaroon_generator()
|
||||
|
||||
# Load the redirect page HTML template
|
||||
self._template = load_jinja2_templates(
|
||||
hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
|
||||
)[0]
|
||||
|
||||
self._server_name = hs.config.server_name
|
||||
|
||||
# cast to tuple for use with str.startswith
|
||||
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
|
||||
|
||||
async def on_successful_auth(
|
||||
self, username, request, client_redirect_url, user_display_name=None
|
||||
):
|
||||
"""Called once the user has successfully authenticated with the SSO.
|
||||
|
||||
Registers the user if necessary, and then returns a redirect (with
|
||||
a login token) to the client.
|
||||
|
||||
Args:
|
||||
username (unicode|bytes): the remote user id. We'll map this onto
|
||||
something sane for a MXID localpath.
|
||||
|
||||
request (SynapseRequest): the incoming request from the browser. We'll
|
||||
respond to it with a redirect.
|
||||
|
||||
client_redirect_url (unicode): the redirect_url the client gave us when
|
||||
it first started the process.
|
||||
|
||||
user_display_name (unicode|None): if set, and we have to register a new user,
|
||||
we will set their displayname to this.
|
||||
|
||||
Returns:
|
||||
Deferred[none]: Completes once we have handled the request.
|
||||
"""
|
||||
localpart = map_username_to_mxid_localpart(username)
|
||||
user_id = UserID(localpart, self._hostname).to_string()
|
||||
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
||||
if not registered_user_id:
|
||||
registered_user_id = await self._registration_handler.register_user(
|
||||
localpart=localpart, default_display_name=user_display_name
|
||||
)
|
||||
|
||||
self._auth_handler.complete_sso_login(
|
||||
registered_user_id, request, client_redirect_url
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
LoginRestServlet(hs).register(http_server)
|
||||
if hs.config.cas_enabled:
|
||||
|
||||
@@ -234,13 +234,16 @@ class PasswordRestServlet(RestServlet):
|
||||
if self.auth.has_access_token(request):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
params = await self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request)
|
||||
requester, request, body, self.hs.get_ip_from_request(request),
|
||||
)
|
||||
user_id = requester.user.to_string()
|
||||
else:
|
||||
requester = None
|
||||
result, params, _ = await self.auth_handler.check_auth(
|
||||
[[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request)
|
||||
[[LoginType.EMAIL_IDENTITY]],
|
||||
request,
|
||||
body,
|
||||
self.hs.get_ip_from_request(request),
|
||||
)
|
||||
|
||||
if LoginType.EMAIL_IDENTITY in result:
|
||||
@@ -308,7 +311,7 @@ class DeactivateAccountRestServlet(RestServlet):
|
||||
return 200, {}
|
||||
|
||||
await self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request)
|
||||
requester, request, body, self.hs.get_ip_from_request(request),
|
||||
)
|
||||
result = await self._deactivate_account_handler.deactivate_account(
|
||||
requester.user.to_string(), erase, id_server=body.get("id_server")
|
||||
@@ -602,6 +605,11 @@ class ThreepidRestServlet(RestServlet):
|
||||
return 200, {"threepids": threepids}
|
||||
|
||||
async def on_POST(self, request):
|
||||
if not self.hs.config.enable_3pid_changes:
|
||||
raise SynapseError(
|
||||
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
body = parse_json_object_from_request(request)
|
||||
@@ -646,6 +654,11 @@ class ThreepidAddRestServlet(RestServlet):
|
||||
|
||||
@interactive_auth_handler
|
||||
async def on_POST(self, request):
|
||||
if not self.hs.config.enable_3pid_changes:
|
||||
raise SynapseError(
|
||||
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
body = parse_json_object_from_request(request)
|
||||
@@ -656,7 +669,7 @@ class ThreepidAddRestServlet(RestServlet):
|
||||
assert_valid_client_secret(client_secret)
|
||||
|
||||
await self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request)
|
||||
requester, request, body, self.hs.get_ip_from_request(request),
|
||||
)
|
||||
|
||||
validation_session = await self.identity_handler.validate_threepid_session(
|
||||
@@ -741,10 +754,16 @@ class ThreepidDeleteRestServlet(RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ThreepidDeleteRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
async def on_POST(self, request):
|
||||
if not self.hs.config.enable_3pid_changes:
|
||||
raise SynapseError(
|
||||
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(body, ["medium", "address"])
|
||||
|
||||
|
||||
@@ -142,14 +142,6 @@ class AuthRestServlet(RestServlet):
|
||||
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
|
||||
"sitekey": self.hs.config.recaptcha_public_key,
|
||||
}
|
||||
html_bytes = html.encode("utf8")
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
|
||||
|
||||
request.write(html_bytes)
|
||||
finish_request(request)
|
||||
return None
|
||||
elif stagetype == LoginType.TERMS:
|
||||
html = TERMS_TEMPLATE % {
|
||||
"session": session,
|
||||
@@ -158,17 +150,19 @@ class AuthRestServlet(RestServlet):
|
||||
"myurl": "%s/r0/auth/%s/fallback/web"
|
||||
% (CLIENT_API_PREFIX, LoginType.TERMS),
|
||||
}
|
||||
html_bytes = html.encode("utf8")
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
|
||||
|
||||
request.write(html_bytes)
|
||||
finish_request(request)
|
||||
return None
|
||||
else:
|
||||
raise SynapseError(404, "Unknown auth stage type")
|
||||
|
||||
# Render the HTML and return.
|
||||
html_bytes = html.encode("utf8")
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
|
||||
|
||||
request.write(html_bytes)
|
||||
finish_request(request)
|
||||
return None
|
||||
|
||||
async def on_POST(self, request, stagetype):
|
||||
|
||||
session = parse_string(request, "session")
|
||||
@@ -196,15 +190,6 @@ class AuthRestServlet(RestServlet):
|
||||
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
|
||||
"sitekey": self.hs.config.recaptcha_public_key,
|
||||
}
|
||||
html_bytes = html.encode("utf8")
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
|
||||
|
||||
request.write(html_bytes)
|
||||
finish_request(request)
|
||||
|
||||
return None
|
||||
elif stagetype == LoginType.TERMS:
|
||||
authdict = {"session": session}
|
||||
|
||||
@@ -225,17 +210,19 @@ class AuthRestServlet(RestServlet):
|
||||
"myurl": "%s/r0/auth/%s/fallback/web"
|
||||
% (CLIENT_API_PREFIX, LoginType.TERMS),
|
||||
}
|
||||
html_bytes = html.encode("utf8")
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
|
||||
|
||||
request.write(html_bytes)
|
||||
finish_request(request)
|
||||
return None
|
||||
else:
|
||||
raise SynapseError(404, "Unknown auth stage type")
|
||||
|
||||
# Render the HTML and return.
|
||||
html_bytes = html.encode("utf8")
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
|
||||
|
||||
request.write(html_bytes)
|
||||
finish_request(request)
|
||||
return None
|
||||
|
||||
def on_OPTIONS(self, _):
|
||||
return 200, {}
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ class DeleteDevicesRestServlet(RestServlet):
|
||||
assert_params_in_dict(body, ["devices"])
|
||||
|
||||
await self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request)
|
||||
requester, request, body, self.hs.get_ip_from_request(request),
|
||||
)
|
||||
|
||||
await self.device_handler.delete_devices(
|
||||
@@ -127,7 +127,7 @@ class DeviceRestServlet(RestServlet):
|
||||
raise
|
||||
|
||||
await self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request)
|
||||
requester, request, body, self.hs.get_ip_from_request(request),
|
||||
)
|
||||
|
||||
await self.device_handler.delete_device(requester.user.to_string(), device_id)
|
||||
|
||||
@@ -263,7 +263,7 @@ class SigningKeyUploadServlet(RestServlet):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
await self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request)
|
||||
requester, request, body, self.hs.get_ip_from_request(request),
|
||||
)
|
||||
|
||||
result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
|
||||
|
||||
58
synapse/rest/client/v2_alpha/password_policy.py
Normal file
58
synapse/rest/client/v2_alpha/password_policy.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from synapse.http.servlet import RestServlet
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PasswordPolicyServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/password_policy$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(PasswordPolicyServlet, self).__init__()
|
||||
|
||||
self.policy = hs.config.password_policy
|
||||
self.enabled = hs.config.password_policy_enabled
|
||||
|
||||
def on_GET(self, request):
|
||||
if not self.enabled or not self.policy:
|
||||
return (200, {})
|
||||
|
||||
policy = {}
|
||||
|
||||
for param in [
|
||||
"minimum_length",
|
||||
"require_digit",
|
||||
"require_symbol",
|
||||
"require_lowercase",
|
||||
"require_uppercase",
|
||||
]:
|
||||
if param in self.policy:
|
||||
policy["m.%s" % param] = self.policy[param]
|
||||
|
||||
return (200, policy)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
PasswordPolicyServlet(hs).register(http_server)
|
||||
@@ -373,6 +373,7 @@ class RegisterRestServlet(RestServlet):
|
||||
self.room_member_handler = hs.get_room_member_handler()
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
self.ratelimiter = hs.get_registration_ratelimiter()
|
||||
self.password_policy_handler = hs.get_password_policy_handler()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
self._registration_flows = _calculate_registration_flows(
|
||||
@@ -420,6 +421,7 @@ class RegisterRestServlet(RestServlet):
|
||||
or len(body["password"]) > 512
|
||||
):
|
||||
raise SynapseError(400, "Invalid password")
|
||||
self.password_policy_handler.validate_password(body["password"])
|
||||
|
||||
desired_username = None
|
||||
if "username" in body:
|
||||
@@ -499,7 +501,10 @@ class RegisterRestServlet(RestServlet):
|
||||
)
|
||||
|
||||
auth_result, params, session_id = await self.auth_handler.check_auth(
|
||||
self._registration_flows, body, self.hs.get_ip_from_request(request)
|
||||
self._registration_flows,
|
||||
request,
|
||||
body,
|
||||
self.hs.get_ip_from_request(request),
|
||||
)
|
||||
|
||||
# Check that we're not trying to register a denied 3pid.
|
||||
|
||||
@@ -188,7 +188,7 @@ class RoomKeysServlet(RestServlet):
|
||||
"""
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
||||
user_id = requester.user.to_string()
|
||||
version = parse_string(request, "version")
|
||||
version = parse_string(request, "version", required=True)
|
||||
|
||||
room_keys = await self.e2e_room_keys_handler.get_room_keys(
|
||||
user_id, version, room_id, session_id
|
||||
|
||||
@@ -50,6 +50,9 @@ class DownloadResource(DirectServeResource):
|
||||
b" media-src 'self';"
|
||||
b" object-src 'self';",
|
||||
)
|
||||
request.setHeader(
|
||||
b"Referrer-Policy", b"no-referrer",
|
||||
)
|
||||
server_name, media_id, name = parse_media_id(request)
|
||||
if server_name == self.server_name:
|
||||
await self.media_repo.get_local_media(request, media_id, name)
|
||||
|
||||
@@ -24,7 +24,6 @@ from six import iteritems
|
||||
|
||||
import twisted.internet.error
|
||||
import twisted.web.http
|
||||
from twisted.internet import defer
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.api.errors import (
|
||||
@@ -114,15 +113,14 @@ class MediaRepository(object):
|
||||
"update_recently_accessed_media", self._update_recently_accessed
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_recently_accessed(self):
|
||||
async def _update_recently_accessed(self):
|
||||
remote_media = self.recently_accessed_remotes
|
||||
self.recently_accessed_remotes = set()
|
||||
|
||||
local_media = self.recently_accessed_locals
|
||||
self.recently_accessed_locals = set()
|
||||
|
||||
yield self.store.update_cached_last_access_time(
|
||||
await self.store.update_cached_last_access_time(
|
||||
local_media, remote_media, self.clock.time_msec()
|
||||
)
|
||||
|
||||
@@ -138,8 +136,7 @@ class MediaRepository(object):
|
||||
else:
|
||||
self.recently_accessed_locals.add(media_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_content(
|
||||
async def create_content(
|
||||
self, media_type, upload_name, content, content_length, auth_user
|
||||
):
|
||||
"""Store uploaded content for a local user and return the mxc URL
|
||||
@@ -158,11 +155,11 @@ class MediaRepository(object):
|
||||
|
||||
file_info = FileInfo(server_name=None, file_id=media_id)
|
||||
|
||||
fname = yield self.media_storage.store_file(content, file_info)
|
||||
fname = await self.media_storage.store_file(content, file_info)
|
||||
|
||||
logger.info("Stored local media in file %r", fname)
|
||||
|
||||
yield self.store.store_local_media(
|
||||
await self.store.store_local_media(
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
@@ -171,12 +168,11 @@ class MediaRepository(object):
|
||||
user_id=auth_user,
|
||||
)
|
||||
|
||||
yield self._generate_thumbnails(None, media_id, media_id, media_type)
|
||||
await self._generate_thumbnails(None, media_id, media_id, media_type)
|
||||
|
||||
return "mxc://%s/%s" % (self.server_name, media_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_local_media(self, request, media_id, name):
|
||||
async def get_local_media(self, request, media_id, name):
|
||||
"""Responds to reqests for local media, if exists, or returns 404.
|
||||
|
||||
Args:
|
||||
@@ -190,7 +186,7 @@ class MediaRepository(object):
|
||||
Deferred: Resolves once a response has successfully been written
|
||||
to request
|
||||
"""
|
||||
media_info = yield self.store.get_local_media(media_id)
|
||||
media_info = await self.store.get_local_media(media_id)
|
||||
if not media_info or media_info["quarantined_by"]:
|
||||
respond_404(request)
|
||||
return
|
||||
@@ -204,13 +200,12 @@ class MediaRepository(object):
|
||||
|
||||
file_info = FileInfo(None, media_id, url_cache=url_cache)
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
yield respond_with_responder(
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
await respond_with_responder(
|
||||
request, responder, media_type, media_length, upload_name
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_remote_media(self, request, server_name, media_id, name):
|
||||
async def get_remote_media(self, request, server_name, media_id, name):
|
||||
"""Respond to requests for remote media.
|
||||
|
||||
Args:
|
||||
@@ -236,8 +231,8 @@ class MediaRepository(object):
|
||||
# We linearize here to ensure that we don't try and download remote
|
||||
# media multiple times concurrently
|
||||
key = (server_name, media_id)
|
||||
with (yield self.remote_media_linearizer.queue(key)):
|
||||
responder, media_info = yield self._get_remote_media_impl(
|
||||
with (await self.remote_media_linearizer.queue(key)):
|
||||
responder, media_info = await self._get_remote_media_impl(
|
||||
server_name, media_id
|
||||
)
|
||||
|
||||
@@ -246,14 +241,13 @@ class MediaRepository(object):
|
||||
media_type = media_info["media_type"]
|
||||
media_length = media_info["media_length"]
|
||||
upload_name = name if name else media_info["upload_name"]
|
||||
yield respond_with_responder(
|
||||
await respond_with_responder(
|
||||
request, responder, media_type, media_length, upload_name
|
||||
)
|
||||
else:
|
||||
respond_404(request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_remote_media_info(self, server_name, media_id):
|
||||
async def get_remote_media_info(self, server_name, media_id):
|
||||
"""Gets the media info associated with the remote file, downloading
|
||||
if necessary.
|
||||
|
||||
@@ -274,8 +268,8 @@ class MediaRepository(object):
|
||||
# We linearize here to ensure that we don't try and download remote
|
||||
# media multiple times concurrently
|
||||
key = (server_name, media_id)
|
||||
with (yield self.remote_media_linearizer.queue(key)):
|
||||
responder, media_info = yield self._get_remote_media_impl(
|
||||
with (await self.remote_media_linearizer.queue(key)):
|
||||
responder, media_info = await self._get_remote_media_impl(
|
||||
server_name, media_id
|
||||
)
|
||||
|
||||
@@ -286,8 +280,7 @@ class MediaRepository(object):
|
||||
|
||||
return media_info
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_remote_media_impl(self, server_name, media_id):
|
||||
async def _get_remote_media_impl(self, server_name, media_id):
|
||||
"""Looks for media in local cache, if not there then attempt to
|
||||
download from remote server.
|
||||
|
||||
@@ -299,7 +292,7 @@ class MediaRepository(object):
|
||||
Returns:
|
||||
Deferred[(Responder, media_info)]
|
||||
"""
|
||||
media_info = yield self.store.get_cached_remote_media(server_name, media_id)
|
||||
media_info = await self.store.get_cached_remote_media(server_name, media_id)
|
||||
|
||||
# file_id is the ID we use to track the file locally. If we've already
|
||||
# seen the file then reuse the existing ID, otherwise genereate a new
|
||||
@@ -317,19 +310,18 @@ class MediaRepository(object):
|
||||
logger.info("Media is quarantined")
|
||||
raise NotFoundError()
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
if responder:
|
||||
return responder, media_info
|
||||
|
||||
# Failed to find the file anywhere, lets download it.
|
||||
|
||||
media_info = yield self._download_remote_file(server_name, media_id, file_id)
|
||||
media_info = await self._download_remote_file(server_name, media_id, file_id)
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
return responder, media_info
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _download_remote_file(self, server_name, media_id, file_id):
|
||||
async def _download_remote_file(self, server_name, media_id, file_id):
|
||||
"""Attempt to download the remote file from the given server name,
|
||||
using the given file_id as the local id.
|
||||
|
||||
@@ -351,7 +343,7 @@ class MediaRepository(object):
|
||||
("/_matrix/media/v1/download", server_name, media_id)
|
||||
)
|
||||
try:
|
||||
length, headers = yield self.client.get_file(
|
||||
length, headers = await self.client.get_file(
|
||||
server_name,
|
||||
request_path,
|
||||
output_stream=f,
|
||||
@@ -397,7 +389,7 @@ class MediaRepository(object):
|
||||
)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
yield finish()
|
||||
await finish()
|
||||
|
||||
media_type = headers[b"Content-Type"][0].decode("ascii")
|
||||
upload_name = get_filename_from_headers(headers)
|
||||
@@ -405,7 +397,7 @@ class MediaRepository(object):
|
||||
|
||||
logger.info("Stored remote media in file %r", fname)
|
||||
|
||||
yield self.store.store_cached_remote_media(
|
||||
await self.store.store_cached_remote_media(
|
||||
origin=server_name,
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
@@ -423,7 +415,7 @@ class MediaRepository(object):
|
||||
"filesystem_id": file_id,
|
||||
}
|
||||
|
||||
yield self._generate_thumbnails(server_name, media_id, file_id, media_type)
|
||||
await self._generate_thumbnails(server_name, media_id, file_id, media_type)
|
||||
|
||||
return media_info
|
||||
|
||||
@@ -458,16 +450,15 @@ class MediaRepository(object):
|
||||
|
||||
return t_byte_source
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def generate_local_exact_thumbnail(
|
||||
async def generate_local_exact_thumbnail(
|
||||
self, media_id, t_width, t_height, t_method, t_type, url_cache
|
||||
):
|
||||
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
|
||||
input_path = await self.media_storage.ensure_media_is_in_local_cache(
|
||||
FileInfo(None, media_id, url_cache=url_cache)
|
||||
)
|
||||
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
t_byte_source = yield defer_to_thread(
|
||||
t_byte_source = await defer_to_thread(
|
||||
self.hs.get_reactor(),
|
||||
self._generate_thumbnail,
|
||||
thumbnailer,
|
||||
@@ -490,7 +481,7 @@ class MediaRepository(object):
|
||||
thumbnail_type=t_type,
|
||||
)
|
||||
|
||||
output_path = yield self.media_storage.store_file(
|
||||
output_path = await self.media_storage.store_file(
|
||||
t_byte_source, file_info
|
||||
)
|
||||
finally:
|
||||
@@ -500,22 +491,21 @@ class MediaRepository(object):
|
||||
|
||||
t_len = os.path.getsize(output_path)
|
||||
|
||||
yield self.store.store_local_thumbnail(
|
||||
await self.store.store_local_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
|
||||
return output_path
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def generate_remote_exact_thumbnail(
|
||||
async def generate_remote_exact_thumbnail(
|
||||
self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
|
||||
):
|
||||
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
|
||||
input_path = await self.media_storage.ensure_media_is_in_local_cache(
|
||||
FileInfo(server_name, file_id, url_cache=False)
|
||||
)
|
||||
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
t_byte_source = yield defer_to_thread(
|
||||
t_byte_source = await defer_to_thread(
|
||||
self.hs.get_reactor(),
|
||||
self._generate_thumbnail,
|
||||
thumbnailer,
|
||||
@@ -537,7 +527,7 @@ class MediaRepository(object):
|
||||
thumbnail_type=t_type,
|
||||
)
|
||||
|
||||
output_path = yield self.media_storage.store_file(
|
||||
output_path = await self.media_storage.store_file(
|
||||
t_byte_source, file_info
|
||||
)
|
||||
finally:
|
||||
@@ -547,7 +537,7 @@ class MediaRepository(object):
|
||||
|
||||
t_len = os.path.getsize(output_path)
|
||||
|
||||
yield self.store.store_remote_media_thumbnail(
|
||||
await self.store.store_remote_media_thumbnail(
|
||||
server_name,
|
||||
media_id,
|
||||
file_id,
|
||||
@@ -560,8 +550,7 @@ class MediaRepository(object):
|
||||
|
||||
return output_path
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_thumbnails(
|
||||
async def _generate_thumbnails(
|
||||
self, server_name, media_id, file_id, media_type, url_cache=False
|
||||
):
|
||||
"""Generate and store thumbnails for an image.
|
||||
@@ -582,7 +571,7 @@ class MediaRepository(object):
|
||||
if not requirements:
|
||||
return
|
||||
|
||||
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
|
||||
input_path = await self.media_storage.ensure_media_is_in_local_cache(
|
||||
FileInfo(server_name, file_id, url_cache=url_cache)
|
||||
)
|
||||
|
||||
@@ -600,7 +589,7 @@ class MediaRepository(object):
|
||||
return
|
||||
|
||||
if thumbnailer.transpose_method is not None:
|
||||
m_width, m_height = yield defer_to_thread(
|
||||
m_width, m_height = await defer_to_thread(
|
||||
self.hs.get_reactor(), thumbnailer.transpose
|
||||
)
|
||||
|
||||
@@ -620,11 +609,11 @@ class MediaRepository(object):
|
||||
for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
|
||||
# Generate the thumbnail
|
||||
if t_method == "crop":
|
||||
t_byte_source = yield defer_to_thread(
|
||||
t_byte_source = await defer_to_thread(
|
||||
self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
|
||||
)
|
||||
elif t_method == "scale":
|
||||
t_byte_source = yield defer_to_thread(
|
||||
t_byte_source = await defer_to_thread(
|
||||
self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
|
||||
)
|
||||
else:
|
||||
@@ -646,7 +635,7 @@ class MediaRepository(object):
|
||||
url_cache=url_cache,
|
||||
)
|
||||
|
||||
output_path = yield self.media_storage.store_file(
|
||||
output_path = await self.media_storage.store_file(
|
||||
t_byte_source, file_info
|
||||
)
|
||||
finally:
|
||||
@@ -656,7 +645,7 @@ class MediaRepository(object):
|
||||
|
||||
# Write to database
|
||||
if server_name:
|
||||
yield self.store.store_remote_media_thumbnail(
|
||||
await self.store.store_remote_media_thumbnail(
|
||||
server_name,
|
||||
media_id,
|
||||
file_id,
|
||||
@@ -667,15 +656,14 @@ class MediaRepository(object):
|
||||
t_len,
|
||||
)
|
||||
else:
|
||||
yield self.store.store_local_thumbnail(
|
||||
await self.store.store_local_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
|
||||
return {"width": m_width, "height": m_height}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_old_remote_media(self, before_ts):
|
||||
old_media = yield self.store.get_remote_media_before(before_ts)
|
||||
async def delete_old_remote_media(self, before_ts):
|
||||
old_media = await self.store.get_remote_media_before(before_ts)
|
||||
|
||||
deleted = 0
|
||||
|
||||
@@ -689,7 +677,7 @@ class MediaRepository(object):
|
||||
|
||||
# TODO: Should we delete from the backup store
|
||||
|
||||
with (yield self.remote_media_linearizer.queue(key)):
|
||||
with (await self.remote_media_linearizer.queue(key)):
|
||||
full_path = self.filepaths.remote_media_filepath(origin, file_id)
|
||||
try:
|
||||
os.remove(full_path)
|
||||
@@ -705,7 +693,7 @@ class MediaRepository(object):
|
||||
)
|
||||
shutil.rmtree(thumbnail_dir, ignore_errors=True)
|
||||
|
||||
yield self.store.delete_remote_media(origin, media_id)
|
||||
await self.store.delete_remote_media(origin, media_id)
|
||||
deleted += 1
|
||||
|
||||
return {"deleted": deleted}
|
||||
|
||||
@@ -165,8 +165,7 @@ class PreviewUrlResource(DirectServeResource):
|
||||
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
|
||||
respond_with_json_bytes(request, 200, og, send_cors=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_preview(self, url, user, ts):
|
||||
async def _do_preview(self, url, user, ts):
|
||||
"""Check the db, and download the URL and build a preview
|
||||
|
||||
Args:
|
||||
@@ -179,7 +178,7 @@ class PreviewUrlResource(DirectServeResource):
|
||||
"""
|
||||
# check the URL cache in the DB (which will also provide us with
|
||||
# historical previews, if we have any)
|
||||
cache_result = yield self.store.get_url_cache(url, ts)
|
||||
cache_result = await self.store.get_url_cache(url, ts)
|
||||
if (
|
||||
cache_result
|
||||
and cache_result["expires_ts"] > ts
|
||||
@@ -192,13 +191,13 @@ class PreviewUrlResource(DirectServeResource):
|
||||
og = og.encode("utf8")
|
||||
return og
|
||||
|
||||
media_info = yield self._download_url(url, user)
|
||||
media_info = await self._download_url(url, user)
|
||||
|
||||
logger.debug("got media_info of '%s'", media_info)
|
||||
|
||||
if _is_media(media_info["media_type"]):
|
||||
file_id = media_info["filesystem_id"]
|
||||
dims = yield self.media_repo._generate_thumbnails(
|
||||
dims = await self.media_repo._generate_thumbnails(
|
||||
None, file_id, file_id, media_info["media_type"], url_cache=True
|
||||
)
|
||||
|
||||
@@ -248,14 +247,14 @@ class PreviewUrlResource(DirectServeResource):
|
||||
# request itself and benefit from the same caching etc. But for now we
|
||||
# just rely on the caching on the master request to speed things up.
|
||||
if "og:image" in og and og["og:image"]:
|
||||
image_info = yield self._download_url(
|
||||
image_info = await self._download_url(
|
||||
_rebase_url(og["og:image"], media_info["uri"]), user
|
||||
)
|
||||
|
||||
if _is_media(image_info["media_type"]):
|
||||
# TODO: make sure we don't choke on white-on-transparent images
|
||||
file_id = image_info["filesystem_id"]
|
||||
dims = yield self.media_repo._generate_thumbnails(
|
||||
dims = await self.media_repo._generate_thumbnails(
|
||||
None, file_id, file_id, image_info["media_type"], url_cache=True
|
||||
)
|
||||
if dims:
|
||||
@@ -293,7 +292,7 @@ class PreviewUrlResource(DirectServeResource):
|
||||
jsonog = json.dumps(og)
|
||||
|
||||
# store OG in history-aware DB cache
|
||||
yield self.store.store_url_cache(
|
||||
await self.store.store_url_cache(
|
||||
url,
|
||||
media_info["response_code"],
|
||||
media_info["etag"],
|
||||
@@ -305,8 +304,7 @@ class PreviewUrlResource(DirectServeResource):
|
||||
|
||||
return jsonog.encode("utf8")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _download_url(self, url, user):
|
||||
async def _download_url(self, url, user):
|
||||
# TODO: we should probably honour robots.txt... except in practice
|
||||
# we're most likely being explicitly triggered by a human rather than a
|
||||
# bot, so are we really a robot?
|
||||
@@ -318,7 +316,7 @@ class PreviewUrlResource(DirectServeResource):
|
||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||
try:
|
||||
logger.debug("Trying to get url '%s'", url)
|
||||
length, headers, uri, code = yield self.client.get_file(
|
||||
length, headers, uri, code = await self.client.get_file(
|
||||
url, output_stream=f, max_size=self.max_spider_size
|
||||
)
|
||||
except SynapseError:
|
||||
@@ -345,7 +343,7 @@ class PreviewUrlResource(DirectServeResource):
|
||||
% (traceback.format_exception_only(sys.exc_info()[0], e),),
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
yield finish()
|
||||
await finish()
|
||||
|
||||
try:
|
||||
if b"Content-Type" in headers:
|
||||
@@ -356,7 +354,7 @@ class PreviewUrlResource(DirectServeResource):
|
||||
|
||||
download_name = get_filename_from_headers(headers)
|
||||
|
||||
yield self.store.store_local_media(
|
||||
await self.store.store_local_media(
|
||||
media_id=file_id,
|
||||
media_type=media_type,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
@@ -393,8 +391,7 @@ class PreviewUrlResource(DirectServeResource):
|
||||
"expire_url_cache_data", self._expire_url_cache_data
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _expire_url_cache_data(self):
|
||||
async def _expire_url_cache_data(self):
|
||||
"""Clean up expired url cache content, media and thumbnails.
|
||||
"""
|
||||
# TODO: Delete from backup media store
|
||||
@@ -403,12 +400,12 @@ class PreviewUrlResource(DirectServeResource):
|
||||
|
||||
logger.info("Running url preview cache expiry")
|
||||
|
||||
if not (yield self.store.db.updates.has_completed_background_updates()):
|
||||
if not (await self.store.db.updates.has_completed_background_updates()):
|
||||
logger.info("Still running DB updates; skipping expiry")
|
||||
return
|
||||
|
||||
# First we delete expired url cache entries
|
||||
media_ids = yield self.store.get_expired_url_cache(now)
|
||||
media_ids = await self.store.get_expired_url_cache(now)
|
||||
|
||||
removed_media = []
|
||||
for media_id in media_ids:
|
||||
@@ -430,7 +427,7 @@ class PreviewUrlResource(DirectServeResource):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
yield self.store.delete_url_cache(removed_media)
|
||||
await self.store.delete_url_cache(removed_media)
|
||||
|
||||
if removed_media:
|
||||
logger.info("Deleted %d entries from url cache", len(removed_media))
|
||||
@@ -440,7 +437,7 @@ class PreviewUrlResource(DirectServeResource):
|
||||
# may have a room open with a preview url thing open).
|
||||
# So we wait a couple of days before deleting, just in case.
|
||||
expire_before = now - 2 * 24 * 60 * 60 * 1000
|
||||
media_ids = yield self.store.get_url_cache_media_before(expire_before)
|
||||
media_ids = await self.store.get_url_cache_media_before(expire_before)
|
||||
|
||||
removed_media = []
|
||||
for media_id in media_ids:
|
||||
@@ -478,7 +475,7 @@ class PreviewUrlResource(DirectServeResource):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
yield self.store.delete_url_cache_media(removed_media)
|
||||
await self.store.delete_url_cache_media(removed_media)
|
||||
|
||||
logger.info("Deleted %d media from url cache", len(removed_media))
|
||||
|
||||
|
||||
@@ -16,8 +16,6 @@
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
set_cors_headers,
|
||||
@@ -79,11 +77,10 @@ class ThumbnailResource(DirectServeResource):
|
||||
)
|
||||
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_local_thumbnail(
|
||||
async def _respond_local_thumbnail(
|
||||
self, request, media_id, width, height, method, m_type
|
||||
):
|
||||
media_info = yield self.store.get_local_media(media_id)
|
||||
media_info = await self.store.get_local_media(media_id)
|
||||
|
||||
if not media_info:
|
||||
respond_404(request)
|
||||
@@ -93,7 +90,7 @@ class ThumbnailResource(DirectServeResource):
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
|
||||
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
||||
|
||||
if thumbnail_infos:
|
||||
thumbnail_info = self._select_thumbnail(
|
||||
@@ -114,14 +111,13 @@ class ThumbnailResource(DirectServeResource):
|
||||
t_type = file_info.thumbnail_type
|
||||
t_length = thumbnail_info["thumbnail_length"]
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
yield respond_with_responder(request, responder, t_type, t_length)
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
await respond_with_responder(request, responder, t_type, t_length)
|
||||
else:
|
||||
logger.info("Couldn't find any generated thumbnails")
|
||||
respond_404(request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _select_or_generate_local_thumbnail(
|
||||
async def _select_or_generate_local_thumbnail(
|
||||
self,
|
||||
request,
|
||||
media_id,
|
||||
@@ -130,7 +126,7 @@ class ThumbnailResource(DirectServeResource):
|
||||
desired_method,
|
||||
desired_type,
|
||||
):
|
||||
media_info = yield self.store.get_local_media(media_id)
|
||||
media_info = await self.store.get_local_media(media_id)
|
||||
|
||||
if not media_info:
|
||||
respond_404(request)
|
||||
@@ -140,7 +136,7 @@ class ThumbnailResource(DirectServeResource):
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
|
||||
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
||||
for info in thumbnail_infos:
|
||||
t_w = info["thumbnail_width"] == desired_width
|
||||
t_h = info["thumbnail_height"] == desired_height
|
||||
@@ -162,15 +158,15 @@ class ThumbnailResource(DirectServeResource):
|
||||
t_type = file_info.thumbnail_type
|
||||
t_length = info["thumbnail_length"]
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
if responder:
|
||||
yield respond_with_responder(request, responder, t_type, t_length)
|
||||
await respond_with_responder(request, responder, t_type, t_length)
|
||||
return
|
||||
|
||||
logger.debug("We don't have a thumbnail of that size. Generating")
|
||||
|
||||
# Okay, so we generate one.
|
||||
file_path = yield self.media_repo.generate_local_exact_thumbnail(
|
||||
file_path = await self.media_repo.generate_local_exact_thumbnail(
|
||||
media_id,
|
||||
desired_width,
|
||||
desired_height,
|
||||
@@ -180,13 +176,12 @@ class ThumbnailResource(DirectServeResource):
|
||||
)
|
||||
|
||||
if file_path:
|
||||
yield respond_with_file(request, desired_type, file_path)
|
||||
await respond_with_file(request, desired_type, file_path)
|
||||
else:
|
||||
logger.warning("Failed to generate thumbnail")
|
||||
respond_404(request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _select_or_generate_remote_thumbnail(
|
||||
async def _select_or_generate_remote_thumbnail(
|
||||
self,
|
||||
request,
|
||||
server_name,
|
||||
@@ -196,9 +191,9 @@ class ThumbnailResource(DirectServeResource):
|
||||
desired_method,
|
||||
desired_type,
|
||||
):
|
||||
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
|
||||
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
|
||||
|
||||
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
|
||||
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
||||
server_name, media_id
|
||||
)
|
||||
|
||||
@@ -224,15 +219,15 @@ class ThumbnailResource(DirectServeResource):
|
||||
t_type = file_info.thumbnail_type
|
||||
t_length = info["thumbnail_length"]
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
if responder:
|
||||
yield respond_with_responder(request, responder, t_type, t_length)
|
||||
await respond_with_responder(request, responder, t_type, t_length)
|
||||
return
|
||||
|
||||
logger.debug("We don't have a thumbnail of that size. Generating")
|
||||
|
||||
# Okay, so we generate one.
|
||||
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
|
||||
file_path = await self.media_repo.generate_remote_exact_thumbnail(
|
||||
server_name,
|
||||
file_id,
|
||||
media_id,
|
||||
@@ -243,21 +238,20 @@ class ThumbnailResource(DirectServeResource):
|
||||
)
|
||||
|
||||
if file_path:
|
||||
yield respond_with_file(request, desired_type, file_path)
|
||||
await respond_with_file(request, desired_type, file_path)
|
||||
else:
|
||||
logger.warning("Failed to generate thumbnail")
|
||||
respond_404(request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_remote_thumbnail(
|
||||
async def _respond_remote_thumbnail(
|
||||
self, request, server_name, media_id, width, height, method, m_type
|
||||
):
|
||||
# TODO: Don't download the whole remote file
|
||||
# We should proxy the thumbnail from the remote server instead of
|
||||
# downloading the remote file and generating our own thumbnails.
|
||||
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
|
||||
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
|
||||
|
||||
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
|
||||
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
||||
server_name, media_id
|
||||
)
|
||||
|
||||
@@ -278,8 +272,8 @@ class ThumbnailResource(DirectServeResource):
|
||||
t_type = file_info.thumbnail_type
|
||||
t_length = thumbnail_info["thumbnail_length"]
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
yield respond_with_responder(request, responder, t_type, t_length)
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
await respond_with_responder(request, responder, t_type, t_length)
|
||||
else:
|
||||
logger.info("Failed to find any generated thumbnails")
|
||||
respond_404(request)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user