Compare commits

29 Commits

Author | SHA1 | Message | Date
Andrew Morgan | 18cf53376d | ree | 2020-04-20 18:05:36 +01:00
Andrew Morgan | e49a90899b | another | 2020-04-20 17:59:46 +01:00
Andrew Morgan | 6317eba770 | ahh | 2020-04-20 17:17:11 +01:00
Andrew Morgan | da51afdc6b | whoops | 2020-04-20 17:07:35 +01:00
Andrew Morgan | 2ccad9a1b6 | add debug logging | 2020-04-20 16:50:17 +01:00
Andrew Morgan | 714e75dc1b | lint | 2020-04-20 16:07:00 +01:00
Andrew Morgan | 0c1b27ecd0 | Resolve review comments | 2020-04-20 16:05:44 +01:00
Andrew Morgan | ac1bbfdd2b | Update changelog | 2020-04-20 15:35:22 +01:00
Andrew Morgan | 7dd06332a9 | Update changelog | 2020-04-20 11:58:35 +01:00
Andrew Morgan | 76f15f4bf2 | Remove extraneous key_id and verify_key | 2020-04-20 11:57:13 +01:00
Andrew Morgan | 887ec58556 | Update method docstring | 2020-04-17 12:54:55 +01:00
Andrew Morgan | b3b2da56b3 | Send device updates, modeled after SigningKeyEduUpdater._handle_signing_key_updates | 2020-04-17 12:30:54 +01:00
Andrew Morgan | cb56a51ada | Factor key retrieval out into a separate function | 2020-04-17 12:08:09 +01:00
Andrew Morgan | 40042dec0d | lint | 2020-04-17 11:39:30 +01:00
Andrew Morgan | 2c881bf8a4 | Remove extraneous items from remote query try/except | 2020-04-17 11:36:08 +01:00
Andrew Morgan | 667e9ca5be | Fix log statements, docstrings | 2020-04-17 11:33:43 +01:00
Andrew Morgan | 8490a8793c | Only fetch master and self_signing key types | 2020-04-16 20:01:34 +01:00
Andrew Morgan | 2ff55e02c1 | Add comment explaining why this is useful | 2020-04-16 17:59:47 +01:00
Andrew Morgan | 0ac339cfe6 | lint | 2020-04-16 17:55:59 +01:00
Andrew Morgan | f4064862b2 | Note that _get_e2e_cross_signing_verify_key can raise a SynapseError | 2020-04-16 17:54:34 +01:00
Andrew Morgan | 67851671e5 | Wrap get_verify_key_from_cross_signing_key in a try/except | 2020-04-16 17:51:00 +01:00
Andrew Morgan | a7dadf87be | Remove very specific exception handling | 2020-04-16 17:48:52 +01:00
Andrew Morgan | c215378411 | Make changelog more useful | 2020-04-16 17:23:28 +01:00
Andrew Morgan | e91373c1fa | Use query_user_devices instead, assume only master, self_signing key types | 2020-04-16 17:13:57 +01:00
Andrew Morgan | 4e48515bbf | Fix and de-brittle remote result dict processing | 2020-04-16 16:14:01 +01:00
Andrew Morgan | 70807c83c6 | lint | 2020-04-16 13:58:29 +01:00
Andrew Morgan | 44cf7cf56e | Save retrieved keys to the db | 2020-04-16 13:54:59 +01:00
Andrew Morgan | d2a9b45df0 | Add changelog | 2020-04-16 12:50:34 +01:00
Andrew Morgan | 81e74fbae5 | Query missing cross-signing keys on local sig upload | 2020-04-16 12:49:00 +01:00
114 changed files with 1452 additions and 3922 deletions


@@ -1,40 +1,15 @@
Next version
============
* New templates (`sso_auth_confirm.html`, `sso_auth_success.html`, and
`sso_account_deactivated.html`) were added to Synapse. If your Synapse is
configured to use SSO and a custom `sso_redirect_confirm_template_dir`
configuration then these templates will need to be duplicated into that
directory.
* Two new templates (`sso_auth_confirm.html` and `sso_account_deactivated.html`)
were added to Synapse. If your Synapse is configured to use SSO and a custom
`sso_redirect_confirm_template_dir` configuration then these templates will
need to be duplicated into that directory.
* Plugins using the `complete_sso_login` method of `synapse.module_api.ModuleApi`
should update to using the async/await version `complete_sso_login_async` which
includes additional checks. The non-async version is considered deprecated.
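As an illustration, here is a minimal sketch of that migration from a module's point of view (the argument list mirrors the existing `complete_sso_login` call; treat the exact signature as an assumption and check the `ModuleApi` docstrings):

```python
# Before (deprecated, synchronous):
#     module_api.complete_sso_login(user_id, request, client_redirect_url)

# After: await the async variant, which performs additional checks
# (such as whether the user has been deactivated) before finishing the login.
async def finish_login(module_api, user_id, request, client_redirect_url):
    await module_api.complete_sso_login_async(
        user_id, request, client_redirect_url
    )
```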
Synapse 1.12.4 (2020-04-23)
===========================
No significant changes.
Synapse 1.12.4rc1 (2020-04-22)
==============================
Features
--------
- Always send users their own device updates. ([\#7160](https://github.com/matrix-org/synapse/issues/7160))
- Add support for handling GET requests for `account_data` on a worker. ([\#7311](https://github.com/matrix-org/synapse/issues/7311))
Bugfixes
--------
- Fix a bug that prevented cross-signing with users on worker-mode synapses. ([\#7255](https://github.com/matrix-org/synapse/issues/7255))
- Do not treat display names as globs in push rules. ([\#7271](https://github.com/matrix-org/synapse/issues/7271))
- Fix a bug with cross-signing devices belonging to remote users who did not share a room with any user on the local homeserver. ([\#7289](https://github.com/matrix-org/synapse/issues/7289))
Synapse 1.12.3 (2020-04-03)
===========================
@@ -67,19 +42,12 @@ Bugfixes
Synapse 1.12.0 (2020-03-23)
===========================
No significant changes since 1.12.0rc1.
Debian packages and Docker images are rebuilt using the latest versions of
dependency libraries, including Twisted 20.3.0. **Please see security advisory
below**.
Potential slow database update during upgrade
---------------------------------------------
Synapse 1.12.0 includes a database update which is run as part of the upgrade,
and which may take some time (several hours in the case of a large
server). Synapse will not respond to HTTP requests while this update is taking
place. For information on checking whether you are affected, and a workaround if
you are, see the [upgrade notes](UPGRADE.rst#upgrading-to-v1120).
Security advisory
-----------------


@@ -75,71 +75,6 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
Upgrading to v1.12.0
====================
This version includes a database update which is run as part of the upgrade,
and which may take some time (several hours in the case of a large
server). Synapse will not respond to HTTP requests while this update is taking
place.
This is only likely to be a problem in the case of a server which is
participating in many rooms.
0. As with all upgrades, it is recommended that you have a recent backup of
your database which can be used for recovery in the event of any problems.
1. As an initial check to see if you will be affected, you can try running the
following query from the `psql` or `sqlite3` console. It is safe to run it
while Synapse is still running.
.. code:: sql

   SELECT MAX(q.v) FROM (
     SELECT (
       SELECT ej.json AS v
       FROM state_events se INNER JOIN event_json ej USING (event_id)
       WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key=''
       LIMIT 1
     ) FROM rooms WHERE rooms.room_version IS NULL
   ) q;
This query will take about the same amount of time as the upgrade process: ie,
if it takes 5 minutes, then it is likely that Synapse will be unresponsive for
5 minutes during the upgrade.
If you consider an outage of this duration to be acceptable, no further
action is necessary and you can simply start Synapse 1.12.0.
If you would prefer to reduce the downtime, continue with the steps below.
2. The easiest workaround for this issue is to manually
create a new index before upgrading. On PostgreSQL, this can be done as follows:
.. code:: sql

   CREATE INDEX CONCURRENTLY tmp_upgrade_1_12_0_index
   ON state_events(room_id) WHERE type = 'm.room.create';
The above query may take some time, but is also safe to run while Synapse is
running.
We assume that no SQLite users have databases large enough to be
affected. If you *are* affected, you can run a similar query, omitting the
``CONCURRENTLY`` keyword. Note however that this operation may in itself cause
Synapse to stop running for some time. Synapse admins are reminded that
`SQLite is not recommended for use outside a test
environment <https://github.com/matrix-org/synapse/blob/master/README.rst#using-postgresql>`_.
3. Once the index has been created, the ``SELECT`` query in step 1 above should
complete quickly. It is therefore safe to upgrade to Synapse 1.12.0.
4. Once Synapse 1.12.0 has successfully started and is responding to HTTP
requests, the temporary index can be removed:
.. code:: sql

   DROP INDEX tmp_upgrade_1_12_0_index;
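The pre-upgrade check in step 1 can also be scripted. A minimal sketch using
``psycopg2`` to time the query (connection details are placeholders, and this
assumes a PostgreSQL database):

.. code:: python

   import time

   import psycopg2

   QUERY = """
   SELECT MAX(q.v) FROM (
     SELECT (
       SELECT ej.json AS v
       FROM state_events se INNER JOIN event_json ej USING (event_id)
       WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key=''
       LIMIT 1
     ) FROM rooms WHERE rooms.room_version IS NULL
   ) q;
   """

   conn = psycopg2.connect(dbname="synapse", user="synapse_user", host="localhost")
   with conn, conn.cursor() as cur:
       start = time.monotonic()
       cur.execute(QUERY)  # safe to run while Synapse is still running
       cur.fetchone()
       # Roughly how long Synapse would be unresponsive during the upgrade.
       print("query took %.1f seconds" % (time.monotonic() - start,))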
Upgrading to v1.10.0
====================


@@ -1 +0,0 @@
Add support for running replication over Redis when using workers.


@@ -1 +0,0 @@
Improve the documentation of application service configuration files.


@@ -1 +0,0 @@
Run replication streamers on workers.

changelog.d/7160.feature (new file)

@@ -0,0 +1 @@
Always send users their own device updates.


@@ -1 +0,0 @@
Add explicit Python build tooling as dependencies for the snapcraft build.


@@ -1 +0,0 @@
Extend room admin api (`GET /_synapse/admin/v1/rooms`) with additional attributes.

changelog.d/7255.bugfix (new file)

@@ -0,0 +1 @@
Fix a bug that prevented cross-signing with users on worker-mode synapses.


@@ -1 +0,0 @@
Reject unknown session IDs during user interactive authentication instead of silently creating a new session.


@@ -1 +0,0 @@
Documentation of media_storage_providers options updated to avoid misunderstandings. Contributed by Tristan Lins.


@@ -1 +0,0 @@
Add some unit tests for replication.


@@ -1 +0,0 @@
Support SSO in the user interactive authentication workflow.


@@ -1 +0,0 @@
Move catchup of replication streams logic to worker.

changelog.d/7289.bugfix (new file)

@@ -0,0 +1 @@
Fix a bug with cross-signing devices with remote users when they did not share a room with any user on the local homeserver.


@@ -1 +0,0 @@
Move catchup of replication streams logic to worker.


@@ -1 +0,0 @@
Improve typing annotations in `synapse.replication.tcp.streams.Stream`.


@@ -1 +0,0 @@
Reduce log verbosity of url cache cleanup tasks.


@@ -1 +0,0 @@
Fix sample SAML Service Provider configuration. Contributed by @frcl.


@@ -1 +0,0 @@
Fix StreamChangeCache to work with multiple entities changing on the same stream id.


@@ -1 +0,0 @@
Allow `/requestToken` endpoints to hide the existence (or lack thereof) of 3PID associations on the homeserver.


@@ -1 +0,0 @@
Move catchup of replication streams logic to worker.


@@ -1 +0,0 @@
Fix an incorrect import in IdentityHandler.


@@ -1 +0,0 @@
Reduce logging verbosity for successful federation requests.


@@ -1 +0,0 @@
Add support for running replication over Redis when using workers.


@@ -1 +0,0 @@
Move catchup of replication streams logic to worker.


@@ -1 +0,0 @@
Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.


@@ -1 +0,0 @@
Convert some federation handler code to async/await.


@@ -1 +0,0 @@
Fix bad error handling that would cause Synapse to crash if it's provided with a YAML configuration file that's either empty or doesn't parse into a key-value map.


@@ -1 +0,0 @@
Support SSO in the user interactive authentication workflow.


@@ -1 +0,0 @@
Fix incorrect metrics reporting for `renew_attestations` background task.


@@ -1 +0,0 @@
Add support for running replication over Redis when using workers.


@@ -1 +0,0 @@
Add documentation on monitoring workers with Prometheus.


@@ -1 +0,0 @@
Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.


@@ -1 +0,0 @@
Fix collation for postgres for unit tests.

debian/changelog (vendored)

@@ -1,16 +1,8 @@
<<<<<<< HEAD
matrix-synapse-py3 (1.12.3ubuntu1) UNRELEASED; urgency=medium
* Add information about .well-known files to Debian installation scripts.
-- Patrick Cloke <patrickc@matrix.org> Mon, 06 Apr 2020 10:10:38 -0400
=======
matrix-synapse-py3 (1.12.4) stable; urgency=medium
* New synapse release 1.12.4.
-- Synapse Packaging team <packages@matrix.org> Thu, 23 Apr 2020 10:58:14 -0400
>>>>>>> master
matrix-synapse-py3 (1.12.3) stable; urgency=medium


@@ -11,21 +11,8 @@ The following query parameters are available:
* `from` - Offset in the returned list. Defaults to `0`.
* `limit` - Maximum amount of rooms to return. Defaults to `100`.
* `order_by` - The method in which to sort the returned list of rooms. Valid values are:
- `alphabetical` - Same as `name`. This is deprecated.
- `size` - Same as `joined_members`. This is deprecated.
- `name` - Rooms are ordered alphabetically by room name. This is the default.
- `canonical_alias` - Rooms are ordered alphabetically by main alias address of the room.
- `joined_members` - Rooms are ordered by the number of members. Largest to smallest.
- `joined_local_members` - Rooms are ordered by the number of local members. Largest to smallest.
- `version` - Rooms are ordered by room version. Largest to smallest.
- `creator` - Rooms are ordered alphabetically by creator of the room.
- `encryption` - Rooms are ordered alphabetically by the end-to-end encryption algorithm.
- `federatable` - Rooms are ordered by whether the room is federatable.
- `public` - Rooms are ordered by visibility in room list.
- `join_rules` - Rooms are ordered alphabetically by join rules of the room.
- `guest_access` - Rooms are ordered alphabetically by guest access option of the room.
- `history_visibility` - Rooms are ordered alphabetically by visibility of history of the room.
- `state_events` - Rooms are ordered by number of state events. Largest to smallest.
- `alphabetical` - Rooms are ordered alphabetically by room name. This is the default.
- `size` - Rooms are ordered by the number of members. Largest to smallest.
* `dir` - Direction of room order. Either `f` for forwards or `b` for backwards. Setting
this value to `b` will reverse the above sort order. Defaults to `f`.
* `search_term` - Filter rooms by their room name. Search term can be contained in any
@@ -39,16 +26,6 @@ The following fields are possible in the JSON response body:
- `name` - The name of the room.
- `canonical_alias` - The canonical (main) alias address of the room.
- `joined_members` - How many users are currently in the room.
- `joined_local_members` - How many local users are currently in the room.
- `version` - The version of the room as a string.
- `creator` - The `user_id` of the room creator.
- `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active.
- `federatable` - Whether users on other servers can join this room.
- `public` - Whether the room is visible in room directory.
- `join_rules` - The type of rules used for users wishing to join this room. One of: ["public", "knock", "invite", "private"].
- `guest_access` - Whether guests can join the room. One of: ["can_join", "forbidden"].
- `history_visibility` - Who can see the room history. One of: ["invited", "joined", "shared", "world_readable"].
- `state_events` - Total number of state_events of a room. Complexity of the room.
* `offset` - The current pagination offset in rooms. This parameter should be
used instead of `next_token` for room offset as `next_token` is
not intended to be parsed.
@@ -83,34 +60,14 @@ Response:
"room_id": "!OGEhHVWSdvArJzumhm:matrix.org",
"name": "Matrix HQ",
"canonical_alias": "#matrix:matrix.org",
"joined_members": 8326,
"joined_local_members": 2,
"version": "1",
"creator": "@foo:matrix.org",
"encryption": null,
"federatable": true,
"public": true,
"join_rules": "invite",
"guest_access": null,
"history_visibility": "shared",
"state_events": 93534
"joined_members": 8326
},
... (8 hidden items) ...
{
"room_id": "!xYvNcQPhnkrdUmYczI:matrix.org",
"name": "This Week In Matrix (TWIM)",
"canonical_alias": "#twim:matrix.org",
"joined_members": 314,
"joined_local_members": 20,
"version": "4",
"creator": "@foo:matrix.org",
"encryption": "m.megolm.v1.aes-sha2",
"federatable": true,
"public": false,
"join_rules": "invite",
"guest_access": null,
"history_visibility": "shared",
"state_events": 8345
"joined_members": 314
}
],
"offset": 0,
@@ -135,17 +92,7 @@ Response:
"room_id": "!xYvNcQPhnkrdUmYczI:matrix.org",
"name": "This Week In Matrix (TWIM)",
"canonical_alias": "#twim:matrix.org",
"joined_members": 314,
"joined_local_members": 20,
"version": "4",
"creator": "@foo:matrix.org",
"encryption": "m.megolm.v1.aes-sha2",
"federatable": true,
"public": false,
"join_rules": "invite",
"guest_access": null,
"history_visibility": "shared",
"state_events": 8
"joined_members": 314
}
],
"offset": 0,
@@ -170,34 +117,14 @@ Response:
"room_id": "!OGEhHVWSdvArJzumhm:matrix.org",
"name": "Matrix HQ",
"canonical_alias": "#matrix:matrix.org",
"joined_members": 8326,
"joined_local_members": 2,
"version": "1",
"creator": "@foo:matrix.org",
"encryption": null,
"federatable": true,
"public": true,
"join_rules": "invite",
"guest_access": null,
"history_visibility": "shared",
"state_events": 93534
"joined_members": 8326
},
... (98 hidden items) ...
{
"room_id": "!xYvNcQPhnkrdUmYczI:matrix.org",
"name": "This Week In Matrix (TWIM)",
"canonical_alias": "#twim:matrix.org",
"joined_members": 314,
"joined_local_members": 20,
"version": "4",
"creator": "@foo:matrix.org",
"encryption": "m.megolm.v1.aes-sha2",
"federatable": true,
"public": false,
"join_rules": "invite",
"guest_access": null,
"history_visibility": "shared",
"state_events": 8345
"joined_members": 314
}
],
"offset": 0,
@@ -227,16 +154,6 @@ Response:
"name": "Music Theory",
"canonical_alias": "#musictheory:matrix.org",
"joined_members": 127
"joined_local_members": 2,
"version": "1",
"creator": "@foo:matrix.org",
"encryption": null,
"federatable": true,
"public": true,
"join_rules": "invite",
"guest_access": null,
"history_visibility": "shared",
"state_events": 93534
},
... (48 hidden items) ...
{
@@ -244,16 +161,6 @@ Response:
"name": "weechat-matrix",
"canonical_alias": "#weechat-matrix:termina.org.uk",
"joined_members": 137
"joined_local_members": 20,
"version": "4",
"creator": "@foo:termina.org.uk",
"encryption": null,
"federatable": true,
"public": true,
"join_rules": "invite",
"guest_access": null,
"history_visibility": "shared",
"state_events": 8345
}
],
"offset": 100,


@@ -23,13 +23,9 @@ namespaces:
users: # List of users we're interested in
- exclusive: <bool>
regex: <regex>
group_id: <group>
- ...
aliases: [] # List of aliases we're interested in
rooms: [] # List of room ids we're interested in
```
`exclusive`: If enabled, only this application service is allowed to register users in its namespace(s).
`group_id`: All users of this application service are dynamically joined to this group. This is useful for e.g. user organisation or flairs.
See the [spec](https://matrix.org/docs/spec/application_service/unstable.html) for further details on how application services work.


@@ -60,31 +60,6 @@
1. Restart Prometheus.
## Monitoring workers
To monitor a Synapse installation using
[workers](https://github.com/matrix-org/synapse/blob/master/docs/workers.md),
every worker needs to be monitored independently, in addition to
the main homeserver process. This is because workers don't send
their metrics to the main homeserver process, but expose them
directly (if they are configured to do so).
To allow collecting metrics from a worker, you need to add a
`metrics` listener to its configuration, by adding the following
under `worker_listeners`:
```yaml
- type: metrics
bind_address: ''
port: 9101
```
The `bind_address` and `port` parameters should be set so that
the resulting listener can be reached by prometheus, and they
don't clash with an existing worker.
With this example, the worker's metrics would then be available
on `http://127.0.0.1:9101`.
## Renaming of metrics & deprecation of old names in 1.2
Synapse 1.2 updates the Prometheus metrics to match the naming


@@ -414,16 +414,6 @@ retention:
# longest_max_lifetime: 1y
# interval: 1d
# Inhibits the /requestToken endpoints from returning an error that might leak
# information about whether an e-mail address is in use or not on this
# homeserver.
# Note that for some endpoints the error situation is the e-mail already being
# used, and for others the error is entering the e-mail being unused.
# If this option is enabled, instead of returning an error, these endpoints will
# act as if no error happened and return a fake session ID ('sid') to clients.
#
#request_token_inhibit_3pid_errors: true
## TLS ##
@@ -745,11 +735,12 @@ media_store_path: "DATADIR/media_store"
#
#media_storage_providers:
# - module: file_system
# # Whether to store newly uploaded local files
# # Whether to write new local files.
# store_local: false
# # Whether to store newly downloaded remote files
# # Whether to write new remote media
# store_remote: false
# # Whether to wait for successful storage for local uploads
# # Whether to block upload requests waiting for write to this
# # provider to complete
# store_synchronous: false
# config:
# directory: /mnt/some/other/directory
@@ -1349,32 +1340,32 @@ saml2_config:
# remote:
# - url: https://our_idp/metadata.xml
#
# # By default, the user has to go to our login page first. If you'd like
# # to allow IdP-initiated login, set 'allow_unsolicited: true' in a
# # 'service.sp' section:
# #
# #service:
# # sp:
# # allow_unsolicited: true
# # By default, the user has to go to our login page first. If you'd like
# # to allow IdP-initiated login, set 'allow_unsolicited: true' in a
# # 'service.sp' section:
# #
# #service:
# # sp:
# # allow_unsolicited: true
#
# # The examples below are just used to generate our metadata xml, and you
# # may well not need them, depending on your setup. Alternatively you
# # may need a whole lot more detail - see the pysaml2 docs!
# # The examples below are just used to generate our metadata xml, and you
# # may well not need them, depending on your setup. Alternatively you
# # may need a whole lot more detail - see the pysaml2 docs!
#
# description: ["My awesome SP", "en"]
# name: ["Test SP", "en"]
# description: ["My awesome SP", "en"]
# name: ["Test SP", "en"]
#
# organization:
# name: Example com
# display_name:
# - ["Example co", "en"]
# url: "http://example.com"
# organization:
# name: Example com
# display_name:
# - ["Example co", "en"]
# url: "http://example.com"
#
# contact_person:
# - given_name: Bob
# sur_name: "the Sysadmin"
# email_address": ["admin@example.com"]
# contact_type": technical
# contact_person:
# - given_name: Bob
# sur_name: "the Sysadmin"
# email_address": ["admin@example.com"]
# contact_type": technical
# Instead of putting the config inline as above, you can specify a
# separate pysaml2 configuration file:
@@ -1518,30 +1509,6 @@ sso:
#
# * server_name: the homeserver's name.
#
# * HTML page which notifies the user that they are authenticating to confirm
# an operation on their account during the user interactive authentication
# process: 'sso_auth_confirm.html'.
#
# When rendering, this template is given the following variables:
# * redirect_url: the URL the user is about to be redirected to. Needs
# manual escaping (see
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
#
# * description: the operation which the user is being asked to confirm
#
# * HTML page shown after a successful user interactive authentication session:
# 'sso_auth_success.html'.
#
# Note that this page must include the JavaScript which notifies of a successful authentication
# (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback).
#
# This template has no additional variables.
#
# * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
# attempts to login: 'sso_account_deactivated.html'.
#
# This template has no additional variables.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#


@@ -196,7 +196,7 @@ Asks the server for the current position of all streams.
#### USER_SYNC (C)
A user has started or stopped syncing on this process.
A user has started or stopped syncing
#### CLEAR_USER_SYNC (C)
@@ -216,6 +216,10 @@ Asks the server for the current position of all streams.
Inform the server a cache should be invalidated
#### SYNC (S, C)
Used exclusively in tests
#### REMOTE_SERVER_UP (S, C)
Inform other processes that a remote server may have come back online.


@@ -120,7 +120,7 @@ Your home server configuration file needs the following extra keys:
As an example, here is the relevant section of the config file for matrix.org:
turn_uris: [ "turn:turn.matrix.org:3478?transport=udp", "turn:turn.matrix.org:3478?transport=tcp" ]
turn_shared_secret: "n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons"
turn_shared_secret: n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons
turn_user_lifetime: 86400000
turn_allow_guests: True


@@ -286,8 +286,6 @@ Additionally, the following REST endpoints can be handled for GET requests:
^/_matrix/client/(api/v1|r0|unstable)/pushrules/.*$
^/_matrix/client/(api/v1|r0|unstable)/groups/.*$
^/_matrix/client/(api/v1|r0|unstable)/user/[^/]*/account_data/
^/_matrix/client/(api/v1|r0|unstable)/user/[^/]*/rooms/[^/]*/account_data/
Additionally, the following REST endpoints can be handled, but all requests must
be routed to the same instance:


@@ -33,10 +33,6 @@ parts:
python-version: python3
python-packages:
- '.[all]'
- pip
- setuptools
- setuptools-scm
- wheel
build-packages:
- libffi-dev
- libturbojpeg0-dev


@@ -1,13 +0,0 @@
from .sorteddict import (
SortedDict,
SortedKeysView,
SortedItemsView,
SortedValuesView,
)
__all__ = [
"SortedDict",
"SortedKeysView",
"SortedItemsView",
"SortedValuesView",
]


@@ -1,124 +0,0 @@
# stub for SortedDict. This is a lightly edited copy of
# https://github.com/grantjenks/python-sortedcontainers/blob/eea42df1f7bad2792e8da77335ff888f04b9e5ae/sortedcontainers/sorteddict.pyi
# (from https://github.com/grantjenks/python-sortedcontainers/pull/107)
from typing import (
Any,
Callable,
Dict,
Hashable,
Iterator,
Iterable,
ItemsView,
KeysView,
List,
Mapping,
Optional,
Sequence,
Type,
TypeVar,
Tuple,
Union,
ValuesView,
overload,
)
_T = TypeVar("_T")
_S = TypeVar("_S")
_T_h = TypeVar("_T_h", bound=Hashable)
_KT = TypeVar("_KT", bound=Hashable) # Key type.
_VT = TypeVar("_VT") # Value type.
_KT_co = TypeVar("_KT_co", covariant=True, bound=Hashable)
_VT_co = TypeVar("_VT_co", covariant=True)
_SD = TypeVar("_SD", bound=SortedDict)
_Key = Callable[[_T], Any]
class SortedDict(Dict[_KT, _VT]):
@overload
def __init__(self, **kwargs: _VT) -> None: ...
@overload
def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
@overload
def __init__(
self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
) -> None: ...
@overload
def __init__(self, __key: _Key[_KT], **kwargs: _VT) -> None: ...
@overload
def __init__(
self, __key: _Key[_KT], __map: Mapping[_KT, _VT], **kwargs: _VT
) -> None: ...
@overload
def __init__(
self, __key: _Key[_KT], __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
) -> None: ...
@property
def key(self) -> Optional[_Key[_KT]]: ...
@property
def iloc(self) -> SortedKeysView[_KT]: ...
def clear(self) -> None: ...
def __delitem__(self, key: _KT) -> None: ...
def __iter__(self) -> Iterator[_KT]: ...
def __reversed__(self) -> Iterator[_KT]: ...
def __setitem__(self, key: _KT, value: _VT) -> None: ...
def _setitem(self, key: _KT, value: _VT) -> None: ...
def copy(self: _SD) -> _SD: ...
def __copy__(self: _SD) -> _SD: ...
@classmethod
@overload
def fromkeys(cls, seq: Iterable[_T_h]) -> SortedDict[_T_h, None]: ...
@classmethod
@overload
def fromkeys(cls, seq: Iterable[_T_h], value: _S) -> SortedDict[_T_h, _S]: ...
def keys(self) -> SortedKeysView[_KT]: ...
def items(self) -> SortedItemsView[_KT, _VT]: ...
def values(self) -> SortedValuesView[_VT]: ...
@overload
def pop(self, key: _KT) -> _VT: ...
@overload
def pop(self, key: _KT, default: _T = ...) -> Union[_VT, _T]: ...
def popitem(self, index: int = ...) -> Tuple[_KT, _VT]: ...
def peekitem(self, index: int = ...) -> Tuple[_KT, _VT]: ...
def setdefault(self, key: _KT, default: Optional[_VT] = ...) -> _VT: ...
@overload
def update(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
@overload
def update(self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT) -> None: ...
@overload
def update(self, **kwargs: _VT) -> None: ...
def __reduce__(
self,
) -> Tuple[
Type[SortedDict[_KT, _VT]], Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]],
]: ...
def __repr__(self) -> str: ...
def _check(self) -> None: ...
def islice(
self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool,
) -> Iterator[_KT]: ...
def bisect_left(self, value: _KT) -> int: ...
def bisect_right(self, value: _KT) -> int: ...
class SortedKeysView(KeysView[_KT_co], Sequence[_KT_co]):
@overload
def __getitem__(self, index: int) -> _KT_co: ...
@overload
def __getitem__(self, index: slice) -> List[_KT_co]: ...
def __delitem__(self, index: Union[int, slice]) -> None: ...
class SortedItemsView( # type: ignore
ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]]
):
def __iter__(self) -> Iterator[Tuple[_KT_co, _VT_co]]: ...
@overload
def __getitem__(self, index: int) -> Tuple[_KT_co, _VT_co]: ...
@overload
def __getitem__(self, index: slice) -> List[Tuple[_KT_co, _VT_co]]: ...
def __delitem__(self, index: Union[int, slice]) -> None: ...
class SortedValuesView(ValuesView[_VT_co], Sequence[_VT_co]):
@overload
def __getitem__(self, index: int) -> _VT_co: ...
@overload
def __getitem__(self, index: slice) -> List[_VT_co]: ...
def __delitem__(self, index: Union[int, slice]) -> None: ...


@@ -1,40 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains *incomplete* type hints for txredisapi.
"""
from typing import List, Optional, Union
class RedisProtocol:
def publish(self, channel: str, message: bytes): ...
class SubscriberProtocol:
def subscribe(self, channels: Union[str, List[str]]): ...
def lazyConnection(
host: str = ...,
port: int = ...,
dbid: Optional[int] = ...,
reconnect: bool = ...,
charset: str = ...,
password: Optional[str] = ...,
connectTimeout: Optional[int] = ...,
replyTimeout: Optional[int] = ...,
convertNumbers: bool = ...,
) -> RedisProtocol: ...
class SubscriberFactory:
def buildProtocol(self, addr): ...
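The stub above deliberately covers only the slice of `txredisapi` that Synapse
touches. A minimal sketch of that surface (host, port, and channel name are
placeholders, and the queue-until-connected behaviour of `lazyConnection` is
an assumption):

```python
from twisted.internet import reactor

import txredisapi

def main():
    # lazyConnection() returns a RedisProtocol-like object immediately and
    # connects in the background; publish() sends a pub/sub message, matching
    # the RedisProtocol stub above.
    connection = txredisapi.lazyConnection(host="localhost", port=6379)
    connection.publish("replication", b"ping")

reactor.callWhenRunning(main)
reactor.callLater(1.0, reactor.stop)
reactor.run()
```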


@@ -36,7 +36,7 @@ try:
except ImportError:
pass
__version__ = "1.12.4"
__version__ = "1.12.3"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when


@@ -97,8 +97,6 @@ class EventTypes(object):
Retention = "m.room.retention"
Presence = "m.presence"
class RejectedReason(object):
AUTH_ERROR = "auth_error"


@@ -17,9 +17,6 @@
import contextlib
import logging
import sys
from typing import Dict, Iterable
from typing_extensions import ContextManager
from twisted.internet import defer, reactor
from twisted.web.resource import NoResource
@@ -41,14 +38,14 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.federation import send_queue
from synapse.federation.transport.server import TransportLayerServer
from synapse.handlers.presence import BasePresenceHandler, get_interested_parties
from synapse.handlers.presence import PresenceHandler, get_interested_parties
from synapse.http.server import JsonResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseSite
from synapse.logging.context import LoggingContext
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._base import BaseSlavedStore, __func__
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
@@ -113,10 +110,6 @@ from synapse.rest.client.v1.voip import VoipRestServlet
from synapse.rest.client.v2_alpha import groups, sync, user_directory
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
from synapse.rest.client.v2_alpha.account_data import (
AccountDataServlet,
RoomAccountDataServlet,
)
from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.rest.client.versions import VersionsRestServlet
@@ -228,32 +221,23 @@ class KeyUploadServlet(RestServlet):
return 200, {"one_time_key_counts": result}
class _NullContextManager(ContextManager[None]):
"""A context manager which does nothing."""
def __exit__(self, exc_type, exc_val, exc_tb):
pass
UPDATE_SYNCING_USERS_MS = 10 * 1000
class GenericWorkerPresence(BasePresenceHandler):
class GenericWorkerPresence(object):
def __init__(self, hs):
super().__init__(hs)
self.hs = hs
self.is_mine_id = hs.is_mine_id
self.http_client = hs.get_simple_http_client()
self._presence_enabled = hs.config.use_presence
# The number of ongoing syncs on this process, by user id.
# Empty if _presence_enabled is false.
self._user_to_num_current_syncs = {} # type: Dict[str, int]
self.store = hs.get_datastore()
self.user_to_num_current_syncs = {}
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self.instance_id = hs.get_instance_id()
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {state.user_id: state for state in active_presence}
# user_id -> last_sync_ms. Lists the users that have stopped syncing
# but we haven't notified the master of that yet
self.users_going_offline = {}
@@ -271,13 +255,13 @@ class GenericWorkerPresence(BasePresenceHandler):
)
def _on_shutdown(self):
if self._presence_enabled:
if self.hs.config.use_presence:
self.hs.get_tcp_replication().send_command(
ClearUserSyncsCommand(self.instance_id)
)
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
if self._presence_enabled:
if self.hs.config.use_presence:
self.hs.get_tcp_replication().send_user_sync(
self.instance_id, user_id, is_syncing, last_sync_ms
)
@@ -319,33 +303,28 @@ class GenericWorkerPresence(BasePresenceHandler):
# TODO Hows this supposed to work?
return defer.succeed(None)
async def user_syncing(
self, user_id: str, affect_presence: bool
) -> ContextManager[None]:
"""Record that a user is syncing.
get_states = __func__(PresenceHandler.get_states)
get_state = __func__(PresenceHandler.get_state)
current_state_for_users = __func__(PresenceHandler.current_state_for_users)
Called by the sync and events servlets to record that a user has connected to
this worker and is waiting for some events.
"""
if not affect_presence or not self._presence_enabled:
return _NullContextManager()
def user_syncing(self, user_id, affect_presence):
if affect_presence:
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1
curr_sync = self._user_to_num_current_syncs.get(user_id, 0)
self._user_to_num_current_syncs[user_id] = curr_sync + 1
# If we went from no in flight sync to some, notify replication
if self._user_to_num_current_syncs[user_id] == 1:
self.mark_as_coming_online(user_id)
# If we went from no in flight sync to some, notify replication
if self.user_to_num_current_syncs[user_id] == 1:
self.mark_as_coming_online(user_id)
def _end():
# We check that the user_id is in user_to_num_current_syncs because
# user_to_num_current_syncs may have been cleared if we are
# shutting down.
if user_id in self._user_to_num_current_syncs:
self._user_to_num_current_syncs[user_id] -= 1
if affect_presence and user_id in self.user_to_num_current_syncs:
self.user_to_num_current_syncs[user_id] -= 1
# If we went from one in flight sync to non, notify replication
if self._user_to_num_current_syncs[user_id] == 0:
if self.user_to_num_current_syncs[user_id] == 0:
self.mark_as_going_offline(user_id)
@contextlib.contextmanager
@@ -355,7 +334,7 @@ class GenericWorkerPresence(BasePresenceHandler):
finally:
_end()
return _user_syncing()
return defer.succeed(_user_syncing())
@defer.inlineCallbacks
def notify_from_replication(self, states, stream_id):
@@ -390,12 +369,15 @@ class GenericWorkerPresence(BasePresenceHandler):
stream_id = token
yield self.notify_from_replication(states, stream_id)
def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
return [
user_id
for user_id, count in self._user_to_num_current_syncs.items()
if count > 0
]
def get_currently_syncing_users(self):
if self.hs.config.use_presence:
return [
user_id
for user_id, count in self.user_to_num_current_syncs.items()
if count > 0
]
else:
return set()
class GenericWorkerTyping(object):
@@ -519,8 +501,6 @@ class GenericWorkerServer(HomeServer):
ProfileDisplaynameRestServlet(self).register(resource)
ProfileRestServlet(self).register(resource)
KeyUploadServlet(self).register(resource)
AccountDataServlet(self).register(resource)
RoomAccountDataServlet(self).register(resource)
sync.register_servlets(self, resource)
events.register_servlets(self, resource)
@@ -639,7 +619,8 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
self.store = hs.get_datastore()
self.typing_handler = hs.get_typing_handler()
self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence
# NB this is a SynchrotronPresence, not a normal PresenceHandler
self.presence_handler = hs.get_presence_handler()
self.notifier = hs.get_notifier()
self.notify_pushers = hs.config.start_pushers
@@ -960,22 +941,17 @@ def start(config_options):
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
hs = GenericWorkerServer(
ss = GenericWorkerServer(
config.server_name,
config=config,
version_string="Synapse/" + get_version_string(synapse),
)
setup_logging(hs, config, use_worker_options=True)
hs.setup()
# Ensure the replication streamer is always started in case we write to any
# streams. Will no-op if no streams can be written to by this worker.
hs.get_replication_streamer()
setup_logging(ss, config, use_worker_options=True)
ss.setup()
reactor.addSystemEventTrigger(
"before", "startup", _base.start, hs, config.worker_listeners
"before", "startup", _base.start, ss, config.worker_listeners
)
_base.start_worker_reactor("synapse-generic-worker", config)


@@ -273,12 +273,6 @@ class SynapseHomeServer(HomeServer):
def start_listening(self, listeners):
config = self.get_config()
if config.redis_enabled:
# If redis is enabled we connect via the replication command handler
# in the same way as the workers (since we're effectively a client
# rather than a server).
self.get_tcp_replication().start_replication(self)
for listener in listeners:
if listener["type"] == "http":
self._listening_services.extend(self._listener_http(config, listener))


@@ -657,12 +657,6 @@ def read_config_files(config_files):
for config_file in config_files:
with open(config_file) as file_stream:
yaml_config = yaml.safe_load(file_stream)
if not isinstance(yaml_config, dict):
err = "File %r is empty or doesn't parse into a key-value map. IGNORING."
print(err % (config_file,))
continue
specified_config.update(yaml_config)
if "server_name" not in specified_config:


@@ -138,7 +138,7 @@ class DatabaseConfig(Config):
database_path = config.get("database_path")
if multi_database_config and database_config:
raise ConfigError("Can't specify both 'database' and 'databases' in config")
raise ConfigError("Can't specify both 'database' and 'datbases' in config")
if multi_database_config:
if database_path:


@@ -31,7 +31,6 @@ from .password import PasswordConfig
from .password_auth_providers import PasswordAuthProviderConfig
from .push import PushConfig
from .ratelimiting import RatelimitConfig
from .redis import RedisConfig
from .registration import RegistrationConfig
from .repository import ContentRepositoryConfig
from .room_directory import RoomDirectoryConfig
@@ -83,5 +82,4 @@ class HomeServerConfig(RootConfig):
RoomDirectoryConfig,
ThirdPartyRulesConfig,
TracerConfig,
RedisConfig,
]


@@ -1,35 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.config._base import Config
from synapse.python_dependencies import check_requirements
class RedisConfig(Config):
section = "redis"
def read_config(self, config, **kwargs):
redis_config = config.get("redis", {})
self.redis_enabled = redis_config.get("enabled", False)
if not self.redis_enabled:
return
check_requirements("redis")
self.redis_host = redis_config.get("host", "localhost")
self.redis_port = redis_config.get("port", 6379)
self.redis_dbid = redis_config.get("dbid")
self.redis_password = redis_config.get("password")


@@ -224,11 +224,12 @@ class ContentRepositoryConfig(Config):
#
#media_storage_providers:
# - module: file_system
# # Whether to store newly uploaded local files
# # Whether to write new local files.
# store_local: false
# # Whether to store newly downloaded remote files
# # Whether to write new remote media
# store_remote: false
# # Whether to wait for successful storage for local uploads
# # Whether to block upload requests waiting for write to this
# # provider to complete
# store_synchronous: false
# config:
# directory: /mnt/some/other/directory


@@ -248,32 +248,32 @@ class SAML2Config(Config):
# remote:
# - url: https://our_idp/metadata.xml
#
# # By default, the user has to go to our login page first. If you'd like
# # to allow IdP-initiated login, set 'allow_unsolicited: true' in a
# # 'service.sp' section:
# #
# #service:
# # sp:
# # allow_unsolicited: true
# # By default, the user has to go to our login page first. If you'd like
# # to allow IdP-initiated login, set 'allow_unsolicited: true' in a
# # 'service.sp' section:
# #
# #service:
# # sp:
# # allow_unsolicited: true
#
# # The examples below are just used to generate our metadata xml, and you
# # may well not need them, depending on your setup. Alternatively you
# # may need a whole lot more detail - see the pysaml2 docs!
# # The examples below are just used to generate our metadata xml, and you
# # may well not need them, depending on your setup. Alternatively you
# # may need a whole lot more detail - see the pysaml2 docs!
#
# description: ["My awesome SP", "en"]
# name: ["Test SP", "en"]
# description: ["My awesome SP", "en"]
# name: ["Test SP", "en"]
#
# organization:
# name: Example com
# display_name:
# - ["Example co", "en"]
# url: "http://example.com"
# organization:
# name: Example com
# display_name:
# - ["Example co", "en"]
# url: "http://example.com"
#
# contact_person:
# - given_name: Bob
# sur_name: "the Sysadmin"
# email_address": ["admin@example.com"]
# contact_type": technical
# contact_person:
# - given_name: Bob
# sur_name: "the Sysadmin"
# email_address": ["admin@example.com"]
# contact_type": technical
# Instead of putting the config inline as above, you can specify a
# separate pysaml2 configuration file:


@@ -507,17 +507,6 @@ class ServerConfig(Config):
self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False)
# Inhibits the /requestToken endpoints from returning an error that might leak
# information about whether an e-mail address is in use or not on this
# homeserver, and instead return a 200 with a fake sid if this kind of error is
# met, without sending anything.
# This is a compromise between sending an email, which could be a spam vector,
# and letting the client know which email address is bound to an account and
# which one isn't.
self.request_token_inhibit_3pid_errors = config.get(
"request_token_inhibit_3pid_errors", False,
)
def has_tls_listener(self) -> bool:
return any(l["tls"] for l in self.listeners)
@@ -983,16 +972,6 @@ class ServerConfig(Config):
# - shortest_max_lifetime: 3d
# longest_max_lifetime: 1y
# interval: 1d
# Inhibits the /requestToken endpoints from returning an error that might leak
# information about whether an e-mail address is in use or not on this
# homeserver.
# Note that for some endpoints the error situation is the e-mail already being
# used, and for others the error is entering the e-mail being unused.
# If this option is enabled, instead of returning an error, these endpoints will
# act as if no error happened and return a fake session ID ('sid') to clients.
#
#request_token_inhibit_3pid_errors: true
"""
% locals()
)


@@ -43,12 +43,6 @@ class SSOConfig(Config):
),
"sso_account_deactivated_template",
)
self.sso_auth_success_template = self.read_file(
os.path.join(
self.sso_redirect_confirm_template_dir, "sso_auth_success.html"
),
"sso_auth_success_template",
)
self.sso_client_whitelist = sso_config.get("client_whitelist") or []
@@ -113,30 +107,6 @@ class SSOConfig(Config):
#
# * server_name: the homeserver's name.
#
# * HTML page which notifies the user that they are authenticating to confirm
# an operation on their account during the user interactive authentication
# process: 'sso_auth_confirm.html'.
#
# When rendering, this template is given the following variables:
# * redirect_url: the URL the user is about to be redirected to. Needs
# manual escaping (see
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
#
# * description: the operation which the user is being asked to confirm
#
# * HTML page shown after a successful user interactive authentication session:
# 'sso_auth_success.html'.
#
# Note that this page must include the JavaScript which notifies of a successful authentication
# (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback).
#
# This template has no additional variables.
#
# * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
# attempts to login: 'sso_account_deactivated.html'.
#
# This template has no additional variables.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#


@@ -399,24 +399,20 @@ class TransportLayerClient(object):
{
"device_keys": {
"<user_id>": ["<device_id>"]
}
}
} }
Response:
{
"device_keys": {
"<user_id>": {
"<device_id>": {...}
}
},
"master_key": {
} }
"master_keys": {
"<user_id>": {...}
}
},
"self_signing_key": {
} }
"self_signing_keys": {
"<user_id>": {...}
}
}
} } }
Args:
destination(str): The server to query.
@@ -440,22 +436,8 @@ class TransportLayerClient(object):
{
"stream_id": "...",
"devices": [ { ... } ],
"master_key": {
"user_id": "<user_id>",
"usage": [...],
"keys": {...},
"signatures": {
"<user_id>": {...}
}
},
"self_signing_key": {
"user_id": "<user_id>",
"usage": [...],
"keys": {...},
"signatures": {
"<user_id>": {...}
}
}
"master_key": { ... },
"self_signing_key: { ... }
}
Args:
@@ -480,10 +462,8 @@ class TransportLayerClient(object):
{
"one_time_keys": {
"<user_id>": {
"<device_id>": "<algorithm>"
}
}
}
"<device_id>": "<algorithm>"
} } }
Response:
{
@@ -491,16 +471,13 @@ class TransportLayerClient(object):
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": "<key_base64>"
}
}
}
}
} } } }
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containing the one-time keys.
A dict containg the one-time keys.
"""
path = _create_v1_path("/user/keys/claim")


@@ -37,16 +37,15 @@ An attestation is a signed blob of json that looks like:
import logging
import random
from typing import Tuple
from signedjson.sign import sign_json
from twisted.internet import defer
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
from synapse.util.async_helpers import yieldable_gather_results
logger = logging.getLogger(__name__)
@@ -163,19 +162,19 @@ class GroupAttestionRenewer(object):
def _start_renew_attestations(self):
return run_as_background_process("renew_attestations", self._renew_attestations)
async def _renew_attestations(self):
@defer.inlineCallbacks
def _renew_attestations(self):
"""Called periodically to check if we need to update any of our attestations
"""
now = self.clock.time_msec()
rows = await self.store.get_attestations_need_renewals(
rows = yield self.store.get_attestations_need_renewals(
now + UPDATE_ATTESTATION_TIME_MS
)
@defer.inlineCallbacks
def _renew_attestation(group_user: Tuple[str, str]):
group_id, user_id = group_user
def _renew_attestation(group_id, user_id):
try:
if not self.is_mine_id(group_id):
destination = get_domain_from_id(group_id)
@@ -208,6 +207,8 @@ class GroupAttestionRenewer(object):
"Error renewing attestation of %r in %r", user_id, group_id
)
await yieldable_gather_results(
_renew_attestation, ((row["group_id"], row["user_id"]) for row in rows)
)
for row in rows:
group_id = row["group_id"]
user_id = row["user_id"]
run_in_background(_renew_attestation, group_id, user_id)
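This hunk contrasts awaiting all renewals together via
`yieldable_gather_results` with a fire-and-forget `run_in_background` loop.
Since `yieldable_gather_results` is Synapse-internal, here is a rough analogy
of the gathered fan-out pattern in plain `asyncio` (names and row shape are
illustrative):

```python
import asyncio

async def renew_attestation(group_user):
    group_id, user_id = group_user
    # Stand-in for the real federation request that renews one attestation.
    await asyncio.sleep(0)
    print("renewed", group_id, user_id)

async def renew_all(rows):
    # One concurrent task per row; the caller waits for all of them, and any
    # exception propagates, analogous to yieldable_gather_results().
    await asyncio.gather(
        *(renew_attestation((r["group_id"], r["user_id"])) for r in rows)
    )

asyncio.run(renew_all([{"group_id": "+grp:example.com", "user_id": "@u:example.com"}]))
```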


@@ -51,6 +51,31 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
SUCCESS_TEMPLATE = """
<html>
<head>
<title>Success!</title>
<meta name='viewport' content='width=device-width, initial-scale=1,
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
<script>
if (window.onAuthDone) {
window.onAuthDone();
} else if (window.opener && window.opener.postMessage) {
window.opener.postMessage("authDone", "*");
}
</script>
</head>
<body>
<div>
<p>Thank you</p>
<p>You may now close this window and return to the application</p>
</div>
</body>
</html>
"""
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
@@ -134,11 +159,6 @@ class AuthHandler(BaseHandler):
self._sso_auth_confirm_template = load_jinja2_templates(
hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"],
)[0]
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
self._sso_auth_success_template = hs.config.sso_auth_success_template
# The following template is shown during the SSO authentication process if
# the account is deactivated.
self._sso_account_deactivated_template = (
hs.config.sso_account_deactivated_template
)
@@ -257,6 +277,10 @@ class AuthHandler(BaseHandler):
Takes a dictionary sent by the client in the login / registration
protocol and handles the User-Interactive Auth flow.
As a side effect, this function fills in the 'creds' key on the user's
session with a map, which maps each auth-type (str) to the relevant
identity authenticated by that auth-type (mostly str, but for captcha, bool).
If no auth flows have been completed successfully, raises an
InteractiveAuthIncompleteError. To handle this, you can use
synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
@@ -300,47 +324,50 @@ class AuthHandler(BaseHandler):
del clientdict["auth"]
if "session" in authdict:
sid = authdict["session"]
session = self._get_session_info(sid)
# If there's no session ID, create a new session.
if not sid:
session = self._create_session(
clientdict, (request.uri, request.method, clientdict), description
if len(clientdict) > 0:
# This was designed to allow the client to omit the parameters
# and just supply the session in subsequent calls so it split
# auth between devices by just sharing the session, (eg. so you
# could continue registration from your phone having clicked the
# email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects
# on a homeserver.
# Revisit: Assuming the REST APIs do sensible validation, the data
# isn't arbintrary.
session["clientdict"] = clientdict
self._save_session(session)
elif "clientdict" in session:
clientdict = session["clientdict"]
# Ensure that the queried operation does not vary between stages of
# the UI authentication session. This is done by generating a stable
# comparator based on the URI, method, and body (minus the auth dict)
# and storing it during the initial query. Subsequent queries ensure
# that this comparator has not changed.
comparator = (request.uri, request.method, clientdict)
if "ui_auth" not in session:
session["ui_auth"] = comparator
self._save_session(session)
elif session["ui_auth"] != comparator:
raise SynapseError(
403,
"Requested operation has changed during the UI authentication session.",
)
session_id = session["id"]
else:
session = self._get_session_info(sid)
session_id = sid
if not clientdict:
# This was designed to allow the client to omit the parameters
# and just supply the session in subsequent calls so it split
# auth between devices by just sharing the session, (eg. so you
# could continue registration from your phone having clicked the
# email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects
# on a homeserver.
# Revisit: Assuming the REST APIs do sensible validation, the data
# isn't arbitrary.
clientdict = session["clientdict"]
# Ensure that the queried operation does not vary between stages of
# the UI authentication session. This is done by generating a stable
# comparator based on the URI, method, and body (minus the auth dict)
# and storing it during the initial query. Subsequent queries ensure
# that this comparator has not changed.
comparator = (request.uri, request.method, clientdict)
if session["ui_auth"] != comparator:
raise SynapseError(
403,
"Requested operation has changed during the UI authentication session.",
)
# Add a human readable description to the session.
if "description" not in session:
session["description"] = description
self._save_session(session)
if not authdict:
raise InteractiveAuthIncompleteError(
self._auth_dict_for_flows(flows, session_id)
self._auth_dict_for_flows(flows, session)
)
if "creds" not in session:
session["creds"] = {}
creds = session["creds"]
# check auth type currently being presented
@@ -380,9 +407,9 @@ class AuthHandler(BaseHandler):
list(clientdict),
)
return creds, clientdict, session_id
return creds, clientdict, session["id"]
ret = self._auth_dict_for_flows(flows, session_id)
ret = self._auth_dict_for_flows(flows, session)
ret["completed"] = list(creds)
ret.update(errordict)
raise InteractiveAuthIncompleteError(ret)
@@ -400,6 +427,8 @@ class AuthHandler(BaseHandler):
raise LoginError(400, "", Codes.MISSING_PARAM)
sess = self._get_session_info(authdict["session"])
if "creds" not in sess:
sess["creds"] = {}
creds = sess["creds"]
result = await self.checkers[stagetype].check_auth(authdict, clientip)
@@ -439,7 +468,7 @@ class AuthHandler(BaseHandler):
value: The data to store
"""
sess = self._get_session_info(session_id)
sess["serverdict"][key] = value
sess.setdefault("serverdict", {})[key] = value
self._save_session(sess)
def get_session_data(
@@ -454,7 +483,7 @@ class AuthHandler(BaseHandler):
default: Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
return sess["serverdict"].get(key, default)
return sess.setdefault("serverdict", {}).get(key, default)
async def _check_auth_dict(
self, authdict: Dict[str, Any], clientip: str
@@ -510,7 +539,7 @@ class AuthHandler(BaseHandler):
}
def _auth_dict_for_flows(
self, flows: List[List[str]], session_id: str,
self, flows: List[List[str]], session: Dict[str, Any]
) -> Dict[str, Any]:
public_flows = []
for f in flows:
@@ -529,72 +558,29 @@ class AuthHandler(BaseHandler):
params[stage] = get_params[stage]()
return {
"session": session_id,
"session": session["id"],
"flows": [{"stages": f} for f in public_flows],
"params": params,
}
def _create_session(
self,
clientdict: Dict[str, Any],
ui_auth: Tuple[bytes, bytes, Dict[str, Any]],
description: str,
) -> dict:
def _get_session_info(self, session_id: Optional[str]) -> dict:
"""
Creates a new user interactive authentication session.
Gets or creates a session given a session ID.
The session can be used to track data across multiple requests, e.g. for
interactive authentication.
Each session has the following keys:
id:
A unique identifier for this session. Passed back to the client
and returned for each stage.
clientdict:
The dictionary from the client root level, not the 'auth' key.
ui_auth:
A tuple which is checked at each stage of the authentication to
ensure that the asked for operation has not changed.
creds:
A map, which maps each auth-type (str) to the relevant identity
authenticated by that auth-type (mostly str, but for captcha, bool).
serverdict:
A map of data that is stored server-side and cannot be modified
by the client.
description:
A string description of the operation that the current
authentication is authorising.
Returns:
The newly created session.
"""
session_id = None
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
if session_id not in self.sessions:
session_id = None
self.sessions[session_id] = {
"id": session_id,
"clientdict": clientdict,
"ui_auth": ui_auth,
"creds": {},
"serverdict": {},
"description": description,
}
if not session_id:
# create a new session
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
self.sessions[session_id] = {"id": session_id}
return self.sessions[session_id]
def _get_session_info(self, session_id: str) -> dict:
"""
Gets a session given a session ID.
The session can be used to track data across multiple requests, e.g. for
interactive authentication.
"""
try:
return self.sessions[session_id]
except KeyError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
async def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
):
@@ -1064,8 +1050,11 @@ class AuthHandler(BaseHandler):
The HTML to render.
"""
session = self._get_session_info(session_id)
# Get a human-readable description of the operation in progress, falling
# back to a generic message if it isn't available for some reason.
description = session.get("description", "modify your account")
return self._sso_auth_confirm_template.render(
description=session["description"], redirect_url=redirect_url,
description=description, redirect_url=redirect_url,
)
def complete_sso_ui_auth(
@@ -1081,6 +1070,8 @@ class AuthHandler(BaseHandler):
"""
# Mark the stage of the authentication as successful.
sess = self._get_session_info(session_id)
if "creds" not in sess:
sess["creds"] = {}
creds = sess["creds"]
# Save the user who authenticated with SSO, this will be used to ensure
@@ -1089,7 +1080,7 @@ class AuthHandler(BaseHandler):
self._save_session(sess)
# Render the HTML and return.
html_bytes = self._sso_auth_success_template.encode("utf-8")
html_bytes = SUCCESS_TEMPLATE.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
@@ -1115,12 +1106,12 @@ class AuthHandler(BaseHandler):
# flow.
deactivated = await self.store.get_user_deactivated_status(registered_user_id)
if deactivated:
html_bytes = self._sso_account_deactivated_template.encode("utf-8")
html = self._sso_account_deactivated_template.encode("utf-8")
request.setResponseCode(403)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
request.setHeader(b"Content-Length", b"%d" % (len(html),))
request.write(html)
finish_request(request)
return
@@ -1162,7 +1153,7 @@ class AuthHandler(BaseHandler):
# URL we redirect users to.
redirect_url_no_params = client_redirect_url.split("?")[0]
html_bytes = self._sso_redirect_confirm_template.render(
html = self._sso_redirect_confirm_template.render(
display_url=redirect_url_no_params,
redirect_url=redirect_url,
server_name=self._server_name,
@@ -1170,8 +1161,8 @@ class AuthHandler(BaseHandler):
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
request.setHeader(b"Content-Length", b"%d" % (len(html),))
request.write(html)
finish_request(request)
@staticmethod

View File

@@ -880,6 +880,7 @@ class E2eKeysHandler(object):
try:
# get our user-signing key to verify the signatures
logger.info("***Getting the user_signing")
(
user_signing_key,
user_signing_key_id,
@@ -903,6 +904,7 @@ class E2eKeysHandler(object):
try:
# get the target user's master key, to make sure it matches
# what was sent
logger.info("***Getting the master")
(
master_key,
master_key_id,
@@ -985,14 +987,11 @@ class E2eKeysHandler(object):
SynapseError: if `user_id` is invalid
"""
user = UserID.from_string(user_id)
logger.info("***Trying to get a %s key for %s from storage...", key_type, user_id)
key = yield self.store.get_e2e_cross_signing_key(
user_id, key_type, from_user_id
)
if key:
# We found a copy of this key in our database. Decode and return it
key_id, verify_key = get_verify_key_from_cross_signing_key(key)
return key, key_id, verify_key
logger.info("***Well, we got this: %s", key)
# If we couldn't find the key locally, and we're looking for keys of
# another user then attempt to fetch the missing key from the remote
@@ -1000,22 +999,37 @@ class E2eKeysHandler(object):
#
# We may run into this in possible edge cases where a user tries to
# cross-sign a remote user, but does not share any rooms with them yet.
# Thus, we would not have their key list yet. We instead fetch the key,
# Thus, we would not have their key list yet. We fetch the key here,
# store it and notify clients of new, associated device IDs.
if self.is_mine(user) or key_type not in ["master", "self_signing"]:
# Note that master and self_signing keys are the only cross-signing keys we
# can request over federation
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
(
key,
key_id,
verify_key,
) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
logger.info("***Checking if we'll do our thingy")
if (
key is None
and not self.is_mine(user)
# We only get "master" and "self_signing" keys from remote servers
and key_type in ["master", "self_signing"]
):
logger.info("***Doing our thingy")
(
key,
key_id,
verify_key,
) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
if key is None:
logger.warning("No %s key found for %s", key_type, user_id)
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
try:
key_id, verify_key = get_verify_key_from_cross_signing_key(key)
except ValueError as e:
logger.warning(
"Invalid %s key retrieved: %s - %s %s", key_type, key, type(e), e,
)
raise SynapseError(
502, "Invalid %s key retrieved from remote server" % (key_type,)
)
logger.info("***Finally returning %s - %s - %s", key, key_id, verify_key)
return key, key_id, verify_key
@defer.inlineCallbacks
@@ -1040,6 +1054,7 @@ class E2eKeysHandler(object):
remote_result = yield self.federation.query_user_devices(
user.domain, user.to_string()
)
logger.info("***Got our remote_result: %s", remote_result)
except Exception as e:
logger.warning(
"Unable to query %s for cross-signing keys of user %s: %s %s",
@@ -1051,38 +1066,31 @@ class E2eKeysHandler(object):
return None, None, None
# Process each of the retrieved cross-signing keys
desired_key = None
desired_key_id = None
desired_verify_key = None
retrieved_device_ids = []
final_key = None
final_key_id = None
final_verify_key = None
device_ids = []
for key_type in ["master", "self_signing"]:
logger.info("***Processing retrieved key type: %s", key_type)
key_content = remote_result.get(key_type + "_key")
logger.info("***Got key_content: %s", key_content)
if not key_content:
continue
# Ensure these keys belong to the correct user
if "user_id" not in key_content:
logger.warning(
"Invalid %s key retrieved, missing user_id field: %s",
key_type,
key_content,
)
continue
if user.to_string() != key_content["user_id"]:
logger.warning(
"Found %s key of user %s when querying for keys of user %s",
key_type,
key_content["user_id"],
user.to_string(),
)
continue
# At the same time, store this key in the db for
# subsequent queries
yield self.store.set_e2e_cross_signing_key(
user.to_string(), key_type, key_content
)
logger.info("***Stored key")
# Validate the key contents
# Note down the device ID attached to this key
try:
# verify_key is a VerifyKey from signedjson, which uses
# .version to denote the portion of the key ID after the
# algorithm and colon, which is the device ID
key_id, verify_key = get_verify_key_from_cross_signing_key(key_content)
logger.info("***Verified key: %s - %s", key_id, verify_key)
except ValueError as e:
logger.warning(
"Invalid %s key retrieved: %s - %s %s",
@@ -1092,29 +1100,23 @@ class E2eKeysHandler(object):
e,
)
continue
# Note down the device ID attached to this key
retrieved_device_ids.append(verify_key.version)
logger.info("***Appending device id: %s - %s", verify_key, verify_key.version)
device_ids.append(verify_key.version)
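# For illustration (hypothetical key material), a cross-signing key dict
# looks roughly like:
#   {"user_id": "@alice:example.org",
#    "usage": ["self_signing"],
#    "keys": {"ed25519:base64pubkey": "base64pubkey"}}
# so key_id is "ed25519:base64pubkey" and verify_key.version is the part
# after the colon, which is what we record as a device ID here.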
# If this is the desired key type, save it and its ID/VerifyKey
if key_type == desired_key_type:
desired_key = key_content
desired_verify_key = verify_key
desired_key_id = key_id
# At the same time, store this key in the db for subsequent queries
yield self.store.set_e2e_cross_signing_key(
user.to_string(), key_type, key_content
)
logger.info("***We found our desired key type, %s!", key_type)
final_key = key_content
final_verify_key = verify_key
final_key_id = key_id
# Notify clients that new devices for this user have been discovered
if retrieved_device_ids:
# XXX is this necessary?
yield self.device_handler.notify_device_update(
user.to_string(), retrieved_device_ids
)
if device_ids:
logger.info("***Updating clients with devices: %s", device_ids)
yield self.device_handler.notify_device_update(user.to_string(), device_ids)
return desired_key, desired_key_id, desired_verify_key
logger.info("***Returning %s - %s - %s", final_key, final_key_id, final_verify_key)
return final_key, final_key_id, final_verify_key
def _check_cross_signing_key(key, user_id, key_type, signing_key=None):

View File

@@ -19,7 +19,6 @@ import random
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
from synapse.logging.utils import log_function
from synapse.types import UserID
from synapse.visibility import filter_events_for_client
@@ -98,8 +97,6 @@ class EventStreamHandler(BaseHandler):
explicit_room_id=room_id,
)
time_now = self.clock.time_msec()
# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
to_add = []
@@ -115,20 +112,19 @@ class EventStreamHandler(BaseHandler):
users = await self.state.get_current_users_in_room(
event.room_id
)
states = await presence_handler.get_states(users, as_event=True)
to_add.extend(states)
else:
users = [event.state_key]
states = await presence_handler.get_states(users)
to_add.extend(
{
"type": EventTypes.Presence,
"content": format_user_presence_state(state, time_now),
}
for state in states
)
ev = await presence_handler.get_state(
UserID.from_string(event.state_key), as_event=True
)
to_add.append(ev)
events.extend(to_add)
time_now = self.clock.time_msec()
chunks = await self._event_serializer.serialize_events(
events,
time_now,

View File

@@ -343,7 +343,7 @@ class FederationHandler(BaseHandler):
ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(ours.values()) # type: List[StateMap[str]]
state_maps = list(ours.values()) # type: list[StateMap[str]]
# we don't need this any more, let's delete it.
del ours
@@ -1694,15 +1694,16 @@ class FederationHandler(BaseHandler):
return None
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
@defer.inlineCallbacks
def get_state_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event.
"""
event = await self.store.get_event(
event = yield self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
if state_groups:
_, state = list(iteritems(state_groups)).pop()
@@ -1713,7 +1714,7 @@ class FederationHandler(BaseHandler):
if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"]
if prev_id != event.event_id:
prev_event = await self.store.get_event(prev_id)
prev_event = yield self.store.get_event(prev_id)
results[(event.type, event.state_key)] = prev_event
else:
del results[(event.type, event.state_key)]
@@ -1723,14 +1724,15 @@ class FederationHandler(BaseHandler):
else:
return []
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
@defer.inlineCallbacks
def get_state_ids_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event.
"""
event = await self.store.get_event(
event = yield self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
if state_groups:
_, state = list(state_groups.items()).pop()
@@ -1749,50 +1751,49 @@ class FederationHandler(BaseHandler):
else:
return []
@defer.inlineCallbacks
@log_function
async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int
) -> List[EventBase]:
in_room = await self.auth.check_host_in_room(room_id, origin)
def on_backfill_request(self, origin, room_id, pdu_list, limit):
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
# Synapse asks for 100 events per backfill request. Do not allow more.
limit = min(limit, 100)
events = await self.store.get_backfill_events(room_id, pdu_list, limit)
events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
events = await filter_events_for_server(self.storage, origin, events)
events = yield filter_events_for_server(self.storage, origin, events)
return events
@defer.inlineCallbacks
@log_function
async def get_persisted_pdu(
self, origin: str, event_id: str
) -> Optional[EventBase]:
def get_persisted_pdu(self, origin, event_id):
"""Get an event from the database for the given server.
Args:
origin: hostname of server which is requesting the event; we
origin [str]: hostname of server which is requesting the event; we
will check that the server is allowed to see it.
event_id: id of the event being requested
event_id [str]: id of the event being requested
Returns:
None if we know nothing about the event; otherwise the (possibly-redacted) event.
Deferred[EventBase|None]: None if we know nothing about the event;
otherwise the (possibly-redacted) event.
Raises:
AuthError if the server is not currently in the room
"""
event = await self.store.get_event(
event = yield self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)
if event:
in_room = await self.auth.check_host_in_room(event.room_id, origin)
in_room = yield self.auth.check_host_in_room(event.room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
events = await filter_events_for_server(self.storage, origin, [event])
events = yield filter_events_for_server(self.storage, origin, [event])
event = events[0]
return event
else:
@@ -2396,7 +2397,7 @@ class FederationHandler(BaseHandler):
"""
# exclude the state key of the new event from the current_state in the context.
if event.is_state():
event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]]
event_key = (event.type, event.state_key)
else:
event_key = None
state_updates = {

View File

@@ -18,7 +18,7 @@
"""Utilities for interacting with Identity Servers"""
import logging
import urllib.parse
import urllib
from canonicaljson import json
from signedjson.key import decode_verify_key_bytes

View File

@@ -381,16 +381,10 @@ class InitialSyncHandler(BaseHandler):
return []
states = await presence_handler.get_states(
[m.user_id for m in room_members]
[m.user_id for m in room_members], as_event=True
)
return [
{
"type": EventTypes.Presence,
"content": format_user_presence_state(s, time_now),
}
for s in states
]
return states
async def get_receipts():
receipts = await self.store.get_linearized_receipts_for_room(

View File

@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,10 +21,10 @@ The methods that define policy are:
- PresenceHandler._handle_timeouts
- should_notify
"""
import abc
import logging
from contextlib import contextmanager
from typing import Dict, Iterable, List, Set
from typing import Dict, List, Set
from six import iteritems, itervalues
@@ -42,7 +41,7 @@ from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.types import UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
@@ -100,106 +99,13 @@ EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000
assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class BasePresenceHandler(abc.ABC):
"""Parts of the PresenceHandler that are shared between workers and master"""
class PresenceHandler(object):
def __init__(self, hs: "synapse.server.HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {state.user_id: state for state in active_presence}
@abc.abstractmethod
async def user_syncing(
self, user_id: str, affect_presence: bool
) -> ContextManager[None]:
"""Returns a context manager that should surround any stream requests
from the user.
This allows us to keep track of who is currently streaming and who isn't
without needing timers outside of this module to avoid flickering
when users disconnect/reconnect.
Args:
user_id: the user that is starting a sync
affect_presence: If false this function will be a no-op.
Useful for streams that are not associated with an actual
client that is being used by a user.
"""
@abc.abstractmethod
def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
"""Get an iterable of syncing users on this worker, to send to the presence handler
This is called when a replication connection is established. It should return
a list of user ids, which are then sent as USER_SYNC commands to inform the
process handling presence about those users.
Returns:
An iterable of user_id strings.
"""
async def get_state(self, target_user: UserID) -> UserPresenceState:
results = await self.get_states([target_user.to_string()])
return results[0]
async def get_states(
self, target_user_ids: Iterable[str]
) -> List[UserPresenceState]:
"""Get the presence state for users."""
updates_d = await self.current_state_for_users(target_user_ids)
updates = list(updates_d.values())
for user_id in set(target_user_ids) - {u.user_id for u in updates}:
updates.append(UserPresenceState.default(user_id))
return updates
async def current_state_for_users(
self, user_ids: Iterable[str]
) -> Dict[str, UserPresenceState]:
"""Get the current presence state for multiple users.
Returns:
dict: `user_id` -> `UserPresenceState`
"""
states = {
user_id: self.user_to_current_state.get(user_id, None)
for user_id in user_ids
}
missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
# There are things not in our in-memory cache. Let's pull them out of
# the database.
res = await self.store.get_presence_for_users(missing)
states.update(res)
missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
new = {
user_id: UserPresenceState.default(user_id) for user_id in missing
}
states.update(new)
self.user_to_current_state.update(new)
return states
@abc.abstractmethod
async def set_state(
self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False
) -> None:
"""Set the presence state of the user. """
class PresenceHandler(BasePresenceHandler):
def __init__(self, hs: "synapse.server.HomeServer"):
super().__init__(hs)
self.hs = hs
self.is_mine_id = hs.is_mine_id
self.server_name = hs.hostname
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.wheel_timer = WheelTimer()
self.notifier = hs.get_notifier()
self.federation = hs.get_federation_sender()
@@ -209,6 +115,13 @@ class PresenceHandler(BasePresenceHandler):
federation_registry.register_edu_handler("m.presence", self.incoming_presence)
active_presence = self.store.take_presence_startup_info()
# A dictionary of the current state of users. This is prefilled with
# non-offline presence from the DB. We should fetch from the DB if
# we can't find a user's presence in here.
self.user_to_current_state = {state.user_id: state for state in active_presence}
LaterGauge(
"synapse_handlers_presence_user_to_current_state_size",
"",
@@ -217,7 +130,7 @@ class PresenceHandler(BasePresenceHandler):
)
now = self.clock.time_msec()
for state in self.user_to_current_state.values():
for state in active_presence:
self.wheel_timer.insert(
now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER
)
@@ -448,18 +361,10 @@ class PresenceHandler(BasePresenceHandler):
timers_fired_counter.inc(len(states))
syncing_user_ids = {
user_id
for user_id, count in self.user_to_num_current_syncs.items()
if count
}
for user_ids in self.external_process_to_current_syncs.values():
syncing_user_ids.update(user_ids)
changes = handle_timeouts(
states,
is_mine_fn=self.is_mine_id,
syncing_user_ids=syncing_user_ids,
syncing_user_ids=self.get_currently_syncing_users(),
now=now,
)
@@ -557,9 +462,22 @@ class PresenceHandler(BasePresenceHandler):
return _user_syncing()
def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
# since we are the process handling presence, there is nothing to do here.
return []
def get_currently_syncing_users(self):
"""Get the set of user ids that are currently syncing on this HS.
Returns:
set(str): A set of user_id strings.
"""
if self.hs.config.use_presence:
syncing_user_ids = {
user_id
for user_id, count in self.user_to_num_current_syncs.items()
if count
}
for user_ids in self.external_process_to_current_syncs.values():
syncing_user_ids.update(user_ids)
return syncing_user_ids
else:
return set()
async def update_external_syncs_row(
self, process_id, user_id, is_syncing, sync_time_msec
@@ -636,6 +554,34 @@ class PresenceHandler(BasePresenceHandler):
res = await self.current_state_for_users([user_id])
return res[user_id]
async def current_state_for_users(self, user_ids):
"""Get the current presence state for multiple users.
Returns:
dict: `user_id` -> `UserPresenceState`
"""
states = {
user_id: self.user_to_current_state.get(user_id, None)
for user_id in user_ids
}
missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
# There are things not in our in-memory cache. Let's pull them out of
# the database.
res = await self.store.get_presence_for_users(missing)
states.update(res)
missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
new = {
user_id: UserPresenceState.default(user_id) for user_id in missing
}
states.update(new)
self.user_to_current_state.update(new)
return states
async def _persist_and_notify(self, states):
"""Persist states in the database, poke the notifier and send to
interested remote servers
@@ -723,6 +669,40 @@ class PresenceHandler(BasePresenceHandler):
federation_presence_counter.inc(len(updates))
await self._update_states(updates)
async def get_state(self, target_user, as_event=False):
results = await self.get_states([target_user.to_string()], as_event=as_event)
return results[0]
async def get_states(self, target_user_ids, as_event=False):
"""Get the presence state for users.
Args:
target_user_ids (list)
as_event (bool): Whether to format it as a client event or not.
Returns:
list
"""
updates = await self.current_state_for_users(target_user_ids)
updates = list(updates.values())
for user_id in set(target_user_ids) - {u.user_id for u in updates}:
updates.append(UserPresenceState.default(user_id))
now = self.clock.time_msec()
if as_event:
return [
{
"type": "m.presence",
"content": format_user_presence_state(state, now),
}
for state in updates
]
else:
return updates
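# An as_event entry looks roughly like (illustrative values):
#   {"type": "m.presence",
#    "content": {"user_id": "@alice:example.org", "presence": "online",
#                "last_active_ago": 52943, "currently_active": True}}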
async def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user.
"""
@@ -909,7 +889,7 @@ class PresenceHandler(BasePresenceHandler):
user_ids = await self.state.get_current_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, user_ids))
states_d = await self.current_state_for_users(user_ids)
states = await self.current_state_for_users(user_ids)
# Filter out old presence, i.e. offline presence states where
# the user hasn't been active for a week. We can change this
@@ -919,7 +899,7 @@ class PresenceHandler(BasePresenceHandler):
now = self.clock.time_msec()
states = [
state
for state in states_d.values()
for state in states.values()
if state.state != PresenceState.OFFLINE
or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
or state.status_msg is not None

View File

@@ -434,27 +434,21 @@ class MatrixFederationHttpClient(object):
logger.info("Failed to send request: %s", e)
raise_from(RequestSendFailed(e, can_retry=True), e)
logger.info(
"{%s} [%s] Got response headers: %d %s",
request.txn_id,
request.destination,
response.code,
response.phrase.decode("ascii", errors="replace"),
)
incoming_responses_counter.labels(method_bytes, response.code).inc()
set_tag(tags.HTTP_STATUS_CODE, response.code)
if 200 <= response.code < 300:
logger.debug(
"{%s} [%s] Got response headers: %d %s",
request.txn_id,
request.destination,
response.code,
response.phrase.decode("ascii", errors="replace"),
)
pass
else:
logger.info(
"{%s} [%s] Got response headers: %d %s",
request.txn_id,
request.destination,
response.code,
response.phrase.decode("ascii", errors="replace"),
)
# :'(
# Update transactions table?
d = treq.content(response)

View File

@@ -220,6 +220,12 @@ class Notifier(object):
"""
self.replication_callbacks.append(cb)
def add_remote_server_up_callback(self, cb: Callable[[str], None]):
"""Add a callback that will be called when synapse detects a server
has come back online.
"""
self.remote_server_up_callbacks.append(cb)
def on_new_room_event(
self, event, room_stream_id, max_room_stream_id, extra_users=[]
):
@@ -538,3 +544,6 @@ class Notifier(object):
# circular dependencies.
if self.federation_sender:
self.federation_sender.wake_destination(server)
for cb in self.remote_server_up_callbacks:
cb(server)

View File

@@ -16,11 +16,9 @@
import logging
import re
from typing import Pattern
from six import string_types
from synapse.events import EventBase
from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
@@ -58,18 +56,18 @@ def _test_ineq_condition(condition, number):
rhs = m.group(2)
if not rhs.isdigit():
return False
rhs_int = int(rhs)
rhs = int(rhs)
if ineq == "" or ineq == "==":
return number == rhs_int
return number == rhs
elif ineq == "<":
return number < rhs_int
return number < rhs
elif ineq == ">":
return number > rhs_int
return number > rhs
elif ineq == ">=":
return number >= rhs_int
return number >= rhs
elif ineq == "<=":
return number <= rhs_int
return number <= rhs
else:
return False
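# For illustration, a push rule condition handled by this helper looks like
#   {"kind": "room_member_count", "is": ">=10"}
# where ineq is ">=" and rhs is "10": rooms with ten or more members match.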
@@ -85,13 +83,7 @@ def tweaks_for_actions(actions):
class PushRuleEvaluatorForEvent(object):
def __init__(
self,
event: EventBase,
room_member_count: int,
sender_power_level: int,
power_levels: dict,
):
def __init__(self, event, room_member_count, sender_power_level, power_levels):
self._event = event
self._room_member_count = room_member_count
self._sender_power_level = sender_power_level
@@ -100,7 +92,7 @@ class PushRuleEvaluatorForEvent(object):
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)
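# e.g. (illustrative, assuming _flatten_dict joins nested keys with "."):
#   >>> _flatten_dict({"content": {"body": "hi"}})
#   {"content.body": "hi"}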
def matches(self, condition: dict, user_id: str, display_name: str) -> bool:
def matches(self, condition, user_id, display_name):
if condition["kind"] == "event_match":
return self._event_match(condition, user_id)
elif condition["kind"] == "contains_display_name":
@@ -114,7 +106,7 @@ class PushRuleEvaluatorForEvent(object):
else:
return True
def _event_match(self, condition: dict, user_id: str) -> bool:
def _event_match(self, condition, user_id):
pattern = condition.get("pattern", None)
if not pattern:
@@ -142,7 +134,7 @@ class PushRuleEvaluatorForEvent(object):
return _glob_matches(pattern, haystack)
def _contains_display_name(self, display_name: str) -> bool:
def _contains_display_name(self, display_name):
if not display_name:
return False
@@ -150,52 +142,51 @@ class PushRuleEvaluatorForEvent(object):
if not body:
return False
# Similar to _glob_matches, but do not treat display_name as a glob.
r = regex_cache.get((display_name, False, True), None)
if not r:
r = re.escape(display_name)
r = _re_word_boundary(r)
r = re.compile(r, flags=re.IGNORECASE)
regex_cache[(display_name, False, True)] = r
return _glob_matches(display_name, body, word_boundary=True)
return r.search(body)
def _get_value(self, dotted_key: str) -> str:
def _get_value(self, dotted_key):
return self._value_cache.get(dotted_key, None)
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
# Caches (glob, word_boundary) -> regex for push. See _glob_matches
regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
register_cache("cache", "regex_push_cache", regex_cache)
def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
def _glob_matches(glob, value, word_boundary=False):
"""Tests if value matches glob.
Args:
glob
value: String to test against glob.
word_boundary: Whether to match against word boundaries or entire
glob (string)
value (string): String to test against glob.
word_boundary (bool): Whether to match against word boundaries or entire
string. Defaults to False.
Returns:
bool
"""
try:
r = regex_cache.get((glob, True, word_boundary), None)
r = regex_cache.get((glob, word_boundary), None)
if not r:
r = _glob_to_re(glob, word_boundary)
regex_cache[(glob, True, word_boundary)] = r
regex_cache[(glob, word_boundary)] = r
return r.search(value)
except re.error:
logger.warning("Failed to parse glob to regex: %r", glob)
return False
def _glob_to_re(glob: str, word_boundary: bool) -> Pattern:
def _glob_to_re(glob, word_boundary):
"""Generates regex for a given glob.
Args:
glob
word_boundary: Whether to match against word boundaries or entire string.
glob (string)
word_boundary (bool): Whether to match against word boundaries or entire
string. Defaults to False.
Returns:
regex object
"""
if IS_GLOB.search(glob):
r = re.escape(glob)
@@ -228,7 +219,7 @@ def _glob_to_re(glob: str, word_boundary: bool) -> Pattern:
return re.compile(r, flags=re.IGNORECASE)
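# Illustrative behaviour (hypothetical inputs):
#   _glob_matches("m*trix", "Matrix")            -> truthy (case-insensitive)
#   _glob_matches("cake", "cheesecake",
#                 word_boundary=True)            -> falsy ("cake" is not a
#                                                   whole word in the value)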
def _re_word_boundary(r: str) -> str:
def _re_word_boundary(r):
"""
Adds word boundary characters to the start and end of an
expression to require that the match occur as a whole word,

View File

@@ -98,7 +98,6 @@ CONDITIONAL_REQUIREMENTS = {
"sentry": ["sentry-sdk>=0.7.2"],
"opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"],
"jwt": ["pyjwt>=1.6.4"],
"redis": ["txredisapi>=1.4.7"],
}
ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]

View File

@@ -28,7 +28,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
The API looks like:
GET /_synapse/replication/get_repl_stream_updates/<stream name>?from_token=0&to_token=10
GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&to_token=10&limit=100
200 OK
@@ -38,9 +38,6 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
limited: False,
}
If there are more rows than can sensibly be returned in one lump, `limited` will be
set to true, and the caller should call again with a new `from_token`.
"""
NAME = "get_repl_stream_updates"
@@ -55,8 +52,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
self.streams = hs.get_replication_streamer().get_streams()
@staticmethod
def _serialize_payload(stream_name, from_token, upto_token):
return {"from_token": from_token, "upto_token": upto_token}
def _serialize_payload(stream_name, from_token, upto_token, limit):
return {"from_token": from_token, "upto_token": upto_token, "limit": limit}
async def _handle_request(self, request, stream_name):
stream = self.streams.get(stream_name)
@@ -65,9 +62,10 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
from_token = parse_integer(request, "from_token", required=True)
upto_token = parse_integer(request, "upto_token", required=True)
limit = parse_integer(request, "limit", required=True)
updates, upto_token, limited = await stream.get_updates_since(
from_token, upto_token
from_token, upto_token, limit
)
return (

View File

@@ -30,7 +30,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
class ReplicationClientFactory(ReconnectingClientFactory):
"""Factory for building connections to the master. Will reconnect if the
connection is lost.

View File

@@ -210,10 +210,7 @@ class ReplicateCommand(Command):
class UserSyncCommand(Command):
"""Sent by the client to inform the server that a user has started or
stopped syncing on this process.
This is used by the process handling presence (typically the master) to
calculate who is online and who is not.
stopped syncing. Used to calculate presence on the master.
Includes a timestamp of when the last user sync was.
@@ -221,7 +218,7 @@ class UserSyncCommand(Command):
USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
Where <state> is either "start" or "end"
Where <state> is either "start" or "stop"
"""
NAME = "USER_SYNC"
@@ -457,21 +454,3 @@ VALID_CLIENT_COMMANDS = (
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,
)
def parse_command_from_line(line: str) -> Command:
"""Parses a command from a received line.
The line should already be stripped of whitespace and checked to ensure it is not blank.
"""
idx = line.find(" ")
if idx >= 0:
cmd_name = line[:idx]
rest_of_line = line[idx + 1 :]
else:
cmd_name = line
rest_of_line = ""
cmd_cls = COMMAND_MAP[cmd_name]
return cmd_cls.from_line(rest_of_line)
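# e.g. (illustrative): parse_command_from_line("PING 1587394823000") splits
# off "PING", looks up PingCommand in COMMAND_MAP and returns
# PingCommand.from_line("1587394823000").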

View File

@@ -15,25 +15,12 @@
# limitations under the License.
import logging
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
)
from typing import Any, Callable, Dict, List, Optional, Set
from prometheus_client import Counter
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.metrics import LaterGauge
from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
from synapse.replication.tcp.client import ReplicationClientFactory
from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command,
@@ -94,7 +81,7 @@ class ReplicationCommandHandler:
self._pending_batches = {} # type: Dict[str, List[Any]]
# The factory used to create connections.
self._factory = None # type: Optional[ReconnectingClientFactory]
self._factory = None # type: Optional[ReplicationClientFactory]
# The currently connected connections.
self._connections = [] # type: List[AbstractConnection]
@@ -115,52 +102,19 @@ class ReplicationCommandHandler:
self._server_notices_sender = None
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
self._notifier.add_remote_server_up_callback(self.send_remote_server_up)
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
using TCP.
"""
if hs.config.redis.redis_enabled:
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
)
import txredisapi
client_name = hs.config.worker_name
self._factory = ReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)
logger.info(
"Connecting to redis (host=%r port=%r DBID=%r)",
hs.config.redis_host,
hs.config.redis_port,
hs.config.redis_dbid,
)
# We need two connections to redis, one for the subscription stream and
# one to send commands to (as you can't send further redis commands to a
# connection after SUBSCRIBE is called).
# First create the connection for sending commands.
outbound_redis_connection = txredisapi.lazyConnection(
host=hs.config.redis_host,
port=hs.config.redis_port,
dbid=hs.config.redis_dbid,
password=hs.config.redis.redis_password,
reconnect=True,
)
# Now create the factory/connection for the subscription stream.
self._factory = RedisDirectTcpReplicationClientFactory(
hs, outbound_redis_connection
)
hs.get_reactor().connectTCP(
hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory,
)
else:
client_name = hs.config.worker_name
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
async def on_REPLICATE(self, cmd: ReplicateCommand):
# We only want to announce positions by the writer of the streams.
# Currently this is just the master process.
if not self._is_master:
@@ -170,7 +124,7 @@ class ReplicationCommandHandler:
current_token = stream.current_token()
self.send_command(PositionCommand(stream_name, current_token))
async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
async def on_USER_SYNC(self, cmd: UserSyncCommand):
user_sync_counter.inc()
if self._is_master:
@@ -178,23 +132,17 @@ class ReplicationCommandHandler:
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
async def on_CLEAR_USER_SYNC(
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
):
async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
if self._is_master:
await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
async def on_FEDERATION_ACK(
self, conn: AbstractConnection, cmd: FederationAckCommand
):
async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.token)
async def on_REMOVE_PUSHER(
self, conn: AbstractConnection, cmd: RemovePusherCommand
):
async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
remove_pusher_counter.inc()
if self._is_master:
@@ -204,9 +152,7 @@ class ReplicationCommandHandler:
self._notifier.on_new_replication_data()
async def on_INVALIDATE_CACHE(
self, conn: AbstractConnection, cmd: InvalidateCacheCommand
):
async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
invalidate_cache_counter.inc()
if self._is_master:
@@ -216,7 +162,7 @@ class ReplicationCommandHandler:
cmd.cache_func, tuple(cmd.keys)
)
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
async def on_USER_IP(self, cmd: UserIpCommand):
user_ip_cache_counter.inc()
if self._is_master:
@@ -232,7 +178,7 @@ class ReplicationCommandHandler:
if self._server_notices_sender:
await self._server_notices_sender.on_user_ip(cmd.user_id)
async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
async def on_RDATA(self, cmd: RdataCommand):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
@@ -283,7 +229,7 @@ class ReplicationCommandHandler:
logger.debug("Received rdata %s -> %s", stream_name, token)
await self._replication_data_handler.on_rdata(stream_name, token, rows)
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
async def on_POSITION(self, cmd: PositionCommand):
stream = self._streams.get(cmd.stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
@@ -322,14 +268,11 @@ class ReplicationCommandHandler:
missing_updates,
) = await stream.get_updates_since(current_token, cmd.token)
# TODO: add some tests for this
# Some streams return multiple rows with the same stream IDs,
# which need to be processed in batches.
for token, rows in _batch_updates(updates):
if updates:
await self.on_rdata(
cmd.stream_name, token, [stream.parse_row(row) for row in rows],
cmd.stream_name,
current_token,
[stream.parse_row(update[1]) for update in updates],
)
# We've now caught up to position sent to us, notify handler.
@@ -337,30 +280,19 @@ class ReplicationCommandHandler:
self._streams_connected.add(cmd.stream_name)
async def on_REMOTE_SERVER_UP(
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
):
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
self._notifier.notify_remote_server_up(cmd.data)
if self._is_master:
self._notifier.notify_remote_server_up(cmd.data)
# We relay to all other connections to ensure every instance gets the
# notification.
#
# When configured to use redis we'll always only have one connection and
# so this is a no-op (all instances will have already received the same
# REMOTE_SERVER_UP command).
#
# For direct TCP connections this will relay to all other connections
# connected to us. When on master this will correctly fan out to all
# other direct TCP clients and on workers there'll only be the one
# connection to master.
#
# (The logic here should also be sound if we have a mix of Redis and
# direct TCP connections so long as there is only one traffic route
# between two instances, but that is not currently supported).
self.send_command(cmd, ignore_conn=conn)
def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users.
"""
return self._presence_handler.get_currently_syncing_users()
def new_connection(self, connection: AbstractConnection):
"""Called when we have a new connection.
@@ -379,11 +311,9 @@ class ReplicationCommandHandler:
if self._factory:
self._factory.resetDelay()
# Tell the other end if we have any users currently syncing.
currently_syncing = (
self._presence_handler.get_currently_syncing_users_for_replication()
)
# Tell the server if we have any users currently syncing (should only
# happen on synchrotrons)
currently_syncing = self.get_currently_syncing_users()
now = self._clock.time_msec()
for user_id in currently_syncing:
connection.send_command(
@@ -405,21 +335,11 @@ class ReplicationCommandHandler:
"""
return bool(self._connections)
def send_command(
self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
):
def send_command(self, cmd: Command):
"""Send a command to all connected connections.
Args:
cmd
ignore_conn: If set don't send command to the given connection.
Used when relaying commands from one connection to all others.
"""
if self._connections:
for connection in self._connections:
if connection == ignore_conn:
continue
try:
connection.send_command(cmd)
except Exception:
@@ -484,52 +404,3 @@ class ReplicationCommandHandler:
We need to check if the client is interested in the stream or not
"""
self.send_command(RdataCommand(stream_name, token, data))
UpdateToken = TypeVar("UpdateToken")
UpdateRow = TypeVar("UpdateRow")
def _batch_updates(
updates: Iterable[Tuple[UpdateToken, UpdateRow]]
) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]:
"""Collect stream updates with the same token together
Given a series of updates returned by Stream.get_updates_since(), collects
the updates which share the same stream_id together.
For example:
[(1, a), (1, b), (2, c), (3, d), (3, e)]
becomes:
[
(1, [a, b]),
(2, [c]),
(3, [d, e]),
]
"""
update_iter = iter(updates)
first_update = next(update_iter, None)
if first_update is None:
# empty input
return
current_batch_token = first_update[0]
current_batch = [first_update[1]]
for token, row in update_iter:
if token != current_batch_token:
# different token to the previous row: flush the previous
# batch and start anew
yield current_batch_token, current_batch
current_batch_token = token
current_batch = []
current_batch.append(row)
# flush the final batch
yield current_batch_token, current_batch

View File

@@ -50,7 +50,10 @@ import abc
import fcntl
import logging
import struct
from typing import TYPE_CHECKING, List
from collections import defaultdict
from typing import TYPE_CHECKING, DefaultDict, List
from six import iteritems
from prometheus_client import Counter
@@ -60,6 +63,7 @@ from twisted.python.failure import Failure
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
COMMAND_MAP,
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
Command,
@@ -68,7 +72,6 @@ from synapse.replication.tcp.commands import (
PingCommand,
ReplicateCommand,
ServerCommand,
parse_command_from_line,
)
from synapse.types import Collection
from synapse.util import Clock
@@ -83,18 +86,6 @@ connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
)
tcp_inbound_commands_counter = Counter(
"synapse_replication_tcp_protocol_inbound_commands",
"Number of commands received from replication, by command and name of process connected to",
["command", "name"],
)
tcp_outbound_commands_counter = Counter(
"synapse_replication_tcp_protocol_outbound_commands",
"Number of commands sent to replication, by command and name of process connected to",
["command", "name"],
)
# A list of all connected protocols. This allows us to send metrics about the
# connections.
connected_connections = []
@@ -160,6 +151,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# The LoopingCall for sending pings.
self._send_ping_loop = None
self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -216,21 +210,37 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
linestr = line.decode("utf-8")
try:
cmd = parse_command_from_line(linestr)
except Exception as e:
logger.exception("[%s] failed to parse line: %r", self.id(), linestr)
self.send_error("failed to parse line: %r (%r):" % (e, linestr))
return
# split at the first " ", handling one-word commands
idx = linestr.find(" ")
if idx >= 0:
cmd_name = linestr[:idx]
rest_of_line = linestr[idx + 1 :]
else:
cmd_name = linestr
rest_of_line = ""
if cmd.NAME not in self.VALID_INBOUND_COMMANDS:
logger.error("[%s] invalid command %s", self.id(), cmd.NAME)
self.send_error("invalid command: %s", cmd.NAME)
if cmd_name not in self.VALID_INBOUND_COMMANDS:
logger.error("[%s] invalid command %s", self.id(), cmd_name)
self.send_error("invalid command: %s", cmd_name)
return
self.last_received_command = self.clock.time_msec()
tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
self.inbound_commands_counter[cmd_name] = (
self.inbound_commands_counter[cmd_name] + 1
)
cmd_cls = COMMAND_MAP[cmd_name]
try:
cmd = cmd_cls.from_line(rest_of_line)
except Exception as e:
logger.exception(
"[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line
)
self.send_error(
"failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line)
)
return
# Now lets try and call on_<CMD_NAME> function
run_as_background_process(
@@ -260,7 +270,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Then call out to the handler.
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(self, cmd)
await cmd_func(cmd)
handled = True
if not handled:
@@ -296,8 +306,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self._queue_command(cmd)
return
tcp_outbound_commands_counter.labels(cmd.NAME, self.name).inc()
self.outbound_commands_counter[cmd.NAME] = (
self.outbound_commands_counter[cmd.NAME] + 1
)
string = "%s %s" % (cmd.NAME, cmd.to_line())
if "\n" in string:
raise Exception("Unexpected newline in command: %r", string)
@@ -549,3 +560,26 @@ tcp_transport_kernel_read_buffer = LaterGauge(
for p in connected_connections
},
)
tcp_inbound_commands = LaterGauge(
"synapse_replication_tcp_protocol_inbound_commands",
"",
["command", "name"],
lambda: {
(k, p.name): count
for p in connected_connections
for k, count in iteritems(p.inbound_commands_counter)
},
)
tcp_outbound_commands = LaterGauge(
"synapse_replication_tcp_protocol_outbound_commands",
"",
["command", "name"],
lambda: {
(k, p.name): count
for p in connected_connections
for k, count in iteritems(p.outbound_commands_counter)
},
)

View File

@@ -1,193 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING
import txredisapi
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
Command,
ReplicateCommand,
parse_command_from_line,
)
from synapse.replication.tcp.protocol import (
AbstractConnection,
tcp_inbound_commands_counter,
tcp_outbound_commands_counter,
)
if TYPE_CHECKING:
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
"""Connection to redis subscribed to replication stream.
Parses incoming messages from redis into replication commands, and passes
them to `ReplicationCommandHandler`
Due to the vagaries of `txredisapi` we don't want to have a custom
constructor, so instead we expect the defined attributes below to be set
immediately after initialisation.
Attributes:
handler: The command handler to handle incoming commands.
stream_name: The *redis* stream name to subscribe to (not anything to
do with Synapse replication streams).
outbound_redis_connection: The connection to redis to use to send
commands.
"""
handler = None # type: ReplicationCommandHandler
stream_name = None # type: str
outbound_redis_connection = None # type: txredisapi.RedisProtocol
def connectionMade(self):
logger.info("Connected to redis instance")
self.subscribe(self.stream_name)
self.send_command(ReplicateCommand())
self.handler.new_connection(self)
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
"""
if message.strip() == "":
# Ignore blank lines
return
try:
cmd = parse_command_from_line(message)
except Exception:
logger.exception(
"failed to parse line: %r", message,
)
return
# We use "redis" as the name here as we don't have 1:1 connections to
# remote instances.
tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
# Now lets try and call on_<CMD_NAME> function
run_as_background_process(
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)
async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.
By default delegates to on_<COMMAND>, which should return an awaitable.
Args:
cmd: received command
"""
handled = False
# First call any command handlers on this instance. These are for redis
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(self, cmd)
handled = True
if not handled:
logger.warning("Unhandled command: %r", cmd)
def connectionLost(self, reason):
logger.info("Lost connection to redis instance")
self.handler.lost_connection(self)
def send_command(self, cmd: Command):
"""Send a command if connection has been established.
Args:
cmd (Command)
"""
string = "%s %s" % (cmd.NAME, cmd.to_line())
if "\n" in string:
raise Exception("Unexpected newline in command: %r", string)
encoded_string = string.encode("utf-8")
# We use "redis" as the name here as we don't have 1:1 connections to
# remote instances.
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
async def _send():
with PreserveLoggingContext():
# Note that we use the other connection as we can't send
# commands using the subscription connection.
await self.outbound_redis_connection.publish(
self.stream_name, encoded_string
)
run_as_background_process("send-cmd", _send)
class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
"""This is a reconnecting factory that connects to redis and immediately
subscribes to a stream.
Args:
hs
outbound_redis_connection: A connection to redis that will be used to
send outbound commands (this is separate from the redis connection
used to subscribe).
"""
maxDelay = 5
continueTrying = True
protocol = RedisSubscriber
def __init__(
self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
):
super().__init__()
# This sets the password on the RedisFactory base class (as
# SubscriberFactory constructor doesn't pass it through).
self.password = hs.config.redis.redis_password
self.handler = hs.get_tcp_replication()
self.stream_name = hs.hostname
self.outbound_redis_connection = outbound_redis_connection
def buildProtocol(self, addr):
p = super().buildProtocol(addr) # type: RedisSubscriber
# We do this here rather than in the constructor of `RedisSubscriber`, as
# doing so would involve overriding `buildProtocol` entirely; however, the
# base method does more than just instantiate the protocol.
p.handler = self.handler
p.outbound_redis_connection = self.outbound_redis_connection
p.stream_name = self.stream_name
return p

View File

@@ -17,7 +17,9 @@
import logging
import random
from typing import Dict, List
from typing import Dict
from six import itervalues
from prometheus_client import Counter
@@ -69,28 +71,29 @@ class ReplicationStreamer(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self._server_notices_sender = hs.get_server_notices_sender()
self._replication_torture_level = hs.config.replication_torture_level
# Work out list of streams that this instance is the source of.
self.streams = [] # type: List[Stream]
if hs.config.worker_app is None:
for stream in STREAMS_MAP.values():
if stream == FederationStream and hs.config.send_federation:
# We only support federation stream if federation sending
# has been disabled on the master.
continue
self.streams.append(stream(hs))
# List of streams that clients can subscribe to.
# We only support federation stream if federation sending has been
# disabled on the master.
self.streams = [
stream(hs)
for stream in itervalues(STREAMS_MAP)
if stream != FederationStream or not hs.config.send_federation
]
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
# Only bother registering the notifier callback if we have streams to
# publish.
if self.streams:
self.notifier.add_replication_callback(self.on_notifier_poke)
self.federation_sender = None
if not hs.config.send_federation:
self.federation_sender = hs.get_federation_sender()
self.notifier.add_replication_callback(self.on_notifier_poke)
# Keeps track of whether we are currently checking for updates
self.is_looping = False

View File

@@ -25,6 +25,8 @@ Each stream is defined by the following information:
update_function: The function that returns a list of updates between two tokens
"""
from typing import Dict, Type
from synapse.replication.tcp.streams._base import (
AccountDataStream,
BackfillStream,
@@ -65,7 +67,8 @@ STREAMS_MAP = {
GroupServerStream,
UserSignatureStream,
)
}
} # type: Dict[str, Type[Stream]]
__all__ = [
"STREAMS_MAP",

View File

@@ -16,16 +16,17 @@
import logging
from collections import namedtuple
from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
from typing import Any, Awaitable, Callable, List, Optional, Tuple
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
# the number of rows to request from an update_function.
_STREAM_UPDATE_TARGET_ROW_COUNT = 100
MAX_EVENTS_BEHIND = 500000
# Some type aliases to make things a bit easier.
@@ -33,36 +34,8 @@ _STREAM_UPDATE_TARGET_ROW_COUNT = 100
# A stream position token
Token = int
# The type of a stream update row, after JSON deserialisation, but before
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
# just a row from a database query, though this is dependent on the stream in question.
#
StreamRow = Tuple
# The type returned by the update_function of a stream, as well as get_updates(),
# get_updates_since, etc.
#
# It consists of a triplet `(updates, new_last_token, limited)`, where:
# * `updates` is a list of `(token, row)` entries.
# * `new_last_token` is the new position in stream.
# * `limited` is whether there are more updates to fetch.
#
StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
# The type of an update_function for a stream
#
# The arguments are:
#
# * from_token: the previous stream token: the starting point for fetching the
# updates
# * to_token: the new stream token: the point to get updates up to
# * target_row_count: a target for the number of rows to be returned.
#
# The update_function is expected to return up to _approximately_ target_row_count rows.
# If there are more updates available, it should set `limited` in the result, and
# it will be called again to get the next batch.
#
UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
StreamRow = Tuple[Token, tuple]
class Stream(object):
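To make the (updates, new_last_token, limited) contract above concrete, here is a hedged sketch of an update_function over an in-memory list of rows; every name in it is illustrative.
from typing import List, Tuple
_FAKE_ROWS = [(1, ("a",)), (2, ("b",)), (3, ("c",))]  # pretend stream contents
async def example_update_function(
    from_token: int, to_token: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
    updates = [(t, r) for (t, r) in _FAKE_ROWS if from_token < t <= to_token]
    limited = len(updates) > limit
    if limited:
        # more rows remain: report the position of the last row we did return
        updates = updates[:limit]
        to_token = updates[-1][0]
    return updates, to_token, limited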
@@ -77,7 +50,7 @@ class Stream(object):
ROW_TYPE = None # type: Any
@classmethod
def parse_row(cls, row: StreamRow):
def parse_row(cls, row):
"""Parse a row received over replication
By default, assumes that the row data is an array object and passes its contents
@@ -91,28 +64,7 @@ class Stream(object):
"""
return cls.ROW_TYPE(*row)
def __init__(
self,
current_token_function: Callable[[], Token],
update_function: UpdateFunction,
):
"""Instantiate a Stream
current_token_function and update_function are callbacks which should be
implemented by subclasses.
current_token_function is called to get the current token of the underlying
stream.
update_function is called to get updates for this stream between a pair of
stream tokens. See the UpdateFunction type definition for more info.
Args:
current_token_function: callback to get the current token, as above
update_function: callback to get stream updates, as above
"""
self.current_token = current_token_function
self.update_function = update_function
def __init__(self, hs):
# The token from which we last asked for updates
self.last_token = self.current_token()
@@ -123,7 +75,7 @@ class Stream(object):
"""
self.last_token = self.current_token()
async def get_updates(self) -> StreamUpdateResult:
async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before).
@@ -142,8 +94,8 @@ class Stream(object):
return updates, current_token, limited
async def get_updates_since(
self, from_token: Token, upto_token: Token
) -> StreamUpdateResult:
self, from_token: Token, upto_token: Token, limit: int = 100
) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Like get_updates except allows specifying from when we should
stream updates
@@ -160,14 +112,33 @@ class Stream(object):
return [], upto_token, False
updates, upto_token, limited = await self.update_function(
from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
from_token, upto_token, limit=limit,
)
return updates, upto_token, limited
def current_token(self):
"""Gets the current token of the underlying streams. Should be provided
by the sub classes
Returns:
int
"""
raise NotImplementedError()
def update_function(self, from_token, current_token, limit):
"""Get updates between from_token and to_token.
Returns:
Deferred(list(tuple)): the first entry in the tuple is the token for
that update, and the rest of the tuple gets used to construct
a ``ROW_TYPE`` instance
"""
raise NotImplementedError()
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
) -> UpdateFunction:
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
@@ -177,16 +148,17 @@ def db_query_to_update_function(
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) == limit:
upto_token = updates[-1][0]
upto_token = rows[-1][0]
limited = True
assert len(updates) <= limit
return updates, upto_token, limited
return update_function
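A usage sketch (hypothetical names throughout): wrapping an async query function this way yields a callable with the standard triplet contract.
async def get_rows_between(from_token, to_token, limit):
    # stand-in for a store method; each row leads with its stream id
    return [(from_token + 1, "foo"), (from_token + 2, "bar")][:limit]
update_function = db_query_to_update_function(get_rows_between)
# awaiting update_function(0, 10, 100) yields
# ([(1, ("foo",)), (2, ("bar",))], 10, False)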
def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
def make_http_update_function(
hs, stream_name: str
) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
"""Makes a suitable function for use as an `update_function` that queries
the master process for updates.
"""
@@ -195,9 +167,12 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
async def update_function(
from_token: int, upto_token: int, limit: int
) -> StreamUpdateResult:
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
result = await client(
stream_name=stream_name, from_token=from_token, upto_token=upto_token,
stream_name=stream_name,
from_token=from_token,
upto_token=upto_token,
limit=limit,
)
return result["updates"], result["upto_token"], result["limited"]
@@ -226,10 +201,10 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
store.get_current_backfill_token,
db_query_to_update_function(store.get_all_new_backfill_event_rows),
)
self.current_token = store.get_current_backfill_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
super(BackfillStream, self).__init__(hs)
class PresenceStream(Stream):
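The constructor boilerplate repeated through the rest of this module is mechanical. As a hedged sketch, a new stream following the callback-based form would look like this (every name here is hypothetical):
class ExampleStream(Stream):
    NAME = "example"
    ROW_TYPE = None  # normally a namedtuple describing one row
    def __init__(self, hs):
        store = hs.get_datastore()
        super().__init__(
            store.get_example_stream_token,  # current_token callback
            db_query_to_update_function(store.get_all_example_rows),
        )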
@@ -251,18 +226,19 @@ class PresenceStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
self._is_worker = hs.config.worker_app is not None
self.current_token = store.get_current_presence_token # type: ignore
if hs.config.worker_app is None:
# on the master, query the presence handler
presence_handler = hs.get_presence_handler()
update_function = db_query_to_update_function(
presence_handler.get_all_presence_updates
)
self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
else:
# Query master process
update_function = make_http_update_function(hs, self.NAME)
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super().__init__(store.get_current_presence_token, update_function)
super(PresenceStream, self).__init__(hs)
class TypingStream(Stream):
@@ -276,16 +252,15 @@ class TypingStream(Stream):
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
self.current_token = typing_handler.get_current_token # type: ignore
if hs.config.worker_app is None:
# on the master, query the typing handler
update_function = db_query_to_update_function(
typing_handler.get_all_typing_updates
)
self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
else:
# Query master process
update_function = make_http_update_function(hs, self.NAME)
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super().__init__(typing_handler.get_current_token, update_function)
super(TypingStream, self).__init__(hs)
class ReceiptsStream(Stream):
@@ -305,10 +280,11 @@ class ReceiptsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
store.get_max_receipt_stream_id,
db_query_to_update_function(store.get_all_updated_receipts),
)
self.current_token = store.get_max_receipt_stream_id # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
super(ReceiptsStream, self).__init__(hs)
class PushRulesStream(Stream):
@@ -322,15 +298,13 @@ class PushRulesStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
super(PushRulesStream, self).__init__(
self._current_token, self._update_function
)
super(PushRulesStream, self).__init__(hs)
def _current_token(self) -> int:
def current_token(self):
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
async def _update_function(self, from_token: Token, to_token: Token, limit: int):
async def update_function(self, from_token, to_token, limit):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
limited = False
@@ -356,10 +330,10 @@ class PushersStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
store.get_pushers_stream_token,
db_query_to_update_function(store.get_all_updated_pushers_rows),
)
self.current_token = store.get_pushers_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
super(PushersStream, self).__init__(hs)
class CachesStream(Stream):
@@ -387,10 +361,11 @@ class CachesStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
store.get_cache_stream_token,
db_query_to_update_function(store.get_all_updated_caches),
)
self.current_token = store.get_cache_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
super(CachesStream, self).__init__(hs)
class PublicRoomsStream(Stream):
@@ -412,10 +387,11 @@ class PublicRoomsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
store.get_current_public_room_stream_id,
db_query_to_update_function(store.get_all_new_public_rooms),
)
self.current_token = store.get_current_public_room_stream_id # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
super(PublicRoomsStream, self).__init__(hs)
class DeviceListsStream(Stream):
@@ -432,10 +408,11 @@ class DeviceListsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
store.get_device_stream_token,
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
)
self.current_token = store.get_device_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
super(DeviceListsStream, self).__init__(hs)
class ToDeviceStream(Stream):
@@ -449,10 +426,11 @@ class ToDeviceStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
store.get_to_device_stream_token,
db_query_to_update_function(store.get_all_new_device_messages),
)
self.current_token = store.get_to_device_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
super(ToDeviceStream, self).__init__(hs)
class TagAccountDataStream(Stream):
@@ -468,10 +446,11 @@ class TagAccountDataStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
store.get_max_account_data_stream_id,
db_query_to_update_function(store.get_all_updated_tags),
)
self.current_token = store.get_max_account_data_stream_id # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
super(TagAccountDataStream, self).__init__(hs)
class AccountDataStream(Stream):
@@ -487,10 +466,11 @@ class AccountDataStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
super().__init__(
self.store.get_max_account_data_stream_id,
db_query_to_update_function(self._update_function),
)
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(AccountDataStream, self).__init__(hs)
async def _update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data(
@@ -517,10 +497,11 @@ class GroupServerStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
store.get_group_stream_token,
db_query_to_update_function(store.get_all_groups_changes),
)
self.current_token = store.get_group_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
super(GroupServerStream, self).__init__(hs)
class UserSignatureStream(Stream):
@@ -534,9 +515,8 @@ class UserSignatureStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
store.get_device_stream_token,
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes
),
)
self.current_token = store.get_device_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
super(UserSignatureStream, self).__init__(hs)

View File

@@ -15,12 +15,11 @@
# limitations under the License.
import heapq
from collections import Iterable
from typing import List, Tuple, Type
from typing import Tuple, Type
import attr
from ._base import Stream, StreamUpdateResult, Token
from ._base import Stream, db_query_to_update_function
"""Handling of the 'events' replication stream
@@ -117,121 +116,29 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
super().__init__(
self._store.get_current_events_token, self._update_function,
self.current_token = self._store.get_current_events_token # type: ignore
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(EventsStream, self).__init__(hs)
async def _update_function(self, from_token, current_token, limit=None):
event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, limit
)
event_updates = (
(row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
)
async def _update_function(
self, from_token: Token, current_token: Token, target_row_count: int
) -> StreamUpdateResult:
# the events stream merges together three separate sources:
# * new events
# * current_state changes
# * events which were previously outliers, but have now been de-outliered.
#
# The merge operation is complicated by the fact that we only have a single
# "stream token" which is supposed to indicate how far we have got through
# all three streams. It's therefore no good to return rows 1-1000 from the
# "new events" table if the state_deltas are limited to rows 1-100 by the
# target_row_count.
#
# In other words: we must pick a new upper limit, and must return *all* rows
# up to that point for each of the three sources.
#
# Start by trying to split the target_row_count up. We expect to have a
# negligible number of ex-outliers, and a rough approximation based on recent
# traffic on sw1v.org shows that there are approximately the same number of
# event rows between a given pair of stream ids as there are state
# updates, so let's split our target_row_count among those two types. The target
# is only an approximation - it doesn't matter if we end up going a bit over it.
target_row_count //= 2
# now we fetch up to that many rows from the events table
event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, target_row_count
) # type: List[Tuple]
# we rely on get_all_new_forward_event_rows strictly honouring the limit, so
# that we know it is safe to just take upper_limit = event_rows[-1][0].
assert (
len(event_rows) <= target_row_count
), "get_all_new_forward_event_rows did not honour row limit"
# if we hit the limit on event_updates, there's no point in going beyond the
# last stream_id in the batch for the other sources.
if len(event_rows) == target_row_count:
limited = True
upper_limit = event_rows[-1][0] # type: int
else:
limited = False
upper_limit = current_token
# next up is the state delta table
state_rows = await self._store.get_all_updated_current_state_deltas(
from_token, upper_limit, target_row_count
) # type: List[Tuple]
assert len(state_rows) <= target_row_count
# there can be more than one row per stream_id in that table, so if we hit
# the limit there, we'll need to truncate the results so that we have a complete
# set of changes for all the stream IDs we include.
if len(state_rows) == target_row_count:
assert state_rows[-1][0] <= upper_limit
upper_limit = state_rows[-1][0] - 1
# search for the point to truncate the list
for idx in range(len(state_rows) - 1, 0, -1):
if state_rows[idx - 1][0] <= upper_limit:
state_rows = state_rows[:idx]
break
else:
# bother. We didn't get a full set of changes for even a single
# stream id. let's run the query again, without a row limit, but for
# just one stream id.
upper_limit += 1
state_rows = await self._store.get_all_updated_current_state_deltas(
from_token, upper_limit, limit=None
)
limited = True
# finally, fetch the ex-outliers rows. We assume there are few enough of these
# not to bother with the limit.
ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
from_token, upper_limit
) # type: List[Tuple]
# we now need to turn the raw database rows returned into tuples suitable
# for the replication protocol (basically, we add an identifier to
# distinguish the row type). At the same time, we can limit the event_rows
# to the max stream_id from state_rows.
event_updates = (
(stream_id, (EventsStreamEventRow.TypeId, rest))
for (stream_id, *rest) in event_rows
if stream_id <= upper_limit
) # type: Iterable[Tuple[int, Tuple]]
from_token, current_token, limit
)
state_updates = (
(stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
for (stream_id, *rest) in state_rows
) # type: Iterable[Tuple[int, Tuple]]
(row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows
)
ex_outliers_updates = (
(stream_id, (EventsStreamEventRow.TypeId, rest))
for (stream_id, *rest) in ex_outliers_rows
) # type: Iterable[Tuple[int, Tuple]]
all_updates = heapq.merge(event_updates, state_updates)
# we need to return a sorted list, so merge them together.
updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates))
return updates, upper_limit, limited
return all_updates
@classmethod
def parse_row(cls, row):
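The merge step above depends on each source already being sorted by stream id; heapq.merge then interleaves them into a single ordered sequence. A tiny standalone illustration:
import heapq
event_updates = [(1, ("ev", "A")), (4, ("ev", "D"))]
state_updates = [(2, ("state", "B"))]
ex_outlier_updates = [(3, ("ev", "C"))]
merged = list(heapq.merge(event_updates, state_updates, ex_outlier_updates))
# -> [(1, ...), (2, ...), (3, ...), (4, ...)], ordered by stream id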

View File

@@ -15,6 +15,8 @@
# limitations under the License.
from collections import namedtuple
from twisted.internet import defer
from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
@@ -33,6 +35,7 @@ class FederationStream(Stream):
NAME = "federation"
ROW_TYPE = FederationStreamRow
_QUERY_MASTER = True
def __init__(self, hs):
# Not all synapse instances will have a federation sender instance,
@@ -40,16 +43,10 @@ class FederationStream(Stream):
# so we stub the stream out when that is the case.
if hs.config.worker_app is None or hs.should_send_federation():
federation_sender = hs.get_federation_sender()
current_token = federation_sender.get_current_token
update_function = db_query_to_update_function(
federation_sender.get_replication_rows
)
self.current_token = federation_sender.get_current_token # type: ignore
self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
else:
current_token = lambda: 0
update_function = self._stub_update_function
self.current_token = lambda: 0 # type: ignore
self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, False))  # type: ignore
super().__init__(current_token, update_function)
@staticmethod
async def _stub_update_function(from_token, upto_token, limit):
return [], upto_token, False
super(FederationStream, self).__init__(hs)

View File

@@ -1,18 +0,0 @@
<html>
<head>
<title>Authentication Successful</title>
<script>
if (window.onAuthDone) {
window.onAuthDone();
} else if (window.opener && window.opener.postMessage) {
window.opener.postMessage("authDone", "*");
}
</script>
</head>
<body>
<div>
<p>Thank you</p>
<p>You may now close this window and return to the application</p>
</div>
</body>
</html>

View File

@@ -183,23 +183,10 @@ class ListRoomRestServlet(RestServlet):
# Extract query parameters
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value)
order_by = parse_string(request, "order_by", default="alphabetical")
if order_by not in (
RoomSortOrder.ALPHABETICAL.value,
RoomSortOrder.SIZE.value,
RoomSortOrder.NAME.value,
RoomSortOrder.CANONICAL_ALIAS.value,
RoomSortOrder.JOINED_MEMBERS.value,
RoomSortOrder.JOINED_LOCAL_MEMBERS.value,
RoomSortOrder.VERSION.value,
RoomSortOrder.CREATOR.value,
RoomSortOrder.ENCRYPTION.value,
RoomSortOrder.FEDERATABLE.value,
RoomSortOrder.PUBLIC.value,
RoomSortOrder.JOIN_RULES.value,
RoomSortOrder.GUEST_ACCESS.value,
RoomSortOrder.HISTORY_VISIBILITY.value,
RoomSortOrder.STATE_EVENTS.value,
):
raise SynapseError(
400,

View File

@@ -30,7 +30,7 @@ from synapse.http.servlet import (
)
from synapse.push.mailer import Mailer, load_jinja2_templates
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.stringutils import assert_valid_client_secret
from synapse.util.threepids import check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -100,11 +100,6 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
)
if existing_user_id is None:
if self.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
@@ -395,11 +390,6 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
)
if existing_user_id is not None:
if self.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
@@ -463,11 +453,6 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
if existing_user_id is not None:
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
return 200, {"sid": random_string(16)}
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
if not self.hs.config.account_threepid_delegate_msisdn:

View File

@@ -38,12 +38,8 @@ class AccountDataServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
async def on_PUT(self, request, user_id, account_data_type):
if self._is_worker:
raise Exception("Cannot handle PUT /account_data on worker")
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@@ -90,12 +86,8 @@ class RoomAccountDataServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
async def on_PUT(self, request, user_id, room_id, account_data_type):
if self._is_worker:
raise Exception("Cannot handle PUT /account_data on worker")
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")

View File

@@ -18,6 +18,7 @@ import logging
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.handlers.auth import SUCCESS_TEMPLATE
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet, parse_string
@@ -89,30 +90,6 @@ TERMS_TEMPLATE = """
</html>
"""
SUCCESS_TEMPLATE = """
<html>
<head>
<title>Success!</title>
<meta name='viewport' content='width=device-width, initial-scale=1,
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
<script>
if (window.onAuthDone) {
window.onAuthDone();
} else if (window.opener && window.opener.postMessage) {
window.opener.postMessage("authDone", "*");
}
</script>
</head>
<body>
<div>
<p>Thank you</p>
<p>You may now close this window and return to the application</p>
</div>
</body>
</html>
"""
class AuthRestServlet(RestServlet):
"""

View File

@@ -49,7 +49,7 @@ from synapse.http.servlet import (
from synapse.push.mailer import load_jinja2_templates
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.stringutils import assert_valid_client_secret
from synapse.util.threepids import check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -135,11 +135,6 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
)
if existing_user_id is not None:
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
@@ -207,11 +202,6 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
)
if existing_user_id is not None:
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
return 200, {"sid": random_string(16)}
raise SynapseError(
400, "Phone number is already in use", Codes.THREEPID_IN_USE
)

View File

@@ -402,7 +402,7 @@ class PreviewUrlResource(DirectServeResource):
now = self.clock.time_msec()
logger.debug("Running url preview cache expiry")
logger.info("Running url preview cache expiry")
if not (await self.store.db.updates.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry")
@@ -435,8 +435,6 @@ class PreviewUrlResource(DirectServeResource):
if removed_media:
logger.info("Deleted %d entries from url cache", len(removed_media))
else:
logger.debug("No entries removed from url cache")
# Now we delete old images associated with the url cache.
# These may be cached for a bit on the client (i.e., they
@@ -483,10 +481,7 @@ class PreviewUrlResource(DirectServeResource):
await self.store.delete_url_cache_media(removed_media)
if removed_media:
logger.info("Deleted %d media from url cache", len(removed_media))
else:
logger.debug("No media removed from url cache")
logger.info("Deleted %d media from url cache", len(removed_media))
def decode_and_calc_og(body, media_uri, request_encoding=None):

View File

@@ -25,7 +25,6 @@ import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
from synapse.events.builder import EventBuilderFactory
class HomeServer(object):
@property
@@ -98,7 +97,7 @@ class HomeServer(object):
pass
def get_notifier(self) -> synapse.notifier.Notifier:
pass
def get_presence_handler(self) -> synapse.handlers.presence.BasePresenceHandler:
def get_presence_handler(self) -> synapse.handlers.presence.PresenceHandler:
pass
def get_clock(self) -> synapse.util.Clock:
pass
@@ -122,7 +121,3 @@ class HomeServer(object):
pass
def get_instance_id(self) -> str:
pass
def get_event_builder_factory(self) -> EventBuilderFactory:
pass
def get_storage(self) -> synapse.storage.Storage:
pass

View File

@@ -973,18 +973,8 @@ class EventsWorkerStore(SQLBaseStore):
return self._stream_id_gen.get_current_token()
def get_all_new_forward_event_rows(self, last_id, current_id, limit):
"""Returns new events, for the Events replication stream
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
limit: the maximum number of rows to return
Returns: Deferred[List[Tuple]]
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
"""
if last_id == current_id:
return defer.succeed([])
def get_all_new_forward_event_rows(txn):
sql = (
@@ -999,26 +989,13 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
new_event_updates = txn.fetchall()
return self.db.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
else:
upper_bound = current_id
def get_ex_outlier_stream_rows(self, last_id, current_id):
"""Returns de-outliered events, for the Events replication stream
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
Returns: Deferred[List[Tuple]]
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
"""
def get_ex_outlier_stream_rows_txn(txn):
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
@@ -1029,14 +1006,15 @@ class EventsWorkerStore(SQLBaseStore):
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering ASC"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (last_id, upper_bound))
new_event_updates.extend(txn)
txn.execute(sql, (last_id, current_id))
return txn.fetchall()
return new_event_updates
return self.db.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
@@ -1084,23 +1062,15 @@ class EventsWorkerStore(SQLBaseStore):
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
def get_all_updated_current_state_deltas(
self, from_token: int, to_token: int, limit: Optional[int]
):
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
ORDER BY stream_id ASC LIMIT ?
"""
params = [from_token, to_token]
if limit is not None:
sql += "LIMIT ?"
params.append(limit)
txn.execute(sql, params)
txn.execute(sql, (from_token, to_token, limit))
return txn.fetchall()
return self.db.runInteraction(

View File

@@ -52,28 +52,12 @@ class RoomSortOrder(Enum):
"""
Enum to define the sorting method used when returning rooms with get_rooms_paginate
NAME = sort rooms alphabetically by name
JOINED_MEMBERS = sort rooms by membership size, highest to lowest
ALPHABETICAL = sort rooms alphabetically by name
SIZE = sort rooms by membership size, highest to lowest
"""
# ALPHABETICAL and SIZE are deprecated.
# ALPHABETICAL is the same as NAME.
ALPHABETICAL = "alphabetical"
# SIZE is the same as JOINED_MEMBERS.
SIZE = "size"
NAME = "name"
CANONICAL_ALIAS = "canonical_alias"
JOINED_MEMBERS = "joined_members"
JOINED_LOCAL_MEMBERS = "joined_local_members"
VERSION = "version"
CREATOR = "creator"
ENCRYPTION = "encryption"
FEDERATABLE = "federatable"
PUBLIC = "public"
JOIN_RULES = "join_rules"
GUEST_ACCESS = "guest_access"
HISTORY_VISIBILITY = "history_visibility"
STATE_EVENTS = "state_events"
class RoomWorkerStore(SQLBaseStore):
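For reference, a hedged example of how an order_by value reaches this enum via the room admin API (path per Synapse's admin API; host and token are placeholders):
import requests
resp = requests.get(
    "https://homeserver.example/_synapse/admin/v1/rooms",
    params={"from": 0, "limit": 100, "order_by": "joined_members"},
    headers={"Authorization": "Bearer <admin access token>"},
)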
@@ -345,52 +329,12 @@ class RoomWorkerStore(SQLBaseStore):
# Set ordering
if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
# Deprecated in favour of RoomSortOrder.JOINED_MEMBERS
order_by_column = "curr.joined_members"
order_by_asc = False
elif RoomSortOrder(order_by) == RoomSortOrder.ALPHABETICAL:
# Deprecated in favour of RoomSortOrder.NAME
# Sort alphabetically
order_by_column = "state.name"
order_by_asc = True
elif RoomSortOrder(order_by) == RoomSortOrder.NAME:
order_by_column = "state.name"
order_by_asc = True
elif RoomSortOrder(order_by) == RoomSortOrder.CANONICAL_ALIAS:
order_by_column = "state.canonical_alias"
order_by_asc = True
elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_MEMBERS:
order_by_column = "curr.joined_members"
order_by_asc = False
elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_LOCAL_MEMBERS:
order_by_column = "curr.local_users_in_room"
order_by_asc = False
elif RoomSortOrder(order_by) == RoomSortOrder.VERSION:
order_by_column = "rooms.room_version"
order_by_asc = False
elif RoomSortOrder(order_by) == RoomSortOrder.CREATOR:
order_by_column = "rooms.creator"
order_by_asc = True
elif RoomSortOrder(order_by) == RoomSortOrder.ENCRYPTION:
order_by_column = "state.encryption"
order_by_asc = True
elif RoomSortOrder(order_by) == RoomSortOrder.FEDERATABLE:
order_by_column = "state.is_federatable"
order_by_asc = True
elif RoomSortOrder(order_by) == RoomSortOrder.PUBLIC:
order_by_column = "rooms.is_public"
order_by_asc = True
elif RoomSortOrder(order_by) == RoomSortOrder.JOIN_RULES:
order_by_column = "state.join_rules"
order_by_asc = True
elif RoomSortOrder(order_by) == RoomSortOrder.GUEST_ACCESS:
order_by_column = "state.guest_access"
order_by_asc = True
elif RoomSortOrder(order_by) == RoomSortOrder.HISTORY_VISIBILITY:
order_by_column = "state.history_visibility"
order_by_asc = True
elif RoomSortOrder(order_by) == RoomSortOrder.STATE_EVENTS:
order_by_column = "curr.current_state_events"
order_by_asc = False
else:
raise StoreError(
500, "Incorrect value for order_by provided: %s" % order_by
@@ -405,13 +349,9 @@ class RoomWorkerStore(SQLBaseStore):
# for, and another query for getting the total number of events that could be
# returned. Thus allowing us to see if there are more events to paginate through
info_sql = """
SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members,
curr.local_users_in_room, rooms.room_version, rooms.creator,
state.encryption, state.is_federatable, rooms.is_public, state.join_rules,
state.guest_access, state.history_visibility, curr.current_state_events
SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members
FROM room_stats_state state
INNER JOIN room_stats_current curr USING (room_id)
INNER JOIN rooms USING (room_id)
%s
ORDER BY %s %s
LIMIT ?
@@ -449,16 +389,6 @@ class RoomWorkerStore(SQLBaseStore):
"name": room[1],
"canonical_alias": room[2],
"joined_members": room[3],
"joined_local_members": room[4],
"version": room[5],
"creator": room[6],
"encryption": room[7],
"federatable": room[8],
"public": room[9],
"join_rules": room[10],
"guest_access": room[11],
"history_visibility": room[12],
"state_events": room[13],
}
)

View File

@@ -14,7 +14,6 @@
# limitations under the License.
import logging
from typing import Dict, Iterable, List, Mapping, Optional, Set
from six import integer_types
@@ -24,11 +23,8 @@ from synapse.util import caches
logger = logging.getLogger(__name__)
# for now, assume all entities in the cache are strings
EntityType = str
class StreamChangeCache:
class StreamChangeCache(object):
"""Keeps track of the stream positions of the latest change in a set of entities.
Typically the entity will be a room or user id.
@@ -38,23 +34,10 @@ class StreamChangeCache:
old then the cache will simply return all given entities.
"""
def __init__(
self,
name: str,
current_stream_pos: int,
max_size=10000,
prefilled_cache: Optional[Mapping[EntityType, int]] = None,
):
def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache=None):
self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR)
self._entity_to_key = {} # type: Dict[EntityType, int]
# map from stream id to the a set of entities which changed at that stream id.
self._cache = SortedDict() # type: SortedDict[int, Set[EntityType]]
# the earliest stream_pos for which we can reliably answer
# get_all_entities_changed. In other words, one less than the earliest
# stream_pos for which we know _cache is valid.
#
self._entity_to_key = {}
self._cache = SortedDict()
self._earliest_known_stream_pos = current_stream_pos
self.name = name
self.metrics = caches.register_cache("cache", self.name, self._cache)
@@ -63,7 +46,7 @@ class StreamChangeCache:
for entity, stream_pos in prefilled_cache.items():
self.entity_has_changed(entity, stream_pos)
def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool:
def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos
"""
assert type(stream_pos) in integer_types
@@ -84,17 +67,22 @@ class StreamChangeCache:
self.metrics.inc_hits()
return False
def get_entities_changed(
self, entities: Iterable[EntityType], stream_pos: int
) -> Set[EntityType]:
def get_entities_changed(self, entities, stream_pos):
"""
Returns subset of entities that have had new things since the given
position. Entities unknown to the cache will be returned. If the
position is too old it will just return the given list.
"""
changed_entities = self.get_all_entities_changed(stream_pos)
if changed_entities is not None:
result = set(changed_entities).intersection(entities)
assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos:
changed_entities = {
self._cache[k]
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos))
}
result = changed_entities.intersection(entities)
self.metrics.inc_hits()
else:
result = set(entities)
@@ -102,13 +90,13 @@ class StreamChangeCache:
return result
def has_any_entity_changed(self, stream_pos: int) -> bool:
def has_any_entity_changed(self, stream_pos):
"""Returns if any entity has changed
"""
assert type(stream_pos) is int
if not self._cache:
# If the cache is empty, nothing can have changed.
# If we have no cache, nothing can have changed.
return False
if stream_pos >= self._earliest_known_stream_pos:
@@ -118,58 +106,42 @@ class StreamChangeCache:
self.metrics.inc_misses()
return True
def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
"""Returns all entities that have had new things since the given
def get_all_entities_changed(self, stream_pos):
"""Returns all entites that have had new things since the given
position. If the position is too old it will return None.
Returns the entities in the order that they were changed.
"""
assert type(stream_pos) is int
if stream_pos < self._earliest_known_stream_pos:
if stream_pos >= self._earliest_known_stream_pos:
return [
self._cache[k]
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos))
]
else:
return None
changed_entities = [] # type: List[EntityType]
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
changed_entities.extend(self._cache[k])
return changed_entities
def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
def entity_has_changed(self, entity, stream_pos):
"""Informs the cache that the entity has been changed at the given
position.
"""
assert type(stream_pos) is int
if stream_pos <= self._earliest_known_stream_pos:
return
if stream_pos > self._earliest_known_stream_pos:
old_pos = self._entity_to_key.get(entity, None)
if old_pos is not None:
stream_pos = max(stream_pos, old_pos)
self._cache.pop(old_pos, None)
self._cache[stream_pos] = entity
self._entity_to_key[entity] = stream_pos
old_pos = self._entity_to_key.get(entity, None)
if old_pos is not None:
if old_pos >= stream_pos:
# nothing to do
return
e = self._cache[old_pos]
e.remove(entity)
if not e:
# cache at this point is now empty
del self._cache[old_pos]
e1 = self._cache.get(stream_pos)
if e1 is None:
e1 = self._cache[stream_pos] = set()
e1.add(entity)
self._entity_to_key[entity] = stream_pos
# if the cache is too big, remove entries
while len(self._cache) > self._max_size:
k, r = self._cache.popitem(0)
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
for entity in r:
del self._entity_to_key[entity]
def get_max_pos_of_last_change(self, entity: EntityType) -> int:
while len(self._cache) > self._max_size:
k, r = self._cache.popitem(0)
self._earliest_known_stream_pos = max(
k, self._earliest_known_stream_pos
)
self._entity_to_key.pop(r, None)
def get_max_pos_of_last_change(self, entity):
"""Returns an upper bound of the stream id of the last change to an
entity.
"""

View File

@@ -1,65 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent
from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
from tests import unittest
class PushRuleEvaluatorTestCase(unittest.TestCase):
def setUp(self):
event = FrozenEvent(
{
"event_id": "$event_id",
"type": "m.room.history_visibility",
"sender": "@user:test",
"state_key": "",
"room_id": "@room:test",
"content": {"body": "foo bar baz"},
},
RoomVersions.V1,
)
room_member_count = 0
sender_power_level = 0
power_levels = {}
self.evaluator = PushRuleEvaluatorForEvent(
event, room_member_count, sender_power_level, power_levels
)
def test_display_name(self):
"""Check for a matching display name in the body of the event."""
condition = {
"kind": "contains_display_name",
}
# Blank names are skipped.
self.assertFalse(self.evaluator.matches(condition, "@user:test", ""))
# Check a display name that doesn't match.
self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found"))
# Check a display name which matches.
self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo"))
# A display name that matches, but is not a full word, does not result in a match.
self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba"))
# A display name should not be interpreted as a regular expression.
self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]"))
# A display name with spaces should work fine.
self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar"))

View File

@@ -16,7 +16,7 @@
from mock import Mock, NonCallableMock
from synapse.replication.tcp.client import (
DirectTcpReplicationClientFactory,
ReplicationClientFactory,
ReplicationDataHandler,
)
from synapse.replication.tcp.handler import ReplicationCommandHandler
@@ -61,7 +61,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
self.slaved_store
)
client_factory = DirectTcpReplicationClientFactory(
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
client_factory.handler = self.replication_handler

View File

@@ -13,72 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Tuple
from mock import Mock
import attr
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.site import SynapseRequest
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport
logger = logging.getLogger(__name__)
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
def make_homeserver(self, reactor, clock):
self.test_handler = Mock(wraps=TestReplicationDataHandler())
return self.setup_test_homeserver(replication_data_handler=self.test_handler)
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
self.server = server_factory.buildProtocol(None)
# Make a new HomeServer object for the worker
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
self.reactor.lookups["testserv"] = "1.2.3.4"
self.worker_hs = self.setup_test_homeserver(
http_client=None,
homeserverToUse=GenericWorkerServer,
config=config,
reactor=self.reactor,
)
# Since we use sqlite in memory databases we need to make sure the
# database objects are the same.
self.worker_hs.get_datastore().db = hs.get_datastore().db
self.test_handler = self._build_replication_data_handler()
self.worker_hs.replication_data_handler = self.test_handler
repl_handler = ReplicationCommandHandler(self.worker_hs)
repl_handler = ReplicationCommandHandler(hs)
repl_handler.handler = self.test_handler
self.client = ClientReplicationStreamProtocol(
self.worker_hs, "client", "test", clock, repl_handler,
hs, "client", "test", clock, repl_handler,
)
self._client_transport = None
self._server_transport = None
def _build_replication_data_handler(self):
return TestReplicationDataHandler(self.worker_hs.get_datastore())
def reconnect(self):
if self._client_transport:
self.client.close()
@@ -108,204 +74,24 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke()
self.pump(0.1)
def handle_http_replication_attempt(self) -> SynapseRequest:
"""Asserts that a connection attempt was made to the master HS on the
HTTP replication port, then proxies it to the master HS object to be
handled.
Returns:
The request object received by master HS.
"""
# We should have an outbound connection attempt.
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 8765)
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
request_factory = OneShotRequestFactory()
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor)
channel.requestFactory = request_factory
channel.site = self.site
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
channel, self.reactor, client_protocol
)
client_protocol.makeConnection(client_to_server_transport)
server_to_client_transport = FakeTransport(
client_protocol, self.reactor, channel
)
channel.makeConnection(server_to_client_transport)
# The request will now be processed by `self.site` and the response
# streamed back.
self.reactor.advance(0)
# We tear down the connection so it doesn't get reused without our
# knowledge.
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()
return request_factory.request
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
):
"""Asserts that the given request is a HTTP replication request for
fetching updates for given stream.
"""
self.assertRegex(
request.path,
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
% (stream_name.encode("ascii"),),
)
self.assertEqual(request.method, b"GET")
class TestReplicationDataHandler(ReplicationDataHandler):
class TestReplicationDataHandler:
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
def __init__(self, store: BaseSlavedStore):
super().__init__(store)
# streams to subscribe to: map from stream id to position
self.stream_positions = {} # type: Dict[str, int]
# list of received (stream_name, token, row) tuples
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
def __init__(self):
self.streams = set()
self._received_rdata_rows = []
def get_streams_to_replicate(self):
return self.stream_positions
positions = {s: 0 for s in self.streams}
for stream, token, _ in self._received_rdata_rows:
if stream in self.streams:
positions[stream] = max(token, positions.get(stream, 0))
return positions
async def on_rdata(self, stream_name, token, rows):
await super().on_rdata(stream_name, token, rows)
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))
self._received_rdata_rows.append((stream_name, token, r))
if (
stream_name in self.stream_positions
and token > self.stream_positions[stream_name]
):
self.stream_positions[stream_name] = token
@attr.s()
class OneShotRequestFactory:
"""A simple request factory that generates a single `SynapseRequest` and
stores it for future use. Can only be used once.
"""
request = attr.ib(default=None)
def __call__(self, *args, **kwargs):
assert self.request is None
self.request = SynapseRequest(*args, **kwargs)
return self.request
class _PushHTTPChannel(HTTPChannel):
"""A HTTPChannel that wraps pull producers to push producers.
This is a hack to get around the fact that HTTPChannel transparently wraps a
pull producer (which is what Synapse uses to reply to requests) with
`_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
uses the standard reactor rather than letting us use our test reactor, which
makes it very hard to test.
"""
def __init__(self, reactor: IReactorTime):
super().__init__()
self.reactor = reactor
self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
def registerProducer(self, producer, streaming):
# Convert pull producers to push producer.
if not streaming:
self._pull_to_push_producer = _PullToPushProducer(
self.reactor, producer, self
)
producer = self._pull_to_push_producer
super().registerProducer(producer, True)
def unregisterProducer(self):
if self._pull_to_push_producer:
# We need to manually stop the _PullToPushProducer.
self._pull_to_push_producer.stop()
class _PullToPushProducer:
"""A push producer that wraps a pull producer.
"""
def __init__(
self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
):
self._clock = Clock(reactor)
self._producer = producer
self._consumer = consumer
# While running we use a looping call with a zero delay to call
# resumeProducing on given producer.
self._looping_call = None # type: Optional[LoopingCall]
# We start writing next reactor tick.
self._start_loop()
def _start_loop(self):
"""Start the looping call to
"""
if not self._looping_call:
# Start a looping call which runs every tick.
self._looping_call = self._clock.looping_call(self._run_once, 0)
def stop(self):
"""Stops calling resumeProducing.
"""
if self._looping_call:
self._looping_call.stop()
self._looping_call = None
def pauseProducing(self):
"""Implements IPushProducer
"""
self.stop()
def resumeProducing(self):
"""Implements IPushProducer
"""
self._start_loop()
def stopProducing(self):
"""Implements IPushProducer
"""
self.stop()
self._producer.stopProducing()
def _run_once(self):
"""Calls resumeProducing on producer once.
"""
try:
self._producer.resumeProducing()
except Exception:
logger.exception("Failed to call resumeProducing")
try:
self._consumer.unregisterProducer()
except Exception:
pass
self.stopProducing()
async def on_position(self, stream_name, token):
pass

View File

@@ -1,390 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import (
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests.replication.tcp.streams._base import BaseStreamTestCase
from tests.test_utils.event_injection import inject_event, inject_member_event
class EventsStreamTestCase(BaseStreamTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor, clock, hs):
super().prepare(reactor, clock, hs)
self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
self.reconnect()
self.test_handler.stream_positions["events"] = 0
self.room_id = self.helper.create_room_as(tok=self.user_tok)
self.test_handler.received_rdata_rows.clear()
def test_update_function_event_row_limit(self):
"""Test replication with many non-state events
Checks that all events are correctly replicated when there are lots of
event rows to be replicated.
"""
# generate lots of non-state events. We inject them using inject_event
# so that they are not send out over replication until we call self.replicate().
events = [
self._inject_test_event()
for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1)
]
# also one state event
state_event = self._inject_state_event()
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
# now fire up the replicator
self.replicate()
# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
for event in events:
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, event.event_id)
stream_name, token, row = received_rows.pop(0)
self.assertIsInstance(row, EventsStreamRow)
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, state_event.event_id)
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state_event.event_id)
self.assertEqual([], received_rows)
def test_update_function_huge_state_change(self):
"""Test replication with many state events
Ensures that all events are correctly replicated when there are lots of
state change rows to be replicated.
"""
# we want to generate lots of state changes at a single stream ID.
#
# We do this by having two branches in the DAG. On one, we have a moderator
# who generates lots of state; on the other, we de-op the moderator,
# thus invalidating all the state.
OTHER_USER = "@other_user:localhost"
# have the user join
inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
self.room_id, EventTypes.PowerLevels, tok=self.user_tok
)
pls["users"][OTHER_USER] = 50
self.helper.send_state(
self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
)
# this is the point in the DAG where we make a fork
fork_point = self.get_success(
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
) # type: List[str]
events = [
self._inject_state_event(sender=OTHER_USER)
for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT)
]
self.replicate()
# all those events and state changes should have landed
self.assertGreaterEqual(
len(self.test_handler.received_rdata_rows), 2 * len(events)
)
self.test_handler.received_rdata_rows.clear()
# a state event which doesn't get rolled back, to check that the state
# before the huge update comes through ok
state1 = self._inject_state_event()
# roll back all the state by de-modding the user
prev_events = fork_point
pls["users"][OTHER_USER] = 0
pl_event = inject_event(
self.hs,
prev_event_ids=prev_events,
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
room_id=self.room_id,
content=pls,
)
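        # (because the de-op event is injected at the fork point, all of the
        # state the moderator created on the other branch is invalidated at
        # this single event's stream position)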
# one more bit of state that doesn't get rolled back
state2 = self._inject_state_event()
        # check we're testing what we think we are: no rows should yet have been
        # received
self.assertEqual([], self.test_handler.received_rdata_rows)
# now fire up the replicator
self.replicate()
# now we should have received all the expected rows in the right order.
#
# we expect:
#
# - two rows for state1
# - the PL event row, plus state rows for the PL event and each
# of the states that got reverted.
# - two rows for state2
received_rows = self.test_handler.received_rdata_rows
# first check the first two rows, which should be state1
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, state1.event_id)
        stream_name, token, row = received_rows.pop(0)
        self.assertEqual("events", stream_name)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertEqual(row.type, "state")
        self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
        self.assertEqual(row.data.event_id, state1.event_id)
# now the last two rows, which should be state2
stream_name, token, row = received_rows.pop(-2)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, state2.event_id)
        stream_name, token, row = received_rows.pop(-1)
        self.assertEqual("events", stream_name)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertEqual(row.type, "state")
        self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
        self.assertEqual(row.data.event_id, state2.event_id)
# that should leave us with the rows for the PL event
self.assertEqual(len(received_rows), len(events) + 2)
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, pl_event.event_id)
# the state rows are unsorted
state_rows = [] # type: List[EventsStreamCurrentStateRow]
for stream_name, token, row in received_rows:
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
state_rows.append(row.data)
state_rows.sort(key=lambda r: r.state_key)
sr = state_rows.pop(0)
self.assertEqual(sr.type, EventTypes.PowerLevels)
self.assertEqual(sr.event_id, pl_event.event_id)
for sr in state_rows:
self.assertEqual(sr.type, "test_state_event")
# "None" indicates the state has been deleted
self.assertIsNone(sr.event_id)

    def test_update_function_state_row_limit(self):
        """Test replication with many state events over several stream ids."""
        # we want to generate lots of state changes, but for this test, we want to
        # spread out the state changes over a few stream IDs.
        #
        # We do this by having two branches in the DAG. On one, we have four moderators,
        # each of whom generates lots of state; on the other, we de-op the users,
        # thus invalidating all the state.
NUM_USERS = 4
STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // 4 + 1
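        # (four users each sending STATES_PER_USER events gives more state rows
        # in total than _STREAM_UPDATE_TARGET_ROW_COUNT, while the four separate
        # de-op events below spread the reverted state over several stream IDs)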
user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)]
# have the users join
for u in user_ids:
inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
self.room_id, EventTypes.PowerLevels, tok=self.user_tok
)
pls["users"].update({u: 50 for u in user_ids})
self.helper.send_state(
self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
)
# this is the point in the DAG where we make a fork
fork_point = self.get_success(
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
) # type: List[str]
events = [] # type: List[EventBase]
for user in user_ids:
events.extend(
self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
)
self.replicate()
# all those events and state changes should have landed
self.assertGreaterEqual(
len(self.test_handler.received_rdata_rows), 2 * len(events)
)
self.test_handler.received_rdata_rows.clear()
# now roll back all that state by de-modding the users
prev_events = fork_point
pl_events = []
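        # chain each de-op event off the previous one, so that each one (and the
        # state it reverts) lands at its own stream ID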
for u in user_ids:
pls["users"][u] = 0
e = inject_event(
self.hs,
prev_event_ids=prev_events,
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
room_id=self.room_id,
content=pls,
)
prev_events = [e.event_id]
pl_events.append(e)
        # check we're testing what we think we are: no rows should yet have been
        # received
self.assertEqual([], self.test_handler.received_rdata_rows)
# now fire up the replicator
self.replicate()
# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
self.assertGreaterEqual(len(received_rows), len(events))
for i in range(NUM_USERS):
# for each user, we expect the PL event row, followed by state rows for
# the PL event and each of the states that got reverted.
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, pl_events[i].event_id)
# the state rows are unsorted
state_rows = [] # type: List[EventsStreamCurrentStateRow]
for j in range(STATES_PER_USER + 1):
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
state_rows.append(row.data)
state_rows.sort(key=lambda r: r.state_key)
sr = state_rows.pop(0)
self.assertEqual(sr.type, EventTypes.PowerLevels)
self.assertEqual(sr.event_id, pl_events[i].event_id)
for sr in state_rows:
self.assertEqual(sr.type, "test_state_event")
# "None" indicates the state has been deleted
self.assertIsNone(sr.event_id)
self.assertEqual([], received_rows)
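
    # monotonic counter used by the _inject_* helpers below to give every
    # injected event a unique body / state key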
event_count = 0

    def _inject_test_event(
self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
) -> EventBase:
if sender is None:
sender = self.user_id
if body is None:
body = "event %i" % (self.event_count,)
self.event_count += 1
return inject_event(
self.hs,
room_id=self.room_id,
sender=sender,
type="test_event",
content={"body": body},
**kwargs
)

    def _inject_state_event(
self,
body: Optional[str] = None,
state_key: Optional[str] = None,
sender: Optional[str] = None,
) -> EventBase:
if sender is None:
sender = self.user_id
if state_key is None:
state_key = "state_%i" % (self.event_count,)
self.event_count += 1
if body is None:
body = "state event %s" % (state_key,)
return inject_event(
self.hs,
room_id=self.room_id,
sender=sender,
type="test_state_event",
state_key=state_key,
content={"body": body},
)
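
# A minimal sketch of how this test case might be run on its own, assuming a
# development checkout with the test dependencies installed and that this file
# lives at tests/replication/tcp/streams/test_events.py:
#
#   trial tests.replication.tcp.streams.test_events.EventsStreamTestCase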

View File

@@ -12,11 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# type: ignore
from mock import Mock
from synapse.replication.tcp.streams._base import ReceiptsStream
from tests.replication.tcp.streams._base import BaseStreamTestCase
@@ -25,14 +20,11 @@ USER_ID = "@feeling:blue"
class ReceiptsStreamTestCase(BaseStreamTestCase):
def _build_replication_data_handler(self):
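        # Mock(wraps=...) delegates every call to the real handler while also
        # recording it, so the test can assert on what was received without
        # changing the handler's behaviour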
return Mock(wraps=super()._build_replication_data_handler())
def test_receipt(self):
self.reconnect()
# make the client subscribe to the receipts stream
self.test_handler.stream_positions.update({"receipts": 0})
self.test_handler.streams.add("receipts")
# tell the master to send a new receipt
self.get_success(

Some files were not shown because too many files have changed in this diff.