Compare commits
29 Commits
hs/sssh-te
...
anoa/testi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
18cf53376d | ||
|
|
e49a90899b | ||
|
|
6317eba770 | ||
|
|
da51afdc6b | ||
|
|
2ccad9a1b6 | ||
|
|
714e75dc1b | ||
|
|
0c1b27ecd0 | ||
|
|
ac1bbfdd2b | ||
|
|
7dd06332a9 | ||
|
|
76f15f4bf2 | ||
|
|
887ec58556 | ||
|
|
b3b2da56b3 | ||
|
|
cb56a51ada | ||
|
|
40042dec0d | ||
|
|
2c881bf8a4 | ||
|
|
667e9ca5be | ||
|
|
8490a8793c | ||
|
|
2ff55e02c1 | ||
|
|
0ac339cfe6 | ||
|
|
f4064862b2 | ||
|
|
67851671e5 | ||
|
|
a7dadf87be | ||
|
|
c215378411 | ||
|
|
e91373c1fa | ||
|
|
4e48515bbf | ||
|
|
70807c83c6 | ||
|
|
44cf7cf56e | ||
|
|
d2a9b45df0 | ||
|
|
81e74fbae5 |
44
CHANGES.md
44
CHANGES.md
@@ -1,40 +1,15 @@
|
||||
Next version
|
||||
============
|
||||
|
||||
* New templates (`sso_auth_confirm.html`, `sso_auth_success.html`, and
|
||||
`sso_account_deactivated.html`) were added to Synapse. If your Synapse is
|
||||
configured to use SSO and a custom `sso_redirect_confirm_template_dir`
|
||||
configuration then these templates will need to be duplicated into that
|
||||
directory.
|
||||
* Two new templates (`sso_auth_confirm.html` and `sso_account_deactivated.html`)
|
||||
were added to Synapse. If your Synapse is configured to use SSO and a custom
|
||||
`sso_redirect_confirm_template_dir` configuration then these templates will
|
||||
need to be duplicated into that directory.
|
||||
|
||||
* Plugins using the `complete_sso_login` method of `synapse.module_api.ModuleApi`
|
||||
should update to using the async/await version `complete_sso_login_async` which
|
||||
includes additional checks. The non-async version is considered deprecated.
|
||||
|
||||
|
||||
Synapse 1.12.4 (2020-04-23)
|
||||
===========================
|
||||
|
||||
No significant changes.
|
||||
|
||||
|
||||
Synapse 1.12.4rc1 (2020-04-22)
|
||||
==============================
|
||||
|
||||
Features
|
||||
--------
|
||||
|
||||
- Always send users their own device updates. ([\#7160](https://github.com/matrix-org/synapse/issues/7160))
|
||||
- Add support for handling GET requests for `account_data` on a worker. ([\#7311](https://github.com/matrix-org/synapse/issues/7311))
|
||||
|
||||
|
||||
Bugfixes
|
||||
--------
|
||||
|
||||
- Fix a bug that prevented cross-signing with users on worker-mode synapses. ([\#7255](https://github.com/matrix-org/synapse/issues/7255))
|
||||
- Do not treat display names as globs in push rules. ([\#7271](https://github.com/matrix-org/synapse/issues/7271))
|
||||
- Fix a bug with cross-signing devices belonging to remote users who did not share a room with any user on the local homeserver. ([\#7289](https://github.com/matrix-org/synapse/issues/7289))
|
||||
|
||||
Synapse 1.12.3 (2020-04-03)
|
||||
===========================
|
||||
|
||||
@@ -67,19 +42,12 @@ Bugfixes
|
||||
Synapse 1.12.0 (2020-03-23)
|
||||
===========================
|
||||
|
||||
No significant changes since 1.12.0rc1.
|
||||
|
||||
Debian packages and Docker images are rebuilt using the latest versions of
|
||||
dependency libraries, including Twisted 20.3.0. **Please see security advisory
|
||||
below**.
|
||||
|
||||
Potential slow database update during upgrade
|
||||
---------------------------------------------
|
||||
|
||||
Synapse 1.12.0 includes a database update which is run as part of the upgrade,
|
||||
and which may take some time (several hours in the case of a large
|
||||
server). Synapse will not respond to HTTP requests while this update is taking
|
||||
place. For imformation on seeing if you are affected, and workaround if you
|
||||
are, see the [upgrade notes](UPGRADE.rst#upgrading-to-v1120).
|
||||
|
||||
Security advisory
|
||||
-----------------
|
||||
|
||||
|
||||
65
UPGRADE.rst
65
UPGRADE.rst
@@ -75,71 +75,6 @@ for example:
|
||||
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
||||
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
||||
|
||||
Upgrading to v1.12.0
|
||||
====================
|
||||
|
||||
This version includes a database update which is run as part of the upgrade,
|
||||
and which may take some time (several hours in the case of a large
|
||||
server). Synapse will not respond to HTTP requests while this update is taking
|
||||
place.
|
||||
|
||||
This is only likely to be a problem in the case of a server which is
|
||||
participating in many rooms.
|
||||
|
||||
0. As with all upgrades, it is recommended that you have a recent backup of
|
||||
your database which can be used for recovery in the event of any problems.
|
||||
|
||||
1. As an initial check to see if you will be affected, you can try running the
|
||||
following query from the `psql` or `sqlite3` console. It is safe to run it
|
||||
while Synapse is still running.
|
||||
|
||||
.. code:: sql
|
||||
|
||||
SELECT MAX(q.v) FROM (
|
||||
SELECT (
|
||||
SELECT ej.json AS v
|
||||
FROM state_events se INNER JOIN event_json ej USING (event_id)
|
||||
WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key=''
|
||||
LIMIT 1
|
||||
) FROM rooms WHERE rooms.room_version IS NULL
|
||||
) q;
|
||||
|
||||
This query will take about the same amount of time as the upgrade process: ie,
|
||||
if it takes 5 minutes, then it is likely that Synapse will be unresponsive for
|
||||
5 minutes during the upgrade.
|
||||
|
||||
If you consider an outage of this duration to be acceptable, no further
|
||||
action is necessary and you can simply start Synapse 1.12.0.
|
||||
|
||||
If you would prefer to reduce the downtime, continue with the steps below.
|
||||
|
||||
2. The easiest workaround for this issue is to manually
|
||||
create a new index before upgrading. On PostgreSQL, his can be done as follows:
|
||||
|
||||
.. code:: sql
|
||||
|
||||
CREATE INDEX CONCURRENTLY tmp_upgrade_1_12_0_index
|
||||
ON state_events(room_id) WHERE type = 'm.room.create';
|
||||
|
||||
The above query may take some time, but is also safe to run while Synapse is
|
||||
running.
|
||||
|
||||
We assume that no SQLite users have databases large enough to be
|
||||
affected. If you *are* affected, you can run a similar query, omitting the
|
||||
``CONCURRENTLY`` keyword. Note however that this operation may in itself cause
|
||||
Synapse to stop running for some time. Synapse admins are reminded that
|
||||
`SQLite is not recommended for use outside a test
|
||||
environment <https://github.com/matrix-org/synapse/blob/master/README.rst#using-postgresql>`_.
|
||||
|
||||
3. Once the index has been created, the ``SELECT`` query in step 1 above should
|
||||
complete quickly. It is therefore safe to upgrade to Synapse 1.12.0.
|
||||
|
||||
4. Once Synapse 1.12.0 has successfully started and is responding to HTTP
|
||||
requests, the temporary index can be removed:
|
||||
|
||||
.. code:: sql
|
||||
|
||||
DROP INDEX tmp_upgrade_1_12_0_index;
|
||||
|
||||
Upgrading to v1.10.0
|
||||
====================
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
Add support for running replication over Redis when using workers.
|
||||
@@ -1 +0,0 @@
|
||||
Improve the documentation of application service configuration files.
|
||||
@@ -1 +0,0 @@
|
||||
Run replication streamers on workers.
|
||||
1
changelog.d/7160.feature
Normal file
1
changelog.d/7160.feature
Normal file
@@ -0,0 +1 @@
|
||||
Always send users their own device updates.
|
||||
@@ -1 +0,0 @@
|
||||
Add explicit Python build tooling as dependencies for the snapcraft build.
|
||||
@@ -1 +0,0 @@
|
||||
Extend room admin api (`GET /_synapse/admin/v1/rooms`) with additional attributes.
|
||||
1
changelog.d/7255.bugfix
Normal file
1
changelog.d/7255.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a bug that prevented cross-signing with users on worker-mode synapses.
|
||||
@@ -1 +0,0 @@
|
||||
Reject unknown session IDs during user interactive authentication instead of silently creating a new session.
|
||||
@@ -1 +0,0 @@
|
||||
Documentation of media_storage_providers options updated to avoid misunderstandings. Contributed by Tristan Lins.
|
||||
@@ -1 +0,0 @@
|
||||
Add some unit tests for replication.
|
||||
@@ -1 +0,0 @@
|
||||
Support SSO in the user interactive authentication workflow.
|
||||
@@ -1 +0,0 @@
|
||||
Move catchup of replication streams logic to worker.
|
||||
1
changelog.d/7289.bugfix
Normal file
1
changelog.d/7289.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a bug with cross-signing devices with remote users when they did not share a room with any user on the local homeserver.
|
||||
@@ -1 +0,0 @@
|
||||
Move catchup of replication streams logic to worker.
|
||||
@@ -1 +0,0 @@
|
||||
Improve typing annotations in `synapse.replication.tcp.streams.Stream`.
|
||||
@@ -1 +0,0 @@
|
||||
Reduce log verbosity of url cache cleanup tasks.
|
||||
@@ -1 +0,0 @@
|
||||
Fix sample SAML Service Provider configuration. Contributed by @frcl.
|
||||
@@ -1 +0,0 @@
|
||||
Fix StreamChangeCache to work with multiple entities changing on the same stream id.
|
||||
@@ -1 +0,0 @@
|
||||
Allow `/requestToken` endpoints to hide the existence (or lack thereof) of 3PID associations on the homeserver.
|
||||
@@ -1 +0,0 @@
|
||||
Move catchup of replication streams logic to worker.
|
||||
@@ -1 +0,0 @@
|
||||
Fix an incorrect import in IdentityHandler.
|
||||
@@ -1 +0,0 @@
|
||||
Reduce logging verbosity for successful federation requests.
|
||||
@@ -1 +0,0 @@
|
||||
Add support for running replication over Redis when using workers.
|
||||
@@ -1 +0,0 @@
|
||||
Move catchup of replication streams logic to worker.
|
||||
@@ -1 +0,0 @@
|
||||
Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.
|
||||
@@ -1 +0,0 @@
|
||||
Convert some federation handler code to async/await.
|
||||
@@ -1 +0,0 @@
|
||||
Fix bad error handling that would cause Synapse to crash if it's provided with a YAML configuration file that's either empty or doesn't parse into a key-value map.
|
||||
@@ -1 +0,0 @@
|
||||
Support SSO in the user interactive authentication workflow.
|
||||
@@ -1 +0,0 @@
|
||||
Fix incorrect metrics reporting for `renew_attestations` background task.
|
||||
@@ -1 +0,0 @@
|
||||
Add support for running replication over Redis when using workers.
|
||||
@@ -1 +0,0 @@
|
||||
Add documentation on monitoring workers with Prometheus.
|
||||
@@ -1 +0,0 @@
|
||||
Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.
|
||||
@@ -1 +0,0 @@
|
||||
Fix collation for postgres for unit tests.
|
||||
8
debian/changelog
vendored
8
debian/changelog
vendored
@@ -1,16 +1,8 @@
|
||||
<<<<<<< HEAD
|
||||
matrix-synapse-py3 (1.12.3ubuntu1) UNRELEASED; urgency=medium
|
||||
|
||||
* Add information about .well-known files to Debian installation scripts.
|
||||
|
||||
-- Patrick Cloke <patrickc@matrix.org> Mon, 06 Apr 2020 10:10:38 -0400
|
||||
=======
|
||||
matrix-synapse-py3 (1.12.4) stable; urgency=medium
|
||||
|
||||
* New synapse release 1.12.4.
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Thu, 23 Apr 2020 10:58:14 -0400
|
||||
>>>>>>> master
|
||||
|
||||
matrix-synapse-py3 (1.12.3) stable; urgency=medium
|
||||
|
||||
|
||||
@@ -11,21 +11,8 @@ The following query parameters are available:
|
||||
* `from` - Offset in the returned list. Defaults to `0`.
|
||||
* `limit` - Maximum amount of rooms to return. Defaults to `100`.
|
||||
* `order_by` - The method in which to sort the returned list of rooms. Valid values are:
|
||||
- `alphabetical` - Same as `name`. This is deprecated.
|
||||
- `size` - Same as `joined_members`. This is deprecated.
|
||||
- `name` - Rooms are ordered alphabetically by room name. This is the default.
|
||||
- `canonical_alias` - Rooms are ordered alphabetically by main alias address of the room.
|
||||
- `joined_members` - Rooms are ordered by the number of members. Largest to smallest.
|
||||
- `joined_local_members` - Rooms are ordered by the number of local members. Largest to smallest.
|
||||
- `version` - Rooms are ordered by room version. Largest to smallest.
|
||||
- `creator` - Rooms are ordered alphabetically by creator of the room.
|
||||
- `encryption` - Rooms are ordered alphabetically by the end-to-end encryption algorithm.
|
||||
- `federatable` - Rooms are ordered by whether the room is federatable.
|
||||
- `public` - Rooms are ordered by visibility in room list.
|
||||
- `join_rules` - Rooms are ordered alphabetically by join rules of the room.
|
||||
- `guest_access` - Rooms are ordered alphabetically by guest access option of the room.
|
||||
- `history_visibility` - Rooms are ordered alphabetically by visibility of history of the room.
|
||||
- `state_events` - Rooms are ordered by number of state events. Largest to smallest.
|
||||
- `alphabetical` - Rooms are ordered alphabetically by room name. This is the default.
|
||||
- `size` - Rooms are ordered by the number of members. Largest to smallest.
|
||||
* `dir` - Direction of room order. Either `f` for forwards or `b` for backwards. Setting
|
||||
this value to `b` will reverse the above sort order. Defaults to `f`.
|
||||
* `search_term` - Filter rooms by their room name. Search term can be contained in any
|
||||
@@ -39,16 +26,6 @@ The following fields are possible in the JSON response body:
|
||||
- `name` - The name of the room.
|
||||
- `canonical_alias` - The canonical (main) alias address of the room.
|
||||
- `joined_members` - How many users are currently in the room.
|
||||
- `joined_local_members` - How many local users are currently in the room.
|
||||
- `version` - The version of the room as a string.
|
||||
- `creator` - The `user_id` of the room creator.
|
||||
- `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active.
|
||||
- `federatable` - Whether users on other servers can join this room.
|
||||
- `public` - Whether the room is visible in room directory.
|
||||
- `join_rules` - The type of rules used for users wishing to join this room. One of: ["public", "knock", "invite", "private"].
|
||||
- `guest_access` - Whether guests can join the room. One of: ["can_join", "forbidden"].
|
||||
- `history_visibility` - Who can see the room history. One of: ["invited", "joined", "shared", "world_readable"].
|
||||
- `state_events` - Total number of state_events of a room. Complexity of the room.
|
||||
* `offset` - The current pagination offset in rooms. This parameter should be
|
||||
used instead of `next_token` for room offset as `next_token` is
|
||||
not intended to be parsed.
|
||||
@@ -83,34 +60,14 @@ Response:
|
||||
"room_id": "!OGEhHVWSdvArJzumhm:matrix.org",
|
||||
"name": "Matrix HQ",
|
||||
"canonical_alias": "#matrix:matrix.org",
|
||||
"joined_members": 8326,
|
||||
"joined_local_members": 2,
|
||||
"version": "1",
|
||||
"creator": "@foo:matrix.org",
|
||||
"encryption": null,
|
||||
"federatable": true,
|
||||
"public": true,
|
||||
"join_rules": "invite",
|
||||
"guest_access": null,
|
||||
"history_visibility": "shared",
|
||||
"state_events": 93534
|
||||
"joined_members": 8326
|
||||
},
|
||||
... (8 hidden items) ...
|
||||
{
|
||||
"room_id": "!xYvNcQPhnkrdUmYczI:matrix.org",
|
||||
"name": "This Week In Matrix (TWIM)",
|
||||
"canonical_alias": "#twim:matrix.org",
|
||||
"joined_members": 314,
|
||||
"joined_local_members": 20,
|
||||
"version": "4",
|
||||
"creator": "@foo:matrix.org",
|
||||
"encryption": "m.megolm.v1.aes-sha2",
|
||||
"federatable": true,
|
||||
"public": false,
|
||||
"join_rules": "invite",
|
||||
"guest_access": null,
|
||||
"history_visibility": "shared",
|
||||
"state_events": 8345
|
||||
"joined_members": 314
|
||||
}
|
||||
],
|
||||
"offset": 0,
|
||||
@@ -135,17 +92,7 @@ Response:
|
||||
"room_id": "!xYvNcQPhnkrdUmYczI:matrix.org",
|
||||
"name": "This Week In Matrix (TWIM)",
|
||||
"canonical_alias": "#twim:matrix.org",
|
||||
"joined_members": 314,
|
||||
"joined_local_members": 20,
|
||||
"version": "4",
|
||||
"creator": "@foo:matrix.org",
|
||||
"encryption": "m.megolm.v1.aes-sha2",
|
||||
"federatable": true,
|
||||
"public": false,
|
||||
"join_rules": "invite",
|
||||
"guest_access": null,
|
||||
"history_visibility": "shared",
|
||||
"state_events": 8
|
||||
"joined_members": 314
|
||||
}
|
||||
],
|
||||
"offset": 0,
|
||||
@@ -170,34 +117,14 @@ Response:
|
||||
"room_id": "!OGEhHVWSdvArJzumhm:matrix.org",
|
||||
"name": "Matrix HQ",
|
||||
"canonical_alias": "#matrix:matrix.org",
|
||||
"joined_members": 8326,
|
||||
"joined_local_members": 2,
|
||||
"version": "1",
|
||||
"creator": "@foo:matrix.org",
|
||||
"encryption": null,
|
||||
"federatable": true,
|
||||
"public": true,
|
||||
"join_rules": "invite",
|
||||
"guest_access": null,
|
||||
"history_visibility": "shared",
|
||||
"state_events": 93534
|
||||
"joined_members": 8326
|
||||
},
|
||||
... (98 hidden items) ...
|
||||
{
|
||||
"room_id": "!xYvNcQPhnkrdUmYczI:matrix.org",
|
||||
"name": "This Week In Matrix (TWIM)",
|
||||
"canonical_alias": "#twim:matrix.org",
|
||||
"joined_members": 314,
|
||||
"joined_local_members": 20,
|
||||
"version": "4",
|
||||
"creator": "@foo:matrix.org",
|
||||
"encryption": "m.megolm.v1.aes-sha2",
|
||||
"federatable": true,
|
||||
"public": false,
|
||||
"join_rules": "invite",
|
||||
"guest_access": null,
|
||||
"history_visibility": "shared",
|
||||
"state_events": 8345
|
||||
"joined_members": 314
|
||||
}
|
||||
],
|
||||
"offset": 0,
|
||||
@@ -227,16 +154,6 @@ Response:
|
||||
"name": "Music Theory",
|
||||
"canonical_alias": "#musictheory:matrix.org",
|
||||
"joined_members": 127
|
||||
"joined_local_members": 2,
|
||||
"version": "1",
|
||||
"creator": "@foo:matrix.org",
|
||||
"encryption": null,
|
||||
"federatable": true,
|
||||
"public": true,
|
||||
"join_rules": "invite",
|
||||
"guest_access": null,
|
||||
"history_visibility": "shared",
|
||||
"state_events": 93534
|
||||
},
|
||||
... (48 hidden items) ...
|
||||
{
|
||||
@@ -244,16 +161,6 @@ Response:
|
||||
"name": "weechat-matrix",
|
||||
"canonical_alias": "#weechat-matrix:termina.org.uk",
|
||||
"joined_members": 137
|
||||
"joined_local_members": 20,
|
||||
"version": "4",
|
||||
"creator": "@foo:termina.org.uk",
|
||||
"encryption": null,
|
||||
"federatable": true,
|
||||
"public": true,
|
||||
"join_rules": "invite",
|
||||
"guest_access": null,
|
||||
"history_visibility": "shared",
|
||||
"state_events": 8345
|
||||
}
|
||||
],
|
||||
"offset": 100,
|
||||
|
||||
@@ -23,13 +23,9 @@ namespaces:
|
||||
users: # List of users we're interested in
|
||||
- exclusive: <bool>
|
||||
regex: <regex>
|
||||
group_id: <group>
|
||||
- ...
|
||||
aliases: [] # List of aliases we're interested in
|
||||
rooms: [] # List of room ids we're interested in
|
||||
```
|
||||
|
||||
`exclusive`: If enabled, only this application service is allowed to register users in its namespace(s).
|
||||
`group_id`: All users of this application service are dynamically joined to this group. This is useful for e.g user organisation or flairs.
|
||||
|
||||
See the [spec](https://matrix.org/docs/spec/application_service/unstable.html) for further details on how application services work.
|
||||
|
||||
@@ -60,31 +60,6 @@
|
||||
|
||||
1. Restart Prometheus.
|
||||
|
||||
## Monitoring workers
|
||||
|
||||
To monitor a Synapse installation using
|
||||
[workers](https://github.com/matrix-org/synapse/blob/master/docs/workers.md),
|
||||
every worker needs to be monitored independently, in addition to
|
||||
the main homeserver process. This is because workers don't send
|
||||
their metrics to the main homeserver process, but expose them
|
||||
directly (if they are configured to do so).
|
||||
|
||||
To allow collecting metrics from a worker, you need to add a
|
||||
`metrics` listener to its configuration, by adding the following
|
||||
under `worker_listeners`:
|
||||
|
||||
```yaml
|
||||
- type: metrics
|
||||
bind_address: ''
|
||||
port: 9101
|
||||
```
|
||||
|
||||
The `bind_address` and `port` parameters should be set so that
|
||||
the resulting listener can be reached by prometheus, and they
|
||||
don't clash with an existing worker.
|
||||
With this example, the worker's metrics would then be available
|
||||
on `http://127.0.0.1:9101`.
|
||||
|
||||
## Renaming of metrics & deprecation of old names in 1.2
|
||||
|
||||
Synapse 1.2 updates the Prometheus metrics to match the naming
|
||||
|
||||
@@ -414,16 +414,6 @@ retention:
|
||||
# longest_max_lifetime: 1y
|
||||
# interval: 1d
|
||||
|
||||
# Inhibits the /requestToken endpoints from returning an error that might leak
|
||||
# information about whether an e-mail address is in use or not on this
|
||||
# homeserver.
|
||||
# Note that for some endpoints the error situation is the e-mail already being
|
||||
# used, and for others the error is entering the e-mail being unused.
|
||||
# If this option is enabled, instead of returning an error, these endpoints will
|
||||
# act as if no error happened and return a fake session ID ('sid') to clients.
|
||||
#
|
||||
#request_token_inhibit_3pid_errors: true
|
||||
|
||||
|
||||
## TLS ##
|
||||
|
||||
@@ -745,11 +735,12 @@ media_store_path: "DATADIR/media_store"
|
||||
#
|
||||
#media_storage_providers:
|
||||
# - module: file_system
|
||||
# # Whether to store newly uploaded local files
|
||||
# # Whether to write new local files.
|
||||
# store_local: false
|
||||
# # Whether to store newly downloaded remote files
|
||||
# # Whether to write new remote media
|
||||
# store_remote: false
|
||||
# # Whether to wait for successful storage for local uploads
|
||||
# # Whether to block upload requests waiting for write to this
|
||||
# # provider to complete
|
||||
# store_synchronous: false
|
||||
# config:
|
||||
# directory: /mnt/some/other/directory
|
||||
@@ -1349,32 +1340,32 @@ saml2_config:
|
||||
# remote:
|
||||
# - url: https://our_idp/metadata.xml
|
||||
#
|
||||
# # By default, the user has to go to our login page first. If you'd like
|
||||
# # to allow IdP-initiated login, set 'allow_unsolicited: true' in a
|
||||
# # 'service.sp' section:
|
||||
# #
|
||||
# #service:
|
||||
# # sp:
|
||||
# # allow_unsolicited: true
|
||||
# # By default, the user has to go to our login page first. If you'd like
|
||||
# # to allow IdP-initiated login, set 'allow_unsolicited: true' in a
|
||||
# # 'service.sp' section:
|
||||
# #
|
||||
# #service:
|
||||
# # sp:
|
||||
# # allow_unsolicited: true
|
||||
#
|
||||
# # The examples below are just used to generate our metadata xml, and you
|
||||
# # may well not need them, depending on your setup. Alternatively you
|
||||
# # may need a whole lot more detail - see the pysaml2 docs!
|
||||
# # The examples below are just used to generate our metadata xml, and you
|
||||
# # may well not need them, depending on your setup. Alternatively you
|
||||
# # may need a whole lot more detail - see the pysaml2 docs!
|
||||
#
|
||||
# description: ["My awesome SP", "en"]
|
||||
# name: ["Test SP", "en"]
|
||||
# description: ["My awesome SP", "en"]
|
||||
# name: ["Test SP", "en"]
|
||||
#
|
||||
# organization:
|
||||
# name: Example com
|
||||
# display_name:
|
||||
# - ["Example co", "en"]
|
||||
# url: "http://example.com"
|
||||
# organization:
|
||||
# name: Example com
|
||||
# display_name:
|
||||
# - ["Example co", "en"]
|
||||
# url: "http://example.com"
|
||||
#
|
||||
# contact_person:
|
||||
# - given_name: Bob
|
||||
# sur_name: "the Sysadmin"
|
||||
# email_address": ["admin@example.com"]
|
||||
# contact_type": technical
|
||||
# contact_person:
|
||||
# - given_name: Bob
|
||||
# sur_name: "the Sysadmin"
|
||||
# email_address": ["admin@example.com"]
|
||||
# contact_type": technical
|
||||
|
||||
# Instead of putting the config inline as above, you can specify a
|
||||
# separate pysaml2 configuration file:
|
||||
@@ -1518,30 +1509,6 @@ sso:
|
||||
#
|
||||
# * server_name: the homeserver's name.
|
||||
#
|
||||
# * HTML page which notifies the user that they are authenticating to confirm
|
||||
# an operation on their account during the user interactive authentication
|
||||
# process: 'sso_auth_confirm.html'.
|
||||
#
|
||||
# When rendering, this template is given the following variables:
|
||||
# * redirect_url: the URL the user is about to be redirected to. Needs
|
||||
# manual escaping (see
|
||||
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
|
||||
#
|
||||
# * description: the operation which the user is being asked to confirm
|
||||
#
|
||||
# * HTML page shown after a successful user interactive authentication session:
|
||||
# 'sso_auth_success.html'.
|
||||
#
|
||||
# Note that this page must include the JavaScript which notifies of a successful authentication
|
||||
# (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback).
|
||||
#
|
||||
# This template has no additional variables.
|
||||
#
|
||||
# * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
|
||||
# attempts to login: 'sso_account_deactivated.html'.
|
||||
#
|
||||
# This template has no additional variables.
|
||||
#
|
||||
# You can see the default templates at:
|
||||
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
|
||||
#
|
||||
|
||||
@@ -196,7 +196,7 @@ Asks the server for the current position of all streams.
|
||||
|
||||
#### USER_SYNC (C)
|
||||
|
||||
A user has started or stopped syncing on this process.
|
||||
A user has started or stopped syncing
|
||||
|
||||
#### CLEAR_USER_SYNC (C)
|
||||
|
||||
@@ -216,6 +216,10 @@ Asks the server for the current position of all streams.
|
||||
|
||||
Inform the server a cache should be invalidated
|
||||
|
||||
#### SYNC (S, C)
|
||||
|
||||
Used exclusively in tests
|
||||
|
||||
### REMOTE_SERVER_UP (S, C)
|
||||
|
||||
Inform other processes that a remote server may have come back online.
|
||||
|
||||
@@ -120,7 +120,7 @@ Your home server configuration file needs the following extra keys:
|
||||
As an example, here is the relevant section of the config file for matrix.org:
|
||||
|
||||
turn_uris: [ "turn:turn.matrix.org:3478?transport=udp", "turn:turn.matrix.org:3478?transport=tcp" ]
|
||||
turn_shared_secret: "n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons"
|
||||
turn_shared_secret: n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons
|
||||
turn_user_lifetime: 86400000
|
||||
turn_allow_guests: True
|
||||
|
||||
|
||||
@@ -286,8 +286,6 @@ Additionally, the following REST endpoints can be handled for GET requests:
|
||||
|
||||
^/_matrix/client/(api/v1|r0|unstable)/pushrules/.*$
|
||||
^/_matrix/client/(api/v1|r0|unstable)/groups/.*$
|
||||
^/_matrix/client/(api/v1|r0|unstable)/user/[^/]*/account_data/
|
||||
^/_matrix/client/(api/v1|r0|unstable)/user/[^/]*/rooms/[^/]*/account_data/
|
||||
|
||||
Additionally, the following REST endpoints can be handled, but all requests must
|
||||
be routed to the same instance:
|
||||
|
||||
@@ -33,10 +33,6 @@ parts:
|
||||
python-version: python3
|
||||
python-packages:
|
||||
- '.[all]'
|
||||
- pip
|
||||
- setuptools
|
||||
- setuptools-scm
|
||||
- wheel
|
||||
build-packages:
|
||||
- libffi-dev
|
||||
- libturbojpeg0-dev
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from .sorteddict import (
|
||||
SortedDict,
|
||||
SortedKeysView,
|
||||
SortedItemsView,
|
||||
SortedValuesView,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SortedDict",
|
||||
"SortedKeysView",
|
||||
"SortedItemsView",
|
||||
"SortedValuesView",
|
||||
]
|
||||
@@ -1,124 +0,0 @@
|
||||
# stub for SortedDict. This is a lightly edited copy of
|
||||
# https://github.com/grantjenks/python-sortedcontainers/blob/eea42df1f7bad2792e8da77335ff888f04b9e5ae/sortedcontainers/sorteddict.pyi
|
||||
# (from https://github.com/grantjenks/python-sortedcontainers/pull/107)
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Hashable,
|
||||
Iterator,
|
||||
Iterable,
|
||||
ItemsView,
|
||||
KeysView,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
TypeVar,
|
||||
Tuple,
|
||||
Union,
|
||||
ValuesView,
|
||||
overload,
|
||||
)
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_S = TypeVar("_S")
|
||||
_T_h = TypeVar("_T_h", bound=Hashable)
|
||||
_KT = TypeVar("_KT", bound=Hashable) # Key type.
|
||||
_VT = TypeVar("_VT") # Value type.
|
||||
_KT_co = TypeVar("_KT_co", covariant=True, bound=Hashable)
|
||||
_VT_co = TypeVar("_VT_co", covariant=True)
|
||||
_SD = TypeVar("_SD", bound=SortedDict)
|
||||
_Key = Callable[[_T], Any]
|
||||
|
||||
class SortedDict(Dict[_KT, _VT]):
|
||||
@overload
|
||||
def __init__(self, **kwargs: _VT) -> None: ...
|
||||
@overload
|
||||
def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
|
||||
@overload
|
||||
def __init__(
|
||||
self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __init__(self, __key: _Key[_KT], **kwargs: _VT) -> None: ...
|
||||
@overload
|
||||
def __init__(
|
||||
self, __key: _Key[_KT], __map: Mapping[_KT, _VT], **kwargs: _VT
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __init__(
|
||||
self, __key: _Key[_KT], __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
|
||||
) -> None: ...
|
||||
@property
|
||||
def key(self) -> Optional[_Key[_KT]]: ...
|
||||
@property
|
||||
def iloc(self) -> SortedKeysView[_KT]: ...
|
||||
def clear(self) -> None: ...
|
||||
def __delitem__(self, key: _KT) -> None: ...
|
||||
def __iter__(self) -> Iterator[_KT]: ...
|
||||
def __reversed__(self) -> Iterator[_KT]: ...
|
||||
def __setitem__(self, key: _KT, value: _VT) -> None: ...
|
||||
def _setitem(self, key: _KT, value: _VT) -> None: ...
|
||||
def copy(self: _SD) -> _SD: ...
|
||||
def __copy__(self: _SD) -> _SD: ...
|
||||
@classmethod
|
||||
@overload
|
||||
def fromkeys(cls, seq: Iterable[_T_h]) -> SortedDict[_T_h, None]: ...
|
||||
@classmethod
|
||||
@overload
|
||||
def fromkeys(cls, seq: Iterable[_T_h], value: _S) -> SortedDict[_T_h, _S]: ...
|
||||
def keys(self) -> SortedKeysView[_KT]: ...
|
||||
def items(self) -> SortedItemsView[_KT, _VT]: ...
|
||||
def values(self) -> SortedValuesView[_VT]: ...
|
||||
@overload
|
||||
def pop(self, key: _KT) -> _VT: ...
|
||||
@overload
|
||||
def pop(self, key: _KT, default: _T = ...) -> Union[_VT, _T]: ...
|
||||
def popitem(self, index: int = ...) -> Tuple[_KT, _VT]: ...
|
||||
def peekitem(self, index: int = ...) -> Tuple[_KT, _VT]: ...
|
||||
def setdefault(self, key: _KT, default: Optional[_VT] = ...) -> _VT: ...
|
||||
@overload
|
||||
def update(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
|
||||
@overload
|
||||
def update(self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT) -> None: ...
|
||||
@overload
|
||||
def update(self, **kwargs: _VT) -> None: ...
|
||||
def __reduce__(
|
||||
self,
|
||||
) -> Tuple[
|
||||
Type[SortedDict[_KT, _VT]], Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]],
|
||||
]: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def _check(self) -> None: ...
|
||||
def islice(
|
||||
self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool,
|
||||
) -> Iterator[_KT]: ...
|
||||
def bisect_left(self, value: _KT) -> int: ...
|
||||
def bisect_right(self, value: _KT) -> int: ...
|
||||
|
||||
class SortedKeysView(KeysView[_KT_co], Sequence[_KT_co]):
|
||||
@overload
|
||||
def __getitem__(self, index: int) -> _KT_co: ...
|
||||
@overload
|
||||
def __getitem__(self, index: slice) -> List[_KT_co]: ...
|
||||
def __delitem__(self, index: Union[int, slice]) -> None: ...
|
||||
|
||||
class SortedItemsView( # type: ignore
|
||||
ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]]
|
||||
):
|
||||
def __iter__(self) -> Iterator[Tuple[_KT_co, _VT_co]]: ...
|
||||
@overload
|
||||
def __getitem__(self, index: int) -> Tuple[_KT_co, _VT_co]: ...
|
||||
@overload
|
||||
def __getitem__(self, index: slice) -> List[Tuple[_KT_co, _VT_co]]: ...
|
||||
def __delitem__(self, index: Union[int, slice]) -> None: ...
|
||||
|
||||
class SortedValuesView(ValuesView[_VT_co], Sequence[_VT_co]):
|
||||
@overload
|
||||
def __getitem__(self, index: int) -> _VT_co: ...
|
||||
@overload
|
||||
def __getitem__(self, index: slice) -> List[_VT_co]: ...
|
||||
def __delitem__(self, index: Union[int, slice]) -> None: ...
|
||||
@@ -1,40 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Contains *incomplete* type hints for txredisapi.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
class RedisProtocol:
|
||||
def publish(self, channel: str, message: bytes): ...
|
||||
|
||||
class SubscriberProtocol:
|
||||
def subscribe(self, channels: Union[str, List[str]]): ...
|
||||
|
||||
def lazyConnection(
|
||||
host: str = ...,
|
||||
port: int = ...,
|
||||
dbid: Optional[int] = ...,
|
||||
reconnect: bool = ...,
|
||||
charset: str = ...,
|
||||
password: Optional[str] = ...,
|
||||
connectTimeout: Optional[int] = ...,
|
||||
replyTimeout: Optional[int] = ...,
|
||||
convertNumbers: bool = ...,
|
||||
) -> RedisProtocol: ...
|
||||
|
||||
class SubscriberFactory:
|
||||
def buildProtocol(self, addr): ...
|
||||
@@ -36,7 +36,7 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
__version__ = "1.12.4"
|
||||
__version__ = "1.12.3"
|
||||
|
||||
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
||||
# We import here so that we don't have to install a bunch of deps when
|
||||
|
||||
@@ -97,8 +97,6 @@ class EventTypes(object):
|
||||
|
||||
Retention = "m.room.retention"
|
||||
|
||||
Presence = "m.presence"
|
||||
|
||||
|
||||
class RejectedReason(object):
|
||||
AUTH_ERROR = "auth_error"
|
||||
|
||||
@@ -17,9 +17,6 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict, Iterable
|
||||
|
||||
from typing_extensions import ContextManager
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.web.resource import NoResource
|
||||
@@ -41,14 +38,14 @@ from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.logger import setup_logging
|
||||
from synapse.federation import send_queue
|
||||
from synapse.federation.transport.server import TransportLayerServer
|
||||
from synapse.handlers.presence import BasePresenceHandler, get_interested_parties
|
||||
from synapse.handlers.presence import PresenceHandler, get_interested_parties
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.logging.context import LoggingContext
|
||||
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore, __func__
|
||||
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
|
||||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
|
||||
@@ -113,10 +110,6 @@ from synapse.rest.client.v1.voip import VoipRestServlet
|
||||
from synapse.rest.client.v2_alpha import groups, sync, user_directory
|
||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||
from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
|
||||
from synapse.rest.client.v2_alpha.account_data import (
|
||||
AccountDataServlet,
|
||||
RoomAccountDataServlet,
|
||||
)
|
||||
from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
|
||||
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
||||
from synapse.rest.client.versions import VersionsRestServlet
|
||||
@@ -228,32 +221,23 @@ class KeyUploadServlet(RestServlet):
|
||||
return 200, {"one_time_key_counts": result}
|
||||
|
||||
|
||||
class _NullContextManager(ContextManager[None]):
|
||||
"""A context manager which does nothing."""
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
|
||||
UPDATE_SYNCING_USERS_MS = 10 * 1000
|
||||
|
||||
|
||||
class GenericWorkerPresence(BasePresenceHandler):
|
||||
class GenericWorkerPresence(object):
|
||||
def __init__(self, hs):
|
||||
super().__init__(hs)
|
||||
self.hs = hs
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
|
||||
self._presence_enabled = hs.config.use_presence
|
||||
|
||||
# The number of ongoing syncs on this process, by user id.
|
||||
# Empty if _presence_enabled is false.
|
||||
self._user_to_num_current_syncs = {} # type: Dict[str, int]
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.user_to_num_current_syncs = {}
|
||||
self.clock = hs.get_clock()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.instance_id = hs.get_instance_id()
|
||||
|
||||
active_presence = self.store.take_presence_startup_info()
|
||||
self.user_to_current_state = {state.user_id: state for state in active_presence}
|
||||
|
||||
# user_id -> last_sync_ms. Lists the users that have stopped syncing
|
||||
# but we haven't notified the master of that yet
|
||||
self.users_going_offline = {}
|
||||
@@ -271,13 +255,13 @@ class GenericWorkerPresence(BasePresenceHandler):
|
||||
)
|
||||
|
||||
def _on_shutdown(self):
|
||||
if self._presence_enabled:
|
||||
if self.hs.config.use_presence:
|
||||
self.hs.get_tcp_replication().send_command(
|
||||
ClearUserSyncsCommand(self.instance_id)
|
||||
)
|
||||
|
||||
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
|
||||
if self._presence_enabled:
|
||||
if self.hs.config.use_presence:
|
||||
self.hs.get_tcp_replication().send_user_sync(
|
||||
self.instance_id, user_id, is_syncing, last_sync_ms
|
||||
)
|
||||
@@ -319,33 +303,28 @@ class GenericWorkerPresence(BasePresenceHandler):
|
||||
# TODO Hows this supposed to work?
|
||||
return defer.succeed(None)
|
||||
|
||||
async def user_syncing(
|
||||
self, user_id: str, affect_presence: bool
|
||||
) -> ContextManager[None]:
|
||||
"""Record that a user is syncing.
|
||||
get_states = __func__(PresenceHandler.get_states)
|
||||
get_state = __func__(PresenceHandler.get_state)
|
||||
current_state_for_users = __func__(PresenceHandler.current_state_for_users)
|
||||
|
||||
Called by the sync and events servlets to record that a user has connected to
|
||||
this worker and is waiting for some events.
|
||||
"""
|
||||
if not affect_presence or not self._presence_enabled:
|
||||
return _NullContextManager()
|
||||
def user_syncing(self, user_id, affect_presence):
|
||||
if affect_presence:
|
||||
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
|
||||
self.user_to_num_current_syncs[user_id] = curr_sync + 1
|
||||
|
||||
curr_sync = self._user_to_num_current_syncs.get(user_id, 0)
|
||||
self._user_to_num_current_syncs[user_id] = curr_sync + 1
|
||||
|
||||
# If we went from no in flight sync to some, notify replication
|
||||
if self._user_to_num_current_syncs[user_id] == 1:
|
||||
self.mark_as_coming_online(user_id)
|
||||
# If we went from no in flight sync to some, notify replication
|
||||
if self.user_to_num_current_syncs[user_id] == 1:
|
||||
self.mark_as_coming_online(user_id)
|
||||
|
||||
def _end():
|
||||
# We check that the user_id is in user_to_num_current_syncs because
|
||||
# user_to_num_current_syncs may have been cleared if we are
|
||||
# shutting down.
|
||||
if user_id in self._user_to_num_current_syncs:
|
||||
self._user_to_num_current_syncs[user_id] -= 1
|
||||
if affect_presence and user_id in self.user_to_num_current_syncs:
|
||||
self.user_to_num_current_syncs[user_id] -= 1
|
||||
|
||||
# If we went from one in flight sync to non, notify replication
|
||||
if self._user_to_num_current_syncs[user_id] == 0:
|
||||
if self.user_to_num_current_syncs[user_id] == 0:
|
||||
self.mark_as_going_offline(user_id)
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -355,7 +334,7 @@ class GenericWorkerPresence(BasePresenceHandler):
|
||||
finally:
|
||||
_end()
|
||||
|
||||
return _user_syncing()
|
||||
return defer.succeed(_user_syncing())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def notify_from_replication(self, states, stream_id):
|
||||
@@ -390,12 +369,15 @@ class GenericWorkerPresence(BasePresenceHandler):
|
||||
stream_id = token
|
||||
yield self.notify_from_replication(states, stream_id)
|
||||
|
||||
def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
|
||||
return [
|
||||
user_id
|
||||
for user_id, count in self._user_to_num_current_syncs.items()
|
||||
if count > 0
|
||||
]
|
||||
def get_currently_syncing_users(self):
|
||||
if self.hs.config.use_presence:
|
||||
return [
|
||||
user_id
|
||||
for user_id, count in self.user_to_num_current_syncs.items()
|
||||
if count > 0
|
||||
]
|
||||
else:
|
||||
return set()
|
||||
|
||||
|
||||
class GenericWorkerTyping(object):
|
||||
@@ -519,8 +501,6 @@ class GenericWorkerServer(HomeServer):
|
||||
ProfileDisplaynameRestServlet(self).register(resource)
|
||||
ProfileRestServlet(self).register(resource)
|
||||
KeyUploadServlet(self).register(resource)
|
||||
AccountDataServlet(self).register(resource)
|
||||
RoomAccountDataServlet(self).register(resource)
|
||||
|
||||
sync.register_servlets(self, resource)
|
||||
events.register_servlets(self, resource)
|
||||
@@ -639,7 +619,8 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.typing_handler = hs.get_typing_handler()
|
||||
self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence
|
||||
# NB this is a SynchrotronPresence, not a normal PresenceHandler
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.notifier = hs.get_notifier()
|
||||
|
||||
self.notify_pushers = hs.config.start_pushers
|
||||
@@ -960,22 +941,17 @@ def start(config_options):
|
||||
|
||||
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
hs = GenericWorkerServer(
|
||||
ss = GenericWorkerServer(
|
||||
config.server_name,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
)
|
||||
|
||||
setup_logging(hs, config, use_worker_options=True)
|
||||
|
||||
hs.setup()
|
||||
|
||||
# Ensure the replication streamer is always started in case we write to any
|
||||
# streams. Will no-op if no streams can be written to by this worker.
|
||||
hs.get_replication_streamer()
|
||||
setup_logging(ss, config, use_worker_options=True)
|
||||
|
||||
ss.setup()
|
||||
reactor.addSystemEventTrigger(
|
||||
"before", "startup", _base.start, hs, config.worker_listeners
|
||||
"before", "startup", _base.start, ss, config.worker_listeners
|
||||
)
|
||||
|
||||
_base.start_worker_reactor("synapse-generic-worker", config)
|
||||
|
||||
@@ -273,12 +273,6 @@ class SynapseHomeServer(HomeServer):
|
||||
def start_listening(self, listeners):
|
||||
config = self.get_config()
|
||||
|
||||
if config.redis_enabled:
|
||||
# If redis is enabled we connect via the replication command handler
|
||||
# in the same way as the workers (since we're effectively a client
|
||||
# rather than a server).
|
||||
self.get_tcp_replication().start_replication(self)
|
||||
|
||||
for listener in listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listening_services.extend(self._listener_http(config, listener))
|
||||
|
||||
@@ -657,12 +657,6 @@ def read_config_files(config_files):
|
||||
for config_file in config_files:
|
||||
with open(config_file) as file_stream:
|
||||
yaml_config = yaml.safe_load(file_stream)
|
||||
|
||||
if not isinstance(yaml_config, dict):
|
||||
err = "File %r is empty or doesn't parse into a key-value map. IGNORING."
|
||||
print(err % (config_file,))
|
||||
continue
|
||||
|
||||
specified_config.update(yaml_config)
|
||||
|
||||
if "server_name" not in specified_config:
|
||||
|
||||
@@ -138,7 +138,7 @@ class DatabaseConfig(Config):
|
||||
database_path = config.get("database_path")
|
||||
|
||||
if multi_database_config and database_config:
|
||||
raise ConfigError("Can't specify both 'database' and 'databases' in config")
|
||||
raise ConfigError("Can't specify both 'database' and 'datbases' in config")
|
||||
|
||||
if multi_database_config:
|
||||
if database_path:
|
||||
|
||||
@@ -31,7 +31,6 @@ from .password import PasswordConfig
|
||||
from .password_auth_providers import PasswordAuthProviderConfig
|
||||
from .push import PushConfig
|
||||
from .ratelimiting import RatelimitConfig
|
||||
from .redis import RedisConfig
|
||||
from .registration import RegistrationConfig
|
||||
from .repository import ContentRepositoryConfig
|
||||
from .room_directory import RoomDirectoryConfig
|
||||
@@ -83,5 +82,4 @@ class HomeServerConfig(RootConfig):
|
||||
RoomDirectoryConfig,
|
||||
ThirdPartyRulesConfig,
|
||||
TracerConfig,
|
||||
RedisConfig,
|
||||
]
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.config._base import Config
|
||||
from synapse.python_dependencies import check_requirements
|
||||
|
||||
|
||||
class RedisConfig(Config):
|
||||
section = "redis"
|
||||
|
||||
def read_config(self, config, **kwargs):
|
||||
redis_config = config.get("redis", {})
|
||||
self.redis_enabled = redis_config.get("enabled", False)
|
||||
|
||||
if not self.redis_enabled:
|
||||
return
|
||||
|
||||
check_requirements("redis")
|
||||
|
||||
self.redis_host = redis_config.get("host", "localhost")
|
||||
self.redis_port = redis_config.get("port", 6379)
|
||||
self.redis_dbid = redis_config.get("dbid")
|
||||
self.redis_password = redis_config.get("password")
|
||||
@@ -224,11 +224,12 @@ class ContentRepositoryConfig(Config):
|
||||
#
|
||||
#media_storage_providers:
|
||||
# - module: file_system
|
||||
# # Whether to store newly uploaded local files
|
||||
# # Whether to write new local files.
|
||||
# store_local: false
|
||||
# # Whether to store newly downloaded remote files
|
||||
# # Whether to write new remote media
|
||||
# store_remote: false
|
||||
# # Whether to wait for successful storage for local uploads
|
||||
# # Whether to block upload requests waiting for write to this
|
||||
# # provider to complete
|
||||
# store_synchronous: false
|
||||
# config:
|
||||
# directory: /mnt/some/other/directory
|
||||
|
||||
@@ -248,32 +248,32 @@ class SAML2Config(Config):
|
||||
# remote:
|
||||
# - url: https://our_idp/metadata.xml
|
||||
#
|
||||
# # By default, the user has to go to our login page first. If you'd like
|
||||
# # to allow IdP-initiated login, set 'allow_unsolicited: true' in a
|
||||
# # 'service.sp' section:
|
||||
# #
|
||||
# #service:
|
||||
# # sp:
|
||||
# # allow_unsolicited: true
|
||||
# # By default, the user has to go to our login page first. If you'd like
|
||||
# # to allow IdP-initiated login, set 'allow_unsolicited: true' in a
|
||||
# # 'service.sp' section:
|
||||
# #
|
||||
# #service:
|
||||
# # sp:
|
||||
# # allow_unsolicited: true
|
||||
#
|
||||
# # The examples below are just used to generate our metadata xml, and you
|
||||
# # may well not need them, depending on your setup. Alternatively you
|
||||
# # may need a whole lot more detail - see the pysaml2 docs!
|
||||
# # The examples below are just used to generate our metadata xml, and you
|
||||
# # may well not need them, depending on your setup. Alternatively you
|
||||
# # may need a whole lot more detail - see the pysaml2 docs!
|
||||
#
|
||||
# description: ["My awesome SP", "en"]
|
||||
# name: ["Test SP", "en"]
|
||||
# description: ["My awesome SP", "en"]
|
||||
# name: ["Test SP", "en"]
|
||||
#
|
||||
# organization:
|
||||
# name: Example com
|
||||
# display_name:
|
||||
# - ["Example co", "en"]
|
||||
# url: "http://example.com"
|
||||
# organization:
|
||||
# name: Example com
|
||||
# display_name:
|
||||
# - ["Example co", "en"]
|
||||
# url: "http://example.com"
|
||||
#
|
||||
# contact_person:
|
||||
# - given_name: Bob
|
||||
# sur_name: "the Sysadmin"
|
||||
# email_address": ["admin@example.com"]
|
||||
# contact_type": technical
|
||||
# contact_person:
|
||||
# - given_name: Bob
|
||||
# sur_name: "the Sysadmin"
|
||||
# email_address": ["admin@example.com"]
|
||||
# contact_type": technical
|
||||
|
||||
# Instead of putting the config inline as above, you can specify a
|
||||
# separate pysaml2 configuration file:
|
||||
|
||||
@@ -507,17 +507,6 @@ class ServerConfig(Config):
|
||||
|
||||
self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False)
|
||||
|
||||
# Inhibits the /requestToken endpoints from returning an error that might leak
|
||||
# information about whether an e-mail address is in use or not on this
|
||||
# homeserver, and instead return a 200 with a fake sid if this kind of error is
|
||||
# met, without sending anything.
|
||||
# This is a compromise between sending an email, which could be a spam vector,
|
||||
# and letting the client know which email address is bound to an account and
|
||||
# which one isn't.
|
||||
self.request_token_inhibit_3pid_errors = config.get(
|
||||
"request_token_inhibit_3pid_errors", False,
|
||||
)
|
||||
|
||||
def has_tls_listener(self) -> bool:
|
||||
return any(l["tls"] for l in self.listeners)
|
||||
|
||||
@@ -983,16 +972,6 @@ class ServerConfig(Config):
|
||||
# - shortest_max_lifetime: 3d
|
||||
# longest_max_lifetime: 1y
|
||||
# interval: 1d
|
||||
|
||||
# Inhibits the /requestToken endpoints from returning an error that might leak
|
||||
# information about whether an e-mail address is in use or not on this
|
||||
# homeserver.
|
||||
# Note that for some endpoints the error situation is the e-mail already being
|
||||
# used, and for others the error is entering the e-mail being unused.
|
||||
# If this option is enabled, instead of returning an error, these endpoints will
|
||||
# act as if no error happened and return a fake session ID ('sid') to clients.
|
||||
#
|
||||
#request_token_inhibit_3pid_errors: true
|
||||
"""
|
||||
% locals()
|
||||
)
|
||||
|
||||
@@ -43,12 +43,6 @@ class SSOConfig(Config):
|
||||
),
|
||||
"sso_account_deactivated_template",
|
||||
)
|
||||
self.sso_auth_success_template = self.read_file(
|
||||
os.path.join(
|
||||
self.sso_redirect_confirm_template_dir, "sso_auth_success.html"
|
||||
),
|
||||
"sso_auth_success_template",
|
||||
)
|
||||
|
||||
self.sso_client_whitelist = sso_config.get("client_whitelist") or []
|
||||
|
||||
@@ -113,30 +107,6 @@ class SSOConfig(Config):
|
||||
#
|
||||
# * server_name: the homeserver's name.
|
||||
#
|
||||
# * HTML page which notifies the user that they are authenticating to confirm
|
||||
# an operation on their account during the user interactive authentication
|
||||
# process: 'sso_auth_confirm.html'.
|
||||
#
|
||||
# When rendering, this template is given the following variables:
|
||||
# * redirect_url: the URL the user is about to be redirected to. Needs
|
||||
# manual escaping (see
|
||||
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
|
||||
#
|
||||
# * description: the operation which the user is being asked to confirm
|
||||
#
|
||||
# * HTML page shown after a successful user interactive authentication session:
|
||||
# 'sso_auth_success.html'.
|
||||
#
|
||||
# Note that this page must include the JavaScript which notifies of a successful authentication
|
||||
# (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback).
|
||||
#
|
||||
# This template has no additional variables.
|
||||
#
|
||||
# * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
|
||||
# attempts to login: 'sso_account_deactivated.html'.
|
||||
#
|
||||
# This template has no additional variables.
|
||||
#
|
||||
# You can see the default templates at:
|
||||
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
|
||||
#
|
||||
|
||||
@@ -399,24 +399,20 @@ class TransportLayerClient(object):
|
||||
{
|
||||
"device_keys": {
|
||||
"<user_id>": ["<device_id>"]
|
||||
}
|
||||
}
|
||||
} }
|
||||
|
||||
Response:
|
||||
{
|
||||
"device_keys": {
|
||||
"<user_id>": {
|
||||
"<device_id>": {...}
|
||||
}
|
||||
},
|
||||
"master_key": {
|
||||
} }
|
||||
"master_keys": {
|
||||
"<user_id>": {...}
|
||||
}
|
||||
},
|
||||
"self_signing_key": {
|
||||
} }
|
||||
"self_signing_keys": {
|
||||
"<user_id>": {...}
|
||||
}
|
||||
}
|
||||
} } }
|
||||
|
||||
Args:
|
||||
destination(str): The server to query.
|
||||
@@ -440,22 +436,8 @@ class TransportLayerClient(object):
|
||||
{
|
||||
"stream_id": "...",
|
||||
"devices": [ { ... } ],
|
||||
"master_key": {
|
||||
"user_id": "<user_id>",
|
||||
"usage": [...],
|
||||
"keys": {...},
|
||||
"signatures": {
|
||||
"<user_id>": {...}
|
||||
}
|
||||
},
|
||||
"self_signing_key": {
|
||||
"user_id": "<user_id>",
|
||||
"usage": [...],
|
||||
"keys": {...},
|
||||
"signatures": {
|
||||
"<user_id>": {...}
|
||||
}
|
||||
}
|
||||
"master_key": { ... },
|
||||
"self_signing_key: { ... }
|
||||
}
|
||||
|
||||
Args:
|
||||
@@ -480,10 +462,8 @@ class TransportLayerClient(object):
|
||||
{
|
||||
"one_time_keys": {
|
||||
"<user_id>": {
|
||||
"<device_id>": "<algorithm>"
|
||||
}
|
||||
}
|
||||
}
|
||||
"<device_id>": "<algorithm>"
|
||||
} } }
|
||||
|
||||
Response:
|
||||
{
|
||||
@@ -491,16 +471,13 @@ class TransportLayerClient(object):
|
||||
"<user_id>": {
|
||||
"<device_id>": {
|
||||
"<algorithm>:<key_id>": "<key_base64>"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} } } }
|
||||
|
||||
Args:
|
||||
destination(str): The server to query.
|
||||
query_content(dict): The user ids to query.
|
||||
Returns:
|
||||
A dict containing the one-time keys.
|
||||
A dict containg the one-time keys.
|
||||
"""
|
||||
|
||||
path = _create_v1_path("/user/keys/claim")
|
||||
|
||||
@@ -37,16 +37,15 @@ An attestation is a signed blob of json that looks like:
|
||||
|
||||
import logging
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
||||
from signedjson.sign import sign_json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
|
||||
from synapse.logging.context import run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.async_helpers import yieldable_gather_results
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -163,19 +162,19 @@ class GroupAttestionRenewer(object):
|
||||
def _start_renew_attestations(self):
|
||||
return run_as_background_process("renew_attestations", self._renew_attestations)
|
||||
|
||||
async def _renew_attestations(self):
|
||||
@defer.inlineCallbacks
|
||||
def _renew_attestations(self):
|
||||
"""Called periodically to check if we need to update any of our attestations
|
||||
"""
|
||||
|
||||
now = self.clock.time_msec()
|
||||
|
||||
rows = await self.store.get_attestations_need_renewals(
|
||||
rows = yield self.store.get_attestations_need_renewals(
|
||||
now + UPDATE_ATTESTATION_TIME_MS
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _renew_attestation(group_user: Tuple[str, str]):
|
||||
group_id, user_id = group_user
|
||||
def _renew_attestation(group_id, user_id):
|
||||
try:
|
||||
if not self.is_mine_id(group_id):
|
||||
destination = get_domain_from_id(group_id)
|
||||
@@ -208,6 +207,8 @@ class GroupAttestionRenewer(object):
|
||||
"Error renewing attestation of %r in %r", user_id, group_id
|
||||
)
|
||||
|
||||
await yieldable_gather_results(
|
||||
_renew_attestation, ((row["group_id"], row["user_id"]) for row in rows)
|
||||
)
|
||||
for row in rows:
|
||||
group_id = row["group_id"]
|
||||
user_id = row["user_id"]
|
||||
|
||||
run_in_background(_renew_attestation, group_id, user_id)
|
||||
|
||||
@@ -51,6 +51,31 @@ from ._base import BaseHandler
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SUCCESS_TEMPLATE = """
|
||||
<html>
|
||||
<head>
|
||||
<title>Success!</title>
|
||||
<meta name='viewport' content='width=device-width, initial-scale=1,
|
||||
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
|
||||
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
|
||||
<script>
|
||||
if (window.onAuthDone) {
|
||||
window.onAuthDone();
|
||||
} else if (window.opener && window.opener.postMessage) {
|
||||
window.opener.postMessage("authDone", "*");
|
||||
}
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div>
|
||||
<p>Thank you</p>
|
||||
<p>You may now close this window and return to the application</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
class AuthHandler(BaseHandler):
|
||||
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
||||
|
||||
@@ -134,11 +159,6 @@ class AuthHandler(BaseHandler):
|
||||
self._sso_auth_confirm_template = load_jinja2_templates(
|
||||
hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"],
|
||||
)[0]
|
||||
# The following template is shown after a successful user interactive
|
||||
# authentication session. It tells the user they can close the window.
|
||||
self._sso_auth_success_template = hs.config.sso_auth_success_template
|
||||
# The following template is shown during the SSO authentication process if
|
||||
# the account is deactivated.
|
||||
self._sso_account_deactivated_template = (
|
||||
hs.config.sso_account_deactivated_template
|
||||
)
|
||||
@@ -257,6 +277,10 @@ class AuthHandler(BaseHandler):
|
||||
Takes a dictionary sent by the client in the login / registration
|
||||
protocol and handles the User-Interactive Auth flow.
|
||||
|
||||
As a side effect, this function fills in the 'creds' key on the user's
|
||||
session with a map, which maps each auth-type (str) to the relevant
|
||||
identity authenticated by that auth-type (mostly str, but for captcha, bool).
|
||||
|
||||
If no auth flows have been completed successfully, raises an
|
||||
InteractiveAuthIncompleteError. To handle this, you can use
|
||||
synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
|
||||
@@ -300,47 +324,50 @@ class AuthHandler(BaseHandler):
|
||||
del clientdict["auth"]
|
||||
if "session" in authdict:
|
||||
sid = authdict["session"]
|
||||
session = self._get_session_info(sid)
|
||||
|
||||
# If there's no session ID, create a new session.
|
||||
if not sid:
|
||||
session = self._create_session(
|
||||
clientdict, (request.uri, request.method, clientdict), description
|
||||
if len(clientdict) > 0:
|
||||
# This was designed to allow the client to omit the parameters
|
||||
# and just supply the session in subsequent calls so it split
|
||||
# auth between devices by just sharing the session, (eg. so you
|
||||
# could continue registration from your phone having clicked the
|
||||
# email auth link on there). It's probably too open to abuse
|
||||
# because it lets unauthenticated clients store arbitrary objects
|
||||
# on a homeserver.
|
||||
# Revisit: Assuming the REST APIs do sensible validation, the data
|
||||
# isn't arbintrary.
|
||||
session["clientdict"] = clientdict
|
||||
self._save_session(session)
|
||||
elif "clientdict" in session:
|
||||
clientdict = session["clientdict"]
|
||||
|
||||
# Ensure that the queried operation does not vary between stages of
|
||||
# the UI authentication session. This is done by generating a stable
|
||||
# comparator based on the URI, method, and body (minus the auth dict)
|
||||
# and storing it during the initial query. Subsequent queries ensure
|
||||
# that this comparator has not changed.
|
||||
comparator = (request.uri, request.method, clientdict)
|
||||
if "ui_auth" not in session:
|
||||
session["ui_auth"] = comparator
|
||||
self._save_session(session)
|
||||
elif session["ui_auth"] != comparator:
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Requested operation has changed during the UI authentication session.",
|
||||
)
|
||||
session_id = session["id"]
|
||||
|
||||
else:
|
||||
session = self._get_session_info(sid)
|
||||
session_id = sid
|
||||
|
||||
if not clientdict:
|
||||
# This was designed to allow the client to omit the parameters
|
||||
# and just supply the session in subsequent calls so it split
|
||||
# auth between devices by just sharing the session, (eg. so you
|
||||
# could continue registration from your phone having clicked the
|
||||
# email auth link on there). It's probably too open to abuse
|
||||
# because it lets unauthenticated clients store arbitrary objects
|
||||
# on a homeserver.
|
||||
# Revisit: Assuming the REST APIs do sensible validation, the data
|
||||
# isn't arbitrary.
|
||||
clientdict = session["clientdict"]
|
||||
|
||||
# Ensure that the queried operation does not vary between stages of
|
||||
# the UI authentication session. This is done by generating a stable
|
||||
# comparator based on the URI, method, and body (minus the auth dict)
|
||||
# and storing it during the initial query. Subsequent queries ensure
|
||||
# that this comparator has not changed.
|
||||
comparator = (request.uri, request.method, clientdict)
|
||||
if session["ui_auth"] != comparator:
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Requested operation has changed during the UI authentication session.",
|
||||
)
|
||||
# Add a human readable description to the session.
|
||||
if "description" not in session:
|
||||
session["description"] = description
|
||||
self._save_session(session)
|
||||
|
||||
if not authdict:
|
||||
raise InteractiveAuthIncompleteError(
|
||||
self._auth_dict_for_flows(flows, session_id)
|
||||
self._auth_dict_for_flows(flows, session)
|
||||
)
|
||||
|
||||
if "creds" not in session:
|
||||
session["creds"] = {}
|
||||
creds = session["creds"]
|
||||
|
||||
# check auth type currently being presented
|
||||
@@ -380,9 +407,9 @@ class AuthHandler(BaseHandler):
|
||||
list(clientdict),
|
||||
)
|
||||
|
||||
return creds, clientdict, session_id
|
||||
return creds, clientdict, session["id"]
|
||||
|
||||
ret = self._auth_dict_for_flows(flows, session_id)
|
||||
ret = self._auth_dict_for_flows(flows, session)
|
||||
ret["completed"] = list(creds)
|
||||
ret.update(errordict)
|
||||
raise InteractiveAuthIncompleteError(ret)
|
||||
@@ -400,6 +427,8 @@ class AuthHandler(BaseHandler):
|
||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||
|
||||
sess = self._get_session_info(authdict["session"])
|
||||
if "creds" not in sess:
|
||||
sess["creds"] = {}
|
||||
creds = sess["creds"]
|
||||
|
||||
result = await self.checkers[stagetype].check_auth(authdict, clientip)
|
||||
@@ -439,7 +468,7 @@ class AuthHandler(BaseHandler):
|
||||
value: The data to store
|
||||
"""
|
||||
sess = self._get_session_info(session_id)
|
||||
sess["serverdict"][key] = value
|
||||
sess.setdefault("serverdict", {})[key] = value
|
||||
self._save_session(sess)
|
||||
|
||||
def get_session_data(
|
||||
@@ -454,7 +483,7 @@ class AuthHandler(BaseHandler):
|
||||
default: Value to return if the key has not been set
|
||||
"""
|
||||
sess = self._get_session_info(session_id)
|
||||
return sess["serverdict"].get(key, default)
|
||||
return sess.setdefault("serverdict", {}).get(key, default)
|
||||
|
||||
async def _check_auth_dict(
|
||||
self, authdict: Dict[str, Any], clientip: str
|
||||
@@ -510,7 +539,7 @@ class AuthHandler(BaseHandler):
|
||||
}
|
||||
|
||||
def _auth_dict_for_flows(
|
||||
self, flows: List[List[str]], session_id: str,
|
||||
self, flows: List[List[str]], session: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
public_flows = []
|
||||
for f in flows:
|
||||
@@ -529,72 +558,29 @@ class AuthHandler(BaseHandler):
|
||||
params[stage] = get_params[stage]()
|
||||
|
||||
return {
|
||||
"session": session_id,
|
||||
"session": session["id"],
|
||||
"flows": [{"stages": f} for f in public_flows],
|
||||
"params": params,
|
||||
}
|
||||
|
||||
def _create_session(
|
||||
self,
|
||||
clientdict: Dict[str, Any],
|
||||
ui_auth: Tuple[bytes, bytes, Dict[str, Any]],
|
||||
description: str,
|
||||
) -> dict:
|
||||
def _get_session_info(self, session_id: Optional[str]) -> dict:
|
||||
"""
|
||||
Creates a new user interactive authentication session.
|
||||
Gets or creates a session given a session ID.
|
||||
|
||||
The session can be used to track data across multiple requests, e.g. for
|
||||
interactive authentication.
|
||||
|
||||
Each session has the following keys:
|
||||
|
||||
id:
|
||||
A unique identifier for this session. Passed back to the client
|
||||
and returned for each stage.
|
||||
clientdict:
|
||||
The dictionary from the client root level, not the 'auth' key.
|
||||
ui_auth:
|
||||
A tuple which is checked at each stage of the authentication to
|
||||
ensure that the asked for operation has not changed.
|
||||
creds:
|
||||
A map, which maps each auth-type (str) to the relevant identity
|
||||
authenticated by that auth-type (mostly str, but for captcha, bool).
|
||||
serverdict:
|
||||
A map of data that is stored server-side and cannot be modified
|
||||
by the client.
|
||||
description:
|
||||
A string description of the operation that the current
|
||||
authentication is authorising.
|
||||
Returns:
|
||||
The newly created session.
|
||||
"""
|
||||
session_id = None
|
||||
while session_id is None or session_id in self.sessions:
|
||||
session_id = stringutils.random_string(24)
|
||||
if session_id not in self.sessions:
|
||||
session_id = None
|
||||
|
||||
self.sessions[session_id] = {
|
||||
"id": session_id,
|
||||
"clientdict": clientdict,
|
||||
"ui_auth": ui_auth,
|
||||
"creds": {},
|
||||
"serverdict": {},
|
||||
"description": description,
|
||||
}
|
||||
if not session_id:
|
||||
# create a new session
|
||||
while session_id is None or session_id in self.sessions:
|
||||
session_id = stringutils.random_string(24)
|
||||
self.sessions[session_id] = {"id": session_id}
|
||||
|
||||
return self.sessions[session_id]
|
||||
|
||||
def _get_session_info(self, session_id: str) -> dict:
|
||||
"""
|
||||
Gets a session given a session ID.
|
||||
|
||||
The session can be used to track data across multiple requests, e.g. for
|
||||
interactive authentication.
|
||||
"""
|
||||
try:
|
||||
return self.sessions[session_id]
|
||||
except KeyError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||
|
||||
async def get_access_token_for_user_id(
|
||||
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
|
||||
):
|
||||
@@ -1064,8 +1050,11 @@ class AuthHandler(BaseHandler):
|
||||
The HTML to render.
|
||||
"""
|
||||
session = self._get_session_info(session_id)
|
||||
# Get the human readable operation of what is occurring, falling back to
|
||||
# a generic message if it isn't available for some reason.
|
||||
description = session.get("description", "modify your account")
|
||||
return self._sso_auth_confirm_template.render(
|
||||
description=session["description"], redirect_url=redirect_url,
|
||||
description=description, redirect_url=redirect_url,
|
||||
)
|
||||
|
||||
def complete_sso_ui_auth(
|
||||
@@ -1081,6 +1070,8 @@ class AuthHandler(BaseHandler):
|
||||
"""
|
||||
# Mark the stage of the authentication as successful.
|
||||
sess = self._get_session_info(session_id)
|
||||
if "creds" not in sess:
|
||||
sess["creds"] = {}
|
||||
creds = sess["creds"]
|
||||
|
||||
# Save the user who authenticated with SSO, this will be used to ensure
|
||||
@@ -1089,7 +1080,7 @@ class AuthHandler(BaseHandler):
|
||||
self._save_session(sess)
|
||||
|
||||
# Render the HTML and return.
|
||||
html_bytes = self._sso_auth_success_template.encode("utf-8")
|
||||
html_bytes = SUCCESS_TEMPLATE.encode("utf8")
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
|
||||
@@ -1115,12 +1106,12 @@ class AuthHandler(BaseHandler):
|
||||
# flow.
|
||||
deactivated = await self.store.get_user_deactivated_status(registered_user_id)
|
||||
if deactivated:
|
||||
html_bytes = self._sso_account_deactivated_template.encode("utf-8")
|
||||
html = self._sso_account_deactivated_template.encode("utf-8")
|
||||
|
||||
request.setResponseCode(403)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
|
||||
request.write(html_bytes)
|
||||
request.setHeader(b"Content-Length", b"%d" % (len(html),))
|
||||
request.write(html)
|
||||
finish_request(request)
|
||||
return
|
||||
|
||||
@@ -1162,7 +1153,7 @@ class AuthHandler(BaseHandler):
|
||||
# URL we redirect users to.
|
||||
redirect_url_no_params = client_redirect_url.split("?")[0]
|
||||
|
||||
html_bytes = self._sso_redirect_confirm_template.render(
|
||||
html = self._sso_redirect_confirm_template.render(
|
||||
display_url=redirect_url_no_params,
|
||||
redirect_url=redirect_url,
|
||||
server_name=self._server_name,
|
||||
@@ -1170,8 +1161,8 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
|
||||
request.write(html_bytes)
|
||||
request.setHeader(b"Content-Length", b"%d" % (len(html),))
|
||||
request.write(html)
|
||||
finish_request(request)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -880,6 +880,7 @@ class E2eKeysHandler(object):
|
||||
|
||||
try:
|
||||
# get our user-signing key to verify the signatures
|
||||
logger.info("***Getting the user_signing")
|
||||
(
|
||||
user_signing_key,
|
||||
user_signing_key_id,
|
||||
@@ -903,6 +904,7 @@ class E2eKeysHandler(object):
|
||||
try:
|
||||
# get the target user's master key, to make sure it matches
|
||||
# what was sent
|
||||
logger.info("***Getting the master")
|
||||
(
|
||||
master_key,
|
||||
master_key_id,
|
||||
@@ -985,14 +987,11 @@ class E2eKeysHandler(object):
|
||||
SynapseError: if `user_id` is invalid
|
||||
"""
|
||||
user = UserID.from_string(user_id)
|
||||
logger.info("***Trying to get a %s key for %s from storage...", key_type, user_id)
|
||||
key = yield self.store.get_e2e_cross_signing_key(
|
||||
user_id, key_type, from_user_id
|
||||
)
|
||||
|
||||
if key:
|
||||
# We found a copy of this key in our database. Decode and return it
|
||||
key_id, verify_key = get_verify_key_from_cross_signing_key(key)
|
||||
return key, key_id, verify_key
|
||||
logger.info("***Well, we got this: %s", key)
|
||||
|
||||
# If we couldn't find the key locally, and we're looking for keys of
|
||||
# another user then attempt to fetch the missing key from the remote
|
||||
@@ -1000,22 +999,37 @@ class E2eKeysHandler(object):
|
||||
#
|
||||
# We may run into this in possible edge cases where a user tries to
|
||||
# cross-sign a remote user, but does not share any rooms with them yet.
|
||||
# Thus, we would not have their key list yet. We instead fetch the key,
|
||||
# Thus, we would not have their key list yet. We fetch the key here,
|
||||
# store it and notify clients of new, associated device IDs.
|
||||
if self.is_mine(user) or key_type not in ["master", "self_signing"]:
|
||||
# Note that master and self_signing keys are the only cross-signing keys we
|
||||
# can request over federation
|
||||
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
|
||||
|
||||
(
|
||||
key,
|
||||
key_id,
|
||||
verify_key,
|
||||
) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
|
||||
logger.info("***Checking if we'll do our thingy")
|
||||
if (
|
||||
key is None
|
||||
and not self.is_mine(user)
|
||||
# We only get "master" and "self_signing" keys from remote servers
|
||||
and key_type in ["master", "self_signing"]
|
||||
):
|
||||
logger.info("***Doing our thingy")
|
||||
(
|
||||
key,
|
||||
key_id,
|
||||
verify_key,
|
||||
) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
|
||||
|
||||
if key is None:
|
||||
logger.warning("No %s key found for %s", key_type, user_id)
|
||||
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
|
||||
|
||||
try:
|
||||
key_id, verify_key = get_verify_key_from_cross_signing_key(key)
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
"Invalid %s key retrieved: %s - %s %s", key_type, key, type(e), e,
|
||||
)
|
||||
raise SynapseError(
|
||||
502, "Invalid %s key retrieved from remote server" % (key_type,)
|
||||
)
|
||||
|
||||
logger.info("***Finally returning %s - %s - %s", key, key_id, verify_key)
|
||||
return key, key_id, verify_key
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -1040,6 +1054,7 @@ class E2eKeysHandler(object):
|
||||
remote_result = yield self.federation.query_user_devices(
|
||||
user.domain, user.to_string()
|
||||
)
|
||||
logger.info("***Got our remote_result: %s", remote_result)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Unable to query %s for cross-signing keys of user %s: %s %s",
|
||||
@@ -1051,38 +1066,31 @@ class E2eKeysHandler(object):
|
||||
return None, None, None
|
||||
|
||||
# Process each of the retrieved cross-signing keys
|
||||
desired_key = None
|
||||
desired_key_id = None
|
||||
desired_verify_key = None
|
||||
retrieved_device_ids = []
|
||||
final_key = None
|
||||
final_key_id = None
|
||||
final_verify_key = None
|
||||
device_ids = []
|
||||
for key_type in ["master", "self_signing"]:
|
||||
logger.info("***Processing retrieved key type: %s", key_type)
|
||||
key_content = remote_result.get(key_type + "_key")
|
||||
logger.info("***Got key_content: %s", key_content)
|
||||
if not key_content:
|
||||
continue
|
||||
|
||||
# Ensure these keys belong to the correct user
|
||||
if "user_id" not in key_content:
|
||||
logger.warning(
|
||||
"Invalid %s key retrieved, missing user_id field: %s",
|
||||
key_type,
|
||||
key_content,
|
||||
)
|
||||
continue
|
||||
if user.to_string() != key_content["user_id"]:
|
||||
logger.warning(
|
||||
"Found %s key of user %s when querying for keys of user %s",
|
||||
key_type,
|
||||
key_content["user_id"],
|
||||
user.to_string(),
|
||||
)
|
||||
continue
|
||||
# At the same time, store this key in the db for
|
||||
# subsequent queries
|
||||
yield self.store.set_e2e_cross_signing_key(
|
||||
user.to_string(), key_type, key_content
|
||||
)
|
||||
logger.info("***Stored key")
|
||||
|
||||
# Validate the key contents
|
||||
# Note down the device ID attached to this key
|
||||
try:
|
||||
# verify_key is a VerifyKey from signedjson, which uses
|
||||
# .version to denote the portion of the key ID after the
|
||||
# algorithm and colon, which is the device ID
|
||||
key_id, verify_key = get_verify_key_from_cross_signing_key(key_content)
|
||||
logger.info("***Verified key: %s - %s", key_id, verify_key)
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
"Invalid %s key retrieved: %s - %s %s",
|
||||
@@ -1092,29 +1100,23 @@ class E2eKeysHandler(object):
|
||||
e,
|
||||
)
|
||||
continue
|
||||
|
||||
# Note down the device ID attached to this key
|
||||
retrieved_device_ids.append(verify_key.version)
|
||||
logger.info("***Appending device id: %s - %s", verify_key, verify_key.version)
|
||||
device_ids.append(verify_key.version)
|
||||
|
||||
# If this is the desired key type, save it and its ID/VerifyKey
|
||||
if key_type == desired_key_type:
|
||||
desired_key = key_content
|
||||
desired_verify_key = verify_key
|
||||
desired_key_id = key_id
|
||||
|
||||
# At the same time, store this key in the db for subsequent queries
|
||||
yield self.store.set_e2e_cross_signing_key(
|
||||
user.to_string(), key_type, key_content
|
||||
)
|
||||
logger.info("***We found our desired key type, %s!", key_type)
|
||||
final_key = key_content
|
||||
final_verify_key = verify_key
|
||||
final_key_id = key_id
|
||||
|
||||
# Notify clients that new devices for this user have been discovered
|
||||
if retrieved_device_ids:
|
||||
# XXX is this necessary?
|
||||
yield self.device_handler.notify_device_update(
|
||||
user.to_string(), retrieved_device_ids
|
||||
)
|
||||
if device_ids:
|
||||
logger.info("***Updating clients with devices: %s", device_ids)
|
||||
yield self.device_handler.notify_device_update(user.to_string(), device_ids)
|
||||
|
||||
return desired_key, desired_key_id, desired_verify_key
|
||||
logger.info("***Returning %s - %s - %s", final_key, final_key_id, final_verify_key)
|
||||
return final_key, final_key_id, final_verify_key
|
||||
|
||||
|
||||
def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
|
||||
|
||||
@@ -19,7 +19,6 @@ import random
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import AuthError, SynapseError
|
||||
from synapse.events import EventBase
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.logging.utils import log_function
|
||||
from synapse.types import UserID
|
||||
from synapse.visibility import filter_events_for_client
|
||||
@@ -98,8 +97,6 @@ class EventStreamHandler(BaseHandler):
|
||||
explicit_room_id=room_id,
|
||||
)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
# When the user joins a new room, or another user joins a currently
|
||||
# joined room, we need to send down presence for those users.
|
||||
to_add = []
|
||||
@@ -115,20 +112,19 @@ class EventStreamHandler(BaseHandler):
|
||||
users = await self.state.get_current_users_in_room(
|
||||
event.room_id
|
||||
)
|
||||
states = await presence_handler.get_states(users, as_event=True)
|
||||
to_add.extend(states)
|
||||
else:
|
||||
users = [event.state_key]
|
||||
|
||||
states = await presence_handler.get_states(users)
|
||||
to_add.extend(
|
||||
{
|
||||
"type": EventTypes.Presence,
|
||||
"content": format_user_presence_state(state, time_now),
|
||||
}
|
||||
for state in states
|
||||
)
|
||||
ev = await presence_handler.get_state(
|
||||
UserID.from_string(event.state_key), as_event=True
|
||||
)
|
||||
to_add.append(ev)
|
||||
|
||||
events.extend(to_add)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
chunks = await self._event_serializer.serialize_events(
|
||||
events,
|
||||
time_now,
|
||||
|
||||
@@ -343,7 +343,7 @@ class FederationHandler(BaseHandler):
|
||||
ours = await self.state_store.get_state_groups_ids(room_id, seen)
|
||||
|
||||
# state_maps is a list of mappings from (type, state_key) to event_id
|
||||
state_maps = list(ours.values()) # type: List[StateMap[str]]
|
||||
state_maps = list(ours.values()) # type: list[StateMap[str]]
|
||||
|
||||
# we don't need this any more, let's delete it.
|
||||
del ours
|
||||
@@ -1694,15 +1694,16 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
return None
|
||||
|
||||
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_pdu(self, room_id, event_id):
|
||||
"""Returns the state at the event. i.e. not including said event.
|
||||
"""
|
||||
|
||||
event = await self.store.get_event(
|
||||
event = yield self.store.get_event(
|
||||
event_id, allow_none=False, check_room_id=room_id
|
||||
)
|
||||
|
||||
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
|
||||
state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
|
||||
|
||||
if state_groups:
|
||||
_, state = list(iteritems(state_groups)).pop()
|
||||
@@ -1713,7 +1714,7 @@ class FederationHandler(BaseHandler):
|
||||
if "replaces_state" in event.unsigned:
|
||||
prev_id = event.unsigned["replaces_state"]
|
||||
if prev_id != event.event_id:
|
||||
prev_event = await self.store.get_event(prev_id)
|
||||
prev_event = yield self.store.get_event(prev_id)
|
||||
results[(event.type, event.state_key)] = prev_event
|
||||
else:
|
||||
del results[(event.type, event.state_key)]
|
||||
@@ -1723,14 +1724,15 @@ class FederationHandler(BaseHandler):
|
||||
else:
|
||||
return []
|
||||
|
||||
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
|
||||
@defer.inlineCallbacks
|
||||
def get_state_ids_for_pdu(self, room_id, event_id):
|
||||
"""Returns the state at the event. i.e. not including said event.
|
||||
"""
|
||||
event = await self.store.get_event(
|
||||
event = yield self.store.get_event(
|
||||
event_id, allow_none=False, check_room_id=room_id
|
||||
)
|
||||
|
||||
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
|
||||
state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
|
||||
|
||||
if state_groups:
|
||||
_, state = list(state_groups.items()).pop()
|
||||
@@ -1749,50 +1751,49 @@ class FederationHandler(BaseHandler):
|
||||
else:
|
||||
return []
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
async def on_backfill_request(
|
||||
self, origin: str, room_id: str, pdu_list: List[str], limit: int
|
||||
) -> List[EventBase]:
|
||||
in_room = await self.auth.check_host_in_room(room_id, origin)
|
||||
def on_backfill_request(self, origin, room_id, pdu_list, limit):
|
||||
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
# Synapse asks for 100 events per backfill request. Do not allow more.
|
||||
limit = min(limit, 100)
|
||||
|
||||
events = await self.store.get_backfill_events(room_id, pdu_list, limit)
|
||||
events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
|
||||
|
||||
events = await filter_events_for_server(self.storage, origin, events)
|
||||
events = yield filter_events_for_server(self.storage, origin, events)
|
||||
|
||||
return events
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
async def get_persisted_pdu(
|
||||
self, origin: str, event_id: str
|
||||
) -> Optional[EventBase]:
|
||||
def get_persisted_pdu(self, origin, event_id):
|
||||
"""Get an event from the database for the given server.
|
||||
|
||||
Args:
|
||||
origin: hostname of server which is requesting the event; we
|
||||
origin [str]: hostname of server which is requesting the event; we
|
||||
will check that the server is allowed to see it.
|
||||
event_id: id of the event being requested
|
||||
event_id [str]: id of the event being requested
|
||||
|
||||
Returns:
|
||||
None if we know nothing about the event; otherwise the (possibly-redacted) event.
|
||||
Deferred[EventBase|None]: None if we know nothing about the event;
|
||||
otherwise the (possibly-redacted) event.
|
||||
|
||||
Raises:
|
||||
AuthError if the server is not currently in the room
|
||||
"""
|
||||
event = await self.store.get_event(
|
||||
event = yield self.store.get_event(
|
||||
event_id, allow_none=True, allow_rejected=True
|
||||
)
|
||||
|
||||
if event:
|
||||
in_room = await self.auth.check_host_in_room(event.room_id, origin)
|
||||
in_room = yield self.auth.check_host_in_room(event.room_id, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
events = await filter_events_for_server(self.storage, origin, [event])
|
||||
events = yield filter_events_for_server(self.storage, origin, [event])
|
||||
event = events[0]
|
||||
return event
|
||||
else:
|
||||
@@ -2396,7 +2397,7 @@ class FederationHandler(BaseHandler):
|
||||
"""
|
||||
# exclude the state key of the new event from the current_state in the context.
|
||||
if event.is_state():
|
||||
event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]]
|
||||
event_key = (event.type, event.state_key)
|
||||
else:
|
||||
event_key = None
|
||||
state_updates = {
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
"""Utilities for interacting with Identity Servers"""
|
||||
|
||||
import logging
|
||||
import urllib.parse
|
||||
import urllib
|
||||
|
||||
from canonicaljson import json
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
|
||||
@@ -381,16 +381,10 @@ class InitialSyncHandler(BaseHandler):
|
||||
return []
|
||||
|
||||
states = await presence_handler.get_states(
|
||||
[m.user_id for m in room_members]
|
||||
[m.user_id for m in room_members], as_event=True
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"type": EventTypes.Presence,
|
||||
"content": format_user_presence_state(s, time_now),
|
||||
}
|
||||
for s in states
|
||||
]
|
||||
return states
|
||||
|
||||
async def get_receipts():
|
||||
receipts = await self.store.get_linearized_receipts_for_room(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -22,10 +21,10 @@ The methods that define policy are:
|
||||
- PresenceHandler._handle_timeouts
|
||||
- should_notify
|
||||
"""
|
||||
import abc
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, Iterable, List, Set
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from six import iteritems, itervalues
|
||||
|
||||
@@ -42,7 +41,7 @@ from synapse.logging.utils import log_function
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.types import JsonDict, UserID, get_domain_from_id
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.metrics import Measure
|
||||
@@ -100,106 +99,13 @@ EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000
|
||||
assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
|
||||
|
||||
|
||||
class BasePresenceHandler(abc.ABC):
|
||||
"""Parts of the PresenceHandler that are shared between workers and master"""
|
||||
|
||||
class PresenceHandler(object):
|
||||
def __init__(self, hs: "synapse.server.HomeServer"):
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
active_presence = self.store.take_presence_startup_info()
|
||||
self.user_to_current_state = {state.user_id: state for state in active_presence}
|
||||
|
||||
@abc.abstractmethod
|
||||
async def user_syncing(
|
||||
self, user_id: str, affect_presence: bool
|
||||
) -> ContextManager[None]:
|
||||
"""Returns a context manager that should surround any stream requests
|
||||
from the user.
|
||||
|
||||
This allows us to keep track of who is currently streaming and who isn't
|
||||
without having to have timers outside of this module to avoid flickering
|
||||
when users disconnect/reconnect.
|
||||
|
||||
Args:
|
||||
user_id: the user that is starting a sync
|
||||
affect_presence: If false this function will be a no-op.
|
||||
Useful for streams that are not associated with an actual
|
||||
client that is being used by a user.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
|
||||
"""Get an iterable of syncing users on this worker, to send to the presence handler
|
||||
|
||||
This is called when a replication connection is established. It should return
|
||||
a list of user ids, which are then sent as USER_SYNC commands to inform the
|
||||
process handling presence about those users.
|
||||
|
||||
Returns:
|
||||
An iterable of user_id strings.
|
||||
"""
|
||||
|
||||
async def get_state(self, target_user: UserID) -> UserPresenceState:
|
||||
results = await self.get_states([target_user.to_string()])
|
||||
return results[0]
|
||||
|
||||
async def get_states(
|
||||
self, target_user_ids: Iterable[str]
|
||||
) -> List[UserPresenceState]:
|
||||
"""Get the presence state for users."""
|
||||
|
||||
updates_d = await self.current_state_for_users(target_user_ids)
|
||||
updates = list(updates_d.values())
|
||||
|
||||
for user_id in set(target_user_ids) - {u.user_id for u in updates}:
|
||||
updates.append(UserPresenceState.default(user_id))
|
||||
|
||||
return updates
|
||||
|
||||
async def current_state_for_users(
|
||||
self, user_ids: Iterable[str]
|
||||
) -> Dict[str, UserPresenceState]:
|
||||
"""Get the current presence state for multiple users.
|
||||
|
||||
Returns:
|
||||
dict: `user_id` -> `UserPresenceState`
|
||||
"""
|
||||
states = {
|
||||
user_id: self.user_to_current_state.get(user_id, None)
|
||||
for user_id in user_ids
|
||||
}
|
||||
|
||||
missing = [user_id for user_id, state in iteritems(states) if not state]
|
||||
if missing:
|
||||
# There are things not in our in memory cache. Lets pull them out of
|
||||
# the database.
|
||||
res = await self.store.get_presence_for_users(missing)
|
||||
states.update(res)
|
||||
|
||||
missing = [user_id for user_id, state in iteritems(states) if not state]
|
||||
if missing:
|
||||
new = {
|
||||
user_id: UserPresenceState.default(user_id) for user_id in missing
|
||||
}
|
||||
states.update(new)
|
||||
self.user_to_current_state.update(new)
|
||||
|
||||
return states
|
||||
|
||||
@abc.abstractmethod
|
||||
async def set_state(
|
||||
self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False
|
||||
) -> None:
|
||||
"""Set the presence state of the user. """
|
||||
|
||||
|
||||
class PresenceHandler(BasePresenceHandler):
|
||||
def __init__(self, hs: "synapse.server.HomeServer"):
|
||||
super().__init__(hs)
|
||||
self.hs = hs
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.server_name = hs.hostname
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.wheel_timer = WheelTimer()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.federation = hs.get_federation_sender()
|
||||
@@ -209,6 +115,13 @@ class PresenceHandler(BasePresenceHandler):
|
||||
|
||||
federation_registry.register_edu_handler("m.presence", self.incoming_presence)
|
||||
|
||||
active_presence = self.store.take_presence_startup_info()
|
||||
|
||||
# A dictionary of the current state of users. This is prefilled with
|
||||
# non-offline presence from the DB. We should fetch from the DB if
|
||||
# we can't find a users presence in here.
|
||||
self.user_to_current_state = {state.user_id: state for state in active_presence}
|
||||
|
||||
LaterGauge(
|
||||
"synapse_handlers_presence_user_to_current_state_size",
|
||||
"",
|
||||
@@ -217,7 +130,7 @@ class PresenceHandler(BasePresenceHandler):
|
||||
)
|
||||
|
||||
now = self.clock.time_msec()
|
||||
for state in self.user_to_current_state.values():
|
||||
for state in active_presence:
|
||||
self.wheel_timer.insert(
|
||||
now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER
|
||||
)
|
||||
@@ -448,18 +361,10 @@ class PresenceHandler(BasePresenceHandler):
|
||||
|
||||
timers_fired_counter.inc(len(states))
|
||||
|
||||
syncing_user_ids = {
|
||||
user_id
|
||||
for user_id, count in self.user_to_num_current_syncs.items()
|
||||
if count
|
||||
}
|
||||
for user_ids in self.external_process_to_current_syncs.values():
|
||||
syncing_user_ids.update(user_ids)
|
||||
|
||||
changes = handle_timeouts(
|
||||
states,
|
||||
is_mine_fn=self.is_mine_id,
|
||||
syncing_user_ids=syncing_user_ids,
|
||||
syncing_user_ids=self.get_currently_syncing_users(),
|
||||
now=now,
|
||||
)
|
||||
|
||||
@@ -557,9 +462,22 @@ class PresenceHandler(BasePresenceHandler):
|
||||
|
||||
return _user_syncing()
|
||||
|
||||
def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
|
||||
# since we are the process handling presence, there is nothing to do here.
|
||||
return []
|
||||
def get_currently_syncing_users(self):
|
||||
"""Get the set of user ids that are currently syncing on this HS.
|
||||
Returns:
|
||||
set(str): A set of user_id strings.
|
||||
"""
|
||||
if self.hs.config.use_presence:
|
||||
syncing_user_ids = {
|
||||
user_id
|
||||
for user_id, count in self.user_to_num_current_syncs.items()
|
||||
if count
|
||||
}
|
||||
for user_ids in self.external_process_to_current_syncs.values():
|
||||
syncing_user_ids.update(user_ids)
|
||||
return syncing_user_ids
|
||||
else:
|
||||
return set()
|
||||
|
||||
async def update_external_syncs_row(
|
||||
self, process_id, user_id, is_syncing, sync_time_msec
|
||||
@@ -636,6 +554,34 @@ class PresenceHandler(BasePresenceHandler):
|
||||
res = await self.current_state_for_users([user_id])
|
||||
return res[user_id]
|
||||
|
||||
async def current_state_for_users(self, user_ids):
|
||||
"""Get the current presence state for multiple users.
|
||||
|
||||
Returns:
|
||||
dict: `user_id` -> `UserPresenceState`
|
||||
"""
|
||||
states = {
|
||||
user_id: self.user_to_current_state.get(user_id, None)
|
||||
for user_id in user_ids
|
||||
}
|
||||
|
||||
missing = [user_id for user_id, state in iteritems(states) if not state]
|
||||
if missing:
|
||||
# There are things not in our in memory cache. Lets pull them out of
|
||||
# the database.
|
||||
res = await self.store.get_presence_for_users(missing)
|
||||
states.update(res)
|
||||
|
||||
missing = [user_id for user_id, state in iteritems(states) if not state]
|
||||
if missing:
|
||||
new = {
|
||||
user_id: UserPresenceState.default(user_id) for user_id in missing
|
||||
}
|
||||
states.update(new)
|
||||
self.user_to_current_state.update(new)
|
||||
|
||||
return states
|
||||
|
||||
async def _persist_and_notify(self, states):
|
||||
"""Persist states in the database, poke the notifier and send to
|
||||
interested remote servers
|
||||
@@ -723,6 +669,40 @@ class PresenceHandler(BasePresenceHandler):
|
||||
federation_presence_counter.inc(len(updates))
|
||||
await self._update_states(updates)
|
||||
|
||||
async def get_state(self, target_user, as_event=False):
|
||||
results = await self.get_states([target_user.to_string()], as_event=as_event)
|
||||
|
||||
return results[0]
|
||||
|
||||
async def get_states(self, target_user_ids, as_event=False):
|
||||
"""Get the presence state for users.
|
||||
|
||||
Args:
|
||||
target_user_ids (list)
|
||||
as_event (bool): Whether to format it as a client event or not.
|
||||
|
||||
Returns:
|
||||
list
|
||||
"""
|
||||
|
||||
updates = await self.current_state_for_users(target_user_ids)
|
||||
updates = list(updates.values())
|
||||
|
||||
for user_id in set(target_user_ids) - {u.user_id for u in updates}:
|
||||
updates.append(UserPresenceState.default(user_id))
|
||||
|
||||
now = self.clock.time_msec()
|
||||
if as_event:
|
||||
return [
|
||||
{
|
||||
"type": "m.presence",
|
||||
"content": format_user_presence_state(state, now),
|
||||
}
|
||||
for state in updates
|
||||
]
|
||||
else:
|
||||
return updates
|
||||
|
||||
async def set_state(self, target_user, state, ignore_status_msg=False):
|
||||
"""Set the presence state of the user.
|
||||
"""
|
||||
@@ -909,7 +889,7 @@ class PresenceHandler(BasePresenceHandler):
|
||||
user_ids = await self.state.get_current_users_in_room(room_id)
|
||||
user_ids = list(filter(self.is_mine_id, user_ids))
|
||||
|
||||
states_d = await self.current_state_for_users(user_ids)
|
||||
states = await self.current_state_for_users(user_ids)
|
||||
|
||||
# Filter out old presence, i.e. offline presence states where
|
||||
# the user hasn't been active for a week. We can change this
|
||||
@@ -919,7 +899,7 @@ class PresenceHandler(BasePresenceHandler):
|
||||
now = self.clock.time_msec()
|
||||
states = [
|
||||
state
|
||||
for state in states_d.values()
|
||||
for state in states.values()
|
||||
if state.state != PresenceState.OFFLINE
|
||||
or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
|
||||
or state.status_msg is not None
|
||||
|
||||
@@ -434,27 +434,21 @@ class MatrixFederationHttpClient(object):
|
||||
logger.info("Failed to send request: %s", e)
|
||||
raise_from(RequestSendFailed(e, can_retry=True), e)
|
||||
|
||||
logger.info(
|
||||
"{%s} [%s] Got response headers: %d %s",
|
||||
request.txn_id,
|
||||
request.destination,
|
||||
response.code,
|
||||
response.phrase.decode("ascii", errors="replace"),
|
||||
)
|
||||
|
||||
incoming_responses_counter.labels(method_bytes, response.code).inc()
|
||||
|
||||
set_tag(tags.HTTP_STATUS_CODE, response.code)
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
logger.debug(
|
||||
"{%s} [%s] Got response headers: %d %s",
|
||||
request.txn_id,
|
||||
request.destination,
|
||||
response.code,
|
||||
response.phrase.decode("ascii", errors="replace"),
|
||||
)
|
||||
pass
|
||||
else:
|
||||
logger.info(
|
||||
"{%s} [%s] Got response headers: %d %s",
|
||||
request.txn_id,
|
||||
request.destination,
|
||||
response.code,
|
||||
response.phrase.decode("ascii", errors="replace"),
|
||||
)
|
||||
# :'(
|
||||
# Update transactions table?
|
||||
d = treq.content(response)
|
||||
|
||||
@@ -220,6 +220,12 @@ class Notifier(object):
|
||||
"""
|
||||
self.replication_callbacks.append(cb)
|
||||
|
||||
def add_remote_server_up_callback(self, cb: Callable[[str], None]):
|
||||
"""Add a callback that will be called when synapse detects a server
|
||||
has been
|
||||
"""
|
||||
self.remote_server_up_callbacks.append(cb)
|
||||
|
||||
def on_new_room_event(
|
||||
self, event, room_stream_id, max_room_stream_id, extra_users=[]
|
||||
):
|
||||
@@ -538,3 +544,6 @@ class Notifier(object):
|
||||
# circular dependencies.
|
||||
if self.federation_sender:
|
||||
self.federation_sender.wake_destination(server)
|
||||
|
||||
for cb in self.remote_server_up_callbacks:
|
||||
cb(server)
|
||||
|
||||
@@ -16,11 +16,9 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Pattern
|
||||
|
||||
from six import string_types
|
||||
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import UserID
|
||||
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
@@ -58,18 +56,18 @@ def _test_ineq_condition(condition, number):
|
||||
rhs = m.group(2)
|
||||
if not rhs.isdigit():
|
||||
return False
|
||||
rhs_int = int(rhs)
|
||||
rhs = int(rhs)
|
||||
|
||||
if ineq == "" or ineq == "==":
|
||||
return number == rhs_int
|
||||
return number == rhs
|
||||
elif ineq == "<":
|
||||
return number < rhs_int
|
||||
return number < rhs
|
||||
elif ineq == ">":
|
||||
return number > rhs_int
|
||||
return number > rhs
|
||||
elif ineq == ">=":
|
||||
return number >= rhs_int
|
||||
return number >= rhs
|
||||
elif ineq == "<=":
|
||||
return number <= rhs_int
|
||||
return number <= rhs
|
||||
else:
|
||||
return False
|
||||
|
||||
@@ -85,13 +83,7 @@ def tweaks_for_actions(actions):
|
||||
|
||||
|
||||
class PushRuleEvaluatorForEvent(object):
|
||||
def __init__(
|
||||
self,
|
||||
event: EventBase,
|
||||
room_member_count: int,
|
||||
sender_power_level: int,
|
||||
power_levels: dict,
|
||||
):
|
||||
def __init__(self, event, room_member_count, sender_power_level, power_levels):
|
||||
self._event = event
|
||||
self._room_member_count = room_member_count
|
||||
self._sender_power_level = sender_power_level
|
||||
@@ -100,7 +92,7 @@ class PushRuleEvaluatorForEvent(object):
|
||||
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
|
||||
self._value_cache = _flatten_dict(event)
|
||||
|
||||
def matches(self, condition: dict, user_id: str, display_name: str) -> bool:
|
||||
def matches(self, condition, user_id, display_name):
|
||||
if condition["kind"] == "event_match":
|
||||
return self._event_match(condition, user_id)
|
||||
elif condition["kind"] == "contains_display_name":
|
||||
@@ -114,7 +106,7 @@ class PushRuleEvaluatorForEvent(object):
|
||||
else:
|
||||
return True
|
||||
|
||||
def _event_match(self, condition: dict, user_id: str) -> bool:
|
||||
def _event_match(self, condition, user_id):
|
||||
pattern = condition.get("pattern", None)
|
||||
|
||||
if not pattern:
|
||||
@@ -142,7 +134,7 @@ class PushRuleEvaluatorForEvent(object):
|
||||
|
||||
return _glob_matches(pattern, haystack)
|
||||
|
||||
def _contains_display_name(self, display_name: str) -> bool:
|
||||
def _contains_display_name(self, display_name):
|
||||
if not display_name:
|
||||
return False
|
||||
|
||||
@@ -150,52 +142,51 @@ class PushRuleEvaluatorForEvent(object):
|
||||
if not body:
|
||||
return False
|
||||
|
||||
# Similar to _glob_matches, but do not treat display_name as a glob.
|
||||
r = regex_cache.get((display_name, False, True), None)
|
||||
if not r:
|
||||
r = re.escape(display_name)
|
||||
r = _re_word_boundary(r)
|
||||
r = re.compile(r, flags=re.IGNORECASE)
|
||||
regex_cache[(display_name, False, True)] = r
|
||||
return _glob_matches(display_name, body, word_boundary=True)
|
||||
|
||||
return r.search(body)
|
||||
|
||||
def _get_value(self, dotted_key: str) -> str:
|
||||
def _get_value(self, dotted_key):
|
||||
return self._value_cache.get(dotted_key, None)
|
||||
|
||||
|
||||
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
|
||||
# Caches (glob, word_boundary) -> regex for push. See _glob_matches
|
||||
regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
|
||||
register_cache("cache", "regex_push_cache", regex_cache)
|
||||
|
||||
|
||||
def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
|
||||
def _glob_matches(glob, value, word_boundary=False):
|
||||
"""Tests if value matches glob.
|
||||
|
||||
Args:
|
||||
glob
|
||||
value: String to test against glob.
|
||||
word_boundary: Whether to match against word boundaries or entire
|
||||
glob (string)
|
||||
value (string): String to test against glob.
|
||||
word_boundary (bool): Whether to match against word boundaries or entire
|
||||
string. Defaults to False.
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
|
||||
try:
|
||||
r = regex_cache.get((glob, True, word_boundary), None)
|
||||
r = regex_cache.get((glob, word_boundary), None)
|
||||
if not r:
|
||||
r = _glob_to_re(glob, word_boundary)
|
||||
regex_cache[(glob, True, word_boundary)] = r
|
||||
regex_cache[(glob, word_boundary)] = r
|
||||
return r.search(value)
|
||||
except re.error:
|
||||
logger.warning("Failed to parse glob to regex: %r", glob)
|
||||
return False
|
||||
|
||||
|
||||
def _glob_to_re(glob: str, word_boundary: bool) -> Pattern:
|
||||
def _glob_to_re(glob, word_boundary):
|
||||
"""Generates regex for a given glob.
|
||||
|
||||
Args:
|
||||
glob
|
||||
word_boundary: Whether to match against word boundaries or entire string.
|
||||
glob (string)
|
||||
word_boundary (bool): Whether to match against word boundaries or entire
|
||||
string. Defaults to False.
|
||||
|
||||
Returns:
|
||||
regex object
|
||||
"""
|
||||
if IS_GLOB.search(glob):
|
||||
r = re.escape(glob)
|
||||
@@ -228,7 +219,7 @@ def _glob_to_re(glob: str, word_boundary: bool) -> Pattern:
|
||||
return re.compile(r, flags=re.IGNORECASE)
|
||||
|
||||
|
||||
def _re_word_boundary(r: str) -> str:
|
||||
def _re_word_boundary(r):
|
||||
"""
|
||||
Adds word boundary characters to the start and end of an
|
||||
expression to require that the match occur as a whole word,
|
||||
|
||||
@@ -98,7 +98,6 @@ CONDITIONAL_REQUIREMENTS = {
|
||||
"sentry": ["sentry-sdk>=0.7.2"],
|
||||
"opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"],
|
||||
"jwt": ["pyjwt>=1.6.4"],
|
||||
"redis": ["txredisapi>=1.4.7"],
|
||||
}
|
||||
|
||||
ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]
|
||||
|
||||
@@ -28,7 +28,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
||||
|
||||
The API looks like:
|
||||
|
||||
GET /_synapse/replication/get_repl_stream_updates/<stream name>?from_token=0&to_token=10
|
||||
GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&to_token=10&limit=100
|
||||
|
||||
200 OK
|
||||
|
||||
@@ -38,9 +38,6 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
||||
limited: False,
|
||||
}
|
||||
|
||||
If there are more rows than can sensibly be returned in one lump, `limited` will be
|
||||
set to true, and the caller should call again with a new `from_token`.
|
||||
|
||||
"""
|
||||
|
||||
NAME = "get_repl_stream_updates"
|
||||
@@ -55,8 +52,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
||||
self.streams = hs.get_replication_streamer().get_streams()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(stream_name, from_token, upto_token):
|
||||
return {"from_token": from_token, "upto_token": upto_token}
|
||||
def _serialize_payload(stream_name, from_token, upto_token, limit):
|
||||
return {"from_token": from_token, "upto_token": upto_token, "limit": limit}
|
||||
|
||||
async def _handle_request(self, request, stream_name):
|
||||
stream = self.streams.get(stream_name)
|
||||
@@ -65,9 +62,10 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
||||
|
||||
from_token = parse_integer(request, "from_token", required=True)
|
||||
upto_token = parse_integer(request, "upto_token", required=True)
|
||||
limit = parse_integer(request, "limit", required=True)
|
||||
|
||||
updates, upto_token, limited = await stream.get_updates_since(
|
||||
from_token, upto_token
|
||||
from_token, upto_token, limit
|
||||
)
|
||||
|
||||
return (
|
||||
|
||||
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
|
||||
class ReplicationClientFactory(ReconnectingClientFactory):
|
||||
"""Factory for building connections to the master. Will reconnect if the
|
||||
connection is lost.
|
||||
|
||||
|
||||
@@ -210,10 +210,7 @@ class ReplicateCommand(Command):
|
||||
|
||||
class UserSyncCommand(Command):
|
||||
"""Sent by the client to inform the server that a user has started or
|
||||
stopped syncing on this process.
|
||||
|
||||
This is used by the process handling presence (typically the master) to
|
||||
calculate who is online and who is not.
|
||||
stopped syncing. Used to calculate presence on the master.
|
||||
|
||||
Includes a timestamp of when the last user sync was.
|
||||
|
||||
@@ -221,7 +218,7 @@ class UserSyncCommand(Command):
|
||||
|
||||
USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
|
||||
|
||||
Where <state> is either "start" or "end"
|
||||
Where <state> is either "start" or "stop"
|
||||
"""
|
||||
|
||||
NAME = "USER_SYNC"
|
||||
@@ -457,21 +454,3 @@ VALID_CLIENT_COMMANDS = (
|
||||
ErrorCommand.NAME,
|
||||
RemoteServerUpCommand.NAME,
|
||||
)
|
||||
|
||||
|
||||
def parse_command_from_line(line: str) -> Command:
|
||||
"""Parses a command from a received line.
|
||||
|
||||
Line should already be stripped of whitespace and be checked if blank.
|
||||
"""
|
||||
|
||||
idx = line.find(" ")
|
||||
if idx >= 0:
|
||||
cmd_name = line[:idx]
|
||||
rest_of_line = line[idx + 1 :]
|
||||
else:
|
||||
cmd_name = line
|
||||
rest_of_line = ""
|
||||
|
||||
cmd_cls = COMMAND_MAP[cmd_name]
|
||||
return cmd_cls.from_line(rest_of_line)
|
||||
|
||||
@@ -15,25 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from twisted.internet.protocol import ReconnectingClientFactory
|
||||
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
|
||||
from synapse.replication.tcp.client import ReplicationClientFactory
|
||||
from synapse.replication.tcp.commands import (
|
||||
ClearUserSyncsCommand,
|
||||
Command,
|
||||
@@ -94,7 +81,7 @@ class ReplicationCommandHandler:
|
||||
self._pending_batches = {} # type: Dict[str, List[Any]]
|
||||
|
||||
# The factory used to create connections.
|
||||
self._factory = None # type: Optional[ReconnectingClientFactory]
|
||||
self._factory = None # type: Optional[ReplicationClientFactory]
|
||||
|
||||
# The currently connected connections.
|
||||
self._connections = [] # type: List[AbstractConnection]
|
||||
@@ -115,52 +102,19 @@ class ReplicationCommandHandler:
|
||||
self._server_notices_sender = None
|
||||
if self._is_master:
|
||||
self._server_notices_sender = hs.get_server_notices_sender()
|
||||
self._notifier.add_remote_server_up_callback(self.send_remote_server_up)
|
||||
|
||||
def start_replication(self, hs):
|
||||
"""Helper method to start a replication connection to the remote server
|
||||
using TCP.
|
||||
"""
|
||||
if hs.config.redis.redis_enabled:
|
||||
from synapse.replication.tcp.redis import (
|
||||
RedisDirectTcpReplicationClientFactory,
|
||||
)
|
||||
import txredisapi
|
||||
client_name = hs.config.worker_name
|
||||
self._factory = ReplicationClientFactory(hs, client_name, self)
|
||||
host = hs.config.worker_replication_host
|
||||
port = hs.config.worker_replication_port
|
||||
hs.get_reactor().connectTCP(host, port, self._factory)
|
||||
|
||||
logger.info(
|
||||
"Connecting to redis (host=%r port=%r DBID=%r)",
|
||||
hs.config.redis_host,
|
||||
hs.config.redis_port,
|
||||
hs.config.redis_dbid,
|
||||
)
|
||||
|
||||
# We need two connections to redis, one for the subscription stream and
|
||||
# one to send commands to (as you can't send further redis commands to a
|
||||
# connection after SUBSCRIBE is called).
|
||||
|
||||
# First create the connection for sending commands.
|
||||
outbound_redis_connection = txredisapi.lazyConnection(
|
||||
host=hs.config.redis_host,
|
||||
port=hs.config.redis_port,
|
||||
dbid=hs.config.redis_dbid,
|
||||
password=hs.config.redis.redis_password,
|
||||
reconnect=True,
|
||||
)
|
||||
|
||||
# Now create the factory/connection for the subscription stream.
|
||||
self._factory = RedisDirectTcpReplicationClientFactory(
|
||||
hs, outbound_redis_connection
|
||||
)
|
||||
hs.get_reactor().connectTCP(
|
||||
hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory,
|
||||
)
|
||||
else:
|
||||
client_name = hs.config.worker_name
|
||||
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
|
||||
host = hs.config.worker_replication_host
|
||||
port = hs.config.worker_replication_port
|
||||
hs.get_reactor().connectTCP(host, port, self._factory)
|
||||
|
||||
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
|
||||
async def on_REPLICATE(self, cmd: ReplicateCommand):
|
||||
# We only want to announce positions by the writer of the streams.
|
||||
# Currently this is just the master process.
|
||||
if not self._is_master:
|
||||
@@ -170,7 +124,7 @@ class ReplicationCommandHandler:
|
||||
current_token = stream.current_token()
|
||||
self.send_command(PositionCommand(stream_name, current_token))
|
||||
|
||||
async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
|
||||
async def on_USER_SYNC(self, cmd: UserSyncCommand):
|
||||
user_sync_counter.inc()
|
||||
|
||||
if self._is_master:
|
||||
@@ -178,23 +132,17 @@ class ReplicationCommandHandler:
|
||||
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
|
||||
)
|
||||
|
||||
async def on_CLEAR_USER_SYNC(
|
||||
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
|
||||
):
|
||||
async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
|
||||
if self._is_master:
|
||||
await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
|
||||
|
||||
async def on_FEDERATION_ACK(
|
||||
self, conn: AbstractConnection, cmd: FederationAckCommand
|
||||
):
|
||||
async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
|
||||
federation_ack_counter.inc()
|
||||
|
||||
if self._federation_sender:
|
||||
self._federation_sender.federation_ack(cmd.token)
|
||||
|
||||
async def on_REMOVE_PUSHER(
|
||||
self, conn: AbstractConnection, cmd: RemovePusherCommand
|
||||
):
|
||||
async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
|
||||
remove_pusher_counter.inc()
|
||||
|
||||
if self._is_master:
|
||||
@@ -204,9 +152,7 @@ class ReplicationCommandHandler:
|
||||
|
||||
self._notifier.on_new_replication_data()
|
||||
|
||||
async def on_INVALIDATE_CACHE(
|
||||
self, conn: AbstractConnection, cmd: InvalidateCacheCommand
|
||||
):
|
||||
async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
|
||||
invalidate_cache_counter.inc()
|
||||
|
||||
if self._is_master:
|
||||
@@ -216,7 +162,7 @@ class ReplicationCommandHandler:
|
||||
cmd.cache_func, tuple(cmd.keys)
|
||||
)
|
||||
|
||||
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
|
||||
async def on_USER_IP(self, cmd: UserIpCommand):
|
||||
user_ip_cache_counter.inc()
|
||||
|
||||
if self._is_master:
|
||||
@@ -232,7 +178,7 @@ class ReplicationCommandHandler:
|
||||
if self._server_notices_sender:
|
||||
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
||||
|
||||
async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
|
||||
async def on_RDATA(self, cmd: RdataCommand):
|
||||
stream_name = cmd.stream_name
|
||||
inbound_rdata_count.labels(stream_name).inc()
|
||||
|
||||
@@ -283,7 +229,7 @@ class ReplicationCommandHandler:
|
||||
logger.debug("Received rdata %s -> %s", stream_name, token)
|
||||
await self._replication_data_handler.on_rdata(stream_name, token, rows)
|
||||
|
||||
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
|
||||
async def on_POSITION(self, cmd: PositionCommand):
|
||||
stream = self._streams.get(cmd.stream_name)
|
||||
if not stream:
|
||||
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
|
||||
@@ -322,14 +268,11 @@ class ReplicationCommandHandler:
|
||||
missing_updates,
|
||||
) = await stream.get_updates_since(current_token, cmd.token)
|
||||
|
||||
# TODO: add some tests for this
|
||||
|
||||
# Some streams return multiple rows with the same stream IDs,
|
||||
# which need to be processed in batches.
|
||||
|
||||
for token, rows in _batch_updates(updates):
|
||||
if updates:
|
||||
await self.on_rdata(
|
||||
cmd.stream_name, token, [stream.parse_row(row) for row in rows],
|
||||
cmd.stream_name,
|
||||
current_token,
|
||||
[stream.parse_row(update[1]) for update in updates],
|
||||
)
|
||||
|
||||
# We've now caught up to position sent to us, notify handler.
|
||||
@@ -337,30 +280,19 @@ class ReplicationCommandHandler:
|
||||
|
||||
self._streams_connected.add(cmd.stream_name)
|
||||
|
||||
async def on_REMOTE_SERVER_UP(
|
||||
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
|
||||
):
|
||||
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
|
||||
""""Called when get a new REMOTE_SERVER_UP command."""
|
||||
self._replication_data_handler.on_remote_server_up(cmd.data)
|
||||
|
||||
self._notifier.notify_remote_server_up(cmd.data)
|
||||
if self._is_master:
|
||||
self._notifier.notify_remote_server_up(cmd.data)
|
||||
|
||||
# We relay to all other connections to ensure every instance gets the
|
||||
# notification.
|
||||
#
|
||||
# When configured to use redis we'll always only have one connection and
|
||||
# so this is a no-op (all instances will have already received the same
|
||||
# REMOTE_SERVER_UP command).
|
||||
#
|
||||
# For direct TCP connections this will relay to all other connections
|
||||
# connected to us. When on master this will correctly fan out to all
|
||||
# other direct TCP clients and on workers there'll only be the one
|
||||
# connection to master.
|
||||
#
|
||||
# (The logic here should also be sound if we have a mix of Redis and
|
||||
# direct TCP connections so long as there is only one traffic route
|
||||
# between two instances, but that is not currently supported).
|
||||
self.send_command(cmd, ignore_conn=conn)
|
||||
def get_currently_syncing_users(self):
|
||||
"""Get the list of currently syncing users (if any). This is called
|
||||
when a connection has been established and we need to send the
|
||||
currently syncing users.
|
||||
"""
|
||||
return self._presence_handler.get_currently_syncing_users()
|
||||
|
||||
def new_connection(self, connection: AbstractConnection):
|
||||
"""Called when we have a new connection.
|
||||
@@ -379,11 +311,9 @@ class ReplicationCommandHandler:
|
||||
if self._factory:
|
||||
self._factory.resetDelay()
|
||||
|
||||
# Tell the other end if we have any users currently syncing.
|
||||
currently_syncing = (
|
||||
self._presence_handler.get_currently_syncing_users_for_replication()
|
||||
)
|
||||
|
||||
# Tell the server if we have any users currently syncing (should only
|
||||
# happen on synchrotrons)
|
||||
currently_syncing = self.get_currently_syncing_users()
|
||||
now = self._clock.time_msec()
|
||||
for user_id in currently_syncing:
|
||||
connection.send_command(
|
||||
@@ -405,21 +335,11 @@ class ReplicationCommandHandler:
|
||||
"""
|
||||
return bool(self._connections)
|
||||
|
||||
def send_command(
|
||||
self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
|
||||
):
|
||||
def send_command(self, cmd: Command):
|
||||
"""Send a command to all connected connections.
|
||||
|
||||
Args:
|
||||
cmd
|
||||
ignore_conn: If set don't send command to the given connection.
|
||||
Used when relaying commands from one connection to all others.
|
||||
"""
|
||||
if self._connections:
|
||||
for connection in self._connections:
|
||||
if connection == ignore_conn:
|
||||
continue
|
||||
|
||||
try:
|
||||
connection.send_command(cmd)
|
||||
except Exception:
|
||||
@@ -484,52 +404,3 @@ class ReplicationCommandHandler:
|
||||
We need to check if the client is interested in the stream or not
|
||||
"""
|
||||
self.send_command(RdataCommand(stream_name, token, data))
|
||||
|
||||
|
||||
UpdateToken = TypeVar("UpdateToken")
|
||||
UpdateRow = TypeVar("UpdateRow")
|
||||
|
||||
|
||||
def _batch_updates(
|
||||
updates: Iterable[Tuple[UpdateToken, UpdateRow]]
|
||||
) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]:
|
||||
"""Collect stream updates with the same token together
|
||||
|
||||
Given a series of updates returned by Stream.get_updates_since(), collects
|
||||
the updates which share the same stream_id together.
|
||||
|
||||
For example:
|
||||
|
||||
[(1, a), (1, b), (2, c), (3, d), (3, e)]
|
||||
|
||||
becomes:
|
||||
|
||||
[
|
||||
(1, [a, b]),
|
||||
(2, [c]),
|
||||
(3, [d, e]),
|
||||
]
|
||||
"""
|
||||
|
||||
update_iter = iter(updates)
|
||||
|
||||
first_update = next(update_iter, None)
|
||||
if first_update is None:
|
||||
# empty input
|
||||
return
|
||||
|
||||
current_batch_token = first_update[0]
|
||||
current_batch = [first_update[1]]
|
||||
|
||||
for token, row in update_iter:
|
||||
if token != current_batch_token:
|
||||
# different token to the previous row: flush the previous
|
||||
# batch and start anew
|
||||
yield current_batch_token, current_batch
|
||||
current_batch_token = token
|
||||
current_batch = []
|
||||
|
||||
current_batch.append(row)
|
||||
|
||||
# flush the final batch
|
||||
yield current_batch_token, current_batch
|
||||
|
||||
@@ -50,7 +50,10 @@ import abc
|
||||
import fcntl
|
||||
import logging
|
||||
import struct
|
||||
from typing import TYPE_CHECKING, List
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, DefaultDict, List
|
||||
|
||||
from six import iteritems
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
@@ -60,6 +63,7 @@ from twisted.python.failure import Failure
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.tcp.commands import (
|
||||
COMMAND_MAP,
|
||||
VALID_CLIENT_COMMANDS,
|
||||
VALID_SERVER_COMMANDS,
|
||||
Command,
|
||||
@@ -68,7 +72,6 @@ from synapse.replication.tcp.commands import (
|
||||
PingCommand,
|
||||
ReplicateCommand,
|
||||
ServerCommand,
|
||||
parse_command_from_line,
|
||||
)
|
||||
from synapse.types import Collection
|
||||
from synapse.util import Clock
|
||||
@@ -83,18 +86,6 @@ connection_close_counter = Counter(
|
||||
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
|
||||
)
|
||||
|
||||
tcp_inbound_commands_counter = Counter(
|
||||
"synapse_replication_tcp_protocol_inbound_commands",
|
||||
"Number of commands received from replication, by command and name of process connected to",
|
||||
["command", "name"],
|
||||
)
|
||||
|
||||
tcp_outbound_commands_counter = Counter(
|
||||
"synapse_replication_tcp_protocol_outbound_commands",
|
||||
"Number of commands sent to replication, by command and name of process connected to",
|
||||
["command", "name"],
|
||||
)
|
||||
|
||||
# A list of all connected protocols. This allows us to send metrics about the
|
||||
# connections.
|
||||
connected_connections = []
|
||||
@@ -160,6 +151,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
# The LoopingCall for sending pings.
|
||||
self._send_ping_loop = None
|
||||
|
||||
self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
|
||||
self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
|
||||
|
||||
def connectionMade(self):
|
||||
logger.info("[%s] Connection established", self.id())
|
||||
|
||||
@@ -216,21 +210,37 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
|
||||
linestr = line.decode("utf-8")
|
||||
|
||||
try:
|
||||
cmd = parse_command_from_line(linestr)
|
||||
except Exception as e:
|
||||
logger.exception("[%s] failed to parse line: %r", self.id(), linestr)
|
||||
self.send_error("failed to parse line: %r (%r):" % (e, linestr))
|
||||
return
|
||||
# split at the first " ", handling one-word commands
|
||||
idx = linestr.index(" ")
|
||||
if idx >= 0:
|
||||
cmd_name = linestr[:idx]
|
||||
rest_of_line = linestr[idx + 1 :]
|
||||
else:
|
||||
cmd_name = linestr
|
||||
rest_of_line = ""
|
||||
|
||||
if cmd.NAME not in self.VALID_INBOUND_COMMANDS:
|
||||
logger.error("[%s] invalid command %s", self.id(), cmd.NAME)
|
||||
self.send_error("invalid command: %s", cmd.NAME)
|
||||
if cmd_name not in self.VALID_INBOUND_COMMANDS:
|
||||
logger.error("[%s] invalid command %s", self.id(), cmd_name)
|
||||
self.send_error("invalid command: %s", cmd_name)
|
||||
return
|
||||
|
||||
self.last_received_command = self.clock.time_msec()
|
||||
|
||||
tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
|
||||
self.inbound_commands_counter[cmd_name] = (
|
||||
self.inbound_commands_counter[cmd_name] + 1
|
||||
)
|
||||
|
||||
cmd_cls = COMMAND_MAP[cmd_name]
|
||||
try:
|
||||
cmd = cmd_cls.from_line(rest_of_line)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line
|
||||
)
|
||||
self.send_error(
|
||||
"failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line)
|
||||
)
|
||||
return
|
||||
|
||||
# Now lets try and call on_<CMD_NAME> function
|
||||
run_as_background_process(
|
||||
@@ -260,7 +270,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
# Then call out to the handler.
|
||||
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(self, cmd)
|
||||
await cmd_func(cmd)
|
||||
handled = True
|
||||
|
||||
if not handled:
|
||||
@@ -296,8 +306,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
self._queue_command(cmd)
|
||||
return
|
||||
|
||||
tcp_outbound_commands_counter.labels(cmd.NAME, self.name).inc()
|
||||
|
||||
self.outbound_commands_counter[cmd.NAME] = (
|
||||
self.outbound_commands_counter[cmd.NAME] + 1
|
||||
)
|
||||
string = "%s %s" % (cmd.NAME, cmd.to_line())
|
||||
if "\n" in string:
|
||||
raise Exception("Unexpected newline in command: %r", string)
|
||||
@@ -549,3 +560,26 @@ tcp_transport_kernel_read_buffer = LaterGauge(
|
||||
for p in connected_connections
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
tcp_inbound_commands = LaterGauge(
|
||||
"synapse_replication_tcp_protocol_inbound_commands",
|
||||
"",
|
||||
["command", "name"],
|
||||
lambda: {
|
||||
(k, p.name): count
|
||||
for p in connected_connections
|
||||
for k, count in iteritems(p.inbound_commands_counter)
|
||||
},
|
||||
)
|
||||
|
||||
tcp_outbound_commands = LaterGauge(
|
||||
"synapse_replication_tcp_protocol_outbound_commands",
|
||||
"",
|
||||
["command", "name"],
|
||||
lambda: {
|
||||
(k, p.name): count
|
||||
for p in connected_connections
|
||||
for k, count in iteritems(p.outbound_commands_counter)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,193 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import txredisapi
|
||||
|
||||
from synapse.logging.context import PreserveLoggingContext
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.tcp.commands import (
|
||||
Command,
|
||||
ReplicateCommand,
|
||||
parse_command_from_line,
|
||||
)
|
||||
from synapse.replication.tcp.protocol import (
|
||||
AbstractConnection,
|
||||
tcp_inbound_commands_counter,
|
||||
tcp_outbound_commands_counter,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||
"""Connection to redis subscribed to replication stream.
|
||||
|
||||
Parses incoming messages from redis into replication commands, and passes
|
||||
them to `ReplicationCommandHandler`
|
||||
|
||||
Due to the vagaries of `txredisapi` we don't want to have a custom
|
||||
constructor, so instead we expect the defined attributes below to be set
|
||||
immediately after initialisation.
|
||||
|
||||
Attributes:
|
||||
handler: The command handler to handle incoming commands.
|
||||
stream_name: The *redis* stream name to subscribe to (not anything to
|
||||
do with Synapse replication streams).
|
||||
outbound_redis_connection: The connection to redis to use to send
|
||||
commands.
|
||||
"""
|
||||
|
||||
handler = None # type: ReplicationCommandHandler
|
||||
stream_name = None # type: str
|
||||
outbound_redis_connection = None # type: txredisapi.RedisProtocol
|
||||
|
||||
def connectionMade(self):
|
||||
logger.info("Connected to redis instance")
|
||||
self.subscribe(self.stream_name)
|
||||
self.send_command(ReplicateCommand())
|
||||
|
||||
self.handler.new_connection(self)
|
||||
|
||||
def messageReceived(self, pattern: str, channel: str, message: str):
|
||||
"""Received a message from redis.
|
||||
"""
|
||||
|
||||
if message.strip() == "":
|
||||
# Ignore blank lines
|
||||
return
|
||||
|
||||
try:
|
||||
cmd = parse_command_from_line(message)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"[%s] failed to parse line: %r", message,
|
||||
)
|
||||
return
|
||||
|
||||
# We use "redis" as the name here as we don't have 1:1 connections to
|
||||
# remote instances.
|
||||
tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
|
||||
|
||||
# Now lets try and call on_<CMD_NAME> function
|
||||
run_as_background_process(
|
||||
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
|
||||
)
|
||||
|
||||
async def handle_command(self, cmd: Command):
|
||||
"""Handle a command we have received over the replication stream.
|
||||
|
||||
By default delegates to on_<COMMAND>, which should return an awaitable.
|
||||
|
||||
Args:
|
||||
cmd: received command
|
||||
"""
|
||||
handled = False
|
||||
|
||||
# First call any command handlers on this instance. These are for redis
|
||||
# specific handling.
|
||||
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(cmd)
|
||||
handled = True
|
||||
|
||||
# Then call out to the handler.
|
||||
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(self, cmd)
|
||||
handled = True
|
||||
|
||||
if not handled:
|
||||
logger.warning("Unhandled command: %r", cmd)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
logger.info("Lost connection to redis instance")
|
||||
self.handler.lost_connection(self)
|
||||
|
||||
def send_command(self, cmd: Command):
|
||||
"""Send a command if connection has been established.
|
||||
|
||||
Args:
|
||||
cmd (Command)
|
||||
"""
|
||||
string = "%s %s" % (cmd.NAME, cmd.to_line())
|
||||
if "\n" in string:
|
||||
raise Exception("Unexpected newline in command: %r", string)
|
||||
|
||||
encoded_string = string.encode("utf-8")
|
||||
|
||||
# We use "redis" as the name here as we don't have 1:1 connections to
|
||||
# remote instances.
|
||||
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
|
||||
|
||||
async def _send():
|
||||
with PreserveLoggingContext():
|
||||
# Note that we use the other connection as we can't send
|
||||
# commands using the subscription connection.
|
||||
await self.outbound_redis_connection.publish(
|
||||
self.stream_name, encoded_string
|
||||
)
|
||||
|
||||
run_as_background_process("send-cmd", _send)
|
||||
|
||||
|
||||
class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
|
||||
"""This is a reconnecting factory that connects to redis and immediately
|
||||
subscribes to a stream.
|
||||
|
||||
Args:
|
||||
hs
|
||||
outbound_redis_connection: A connection to redis that will be used to
|
||||
send outbound commands (this is seperate to the redis connection
|
||||
used to subscribe).
|
||||
"""
|
||||
|
||||
maxDelay = 5
|
||||
continueTrying = True
|
||||
protocol = RedisSubscriber
|
||||
|
||||
def __init__(
|
||||
self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
# This sets the password on the RedisFactory base class (as
|
||||
# SubscriberFactory constructor doesn't pass it through).
|
||||
self.password = hs.config.redis.redis_password
|
||||
|
||||
self.handler = hs.get_tcp_replication()
|
||||
self.stream_name = hs.hostname
|
||||
|
||||
self.outbound_redis_connection = outbound_redis_connection
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
p = super().buildProtocol(addr) # type: RedisSubscriber
|
||||
|
||||
# We do this here rather than add to the constructor of `RedisSubcriber`
|
||||
# as to do so would involve overriding `buildProtocol` entirely, however
|
||||
# the base method does some other things than just instantiating the
|
||||
# protocol.
|
||||
p.handler = self.handler
|
||||
p.outbound_redis_connection = self.outbound_redis_connection
|
||||
p.stream_name = self.stream_name
|
||||
|
||||
return p
|
||||
@@ -17,7 +17,9 @@
|
||||
|
||||
import logging
|
||||
import random
|
||||
from typing import Dict, List
|
||||
from typing import Dict
|
||||
|
||||
from six import itervalues
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
@@ -69,28 +71,29 @@ class ReplicationStreamer(object):
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.notifier = hs.get_notifier()
|
||||
self._server_notices_sender = hs.get_server_notices_sender()
|
||||
|
||||
self._replication_torture_level = hs.config.replication_torture_level
|
||||
|
||||
# Work out list of streams that this instance is the source of.
|
||||
self.streams = [] # type: List[Stream]
|
||||
if hs.config.worker_app is None:
|
||||
for stream in STREAMS_MAP.values():
|
||||
if stream == FederationStream and hs.config.send_federation:
|
||||
# We only support federation stream if federation sending
|
||||
# hase been disabled on the master.
|
||||
continue
|
||||
|
||||
self.streams.append(stream(hs))
|
||||
# List of streams that clients can subscribe to.
|
||||
# We only support federation stream if federation sending hase been
|
||||
# disabled on the master.
|
||||
self.streams = [
|
||||
stream(hs)
|
||||
for stream in itervalues(STREAMS_MAP)
|
||||
if stream != FederationStream or not hs.config.send_federation
|
||||
]
|
||||
|
||||
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
|
||||
|
||||
# Only bother registering the notifier callback if we have streams to
|
||||
# publish.
|
||||
if self.streams:
|
||||
self.notifier.add_replication_callback(self.on_notifier_poke)
|
||||
self.federation_sender = None
|
||||
if not hs.config.send_federation:
|
||||
self.federation_sender = hs.get_federation_sender()
|
||||
|
||||
self.notifier.add_replication_callback(self.on_notifier_poke)
|
||||
|
||||
# Keeps track of whether we are currently checking for updates
|
||||
self.is_looping = False
|
||||
|
||||
@@ -25,6 +25,8 @@ Each stream is defined by the following information:
|
||||
update_function: The function that returns a list of updates between two tokens
|
||||
"""
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
from synapse.replication.tcp.streams._base import (
|
||||
AccountDataStream,
|
||||
BackfillStream,
|
||||
@@ -65,7 +67,8 @@ STREAMS_MAP = {
|
||||
GroupServerStream,
|
||||
UserSignatureStream,
|
||||
)
|
||||
}
|
||||
} # type: Dict[str, Type[Stream]]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"STREAMS_MAP",
|
||||
|
||||
@@ -16,16 +16,17 @@
|
||||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
|
||||
from typing import Any, Awaitable, Callable, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.replication.http.streams import ReplicationGetStreamUpdates
|
||||
from synapse.types import JsonDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# the number of rows to request from an update_function.
|
||||
_STREAM_UPDATE_TARGET_ROW_COUNT = 100
|
||||
|
||||
MAX_EVENTS_BEHIND = 500000
|
||||
|
||||
|
||||
# Some type aliases to make things a bit easier.
|
||||
@@ -33,36 +34,8 @@ _STREAM_UPDATE_TARGET_ROW_COUNT = 100
|
||||
# A stream position token
|
||||
Token = int
|
||||
|
||||
# The type of a stream update row, after JSON deserialisation, but before
|
||||
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
|
||||
# just a row from a database query, though this is dependent on the stream in question.
|
||||
#
|
||||
StreamRow = Tuple
|
||||
|
||||
# The type returned by the update_function of a stream, as well as get_updates(),
|
||||
# get_updates_since, etc.
|
||||
#
|
||||
# It consists of a triplet `(updates, new_last_token, limited)`, where:
|
||||
# * `updates` is a list of `(token, row)` entries.
|
||||
# * `new_last_token` is the new position in stream.
|
||||
# * `limited` is whether there are more updates to fetch.
|
||||
#
|
||||
StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
|
||||
|
||||
# The type of an update_function for a stream
|
||||
#
|
||||
# The arguments are:
|
||||
#
|
||||
# * from_token: the previous stream token: the starting point for fetching the
|
||||
# updates
|
||||
# * to_token: the new stream token: the point to get updates up to
|
||||
# * target_row_count: a target for the number of rows to be returned.
|
||||
#
|
||||
# The update_function is expected to return up to _approximately_ target_row_count rows.
|
||||
# If there are more updates available, it should set `limited` in the result, and
|
||||
# it will be called again to get the next batch.
|
||||
#
|
||||
UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
|
||||
# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
|
||||
StreamRow = Tuple[Token, tuple]
|
||||
|
||||
|
||||
class Stream(object):
|
||||
@@ -77,7 +50,7 @@ class Stream(object):
|
||||
ROW_TYPE = None # type: Any
|
||||
|
||||
@classmethod
|
||||
def parse_row(cls, row: StreamRow):
|
||||
def parse_row(cls, row):
|
||||
"""Parse a row received over replication
|
||||
|
||||
By default, assumes that the row data is an array object and passes its contents
|
||||
@@ -91,28 +64,7 @@ class Stream(object):
|
||||
"""
|
||||
return cls.ROW_TYPE(*row)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
current_token_function: Callable[[], Token],
|
||||
update_function: UpdateFunction,
|
||||
):
|
||||
"""Instantiate a Stream
|
||||
|
||||
current_token_function and update_function are callbacks which should be
|
||||
implemented by subclasses.
|
||||
|
||||
current_token_function is called to get the current token of the underlying
|
||||
stream.
|
||||
|
||||
update_function is called to get updates for this stream between a pair of
|
||||
stream tokens. See the UpdateFunction type definition for more info.
|
||||
|
||||
Args:
|
||||
current_token_function: callback to get the current token, as above
|
||||
update_function: callback go get stream updates, as above
|
||||
"""
|
||||
self.current_token = current_token_function
|
||||
self.update_function = update_function
|
||||
def __init__(self, hs):
|
||||
|
||||
# The token from which we last asked for updates
|
||||
self.last_token = self.current_token()
|
||||
@@ -123,7 +75,7 @@ class Stream(object):
|
||||
"""
|
||||
self.last_token = self.current_token()
|
||||
|
||||
async def get_updates(self) -> StreamUpdateResult:
|
||||
async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
|
||||
"""Gets all updates since the last time this function was called (or
|
||||
since the stream was constructed if it hadn't been called before).
|
||||
|
||||
@@ -142,8 +94,8 @@ class Stream(object):
|
||||
return updates, current_token, limited
|
||||
|
||||
async def get_updates_since(
|
||||
self, from_token: Token, upto_token: Token
|
||||
) -> StreamUpdateResult:
|
||||
self, from_token: Token, upto_token: Token, limit: int = 100
|
||||
) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
|
||||
"""Like get_updates except allows specifying from when we should
|
||||
stream updates
|
||||
|
||||
@@ -160,14 +112,33 @@ class Stream(object):
|
||||
return [], upto_token, False
|
||||
|
||||
updates, upto_token, limited = await self.update_function(
|
||||
from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
|
||||
from_token, upto_token, limit=limit,
|
||||
)
|
||||
return updates, upto_token, limited
|
||||
|
||||
def current_token(self):
|
||||
"""Gets the current token of the underlying streams. Should be provided
|
||||
by the sub classes
|
||||
|
||||
Returns:
|
||||
int
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_function(self, from_token, current_token, limit):
|
||||
"""Get updates between from_token and to_token.
|
||||
|
||||
Returns:
|
||||
Deferred(list(tuple)): the first entry in the tuple is the token for
|
||||
that update, and the rest of the tuple gets used to construct
|
||||
a ``ROW_TYPE`` instance
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def db_query_to_update_function(
|
||||
query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
|
||||
) -> UpdateFunction:
|
||||
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
|
||||
) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
|
||||
"""Wraps a db query function which returns a list of rows to make it
|
||||
suitable for use as an `update_function` for the Stream class
|
||||
"""
|
||||
@@ -177,16 +148,17 @@ def db_query_to_update_function(
|
||||
updates = [(row[0], row[1:]) for row in rows]
|
||||
limited = False
|
||||
if len(updates) == limit:
|
||||
upto_token = updates[-1][0]
|
||||
upto_token = rows[-1][0]
|
||||
limited = True
|
||||
assert len(updates) <= limit
|
||||
|
||||
return updates, upto_token, limited
|
||||
|
||||
return update_function
|
||||
|
||||
|
||||
def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
|
||||
def make_http_update_function(
|
||||
hs, stream_name: str
|
||||
) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
|
||||
"""Makes a suitable function for use as an `update_function` that queries
|
||||
the master process for updates.
|
||||
"""
|
||||
@@ -195,9 +167,12 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
|
||||
|
||||
async def update_function(
|
||||
from_token: int, upto_token: int, limit: int
|
||||
) -> StreamUpdateResult:
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
result = await client(
|
||||
stream_name=stream_name, from_token=from_token, upto_token=upto_token,
|
||||
stream_name=stream_name,
|
||||
from_token=from_token,
|
||||
upto_token=upto_token,
|
||||
limit=limit,
|
||||
)
|
||||
return result["updates"], result["upto_token"], result["limited"]
|
||||
|
||||
@@ -226,10 +201,10 @@ class BackfillStream(Stream):
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
store.get_current_backfill_token,
|
||||
db_query_to_update_function(store.get_all_new_backfill_event_rows),
|
||||
)
|
||||
self.current_token = store.get_current_backfill_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
|
||||
|
||||
super(BackfillStream, self).__init__(hs)
|
||||
|
||||
|
||||
class PresenceStream(Stream):
|
||||
@@ -251,18 +226,19 @@ class PresenceStream(Stream):
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
presence_handler = hs.get_presence_handler()
|
||||
|
||||
self._is_worker = hs.config.worker_app is not None
|
||||
|
||||
self.current_token = store.get_current_presence_token # type: ignore
|
||||
|
||||
if hs.config.worker_app is None:
|
||||
# on the master, query the presence handler
|
||||
presence_handler = hs.get_presence_handler()
|
||||
update_function = db_query_to_update_function(
|
||||
presence_handler.get_all_presence_updates
|
||||
)
|
||||
self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
|
||||
else:
|
||||
# Query master process
|
||||
update_function = make_http_update_function(hs, self.NAME)
|
||||
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
|
||||
|
||||
super().__init__(store.get_current_presence_token, update_function)
|
||||
super(PresenceStream, self).__init__(hs)
|
||||
|
||||
|
||||
class TypingStream(Stream):
|
||||
@@ -276,16 +252,15 @@ class TypingStream(Stream):
|
||||
def __init__(self, hs):
|
||||
typing_handler = hs.get_typing_handler()
|
||||
|
||||
self.current_token = typing_handler.get_current_token # type: ignore
|
||||
|
||||
if hs.config.worker_app is None:
|
||||
# on the master, query the typing handler
|
||||
update_function = db_query_to_update_function(
|
||||
typing_handler.get_all_typing_updates
|
||||
)
|
||||
self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
|
||||
else:
|
||||
# Query master process
|
||||
update_function = make_http_update_function(hs, self.NAME)
|
||||
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
|
||||
|
||||
super().__init__(typing_handler.get_current_token, update_function)
|
||||
super(TypingStream, self).__init__(hs)
|
||||
|
||||
|
||||
class ReceiptsStream(Stream):
|
||||
@@ -305,10 +280,11 @@ class ReceiptsStream(Stream):
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
store.get_max_receipt_stream_id,
|
||||
db_query_to_update_function(store.get_all_updated_receipts),
|
||||
)
|
||||
|
||||
self.current_token = store.get_max_receipt_stream_id # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
|
||||
|
||||
super(ReceiptsStream, self).__init__(hs)
|
||||
|
||||
|
||||
class PushRulesStream(Stream):
|
||||
@@ -322,15 +298,13 @@ class PushRulesStream(Stream):
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
super(PushRulesStream, self).__init__(
|
||||
self._current_token, self._update_function
|
||||
)
|
||||
super(PushRulesStream, self).__init__(hs)
|
||||
|
||||
def _current_token(self) -> int:
|
||||
def current_token(self):
|
||||
push_rules_token, _ = self.store.get_push_rules_stream_token()
|
||||
return push_rules_token
|
||||
|
||||
async def _update_function(self, from_token: Token, to_token: Token, limit: int):
|
||||
async def update_function(self, from_token, to_token, limit):
|
||||
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
|
||||
|
||||
limited = False
|
||||
@@ -356,10 +330,10 @@ class PushersStream(Stream):
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
super().__init__(
|
||||
store.get_pushers_stream_token,
|
||||
db_query_to_update_function(store.get_all_updated_pushers_rows),
|
||||
)
|
||||
self.current_token = store.get_pushers_stream_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
|
||||
|
||||
super(PushersStream, self).__init__(hs)
|
||||
|
||||
|
||||
class CachesStream(Stream):
|
||||
@@ -387,10 +361,11 @@ class CachesStream(Stream):
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
store.get_cache_stream_token,
|
||||
db_query_to_update_function(store.get_all_updated_caches),
|
||||
)
|
||||
|
||||
self.current_token = store.get_cache_stream_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
|
||||
|
||||
super(CachesStream, self).__init__(hs)
|
||||
|
||||
|
||||
class PublicRoomsStream(Stream):
|
||||
@@ -412,10 +387,11 @@ class PublicRoomsStream(Stream):
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
store.get_current_public_room_stream_id,
|
||||
db_query_to_update_function(store.get_all_new_public_rooms),
|
||||
)
|
||||
|
||||
self.current_token = store.get_current_public_room_stream_id # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
|
||||
|
||||
super(PublicRoomsStream, self).__init__(hs)
|
||||
|
||||
|
||||
class DeviceListsStream(Stream):
|
||||
@@ -432,10 +408,11 @@ class DeviceListsStream(Stream):
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
store.get_device_stream_token,
|
||||
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
|
||||
)
|
||||
|
||||
self.current_token = store.get_device_stream_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
|
||||
|
||||
super(DeviceListsStream, self).__init__(hs)
|
||||
|
||||
|
||||
class ToDeviceStream(Stream):
|
||||
@@ -449,10 +426,11 @@ class ToDeviceStream(Stream):
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
store.get_to_device_stream_token,
|
||||
db_query_to_update_function(store.get_all_new_device_messages),
|
||||
)
|
||||
|
||||
self.current_token = store.get_to_device_stream_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
|
||||
|
||||
super(ToDeviceStream, self).__init__(hs)
|
||||
|
||||
|
||||
class TagAccountDataStream(Stream):
|
||||
@@ -468,10 +446,11 @@ class TagAccountDataStream(Stream):
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
store.get_max_account_data_stream_id,
|
||||
db_query_to_update_function(store.get_all_updated_tags),
|
||||
)
|
||||
|
||||
self.current_token = store.get_max_account_data_stream_id # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
|
||||
|
||||
super(TagAccountDataStream, self).__init__(hs)
|
||||
|
||||
|
||||
class AccountDataStream(Stream):
|
||||
@@ -487,10 +466,11 @@ class AccountDataStream(Stream):
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
super().__init__(
|
||||
self.store.get_max_account_data_stream_id,
|
||||
db_query_to_update_function(self._update_function),
|
||||
)
|
||||
|
||||
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
|
||||
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
|
||||
|
||||
super(AccountDataStream, self).__init__(hs)
|
||||
|
||||
async def _update_function(self, from_token, to_token, limit):
|
||||
global_results, room_results = await self.store.get_all_updated_account_data(
|
||||
@@ -517,10 +497,11 @@ class GroupServerStream(Stream):
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
store.get_group_stream_token,
|
||||
db_query_to_update_function(store.get_all_groups_changes),
|
||||
)
|
||||
|
||||
self.current_token = store.get_group_stream_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
|
||||
|
||||
super(GroupServerStream, self).__init__(hs)
|
||||
|
||||
|
||||
class UserSignatureStream(Stream):
|
||||
@@ -534,9 +515,8 @@ class UserSignatureStream(Stream):
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
store.get_device_stream_token,
|
||||
db_query_to_update_function(
|
||||
store.get_all_user_signature_changes_for_remotes
|
||||
),
|
||||
)
|
||||
|
||||
self.current_token = store.get_device_stream_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
|
||||
|
||||
super(UserSignatureStream, self).__init__(hs)
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
import heapq
|
||||
from collections import Iterable
|
||||
from typing import List, Tuple, Type
|
||||
from typing import Tuple, Type
|
||||
|
||||
import attr
|
||||
|
||||
from ._base import Stream, StreamUpdateResult, Token
|
||||
from ._base import Stream, db_query_to_update_function
|
||||
|
||||
|
||||
"""Handling of the 'events' replication stream
|
||||
@@ -117,121 +116,29 @@ class EventsStream(Stream):
|
||||
|
||||
def __init__(self, hs):
|
||||
self._store = hs.get_datastore()
|
||||
super().__init__(
|
||||
self._store.get_current_events_token, self._update_function,
|
||||
self.current_token = self._store.get_current_events_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
|
||||
|
||||
super(EventsStream, self).__init__(hs)
|
||||
|
||||
async def _update_function(self, from_token, current_token, limit=None):
|
||||
event_rows = await self._store.get_all_new_forward_event_rows(
|
||||
from_token, current_token, limit
|
||||
)
|
||||
event_updates = (
|
||||
(row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
|
||||
)
|
||||
|
||||
async def _update_function(
|
||||
self, from_token: Token, current_token: Token, target_row_count: int
|
||||
) -> StreamUpdateResult:
|
||||
|
||||
# the events stream merges together three separate sources:
|
||||
# * new events
|
||||
# * current_state changes
|
||||
# * events which were previously outliers, but have now been de-outliered.
|
||||
#
|
||||
# The merge operation is complicated by the fact that we only have a single
|
||||
# "stream token" which is supposed to indicate how far we have got through
|
||||
# all three streams. It's therefore no good to return rows 1-1000 from the
|
||||
# "new events" table if the state_deltas are limited to rows 1-100 by the
|
||||
# target_row_count.
|
||||
#
|
||||
# In other words: we must pick a new upper limit, and must return *all* rows
|
||||
# up to that point for each of the three sources.
|
||||
#
|
||||
# Start by trying to split the target_row_count up. We expect to have a
|
||||
# negligible number of ex-outliers, and a rough approximation based on recent
|
||||
# traffic on sw1v.org shows that there are approximately the same number of
|
||||
# event rows between a given pair of stream ids as there are state
|
||||
# updates, so let's split our target_row_count among those two types. The target
|
||||
# is only an approximation - it doesn't matter if we end up going a bit over it.
|
||||
|
||||
target_row_count //= 2
|
||||
|
||||
# now we fetch up to that many rows from the events table
|
||||
|
||||
event_rows = await self._store.get_all_new_forward_event_rows(
|
||||
from_token, current_token, target_row_count
|
||||
) # type: List[Tuple]
|
||||
|
||||
# we rely on get_all_new_forward_event_rows strictly honouring the limit, so
|
||||
# that we know it is safe to just take upper_limit = event_rows[-1][0].
|
||||
assert (
|
||||
len(event_rows) <= target_row_count
|
||||
), "get_all_new_forward_event_rows did not honour row limit"
|
||||
|
||||
# if we hit the limit on event_updates, there's no point in going beyond the
|
||||
# last stream_id in the batch for the other sources.
|
||||
|
||||
if len(event_rows) == target_row_count:
|
||||
limited = True
|
||||
upper_limit = event_rows[-1][0] # type: int
|
||||
else:
|
||||
limited = False
|
||||
upper_limit = current_token
|
||||
|
||||
# next up is the state delta table
|
||||
|
||||
state_rows = await self._store.get_all_updated_current_state_deltas(
|
||||
from_token, upper_limit, target_row_count
|
||||
) # type: List[Tuple]
|
||||
|
||||
assert len(state_rows) <= target_row_count
|
||||
|
||||
# there can be more than one row per stream_id in that table, so if we hit
|
||||
# the limit there, we'll need to truncate the results so that we have a complete
|
||||
# set of changes for all the stream IDs we include.
|
||||
if len(state_rows) == target_row_count:
|
||||
assert state_rows[-1][0] <= upper_limit
|
||||
upper_limit = state_rows[-1][0] - 1
|
||||
|
||||
# search for the point to truncate the list
|
||||
for idx in range(len(state_rows) - 1, 0, -1):
|
||||
if state_rows[idx - 1][0] <= upper_limit:
|
||||
state_rows = state_rows[:idx]
|
||||
break
|
||||
else:
|
||||
# bother. We didn't get a full set of changes for even a single
|
||||
# stream id. let's run the query again, without a row limit, but for
|
||||
# just one stream id.
|
||||
upper_limit += 1
|
||||
state_rows = await self._store.get_all_updated_current_state_deltas(
|
||||
from_token, upper_limit, limit=None
|
||||
)
|
||||
|
||||
limited = True
|
||||
|
||||
# finally, fetch the ex-outliers rows. We assume there are few enough of these
|
||||
# not to bother with the limit.
|
||||
|
||||
ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
|
||||
from_token, upper_limit
|
||||
) # type: List[Tuple]
|
||||
|
||||
# we now need to turn the raw database rows returned into tuples suitable
|
||||
# for the replication protocol (basically, we add an identifier to
|
||||
# distinguish the row type). At the same time, we can limit the event_rows
|
||||
# to the max stream_id from state_rows.
|
||||
|
||||
event_updates = (
|
||||
(stream_id, (EventsStreamEventRow.TypeId, rest))
|
||||
for (stream_id, *rest) in event_rows
|
||||
if stream_id <= upper_limit
|
||||
) # type: Iterable[Tuple[int, Tuple]]
|
||||
|
||||
from_token, current_token, limit
|
||||
)
|
||||
state_updates = (
|
||||
(stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
|
||||
for (stream_id, *rest) in state_rows
|
||||
) # type: Iterable[Tuple[int, Tuple]]
|
||||
(row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows
|
||||
)
|
||||
|
||||
ex_outliers_updates = (
|
||||
(stream_id, (EventsStreamEventRow.TypeId, rest))
|
||||
for (stream_id, *rest) in ex_outliers_rows
|
||||
) # type: Iterable[Tuple[int, Tuple]]
|
||||
all_updates = heapq.merge(event_updates, state_updates)
|
||||
|
||||
# we need to return a sorted list, so merge them together.
|
||||
updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates))
|
||||
return updates, upper_limit, limited
|
||||
return all_updates
|
||||
|
||||
@classmethod
|
||||
def parse_row(cls, row):
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
# limitations under the License.
|
||||
from collections import namedtuple
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
|
||||
|
||||
|
||||
@@ -33,6 +35,7 @@ class FederationStream(Stream):
|
||||
|
||||
NAME = "federation"
|
||||
ROW_TYPE = FederationStreamRow
|
||||
_QUERY_MASTER = True
|
||||
|
||||
def __init__(self, hs):
|
||||
# Not all synapse instances will have a federation sender instance,
|
||||
@@ -40,16 +43,10 @@ class FederationStream(Stream):
|
||||
# so we stub the stream out when that is the case.
|
||||
if hs.config.worker_app is None or hs.should_send_federation():
|
||||
federation_sender = hs.get_federation_sender()
|
||||
current_token = federation_sender.get_current_token
|
||||
update_function = db_query_to_update_function(
|
||||
federation_sender.get_replication_rows
|
||||
)
|
||||
self.current_token = federation_sender.get_current_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
|
||||
else:
|
||||
current_token = lambda: 0
|
||||
update_function = self._stub_update_function
|
||||
self.current_token = lambda: 0 # type: ignore
|
||||
self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool)) # type: ignore
|
||||
|
||||
super().__init__(current_token, update_function)
|
||||
|
||||
@staticmethod
|
||||
async def _stub_update_function(from_token, upto_token, limit):
|
||||
return [], upto_token, False
|
||||
super(FederationStream, self).__init__(hs)
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
<html>
|
||||
<head>
|
||||
<title>Authentication Successful</title>
|
||||
<script>
|
||||
if (window.onAuthDone) {
|
||||
window.onAuthDone();
|
||||
} else if (window.opener && window.opener.postMessage) {
|
||||
window.opener.postMessage("authDone", "*");
|
||||
}
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div>
|
||||
<p>Thank you</p>
|
||||
<p>You may now close this window and return to the application</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
@@ -183,23 +183,10 @@ class ListRoomRestServlet(RestServlet):
|
||||
# Extract query parameters
|
||||
start = parse_integer(request, "from", default=0)
|
||||
limit = parse_integer(request, "limit", default=100)
|
||||
order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value)
|
||||
order_by = parse_string(request, "order_by", default="alphabetical")
|
||||
if order_by not in (
|
||||
RoomSortOrder.ALPHABETICAL.value,
|
||||
RoomSortOrder.SIZE.value,
|
||||
RoomSortOrder.NAME.value,
|
||||
RoomSortOrder.CANONICAL_ALIAS.value,
|
||||
RoomSortOrder.JOINED_MEMBERS.value,
|
||||
RoomSortOrder.JOINED_LOCAL_MEMBERS.value,
|
||||
RoomSortOrder.VERSION.value,
|
||||
RoomSortOrder.CREATOR.value,
|
||||
RoomSortOrder.ENCRYPTION.value,
|
||||
RoomSortOrder.FEDERATABLE.value,
|
||||
RoomSortOrder.PUBLIC.value,
|
||||
RoomSortOrder.JOIN_RULES.value,
|
||||
RoomSortOrder.GUEST_ACCESS.value,
|
||||
RoomSortOrder.HISTORY_VISIBILITY.value,
|
||||
RoomSortOrder.STATE_EVENTS.value,
|
||||
):
|
||||
raise SynapseError(
|
||||
400,
|
||||
|
||||
@@ -30,7 +30,7 @@ from synapse.http.servlet import (
|
||||
)
|
||||
from synapse.push.mailer import Mailer, load_jinja2_templates
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.stringutils import assert_valid_client_secret, random_string
|
||||
from synapse.util.stringutils import assert_valid_client_secret
|
||||
from synapse.util.threepids import check_3pid_allowed
|
||||
|
||||
from ._base import client_patterns, interactive_auth_handler
|
||||
@@ -100,11 +100,6 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
||||
)
|
||||
|
||||
if existing_user_id is None:
|
||||
if self.config.request_token_inhibit_3pid_errors:
|
||||
# Make the client think the operation succeeded. See the rationale in the
|
||||
# comments for request_token_inhibit_3pid_errors.
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
|
||||
|
||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
|
||||
@@ -395,11 +390,6 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||
)
|
||||
|
||||
if existing_user_id is not None:
|
||||
if self.config.request_token_inhibit_3pid_errors:
|
||||
# Make the client think the operation succeeded. See the rationale in the
|
||||
# comments for request_token_inhibit_3pid_errors.
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||
|
||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
|
||||
@@ -463,11 +453,6 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
||||
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
|
||||
|
||||
if existing_user_id is not None:
|
||||
if self.hs.config.request_token_inhibit_3pid_errors:
|
||||
# Make the client think the operation succeeded. See the rationale in the
|
||||
# comments for request_token_inhibit_3pid_errors.
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
|
||||
|
||||
if not self.hs.config.account_threepid_delegate_msisdn:
|
||||
|
||||
@@ -38,12 +38,8 @@ class AccountDataServlet(RestServlet):
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.notifier = hs.get_notifier()
|
||||
self._is_worker = hs.config.worker_app is not None
|
||||
|
||||
async def on_PUT(self, request, user_id, account_data_type):
|
||||
if self._is_worker:
|
||||
raise Exception("Cannot handle PUT /account_data on worker")
|
||||
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
if user_id != requester.user.to_string():
|
||||
raise AuthError(403, "Cannot add account data for other users.")
|
||||
@@ -90,12 +86,8 @@ class RoomAccountDataServlet(RestServlet):
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.notifier = hs.get_notifier()
|
||||
self._is_worker = hs.config.worker_app is not None
|
||||
|
||||
async def on_PUT(self, request, user_id, room_id, account_data_type):
|
||||
if self._is_worker:
|
||||
raise Exception("Cannot handle PUT /account_data on worker")
|
||||
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
if user_id != requester.user.to_string():
|
||||
raise AuthError(403, "Cannot add account data for other users.")
|
||||
|
||||
@@ -18,6 +18,7 @@ import logging
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.urls import CLIENT_API_PREFIX
|
||||
from synapse.handlers.auth import SUCCESS_TEMPLATE
|
||||
from synapse.http.server import finish_request
|
||||
from synapse.http.servlet import RestServlet, parse_string
|
||||
|
||||
@@ -89,30 +90,6 @@ TERMS_TEMPLATE = """
|
||||
</html>
|
||||
"""
|
||||
|
||||
SUCCESS_TEMPLATE = """
|
||||
<html>
|
||||
<head>
|
||||
<title>Success!</title>
|
||||
<meta name='viewport' content='width=device-width, initial-scale=1,
|
||||
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
|
||||
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
|
||||
<script>
|
||||
if (window.onAuthDone) {
|
||||
window.onAuthDone();
|
||||
} else if (window.opener && window.opener.postMessage) {
|
||||
window.opener.postMessage("authDone", "*");
|
||||
}
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div>
|
||||
<p>Thank you</p>
|
||||
<p>You may now close this window and return to the application</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
class AuthRestServlet(RestServlet):
|
||||
"""
|
||||
|
||||
@@ -49,7 +49,7 @@ from synapse.http.servlet import (
|
||||
from synapse.push.mailer import load_jinja2_templates
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
from synapse.util.stringutils import assert_valid_client_secret, random_string
|
||||
from synapse.util.stringutils import assert_valid_client_secret
|
||||
from synapse.util.threepids import check_3pid_allowed
|
||||
|
||||
from ._base import client_patterns, interactive_auth_handler
|
||||
@@ -135,11 +135,6 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
||||
)
|
||||
|
||||
if existing_user_id is not None:
|
||||
if self.hs.config.request_token_inhibit_3pid_errors:
|
||||
# Make the client think the operation succeeded. See the rationale in the
|
||||
# comments for request_token_inhibit_3pid_errors.
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||
|
||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
|
||||
@@ -207,11 +202,6 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
||||
)
|
||||
|
||||
if existing_user_id is not None:
|
||||
if self.hs.config.request_token_inhibit_3pid_errors:
|
||||
# Make the client think the operation succeeded. See the rationale in the
|
||||
# comments for request_token_inhibit_3pid_errors.
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(
|
||||
400, "Phone number is already in use", Codes.THREEPID_IN_USE
|
||||
)
|
||||
|
||||
@@ -402,7 +402,7 @@ class PreviewUrlResource(DirectServeResource):
|
||||
|
||||
now = self.clock.time_msec()
|
||||
|
||||
logger.debug("Running url preview cache expiry")
|
||||
logger.info("Running url preview cache expiry")
|
||||
|
||||
if not (await self.store.db.updates.has_completed_background_updates()):
|
||||
logger.info("Still running DB updates; skipping expiry")
|
||||
@@ -435,8 +435,6 @@ class PreviewUrlResource(DirectServeResource):
|
||||
|
||||
if removed_media:
|
||||
logger.info("Deleted %d entries from url cache", len(removed_media))
|
||||
else:
|
||||
logger.debug("No entries removed from url cache")
|
||||
|
||||
# Now we delete old images associated with the url cache.
|
||||
# These may be cached for a bit on the client (i.e., they
|
||||
@@ -483,10 +481,7 @@ class PreviewUrlResource(DirectServeResource):
|
||||
|
||||
await self.store.delete_url_cache_media(removed_media)
|
||||
|
||||
if removed_media:
|
||||
logger.info("Deleted %d media from url cache", len(removed_media))
|
||||
else:
|
||||
logger.debug("No media removed from url cache")
|
||||
logger.info("Deleted %d media from url cache", len(removed_media))
|
||||
|
||||
|
||||
def decode_and_calc_og(body, media_uri, request_encoding=None):
|
||||
|
||||
@@ -25,7 +25,6 @@ import synapse.server_notices.server_notices_manager
|
||||
import synapse.server_notices.server_notices_sender
|
||||
import synapse.state
|
||||
import synapse.storage
|
||||
from synapse.events.builder import EventBuilderFactory
|
||||
|
||||
class HomeServer(object):
|
||||
@property
|
||||
@@ -98,7 +97,7 @@ class HomeServer(object):
|
||||
pass
|
||||
def get_notifier(self) -> synapse.notifier.Notifier:
|
||||
pass
|
||||
def get_presence_handler(self) -> synapse.handlers.presence.BasePresenceHandler:
|
||||
def get_presence_handler(self) -> synapse.handlers.presence.PresenceHandler:
|
||||
pass
|
||||
def get_clock(self) -> synapse.util.Clock:
|
||||
pass
|
||||
@@ -122,7 +121,3 @@ class HomeServer(object):
|
||||
pass
|
||||
def get_instance_id(self) -> str:
|
||||
pass
|
||||
def get_event_builder_factory(self) -> EventBuilderFactory:
|
||||
pass
|
||||
def get_storage(self) -> synapse.storage.Storage:
|
||||
pass
|
||||
|
||||
@@ -973,18 +973,8 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
return self._stream_id_gen.get_current_token()
|
||||
|
||||
def get_all_new_forward_event_rows(self, last_id, current_id, limit):
|
||||
"""Returns new events, for the Events replication stream
|
||||
|
||||
Args:
|
||||
last_id: the last stream_id from the previous batch.
|
||||
current_id: the maximum stream_id to return up to
|
||||
limit: the maximum number of rows to return
|
||||
|
||||
Returns: Deferred[List[Tuple]]
|
||||
a list of events stream rows. Each tuple consists of a stream id as
|
||||
the first element, followed by fields suitable for casting into an
|
||||
EventsStreamRow.
|
||||
"""
|
||||
if last_id == current_id:
|
||||
return defer.succeed([])
|
||||
|
||||
def get_all_new_forward_event_rows(txn):
|
||||
sql = (
|
||||
@@ -999,26 +989,13 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
" LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
return txn.fetchall()
|
||||
new_event_updates = txn.fetchall()
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
|
||||
)
|
||||
if len(new_event_updates) == limit:
|
||||
upper_bound = new_event_updates[-1][0]
|
||||
else:
|
||||
upper_bound = current_id
|
||||
|
||||
def get_ex_outlier_stream_rows(self, last_id, current_id):
|
||||
"""Returns de-outliered events, for the Events replication stream
|
||||
|
||||
Args:
|
||||
last_id: the last stream_id from the previous batch.
|
||||
current_id: the maximum stream_id to return up to
|
||||
|
||||
Returns: Deferred[List[Tuple]]
|
||||
a list of events stream rows. Each tuple consists of a stream id as
|
||||
the first element, followed by fields suitable for casting into an
|
||||
EventsStreamRow.
|
||||
"""
|
||||
|
||||
def get_ex_outlier_stream_rows_txn(txn):
|
||||
sql = (
|
||||
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts, relates_to_id"
|
||||
@@ -1029,14 +1006,15 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
" LEFT JOIN event_relations USING (event_id)"
|
||||
" WHERE ? < event_stream_ordering"
|
||||
" AND event_stream_ordering <= ?"
|
||||
" ORDER BY event_stream_ordering ASC"
|
||||
" ORDER BY event_stream_ordering DESC"
|
||||
)
|
||||
txn.execute(sql, (last_id, upper_bound))
|
||||
new_event_updates.extend(txn)
|
||||
|
||||
txn.execute(sql, (last_id, current_id))
|
||||
return txn.fetchall()
|
||||
return new_event_updates
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
|
||||
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
|
||||
)
|
||||
|
||||
def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
|
||||
@@ -1084,23 +1062,15 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
|
||||
)
|
||||
|
||||
def get_all_updated_current_state_deltas(
|
||||
self, from_token: int, to_token: int, limit: Optional[int]
|
||||
):
|
||||
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
|
||||
def get_all_updated_current_state_deltas_txn(txn):
|
||||
sql = """
|
||||
SELECT stream_id, room_id, type, state_key, event_id
|
||||
FROM current_state_delta_stream
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
ORDER BY stream_id ASC
|
||||
ORDER BY stream_id ASC LIMIT ?
|
||||
"""
|
||||
params = [from_token, to_token]
|
||||
|
||||
if limit is not None:
|
||||
sql += "LIMIT ?"
|
||||
params.append(limit)
|
||||
|
||||
txn.execute(sql, params)
|
||||
txn.execute(sql, (from_token, to_token, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
return self.db.runInteraction(
|
||||
|
||||
@@ -52,28 +52,12 @@ class RoomSortOrder(Enum):
|
||||
"""
|
||||
Enum to define the sorting method used when returning rooms with get_rooms_paginate
|
||||
|
||||
NAME = sort rooms alphabetically by name
|
||||
JOINED_MEMBERS = sort rooms by membership size, highest to lowest
|
||||
ALPHABETICAL = sort rooms alphabetically by name
|
||||
SIZE = sort rooms by membership size, highest to lowest
|
||||
"""
|
||||
|
||||
# ALPHABETICAL and SIZE are deprecated.
|
||||
# ALPHABETICAL is the same as NAME.
|
||||
ALPHABETICAL = "alphabetical"
|
||||
# SIZE is the same as JOINED_MEMBERS.
|
||||
SIZE = "size"
|
||||
NAME = "name"
|
||||
CANONICAL_ALIAS = "canonical_alias"
|
||||
JOINED_MEMBERS = "joined_members"
|
||||
JOINED_LOCAL_MEMBERS = "joined_local_members"
|
||||
VERSION = "version"
|
||||
CREATOR = "creator"
|
||||
ENCRYPTION = "encryption"
|
||||
FEDERATABLE = "federatable"
|
||||
PUBLIC = "public"
|
||||
JOIN_RULES = "join_rules"
|
||||
GUEST_ACCESS = "guest_access"
|
||||
HISTORY_VISIBILITY = "history_visibility"
|
||||
STATE_EVENTS = "state_events"
|
||||
|
||||
|
||||
class RoomWorkerStore(SQLBaseStore):
|
||||
@@ -345,52 +329,12 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
|
||||
# Set ordering
|
||||
if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
|
||||
# Deprecated in favour of RoomSortOrder.JOINED_MEMBERS
|
||||
order_by_column = "curr.joined_members"
|
||||
order_by_asc = False
|
||||
elif RoomSortOrder(order_by) == RoomSortOrder.ALPHABETICAL:
|
||||
# Deprecated in favour of RoomSortOrder.NAME
|
||||
# Sort alphabetically
|
||||
order_by_column = "state.name"
|
||||
order_by_asc = True
|
||||
elif RoomSortOrder(order_by) == RoomSortOrder.NAME:
|
||||
order_by_column = "state.name"
|
||||
order_by_asc = True
|
||||
elif RoomSortOrder(order_by) == RoomSortOrder.CANONICAL_ALIAS:
|
||||
order_by_column = "state.canonical_alias"
|
||||
order_by_asc = True
|
||||
elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_MEMBERS:
|
||||
order_by_column = "curr.joined_members"
|
||||
order_by_asc = False
|
||||
elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_LOCAL_MEMBERS:
|
||||
order_by_column = "curr.local_users_in_room"
|
||||
order_by_asc = False
|
||||
elif RoomSortOrder(order_by) == RoomSortOrder.VERSION:
|
||||
order_by_column = "rooms.room_version"
|
||||
order_by_asc = False
|
||||
elif RoomSortOrder(order_by) == RoomSortOrder.CREATOR:
|
||||
order_by_column = "rooms.creator"
|
||||
order_by_asc = True
|
||||
elif RoomSortOrder(order_by) == RoomSortOrder.ENCRYPTION:
|
||||
order_by_column = "state.encryption"
|
||||
order_by_asc = True
|
||||
elif RoomSortOrder(order_by) == RoomSortOrder.FEDERATABLE:
|
||||
order_by_column = "state.is_federatable"
|
||||
order_by_asc = True
|
||||
elif RoomSortOrder(order_by) == RoomSortOrder.PUBLIC:
|
||||
order_by_column = "rooms.is_public"
|
||||
order_by_asc = True
|
||||
elif RoomSortOrder(order_by) == RoomSortOrder.JOIN_RULES:
|
||||
order_by_column = "state.join_rules"
|
||||
order_by_asc = True
|
||||
elif RoomSortOrder(order_by) == RoomSortOrder.GUEST_ACCESS:
|
||||
order_by_column = "state.guest_access"
|
||||
order_by_asc = True
|
||||
elif RoomSortOrder(order_by) == RoomSortOrder.HISTORY_VISIBILITY:
|
||||
order_by_column = "state.history_visibility"
|
||||
order_by_asc = True
|
||||
elif RoomSortOrder(order_by) == RoomSortOrder.STATE_EVENTS:
|
||||
order_by_column = "curr.current_state_events"
|
||||
order_by_asc = False
|
||||
else:
|
||||
raise StoreError(
|
||||
500, "Incorrect value for order_by provided: %s" % order_by
|
||||
@@ -405,13 +349,9 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
# for, and another query for getting the total number of events that could be
|
||||
# returned. Thus allowing us to see if there are more events to paginate through
|
||||
info_sql = """
|
||||
SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members,
|
||||
curr.local_users_in_room, rooms.room_version, rooms.creator,
|
||||
state.encryption, state.is_federatable, rooms.is_public, state.join_rules,
|
||||
state.guest_access, state.history_visibility, curr.current_state_events
|
||||
SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members
|
||||
FROM room_stats_state state
|
||||
INNER JOIN room_stats_current curr USING (room_id)
|
||||
INNER JOIN rooms USING (room_id)
|
||||
%s
|
||||
ORDER BY %s %s
|
||||
LIMIT ?
|
||||
@@ -449,16 +389,6 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
"name": room[1],
|
||||
"canonical_alias": room[2],
|
||||
"joined_members": room[3],
|
||||
"joined_local_members": room[4],
|
||||
"version": room[5],
|
||||
"creator": room[6],
|
||||
"encryption": room[7],
|
||||
"federatable": room[8],
|
||||
"public": room[9],
|
||||
"join_rules": room[10],
|
||||
"guest_access": room[11],
|
||||
"history_visibility": room[12],
|
||||
"state_events": room[13],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Iterable, List, Mapping, Optional, Set
|
||||
|
||||
from six import integer_types
|
||||
|
||||
@@ -24,11 +23,8 @@ from synapse.util import caches
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# for now, assume all entities in the cache are strings
|
||||
EntityType = str
|
||||
|
||||
|
||||
class StreamChangeCache:
|
||||
class StreamChangeCache(object):
|
||||
"""Keeps track of the stream positions of the latest change in a set of entities.
|
||||
|
||||
Typically the entity will be a room or user id.
|
||||
@@ -38,23 +34,10 @@ class StreamChangeCache:
|
||||
old then the cache will simply return all given entities.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
current_stream_pos: int,
|
||||
max_size=10000,
|
||||
prefilled_cache: Optional[Mapping[EntityType, int]] = None,
|
||||
):
|
||||
def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache=None):
|
||||
self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR)
|
||||
self._entity_to_key = {} # type: Dict[EntityType, int]
|
||||
|
||||
# map from stream id to the a set of entities which changed at that stream id.
|
||||
self._cache = SortedDict() # type: SortedDict[int, Set[EntityType]]
|
||||
|
||||
# the earliest stream_pos for which we can reliably answer
|
||||
# get_all_entities_changed. In other words, one less than the earliest
|
||||
# stream_pos for which we know _cache is valid.
|
||||
#
|
||||
self._entity_to_key = {}
|
||||
self._cache = SortedDict()
|
||||
self._earliest_known_stream_pos = current_stream_pos
|
||||
self.name = name
|
||||
self.metrics = caches.register_cache("cache", self.name, self._cache)
|
||||
@@ -63,7 +46,7 @@ class StreamChangeCache:
|
||||
for entity, stream_pos in prefilled_cache.items():
|
||||
self.entity_has_changed(entity, stream_pos)
|
||||
|
||||
def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool:
|
||||
def has_entity_changed(self, entity, stream_pos):
|
||||
"""Returns True if the entity may have been updated since stream_pos
|
||||
"""
|
||||
assert type(stream_pos) in integer_types
|
||||
@@ -84,17 +67,22 @@ class StreamChangeCache:
|
||||
self.metrics.inc_hits()
|
||||
return False
|
||||
|
||||
def get_entities_changed(
|
||||
self, entities: Iterable[EntityType], stream_pos: int
|
||||
) -> Set[EntityType]:
|
||||
def get_entities_changed(self, entities, stream_pos):
|
||||
"""
|
||||
Returns subset of entities that have had new things since the given
|
||||
position. Entities unknown to the cache will be returned. If the
|
||||
position is too old it will just return the given list.
|
||||
"""
|
||||
changed_entities = self.get_all_entities_changed(stream_pos)
|
||||
if changed_entities is not None:
|
||||
result = set(changed_entities).intersection(entities)
|
||||
assert type(stream_pos) is int
|
||||
|
||||
if stream_pos >= self._earliest_known_stream_pos:
|
||||
changed_entities = {
|
||||
self._cache[k]
|
||||
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos))
|
||||
}
|
||||
|
||||
result = changed_entities.intersection(entities)
|
||||
|
||||
self.metrics.inc_hits()
|
||||
else:
|
||||
result = set(entities)
|
||||
@@ -102,13 +90,13 @@ class StreamChangeCache:
|
||||
|
||||
return result
|
||||
|
||||
def has_any_entity_changed(self, stream_pos: int) -> bool:
|
||||
def has_any_entity_changed(self, stream_pos):
|
||||
"""Returns if any entity has changed
|
||||
"""
|
||||
assert type(stream_pos) is int
|
||||
|
||||
if not self._cache:
|
||||
# If the cache is empty, nothing can have changed.
|
||||
# If we have no cache, nothing can have changed.
|
||||
return False
|
||||
|
||||
if stream_pos >= self._earliest_known_stream_pos:
|
||||
@@ -118,58 +106,42 @@ class StreamChangeCache:
|
||||
self.metrics.inc_misses()
|
||||
return True
|
||||
|
||||
def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
|
||||
"""Returns all entities that have had new things since the given
|
||||
def get_all_entities_changed(self, stream_pos):
|
||||
"""Returns all entites that have had new things since the given
|
||||
position. If the position is too old it will return None.
|
||||
|
||||
Returns the entities in the order that they were changed.
|
||||
"""
|
||||
assert type(stream_pos) is int
|
||||
|
||||
if stream_pos < self._earliest_known_stream_pos:
|
||||
if stream_pos >= self._earliest_known_stream_pos:
|
||||
return [
|
||||
self._cache[k]
|
||||
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos))
|
||||
]
|
||||
else:
|
||||
return None
|
||||
|
||||
changed_entities = [] # type: List[EntityType]
|
||||
|
||||
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
|
||||
changed_entities.extend(self._cache[k])
|
||||
return changed_entities
|
||||
|
||||
def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
|
||||
def entity_has_changed(self, entity, stream_pos):
|
||||
"""Informs the cache that the entity has been changed at the given
|
||||
position.
|
||||
"""
|
||||
assert type(stream_pos) is int
|
||||
|
||||
if stream_pos <= self._earliest_known_stream_pos:
|
||||
return
|
||||
if stream_pos > self._earliest_known_stream_pos:
|
||||
old_pos = self._entity_to_key.get(entity, None)
|
||||
if old_pos is not None:
|
||||
stream_pos = max(stream_pos, old_pos)
|
||||
self._cache.pop(old_pos, None)
|
||||
self._cache[stream_pos] = entity
|
||||
self._entity_to_key[entity] = stream_pos
|
||||
|
||||
old_pos = self._entity_to_key.get(entity, None)
|
||||
if old_pos is not None:
|
||||
if old_pos >= stream_pos:
|
||||
# nothing to do
|
||||
return
|
||||
e = self._cache[old_pos]
|
||||
e.remove(entity)
|
||||
if not e:
|
||||
# cache at this point is now empty
|
||||
del self._cache[old_pos]
|
||||
|
||||
e1 = self._cache.get(stream_pos)
|
||||
if e1 is None:
|
||||
e1 = self._cache[stream_pos] = set()
|
||||
e1.add(entity)
|
||||
self._entity_to_key[entity] = stream_pos
|
||||
|
||||
# if the cache is too big, remove entries
|
||||
while len(self._cache) > self._max_size:
|
||||
k, r = self._cache.popitem(0)
|
||||
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
|
||||
for entity in r:
|
||||
del self._entity_to_key[entity]
|
||||
|
||||
def get_max_pos_of_last_change(self, entity: EntityType) -> int:
|
||||
while len(self._cache) > self._max_size:
|
||||
k, r = self._cache.popitem(0)
|
||||
self._earliest_known_stream_pos = max(
|
||||
k, self._earliest_known_stream_pos
|
||||
)
|
||||
self._entity_to_key.pop(r, None)
|
||||
|
||||
def get_max_pos_of_last_change(self, entity):
|
||||
"""Returns an upper bound of the stream id of the last change to an
|
||||
entity.
|
||||
"""
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class PushRuleEvaluatorTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
event = FrozenEvent(
|
||||
{
|
||||
"event_id": "$event_id",
|
||||
"type": "m.room.history_visibility",
|
||||
"sender": "@user:test",
|
||||
"state_key": "",
|
||||
"room_id": "@room:test",
|
||||
"content": {"body": "foo bar baz"},
|
||||
},
|
||||
RoomVersions.V1,
|
||||
)
|
||||
room_member_count = 0
|
||||
sender_power_level = 0
|
||||
power_levels = {}
|
||||
self.evaluator = PushRuleEvaluatorForEvent(
|
||||
event, room_member_count, sender_power_level, power_levels
|
||||
)
|
||||
|
||||
def test_display_name(self):
|
||||
"""Check for a matching display name in the body of the event."""
|
||||
condition = {
|
||||
"kind": "contains_display_name",
|
||||
}
|
||||
|
||||
# Blank names are skipped.
|
||||
self.assertFalse(self.evaluator.matches(condition, "@user:test", ""))
|
||||
|
||||
# Check a display name that doesn't match.
|
||||
self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found"))
|
||||
|
||||
# Check a display name which matches.
|
||||
self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo"))
|
||||
|
||||
# A display name that matches, but not a full word does not result in a match.
|
||||
self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba"))
|
||||
|
||||
# A display name should not be interpreted as a regular expression.
|
||||
self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]"))
|
||||
|
||||
# A display name with spaces should work fine.
|
||||
self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar"))
|
||||
@@ -16,7 +16,7 @@
|
||||
from mock import Mock, NonCallableMock
|
||||
|
||||
from synapse.replication.tcp.client import (
|
||||
DirectTcpReplicationClientFactory,
|
||||
ReplicationClientFactory,
|
||||
ReplicationDataHandler,
|
||||
)
|
||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||
@@ -61,7 +61,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
|
||||
self.slaved_store
|
||||
)
|
||||
|
||||
client_factory = DirectTcpReplicationClientFactory(
|
||||
client_factory = ReplicationClientFactory(
|
||||
self.hs, "client_name", self.replication_handler
|
||||
)
|
||||
client_factory.handler = self.replication_handler
|
||||
|
||||
@@ -13,72 +13,38 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from mock import Mock
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
|
||||
from twisted.internet.task import LoopingCall
|
||||
from twisted.web.http import HTTPChannel
|
||||
|
||||
from synapse.app.generic_worker import GenericWorkerServer
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.tcp.client import ReplicationDataHandler
|
||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeTransport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
"""Base class for tests of the replication streams"""
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
self.test_handler = Mock(wraps=TestReplicationDataHandler())
|
||||
return self.setup_test_homeserver(replication_data_handler=self.test_handler)
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
# build a replication server
|
||||
server_factory = ReplicationStreamProtocolFactory(hs)
|
||||
self.streamer = hs.get_replication_streamer()
|
||||
self.server = server_factory.buildProtocol(None)
|
||||
|
||||
# Make a new HomeServer object for the worker
|
||||
config = self.default_config()
|
||||
config["worker_app"] = "synapse.app.generic_worker"
|
||||
config["worker_replication_host"] = "testserv"
|
||||
config["worker_replication_http_port"] = "8765"
|
||||
|
||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||
|
||||
self.worker_hs = self.setup_test_homeserver(
|
||||
http_client=None,
|
||||
homeserverToUse=GenericWorkerServer,
|
||||
config=config,
|
||||
reactor=self.reactor,
|
||||
)
|
||||
|
||||
# Since we use sqlite in memory databases we need to make sure the
|
||||
# databases objects are the same.
|
||||
self.worker_hs.get_datastore().db = hs.get_datastore().db
|
||||
|
||||
self.test_handler = self._build_replication_data_handler()
|
||||
self.worker_hs.replication_data_handler = self.test_handler
|
||||
|
||||
repl_handler = ReplicationCommandHandler(self.worker_hs)
|
||||
repl_handler = ReplicationCommandHandler(hs)
|
||||
repl_handler.handler = self.test_handler
|
||||
self.client = ClientReplicationStreamProtocol(
|
||||
self.worker_hs, "client", "test", clock, repl_handler,
|
||||
hs, "client", "test", clock, repl_handler,
|
||||
)
|
||||
|
||||
self._client_transport = None
|
||||
self._server_transport = None
|
||||
|
||||
def _build_replication_data_handler(self):
|
||||
return TestReplicationDataHandler(self.worker_hs.get_datastore())
|
||||
|
||||
def reconnect(self):
|
||||
if self._client_transport:
|
||||
self.client.close()
|
||||
@@ -108,204 +74,24 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
self.streamer.on_notifier_poke()
|
||||
self.pump(0.1)
|
||||
|
||||
def handle_http_replication_attempt(self) -> SynapseRequest:
|
||||
"""Asserts that a connection attempt was made to the master HS on the
|
||||
HTTP replication port, then proxies it to the master HS object to be
|
||||
handled.
|
||||
|
||||
Returns:
|
||||
The request object received by master HS.
|
||||
"""
|
||||
|
||||
# We should have an outbound connection attempt.
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
|
||||
self.assertEqual(host, "1.2.3.4")
|
||||
self.assertEqual(port, 8765)
|
||||
|
||||
# Set up client side protocol
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
|
||||
request_factory = OneShotRequestFactory()
|
||||
|
||||
# Set up the server side protocol
|
||||
channel = _PushHTTPChannel(self.reactor)
|
||||
channel.requestFactory = request_factory
|
||||
channel.site = self.site
|
||||
|
||||
# Connect client to server and vice versa.
|
||||
client_to_server_transport = FakeTransport(
|
||||
channel, self.reactor, client_protocol
|
||||
)
|
||||
client_protocol.makeConnection(client_to_server_transport)
|
||||
|
||||
server_to_client_transport = FakeTransport(
|
||||
client_protocol, self.reactor, channel
|
||||
)
|
||||
channel.makeConnection(server_to_client_transport)
|
||||
|
||||
# The request will now be processed by `self.site` and the response
|
||||
# streamed back.
|
||||
self.reactor.advance(0)
|
||||
|
||||
# We tear down the connection so it doesn't get reused without our
|
||||
# knowledge.
|
||||
server_to_client_transport.loseConnection()
|
||||
client_to_server_transport.loseConnection()
|
||||
|
||||
return request_factory.request
|
||||
|
||||
def assert_request_is_get_repl_stream_updates(
|
||||
self, request: SynapseRequest, stream_name: str
|
||||
):
|
||||
"""Asserts that the given request is a HTTP replication request for
|
||||
fetching updates for given stream.
|
||||
"""
|
||||
|
||||
self.assertRegex(
|
||||
request.path,
|
||||
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
|
||||
% (stream_name.encode("ascii"),),
|
||||
)
|
||||
|
||||
self.assertEqual(request.method, b"GET")
|
||||
|
||||
|
||||
class TestReplicationDataHandler(ReplicationDataHandler):
|
||||
class TestReplicationDataHandler:
|
||||
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
|
||||
|
||||
def __init__(self, store: BaseSlavedStore):
|
||||
super().__init__(store)
|
||||
|
||||
# streams to subscribe to: map from stream id to position
|
||||
self.stream_positions = {} # type: Dict[str, int]
|
||||
|
||||
# list of received (stream_name, token, row) tuples
|
||||
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
|
||||
def __init__(self):
|
||||
self.streams = set()
|
||||
self._received_rdata_rows = []
|
||||
|
||||
def get_streams_to_replicate(self):
|
||||
return self.stream_positions
|
||||
positions = {s: 0 for s in self.streams}
|
||||
for stream, token, _ in self._received_rdata_rows:
|
||||
if stream in self.streams:
|
||||
positions[stream] = max(token, positions.get(stream, 0))
|
||||
return positions
|
||||
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
await super().on_rdata(stream_name, token, rows)
|
||||
for r in rows:
|
||||
self.received_rdata_rows.append((stream_name, token, r))
|
||||
self._received_rdata_rows.append((stream_name, token, r))
|
||||
|
||||
if (
|
||||
stream_name in self.stream_positions
|
||||
and token > self.stream_positions[stream_name]
|
||||
):
|
||||
self.stream_positions[stream_name] = token
|
||||
|
||||
|
||||
@attr.s()
|
||||
class OneShotRequestFactory:
|
||||
"""A simple request factory that generates a single `SynapseRequest` and
|
||||
stores it for future use. Can only be used once.
|
||||
"""
|
||||
|
||||
request = attr.ib(default=None)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
assert self.request is None
|
||||
|
||||
self.request = SynapseRequest(*args, **kwargs)
|
||||
return self.request
|
||||
|
||||
|
||||
class _PushHTTPChannel(HTTPChannel):
|
||||
"""A HTTPChannel that wraps pull producers to push producers.
|
||||
|
||||
This is a hack to get around the fact that HTTPChannel transparently wraps a
|
||||
pull producer (which is what Synapse uses to reply to requests) with
|
||||
`_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
|
||||
uses the standard reactor rather than letting us use our test reactor, which
|
||||
makes it very hard to test.
|
||||
"""
|
||||
|
||||
def __init__(self, reactor: IReactorTime):
|
||||
super().__init__()
|
||||
self.reactor = reactor
|
||||
|
||||
self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
# Convert pull producers to push producer.
|
||||
if not streaming:
|
||||
self._pull_to_push_producer = _PullToPushProducer(
|
||||
self.reactor, producer, self
|
||||
)
|
||||
producer = self._pull_to_push_producer
|
||||
|
||||
super().registerProducer(producer, True)
|
||||
|
||||
def unregisterProducer(self):
|
||||
if self._pull_to_push_producer:
|
||||
# We need to manually stop the _PullToPushProducer.
|
||||
self._pull_to_push_producer.stop()
|
||||
|
||||
|
||||
class _PullToPushProducer:
|
||||
"""A push producer that wraps a pull producer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
|
||||
):
|
||||
self._clock = Clock(reactor)
|
||||
self._producer = producer
|
||||
self._consumer = consumer
|
||||
|
||||
# While running we use a looping call with a zero delay to call
|
||||
# resumeProducing on given producer.
|
||||
self._looping_call = None # type: Optional[LoopingCall]
|
||||
|
||||
# We start writing next reactor tick.
|
||||
self._start_loop()
|
||||
|
||||
def _start_loop(self):
|
||||
"""Start the looping call to
|
||||
"""
|
||||
|
||||
if not self._looping_call:
|
||||
# Start a looping call which runs every tick.
|
||||
self._looping_call = self._clock.looping_call(self._run_once, 0)
|
||||
|
||||
def stop(self):
|
||||
"""Stops calling resumeProducing.
|
||||
"""
|
||||
if self._looping_call:
|
||||
self._looping_call.stop()
|
||||
self._looping_call = None
|
||||
|
||||
def pauseProducing(self):
|
||||
"""Implements IPushProducer
|
||||
"""
|
||||
self.stop()
|
||||
|
||||
def resumeProducing(self):
|
||||
"""Implements IPushProducer
|
||||
"""
|
||||
self._start_loop()
|
||||
|
||||
def stopProducing(self):
|
||||
"""Implements IPushProducer
|
||||
"""
|
||||
self.stop()
|
||||
self._producer.stopProducing()
|
||||
|
||||
def _run_once(self):
|
||||
"""Calls resumeProducing on producer once.
|
||||
"""
|
||||
|
||||
try:
|
||||
self._producer.resumeProducing()
|
||||
except Exception:
|
||||
logger.exception("Failed to call resumeProducing")
|
||||
try:
|
||||
self._consumer.unregisterProducer()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.stopProducing()
|
||||
async def on_position(self, stream_name, token):
|
||||
pass
|
||||
|
||||
@@ -1,390 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.events import EventBase
|
||||
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
|
||||
from synapse.replication.tcp.streams.events import (
|
||||
EventsStreamCurrentStateRow,
|
||||
EventsStreamEventRow,
|
||||
EventsStreamRow,
|
||||
)
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client.v1 import login, room
|
||||
|
||||
from tests.replication.tcp.streams._base import BaseStreamTestCase
|
||||
from tests.test_utils.event_injection import inject_event, inject_member_event
|
||||
|
||||
|
||||
class EventsStreamTestCase(BaseStreamTestCase):
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
super().prepare(reactor, clock, hs)
|
||||
self.user_id = self.register_user("u1", "pass")
|
||||
self.user_tok = self.login("u1", "pass")
|
||||
|
||||
self.reconnect()
|
||||
self.test_handler.stream_positions["events"] = 0
|
||||
|
||||
self.room_id = self.helper.create_room_as(tok=self.user_tok)
|
||||
self.test_handler.received_rdata_rows.clear()
|
||||
|
||||
def test_update_function_event_row_limit(self):
|
||||
"""Test replication with many non-state events
|
||||
|
||||
Checks that all events are correctly replicated when there are lots of
|
||||
event rows to be replicated.
|
||||
"""
|
||||
|
||||
# generate lots of non-state events. We inject them using inject_event
|
||||
# so that they are not send out over replication until we call self.replicate().
|
||||
events = [
|
||||
self._inject_test_event()
|
||||
for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1)
|
||||
]
|
||||
|
||||
# also one state event
|
||||
state_event = self._inject_state_event()
|
||||
|
||||
# check we're testing what we think we are: no rows should yet have been
|
||||
# receieved
|
||||
self.assertEqual([], self.test_handler.received_rdata_rows)
|
||||
|
||||
# now fire up the replicator
|
||||
self.replicate()
|
||||
|
||||
# we should have received all the expected rows in the right order
|
||||
received_rows = self.test_handler.received_rdata_rows
|
||||
for event in events:
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertEqual("events", stream_name)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "ev")
|
||||
self.assertIsInstance(row.data, EventsStreamEventRow)
|
||||
self.assertEqual(row.data.event_id, event.event_id)
|
||||
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertIsInstance(row.data, EventsStreamEventRow)
|
||||
self.assertEqual(row.data.event_id, state_event.event_id)
|
||||
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertEqual("events", stream_name)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "state")
|
||||
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
|
||||
self.assertEqual(row.data.event_id, state_event.event_id)
|
||||
|
||||
self.assertEqual([], received_rows)
|
||||
|
||||
def test_update_function_huge_state_change(self):
|
||||
"""Test replication with many state events
|
||||
|
||||
Ensures that all events are correctly replicated when there are lots of
|
||||
state change rows to be replicated.
|
||||
"""
|
||||
|
||||
# we want to generate lots of state changes at a single stream ID.
|
||||
#
|
||||
# We do this by having two branches in the DAG. On one, we have a moderator
|
||||
# which that generates lots of state; on the other, we de-op the moderator,
|
||||
# thus invalidating all the state.
|
||||
|
||||
OTHER_USER = "@other_user:localhost"
|
||||
|
||||
# have the user join
|
||||
inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
|
||||
|
||||
# Update existing power levels with mod at PL50
|
||||
pls = self.helper.get_state(
|
||||
self.room_id, EventTypes.PowerLevels, tok=self.user_tok
|
||||
)
|
||||
pls["users"][OTHER_USER] = 50
|
||||
self.helper.send_state(
|
||||
self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
|
||||
)
|
||||
|
||||
# this is the point in the DAG where we make a fork
|
||||
fork_point = self.get_success(
|
||||
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
|
||||
) # type: List[str]
|
||||
|
||||
events = [
|
||||
self._inject_state_event(sender=OTHER_USER)
|
||||
for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT)
|
||||
]
|
||||
|
||||
self.replicate()
|
||||
# all those events and state changes should have landed
|
||||
self.assertGreaterEqual(
|
||||
len(self.test_handler.received_rdata_rows), 2 * len(events)
|
||||
)
|
||||
self.test_handler.received_rdata_rows.clear()
|
||||
|
||||
# a state event which doesn't get rolled back, to check that the state
|
||||
# before the huge update comes through ok
|
||||
state1 = self._inject_state_event()
|
||||
|
||||
# roll back all the state by de-modding the user
|
||||
prev_events = fork_point
|
||||
pls["users"][OTHER_USER] = 0
|
||||
pl_event = inject_event(
|
||||
self.hs,
|
||||
prev_event_ids=prev_events,
|
||||
type=EventTypes.PowerLevels,
|
||||
state_key="",
|
||||
sender=self.user_id,
|
||||
room_id=self.room_id,
|
||||
content=pls,
|
||||
)
|
||||
|
||||
# one more bit of state that doesn't get rolled back
|
||||
state2 = self._inject_state_event()
|
||||
|
||||
# check we're testing what we think we are: no rows should yet have been
|
||||
# receieved
|
||||
self.assertEqual([], self.test_handler.received_rdata_rows)
|
||||
|
||||
# now fire up the replicator
|
||||
self.replicate()
|
||||
|
||||
# now we should have received all the expected rows in the right order.
|
||||
#
|
||||
# we expect:
|
||||
#
|
||||
# - two rows for state1
|
||||
# - the PL event row, plus state rows for the PL event and each
|
||||
# of the states that got reverted.
|
||||
# - two rows for state2
|
||||
|
||||
received_rows = self.test_handler.received_rdata_rows
|
||||
|
||||
# first check the first two rows, which should be state1
|
||||
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertEqual("events", stream_name)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "ev")
|
||||
self.assertIsInstance(row.data, EventsStreamEventRow)
|
||||
self.assertEqual(row.data.event_id, state1.event_id)
|
||||
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "state")
|
||||
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
|
||||
self.assertEqual(row.data.event_id, state1.event_id)
|
||||
|
||||
# now the last two rows, which should be state2
|
||||
stream_name, token, row = received_rows.pop(-2)
|
||||
self.assertEqual("events", stream_name)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "ev")
|
||||
self.assertIsInstance(row.data, EventsStreamEventRow)
|
||||
self.assertEqual(row.data.event_id, state2.event_id)
|
||||
|
||||
stream_name, token, row = received_rows.pop(-1)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "state")
|
||||
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
|
||||
self.assertEqual(row.data.event_id, state2.event_id)
|
||||
|
||||
# that should leave us with the rows for the PL event
|
||||
self.assertEqual(len(received_rows), len(events) + 2)
|
||||
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertEqual("events", stream_name)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "ev")
|
||||
self.assertIsInstance(row.data, EventsStreamEventRow)
|
||||
self.assertEqual(row.data.event_id, pl_event.event_id)
|
||||
|
||||
# the state rows are unsorted
|
||||
state_rows = [] # type: List[EventsStreamCurrentStateRow]
|
||||
for stream_name, token, row in received_rows:
|
||||
self.assertEqual("events", stream_name)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "state")
|
||||
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
|
||||
state_rows.append(row.data)
|
||||
|
||||
state_rows.sort(key=lambda r: r.state_key)
|
||||
|
||||
sr = state_rows.pop(0)
|
||||
self.assertEqual(sr.type, EventTypes.PowerLevels)
|
||||
self.assertEqual(sr.event_id, pl_event.event_id)
|
||||
for sr in state_rows:
|
||||
self.assertEqual(sr.type, "test_state_event")
|
||||
# "None" indicates the state has been deleted
|
||||
self.assertIsNone(sr.event_id)
|
||||
|
||||
def test_update_function_state_row_limit(self):
|
||||
"""Test replication with many state events over several stream ids.
|
||||
"""
|
||||
|
||||
# we want to generate lots of state changes, but for this test, we want to
|
||||
# spread out the state changes over a few stream IDs.
|
||||
#
|
||||
# We do this by having two branches in the DAG. On one, we have four moderators,
|
||||
# each of which that generates lots of state; on the other, we de-op the users,
|
||||
# thus invalidating all the state.
|
||||
|
||||
NUM_USERS = 4
|
||||
STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // 4 + 1
|
||||
|
||||
user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)]
|
||||
|
||||
# have the users join
|
||||
for u in user_ids:
|
||||
inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
|
||||
|
||||
# Update existing power levels with mod at PL50
|
||||
pls = self.helper.get_state(
|
||||
self.room_id, EventTypes.PowerLevels, tok=self.user_tok
|
||||
)
|
||||
pls["users"].update({u: 50 for u in user_ids})
|
||||
self.helper.send_state(
|
||||
self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
|
||||
)
|
||||
|
||||
# this is the point in the DAG where we make a fork
|
||||
fork_point = self.get_success(
|
||||
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
|
||||
) # type: List[str]
|
||||
|
||||
events = [] # type: List[EventBase]
|
||||
for user in user_ids:
|
||||
events.extend(
|
||||
self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
|
||||
)
|
||||
|
||||
self.replicate()
|
||||
# all those events and state changes should have landed
|
||||
self.assertGreaterEqual(
|
||||
len(self.test_handler.received_rdata_rows), 2 * len(events)
|
||||
)
|
||||
self.test_handler.received_rdata_rows.clear()
|
||||
|
||||
# now roll back all that state by de-modding the users
|
||||
prev_events = fork_point
|
||||
pl_events = []
|
||||
for u in user_ids:
|
||||
pls["users"][u] = 0
|
||||
e = inject_event(
|
||||
self.hs,
|
||||
prev_event_ids=prev_events,
|
||||
type=EventTypes.PowerLevels,
|
||||
state_key="",
|
||||
sender=self.user_id,
|
||||
room_id=self.room_id,
|
||||
content=pls,
|
||||
)
|
||||
prev_events = [e.event_id]
|
||||
pl_events.append(e)
|
||||
|
||||
# check we're testing what we think we are: no rows should yet have been
|
||||
# receieved
|
||||
self.assertEqual([], self.test_handler.received_rdata_rows)
|
||||
|
||||
# now fire up the replicator
|
||||
self.replicate()
|
||||
|
||||
# we should have received all the expected rows in the right order
|
||||
|
||||
received_rows = self.test_handler.received_rdata_rows
|
||||
self.assertGreaterEqual(len(received_rows), len(events))
|
||||
for i in range(NUM_USERS):
|
||||
# for each user, we expect the PL event row, followed by state rows for
|
||||
# the PL event and each of the states that got reverted.
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertEqual("events", stream_name)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "ev")
|
||||
self.assertIsInstance(row.data, EventsStreamEventRow)
|
||||
self.assertEqual(row.data.event_id, pl_events[i].event_id)
|
||||
|
||||
# the state rows are unsorted
|
||||
state_rows = [] # type: List[EventsStreamCurrentStateRow]
|
||||
for j in range(STATES_PER_USER + 1):
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertEqual("events", stream_name)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "state")
|
||||
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
|
||||
state_rows.append(row.data)
|
||||
|
||||
state_rows.sort(key=lambda r: r.state_key)
|
||||
|
||||
sr = state_rows.pop(0)
|
||||
self.assertEqual(sr.type, EventTypes.PowerLevels)
|
||||
self.assertEqual(sr.event_id, pl_events[i].event_id)
|
||||
for sr in state_rows:
|
||||
self.assertEqual(sr.type, "test_state_event")
|
||||
# "None" indicates the state has been deleted
|
||||
self.assertIsNone(sr.event_id)
|
||||
|
||||
self.assertEqual([], received_rows)
|
||||
|
||||
event_count = 0
|
||||
|
||||
def _inject_test_event(
|
||||
self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
|
||||
) -> EventBase:
|
||||
if sender is None:
|
||||
sender = self.user_id
|
||||
|
||||
if body is None:
|
||||
body = "event %i" % (self.event_count,)
|
||||
self.event_count += 1
|
||||
|
||||
return inject_event(
|
||||
self.hs,
|
||||
room_id=self.room_id,
|
||||
sender=sender,
|
||||
type="test_event",
|
||||
content={"body": body},
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _inject_state_event(
|
||||
self,
|
||||
body: Optional[str] = None,
|
||||
state_key: Optional[str] = None,
|
||||
sender: Optional[str] = None,
|
||||
) -> EventBase:
|
||||
if sender is None:
|
||||
sender = self.user_id
|
||||
|
||||
if state_key is None:
|
||||
state_key = "state_%i" % (self.event_count,)
|
||||
self.event_count += 1
|
||||
|
||||
if body is None:
|
||||
body = "state event %s" % (state_key,)
|
||||
|
||||
return inject_event(
|
||||
self.hs,
|
||||
room_id=self.room_id,
|
||||
sender=sender,
|
||||
type="test_state_event",
|
||||
state_key=state_key,
|
||||
content={"body": body},
|
||||
)
|
||||
@@ -12,11 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# type: ignore
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from synapse.replication.tcp.streams._base import ReceiptsStream
|
||||
|
||||
from tests.replication.tcp.streams._base import BaseStreamTestCase
|
||||
@@ -25,14 +20,11 @@ USER_ID = "@feeling:blue"
|
||||
|
||||
|
||||
class ReceiptsStreamTestCase(BaseStreamTestCase):
|
||||
def _build_replication_data_handler(self):
|
||||
return Mock(wraps=super()._build_replication_data_handler())
|
||||
|
||||
def test_receipt(self):
|
||||
self.reconnect()
|
||||
|
||||
# make the client subscribe to the receipts stream
|
||||
self.test_handler.stream_positions.update({"receipts": 0})
|
||||
self.test_handler.streams.add("receipts")
|
||||
|
||||
# tell the master to send a new receipt
|
||||
self.get_success(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user