1
0

Compare commits

...

41 Commits

Author SHA1 Message Date
Richard van der Hoff 9347b961b2 changelog 2023-10-30 12:22:46 +00:00
Richard van der Hoff 6dbad83998 Implement MSC4072: return result for all /keys/claims 2023-10-30 12:03:36 +00:00
Richard van der Hoff 4c586567f6 Improve tracing for claim_one_time_keys
`set_tag` isn't really appropriate here, since we're logging an entire dict
rather than something useful to search on. (Also, it will get truncated.) Use
`log_kv` instead.
2023-10-27 15:21:58 +01:00
Richard van der Hoff 27546ac171 claim_local_one_time_keys: pass in result object
Rather than making the caller merge the results together, do it inside
`claim_local_one_time_keys`.
2023-10-27 15:10:55 +01:00
Richard van der Hoff dd45ba4d67 Fix types for OTK claims
We don't know for certain that keys will be `JsonDict`s -- indeed if the key is
not signed it will just be a string. Fix up the types to reflect this.
2023-10-27 15:00:30 +01:00
Erik Johnston 928e964857 Fix cross-worker ratelimiting (#16558)
c.f. #16481
2023-10-27 12:52:40 +01:00
Erik Johnston 0680d76659 Reduce replication traffic due to reflected cache stream POSITION (#16557) 2023-10-27 12:51:08 +01:00
Erik Johnston c02406ac71 Add new module API for adding custom fields to events unsigned section (#16549) 2023-10-27 09:04:08 +00:00
Patrick Cloke 679c691f6f Remove more usages of cursor_to_dict. (#16551)
Mostly to improve type safety.
2023-10-26 15:12:28 -04:00
Patrick Cloke 85e5f2dc25 Add a new module API to update user presence state. (#16544)
This adds a module API which allows a module to update a user's
presence state/status message. This is useful for controlling presence
from an external system.

To fully control presence from the module the presence.enabled config
parameter gains a new state of "untracked" which disables internal tracking
of presence changes via user actions, etc. Only updates from the module will
be persisted and sent down sync properly).
2023-10-26 15:11:24 -04:00
Patrick Cloke 9407d5ba78 Convert simple_select_list and simple_select_list_txn to return lists of tuples (#16505)
This should use fewer allocations and improves type hints.
2023-10-26 13:01:36 -04:00
David Robertson c14a7de6af Pin the recommended poetry version in contributors' guide (#16550) 2023-10-25 16:31:15 +01:00
Erik Johnston ba47fea528 Allow multiple workers to write to receipts stream. (#16432)
Fixes #16417
2023-10-25 16:16:19 +01:00
Patrick Cloke e182dbb5b9 Fix tests on Twisted trunk. (#16528)
Twisted trunk makes a change to the `TLSMemoryBIOFactory` where
the underlying protocol is changed from `TLSMemoryBIOProtocol` to
`BufferingTLSTransport` to improve performance of TLS code (see
https://github.com/twisted/twisted/issues/11989).

In order to properly hook this code up in tests we need to pass the test
reactor's clock into `TLSMemoryBIOFactory` to avoid the global (trial)
reactor being used by default.

Twisted does something similar internally for tests:
https://github.com/twisted/twisted/blob/157cd8e659705940e895d321339d467e76ae9d0a/src/twisted/web/test/test_agent.py#L871-L874
2023-10-25 07:39:45 -04:00
Richard Brežák 95076f77c1 Fix http/s proxy authentication with long username/passwords (#16504) 2023-10-24 13:45:21 +00:00
David Robertson 2f1065f81b Revert "Add test case to detect dodgy b64 encoding"
This reverts commit 5fe76b9434.

I think I had this accidentally commited on my local develop branch, and
so it accidentally got merged into upstream develop.

This should re-land with corrections in #16504.
2023-10-24 14:34:47 +01:00
David Robertson 2f35424812 Merge branch 'master' into develop 2023-10-24 14:23:20 +01:00
David Robertson c0d2f7649e Merge branch 'develop' of github.com:matrix-org/synapse into develop 2023-10-24 14:23:19 +01:00
David Robertson 6ec98810e3 Rework alias and public room list rules docs (#16541) 2023-10-24 13:26:41 +01:00
Jason Little ffbe9b7666 Remove duplicate call to wake a remote destination when using federation sending worker (#16515) 2023-10-24 08:09:59 -04:00
Michael Sasser 3df70aa800 Replace all Prometheus datasource UIDs of the Grafana Dashboard with the variable ${DS_PROMETHEUS} and remove __inputs (#16471) 2023-10-23 19:50:50 +01:00
David Robertson 5fe76b9434 Add test case to detect dodgy b64 encoding 2023-10-23 19:29:22 +01:00
Patrick Cloke 3ab861ab9e Fix type hint errors from Twisted trunk (#16526) 2023-10-23 14:28:05 -04:00
Erik Johnston 8f35f8148e Fix bug where a new writer advances their token too quickly (#16473)
* Fix bug where a new writer advances their token too quickly

When starting a new writer (for e.g. persisting events), the
`MultiWriterIdGenerator` doesn't have a minimum token for it as there
are no rows matching that new writer in the DB.

This results in the the first stream ID it acquired being announced as
persisted *before* it actually finishes persisting, if another writer
gets and persists a subsequent stream ID. This is due to the logic of
setting the minimum persisted position to the minimum known position of
across all writers, and the new writer starts off not being considered.

* Fix sending out POSITIONs when our token advances without update

Broke in #14820

* For replication HTTP requests, only wait for minimal position
2023-10-23 16:57:30 +01:00
Erik Johnston 3bc23cc45c Fix bug that could cause a /sync to tightloop with sqlite after restart (#16540)
This could happen if the last rows in the account data stream were inserted into `account_data`. After a restart the max account ID would be calculated without looking at the `account_data` table, and so have an old ID.
2023-10-23 13:39:25 +00:00
Marcel 3bcb6a059f Mention how to redirect the Jaeger traces to a specific Jaeger instance (#16531) 2023-10-23 11:55:36 +00:00
Denis Kasak 3a0aa6fe76 Force TLS certificate verification in registration script. (#16530)
If using the script remotely, there's no particularly convincing reason
to disable certificate verification, as this makes the connection
interceptible.

If on the other hand, the script is used locally (the most common use
case), you can simply target the HTTP listener and avoid TLS altogether.
This is what the script already attempts to do if passed a homeserver
configuration YAML file.
2023-10-23 07:38:51 -04:00
Patrick Cloke 12ca87f5ea Remove the last reference to event_txn_id. (#16521)
This table was no longer used, except for a background process
which purged old entries in it.
2023-10-23 07:37:45 -04:00
David Robertson 478a6c65eb Bump matrix-synapse-ldap3 from 0.2.2 to 0.3.0 (#16539) 2023-10-23 12:28:29 +01:00
dependabot[bot] f835ab8de5 Bump black from 23.9.1 to 23.10.0 (#16538)
Bumps [black](https://github.com/psf/black) from 23.9.1 to 23.10.0.
- [Release notes](https://github.com/psf/black/releases)
- [Changelog](https://github.com/psf/black/blob/main/CHANGES.md)
- [Commits](https://github.com/psf/black/compare/23.9.1...23.10.0)

---
updated-dependencies:
- dependency-name: black
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-23 10:25:14 +01:00
dependabot[bot] 786b614fb2 Bump types-requests from 2.31.0.2 to 2.31.0.10 (#16537)
Bumps [types-requests](https://github.com/python/typeshed) from 2.31.0.2 to 2.31.0.10.
- [Commits](https://github.com/python/typeshed/commits)

---
updated-dependencies:
- dependency-name: types-requests
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-23 10:25:02 +01:00
dependabot[bot] a8026209d2 Bump gitpython from 3.1.37 to 3.1.40 (#16534)
Bumps [gitpython](https://github.com/gitpython-developers/GitPython) from 3.1.37 to 3.1.40.
- [Release notes](https://github.com/gitpython-developers/GitPython/releases)
- [Changelog](https://github.com/gitpython-developers/GitPython/blob/main/CHANGES)
- [Commits](https://github.com/gitpython-developers/GitPython/compare/3.1.37...3.1.40)

---
updated-dependencies:
- dependency-name: gitpython
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-23 10:24:46 +01:00
dependabot[bot] 2d12163cb4 Bump types-pillow from 10.0.0.3 to 10.1.0.0 (#16536)
Bumps [types-pillow](https://github.com/python/typeshed) from 10.0.0.3 to 10.1.0.0.
- [Commits](https://github.com/python/typeshed/commits)

---
updated-dependencies:
- dependency-name: types-pillow
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-23 09:46:55 +01:00
dependabot[bot] 9171bf3b35 Bump pygithub from 1.59.1 to 2.1.1 (#16535)
Bumps [pygithub](https://github.com/pygithub/pygithub) from 1.59.1 to 2.1.1.
- [Release notes](https://github.com/pygithub/pygithub/releases)
- [Changelog](https://github.com/PyGithub/PyGithub/blob/main/doc/changes.rst)
- [Commits](https://github.com/pygithub/pygithub/compare/v1.59.1...v2.1.1)

---
updated-dependencies:
- dependency-name: pygithub
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-23 09:45:12 +01:00
Patrick Cloke d2eab22de7 Clarify presence router docs. (#16529) 2023-10-20 11:40:26 -04:00
Erik Johnston e9069c9f91 Mark sync as limited if there is a gap in the timeline (#16485)
This splits thinsg into two queries, but most of the time we won't have
new event backwards extremities so this shouldn't actually add an extra
RTT for the majority of cases.

Note this removes the check for events with no prev events, but that was
part of MSC2716 work that has since been removed.
2023-10-19 15:04:18 +01:00
Patrick Cloke 49c9745b45 Avoid sending massive replication updates when purging a room. (#16510) 2023-10-18 12:26:01 -04:00
Mathieu Velten bcff01b406 Improve performance of delete device messages query (#16492) 2023-10-18 16:42:01 +01:00
Patrick Cloke 8841db4d27 Run trial/integration tests if .ci is modified. (#16512) 2023-10-18 07:19:53 -04:00
dependabot[bot] 19033313e6 Bump urllib3 from 1.26.17 to 1.26.18 (#16516)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-18 11:58:16 +01:00
Patrick Cloke 68d9559fef Test against Python 3.12 release (#16511) 2023-10-17 14:41:10 -04:00
152 changed files with 3372 additions and 1610 deletions
+2 -2
View File
@@ -47,7 +47,7 @@ if not IS_PR:
"database": "sqlite",
"extras": "all",
}
for version in ("3.9", "3.10", "3.11", "3.12.0-rc.2")
for version in ("3.9", "3.10", "3.11", "3.12")
)
trial_postgres_tests = [
@@ -62,7 +62,7 @@ trial_postgres_tests = [
if not IS_PR:
trial_postgres_tests.append(
{
"python-version": "3.11",
"python-version": "3.12",
"database": "postgres",
"postgres-version": "16",
"extras": "all",
+6
View File
@@ -37,15 +37,18 @@ jobs:
- 'Cargo.toml'
- 'Cargo.lock'
- '.rustfmt.toml'
- '.github/workflows/tests.yml'
trial:
- 'synapse/**'
- 'tests/**'
- 'rust/**'
- '.ci/scripts/calculate_jobs.py'
- 'Cargo.toml'
- 'Cargo.lock'
- 'pyproject.toml'
- 'poetry.lock'
- '.github/workflows/tests.yml'
integration:
- 'synapse/**'
@@ -56,7 +59,9 @@ jobs:
- 'pyproject.toml'
- 'poetry.lock'
- 'docker/**'
- '.ci/**'
- 'scripts-dev/complement.sh'
- '.github/workflows/tests.yml'
linting:
- 'synapse/**'
@@ -70,6 +75,7 @@ jobs:
- 'mypy.ini'
- 'pyproject.toml'
- 'poetry.lock'
- '.github/workflows/tests.yml'
check-sampleconfig:
runs-on: ubuntu-latest
+1
View File
@@ -0,0 +1 @@
Allow multiple workers to write to receipts stream.
+1
View File
@@ -0,0 +1 @@
Fixed a bug that prevents Grafana from finding the correct datasource. Contributed by @MichaelSasser.
+1
View File
@@ -0,0 +1 @@
Fix a long-standing, exceedingly rare edge case where the first event persisted by a new event persister worker might not be sent down `/sync`.
+1
View File
@@ -0,0 +1 @@
Fix long-standing bug where `/sync` incorrectly did not mark a room as `limited` in a sync requests when there were missing remote events.
+1
View File
@@ -0,0 +1 @@
Improve performance of delete device messages query, cf issue [16479](https://github.com/matrix-org/synapse/issues/16479).
+1
View File
@@ -0,0 +1 @@
Fix a bug introduced in Synapse 1.41 where HTTP(S) forward proxy authorization would fail when using basic HTTP authentication with a long `username:password` string.
+1
View File
@@ -0,0 +1 @@
Reduce memory allocations.
+1
View File
@@ -0,0 +1 @@
Improve replication performance when purging rooms.
+1
View File
@@ -0,0 +1 @@
Run tests against Python 3.12.
+1
View File
@@ -0,0 +1 @@
Run trial & integration tests in continuous integration when `.ci` directory is modified.
+1
View File
@@ -0,0 +1 @@
Remove duplicate call to mark remote server 'awake' when using a federation sending worker.
+1
View File
@@ -0,0 +1 @@
Stop deleting from an unused table.
+1
View File
@@ -0,0 +1 @@
Improve type hints.
+1
View File
@@ -0,0 +1 @@
Fix running unit tests on Twisted trunk.
+1
View File
@@ -0,0 +1 @@
Improve documentation of presence router.
+1
View File
@@ -0,0 +1 @@
Force TLS certificate verification in user registration script.
+1
View File
@@ -0,0 +1 @@
Add a sentence to the opentracing docs on how you can have jaeger in a different place than synapse.
+1
View File
@@ -0,0 +1 @@
Bump matrix-synapse-ldap3 from 0.2.2 to 0.3.0.
+1
View File
@@ -0,0 +1 @@
Fix long-standing bug where `/sync` could tightloop after restart when using SQLite.
+1
View File
@@ -0,0 +1 @@
Correctly describe the meaning of unspecified rule lists in the [`alias_creation_rules`](https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#alias_creation_rules) and [`room_list_publication_rules`](https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#room_list_publication_rules) config options and improve their descriptions more generally.
+1
View File
@@ -0,0 +1 @@
Add a new module API for controller presence.
+1
View File
@@ -0,0 +1 @@
Add a new module API callback that allows adding extra fields to events' unsigned section when sent down to clients.
+1
View File
@@ -0,0 +1 @@
Pin the recommended poetry version in contributors' guide.
+1
View File
@@ -0,0 +1 @@
Improve type hints.
+1
View File
@@ -0,0 +1 @@
Fix a long-standing, exceedingly rare edge case where the first event persisted by a new event persister worker might not be sent down `/sync`.
+1
View File
@@ -0,0 +1 @@
Fix ratelimiting of message sending when using workers, where the ratelimit would only be applied after most of the work has been done.
+1
View File
@@ -0,0 +1 @@
Experimental support for [MSC4072](https://github.com/matrix-org/matrix-spec-proposals/pull/4072): Return a result for all devices requested in a `/keys/claim` request.
File diff suppressed because it is too large Load Diff
+2 -1
View File
@@ -19,7 +19,7 @@
# Usage
- [Federation](federate.md)
- [Configuration](usage/configuration/README.md)
- [Configuration Manual](usage/configuration/config_documentation.md)
- [Configuration Manual](usage/configuration/config_documentation.md)
- [Homeserver Sample Config File](usage/configuration/homeserver_sample_config.md)
- [Logging Sample Config File](usage/configuration/logging_sample_config.md)
- [Structured Logging](structured_logging.md)
@@ -48,6 +48,7 @@
- [Password auth provider callbacks](modules/password_auth_provider_callbacks.md)
- [Background update controller callbacks](modules/background_update_controller_callbacks.md)
- [Account data callbacks](modules/account_data_callbacks.md)
- [Add extra fields to client events unsigned section callbacks](modules/add_extra_fields_to_client_events_unsigned.md)
- [Porting a legacy module to the new interface](modules/porting_legacy_module.md)
- [Workers](workers.md)
- [Using `synctl` with Workers](synctl_workers.md)
+1 -1
View File
@@ -66,7 +66,7 @@ Of their installation methods, we recommend
```shell
pip install --user pipx
pipx install poetry
pipx install poetry==1.5.2 # Problems with Poetry 1.6, see https://github.com/matrix-org/synapse/issues/16147
```
but see poetry's [installation instructions](https://python-poetry.org/docs/#installation)
@@ -51,17 +51,24 @@ will be inserted with that ID.
For any given stream reader (including writers themselves), we may define a per-writer current stream ID:
> The current stream ID _for a writer W_ is the largest stream ID such that
> A current stream ID _for a writer W_ is the largest stream ID such that
> all transactions added by W with equal or smaller ID have completed.
Similarly, there is a "linear" notion of current stream ID:
> The "linear" current stream ID is the largest stream ID such that
> A "linear" current stream ID is the largest stream ID such that
> all facts (added by any writer) with equal or smaller ID have completed.
Because different stream readers A and B learn about new facts at different times, A and B may disagree about current stream IDs.
Put differently: we should think of stream readers as being independent of each other, proceeding through a stream of facts at different rates.
The above definition does not give a unique current stream ID, in fact there can
be a range of current stream IDs. Synapse uses both the minimum and maximum IDs
for different purposes. Most often the maximum is used, as its generally
beneficial for workers to advance their IDs as soon as possible. However, the
minimum is used in situations where e.g. another worker is going to wait until
the stream advances past a position.
**NB.** For both senses of "current", that if a writer opens a transaction that never completes, the current stream ID will never advance beyond that writer's last written stream ID.
For single-writer streams, the per-writer current ID and the linear current ID are the same.
@@ -114,7 +121,7 @@ Writers need to track:
- track their current position (i.e. its own per-writer stream ID).
- their facts currently awaiting completion.
At startup,
At startup,
- the current position of that writer can be found by querying the database (which suggests that facts need to be written to the database atomically, in a transaction); and
- there are no facts awaiting completion.
@@ -0,0 +1,32 @@
# Add extra fields to client events unsigned section callbacks
_First introduced in Synapse v1.96.0_
This callback allows modules to add extra fields to the unsigned section of
events when they get sent down to clients.
These get called *every* time an event is to be sent to clients, so care should
be taken to ensure with respect to performance.
### API
To register the callback, use
`register_add_extra_fields_to_unsigned_client_event_callbacks` on the
`ModuleApi`.
The callback should be of the form
```python
async def add_field_to_unsigned(
event: EventBase,
) -> JsonDict:
```
where the extra fields to add to the event's unsigned section is returned.
(Modules must not attempt to modify the `event` directly).
This cannot be used to alter the "core" fields in the unsigned section emitted
by Synapse itself.
If multiple such callbacks try to add the same field to an event's unsigned
section, the last-registered callback wins.
+11 -3
View File
@@ -1,8 +1,16 @@
# Presence router callbacks
Presence router callbacks allow module developers to specify additional users (local or remote)
to receive certain presence updates from local users. Presence router callbacks can be
registered using the module API's `register_presence_router_callbacks` method.
Presence router callbacks allow module developers to define additional users
which receive presence updates from local users. The additional users
can be local or remote.
For example, it could be used to direct all of `@alice:example.com` (a local user)'s
presence updates to `@bob:matrix.org` (a remote user), even though they don't share a
room. (Note that those presence updates might not make it to `@bob:matrix.org`'s client
unless a similar presence router is running on that homeserver.)
Presence router callbacks can be registered using the module API's
`register_presence_router_callbacks` method.
## Callbacks
+5
View File
@@ -51,6 +51,11 @@ docker run -d --name jaeger \
jaegertracing/all-in-one:1
```
By default, Synapse will publish traces to Jaeger on localhost.
If Jaeger is hosted elsewhere, point Synapse to the correct host by setting
`opentracing.jaeger_config.local_agent.reporting_host` [in the Synapse configuration](usage/configuration/config_documentation.md#opentracing-1)
or by setting the `JAEGER_AGENT_HOST` environment variable to the desired address.
Latest documentation is probably at
https://www.jaegertracing.io/docs/latest/getting-started.
+140 -35
View File
@@ -230,6 +230,13 @@ Example configuration:
presence:
enabled: false
```
`enabled` can also be set to a special value of "untracked" which ignores updates
received via clients and federation, while still accepting updates from the
[module API](../../modules/index.md).
*The "untracked" option was added in Synapse 1.96.0.*
---
### `require_auth_for_profile_requests`
@@ -3797,62 +3804,160 @@ enable_room_list_search: false
---
### `alias_creation_rules`
The `alias_creation_rules` option controls who is allowed to create aliases
on this server.
The `alias_creation_rules` option allows server admins to prevent unwanted
alias creation on this server.
The format of this option is a list of rules that contain globs that
match against user_id, room_id and the new alias (fully qualified with
server name). The action in the first rule that matches is taken,
which can currently either be "allow" or "deny".
This setting is an optional list of 0 or more rules. By default, no list is
provided, meaning that all alias creations are permitted.
Missing user_id/room_id/alias fields default to "*".
Otherwise, requests to create aliases are matched against each rule in order.
The first rule that matches decides if the request is allowed or denied. If no
rule matches, the request is denied. In particular, this means that configuring
an empty list of rules will deny every alias creation request.
If no rules match the request is denied. An empty list means no one
can create aliases.
Each rule is a YAML object containing four fields, each of which is an optional string:
Options for the rules include:
* `user_id`: Matches against the creator of the alias. Defaults to "*".
* `alias`: Matches against the alias being created. Defaults to "*".
* `room_id`: Matches against the room ID the alias is being pointed at. Defaults to "*"
* `action`: Whether to "allow" or "deny" the request if the rule matches. Defaults to allow.
* `user_id`: a glob pattern that matches against the creator of the alias.
* `alias`: a glob pattern that matches against the alias being created.
* `room_id`: a glob pattern that matches against the room ID the alias is being pointed at.
* `action`: either `allow` or `deny`. What to do with the request if the rule matches. Defaults to `allow`.
Each of the glob patterns is optional, defaulting to `*` ("match anything").
Note that the patterns match against fully qualified IDs, e.g. against
`@alice:example.com`, `#room:example.com` and `!abcdefghijk:example.com` instead
of `alice`, `room` and `abcedgghijk`.
Example configuration:
```yaml
# No rule list specified. All alias creations are allowed.
# This is the default behaviour.
alias_creation_rules:
- user_id: "bad_user"
alias: "spammy_alias"
room_id: "*"
action: deny
```
```yaml
# A list of one rule which allows everything.
# This has the same effect as the previous example.
alias_creation_rules:
- "action": "allow"
```
```yaml
# An empty list of rules. All alias creations are denied.
alias_creation_rules: []
```
```yaml
# A list of one rule which denies everything.
# This has the same effect as the previous example.
alias_creation_rules:
- "action": "deny"
```
```yaml
# Prevent a specific user from creating aliases.
# Allow other users to create any alias
alias_creation_rules:
- user_id: "@bad_user:example.com"
action: deny
- action: allow
```
```yaml
# Prevent aliases being created which point to a specific room.
alias_creation_rules:
- room_id: "!forbiddenRoom:example.com"
action: deny
- action: allow
```
---
### `room_list_publication_rules`
The `room_list_publication_rules` option controls who can publish and
which rooms can be published in the public room list.
The `room_list_publication_rules` option allows server admins to prevent
unwanted entries from being published in the public room list.
The format of this option is the same as that for
`alias_creation_rules`.
[`alias_creation_rules`](#alias_creation_rules): an optional list of 0 or more
rules. By default, no list is provided, meaning that all rooms may be
published to the room list.
If the room has one or more aliases associated with it, only one of
the aliases needs to match the alias rule. If there are no aliases
then only rules with `alias: *` match.
Otherwise, requests to publish a room are matched against each rule in order.
The first rule that matches decides if the request is allowed or denied. If no
rule matches, the request is denied. In particular, this means that configuring
an empty list of rules will deny every alias creation request.
If no rules match the request is denied. An empty list means no one
can publish rooms.
Each rule is a YAML object containing four fields, each of which is an optional string:
* `user_id`: a glob pattern that matches against the user publishing the room.
* `alias`: a glob pattern that matches against one of published room's aliases.
- If the room has no aliases, the alias match fails unless `alias` is unspecified or `*`.
- If the room has exactly one alias, the alias match succeeds if the `alias` pattern matches that alias.
- If the room has two or more aliases, the alias match succeeds if the pattern matches at least one of the aliases.
* `room_id`: a glob pattern that matches against the room ID of the room being published.
* `action`: either `allow` or `deny`. What to do with the request if the rule matches. Defaults to `allow`.
Each of the glob patterns is optional, defaulting to `*` ("match anything").
Note that the patterns match against fully qualified IDs, e.g. against
`@alice:example.com`, `#room:example.com` and `!abcdefghijk:example.com` instead
of `alice`, `room` and `abcedgghijk`.
Options for the rules include:
* `user_id`: Matches against the creator of the alias. Defaults to "*".
* `alias`: Matches against any current local or canonical aliases associated with the room. Defaults to "*".
* `room_id`: Matches against the room ID being published. Defaults to "*".
* `action`: Whether to "allow" or "deny" the request if the rule matches. Defaults to allow.
Example configuration:
```yaml
# No rule list specified. Anyone may publish any room to the public list.
# This is the default behaviour.
room_list_publication_rules:
- user_id: "*"
alias: "*"
room_id: "*"
action: allow
```
```yaml
# A list of one rule which allows everything.
# This has the same effect as the previous example.
room_list_publication_rules:
- "action": "allow"
```
```yaml
# An empty list of rules. No-one may publish to the room list.
room_list_publication_rules: []
```
```yaml
# A list of one rule which denies everything.
# This has the same effect as the previous example.
room_list_publication_rules:
- "action": "deny"
```
```yaml
# Prevent a specific user from publishing rooms.
# Allow other users to publish anything.
room_list_publication_rules:
- user_id: "@bad_user:example.com"
action: deny
- action: allow
```
```yaml
# Prevent publication of a specific room.
room_list_publication_rules:
- room_id: "!forbiddenRoom:example.com"
action: deny
- action: allow
```
```yaml
# Prevent publication of rooms with at least one alias containing the word "potato".
room_list_publication_rules:
- alias: "#*potato*:example.com"
action: deny
- action: allow
```
---
Generated
+52 -63
View File
@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
[[package]]
name = "alabaster"
@@ -162,33 +162,29 @@ lxml = ["lxml"]
[[package]]
name = "black"
version = "23.9.1"
version = "23.10.0"
description = "The uncompromising code formatter."
optional = false
python-versions = ">=3.8"
files = [
{file = "black-23.9.1-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:d6bc09188020c9ac2555a498949401ab35bb6bf76d4e0f8ee251694664df6301"},
{file = "black-23.9.1-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:13ef033794029b85dfea8032c9d3b92b42b526f1ff4bf13b2182ce4e917f5100"},
{file = "black-23.9.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:75a2dc41b183d4872d3a500d2b9c9016e67ed95738a3624f4751a0cb4818fe71"},
{file = "black-23.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13a2e4a93bb8ca74a749b6974925c27219bb3df4d42fc45e948a5d9feb5122b7"},
{file = "black-23.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:adc3e4442eef57f99b5590b245a328aad19c99552e0bdc7f0b04db6656debd80"},
{file = "black-23.9.1-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:8431445bf62d2a914b541da7ab3e2b4f3bc052d2ccbf157ebad18ea126efb91f"},
{file = "black-23.9.1-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:8fc1ddcf83f996247505db6b715294eba56ea9372e107fd54963c7553f2b6dfe"},
{file = "black-23.9.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:7d30ec46de88091e4316b17ae58bbbfc12b2de05e069030f6b747dfc649ad186"},
{file = "black-23.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:031e8c69f3d3b09e1aa471a926a1eeb0b9071f80b17689a655f7885ac9325a6f"},
{file = "black-23.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:538efb451cd50f43aba394e9ec7ad55a37598faae3348d723b59ea8e91616300"},
{file = "black-23.9.1-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:638619a559280de0c2aa4d76f504891c9860bb8fa214267358f0a20f27c12948"},
{file = "black-23.9.1-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:a732b82747235e0542c03bf352c126052c0fbc458d8a239a94701175b17d4855"},
{file = "black-23.9.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:cf3a4d00e4cdb6734b64bf23cd4341421e8953615cba6b3670453737a72ec204"},
{file = "black-23.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf99f3de8b3273a8317681d8194ea222f10e0133a24a7548c73ce44ea1679377"},
{file = "black-23.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:14f04c990259576acd093871e7e9b14918eb28f1866f91968ff5524293f9c573"},
{file = "black-23.9.1-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:c619f063c2d68f19b2d7270f4cf3192cb81c9ec5bc5ba02df91471d0b88c4c5c"},
{file = "black-23.9.1-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:6a3b50e4b93f43b34a9d3ef00d9b6728b4a722c997c99ab09102fd5efdb88325"},
{file = "black-23.9.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:c46767e8df1b7beefb0899c4a95fb43058fa8500b6db144f4ff3ca38eb2f6393"},
{file = "black-23.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50254ebfa56aa46a9fdd5d651f9637485068a1adf42270148cd101cdf56e0ad9"},
{file = "black-23.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:403397c033adbc45c2bd41747da1f7fc7eaa44efbee256b53842470d4ac5a70f"},
{file = "black-23.9.1-py3-none-any.whl", hash = "sha256:6ccd59584cc834b6d127628713e4b6b968e5f79572da66284532525a042549f9"},
{file = "black-23.9.1.tar.gz", hash = "sha256:24b6b3ff5c6d9ea08a8888f6977eae858e1f340d7260cf56d70a49823236b62d"},
{file = "black-23.10.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:f8dc7d50d94063cdfd13c82368afd8588bac4ce360e4224ac399e769d6704e98"},
{file = "black-23.10.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:f20ff03f3fdd2fd4460b4f631663813e57dc277e37fb216463f3b907aa5a9bdd"},
{file = "black-23.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3d9129ce05b0829730323bdcb00f928a448a124af5acf90aa94d9aba6969604"},
{file = "black-23.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:960c21555be135c4b37b7018d63d6248bdae8514e5c55b71e994ad37407f45b8"},
{file = "black-23.10.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:30b78ac9b54cf87bcb9910ee3d499d2bc893afd52495066c49d9ee6b21eee06e"},
{file = "black-23.10.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:0e232f24a337fed7a82c1185ae46c56c4a6167fb0fe37411b43e876892c76699"},
{file = "black-23.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31946ec6f9c54ed7ba431c38bc81d758970dd734b96b8e8c2b17a367d7908171"},
{file = "black-23.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:c870bee76ad5f7a5ea7bd01dc646028d05568d33b0b09b7ecfc8ec0da3f3f39c"},
{file = "black-23.10.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:6901631b937acbee93c75537e74f69463adaf34379a04eef32425b88aca88a23"},
{file = "black-23.10.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:481167c60cd3e6b1cb8ef2aac0f76165843a374346aeeaa9d86765fe0dd0318b"},
{file = "black-23.10.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f74892b4b836e5162aa0452393112a574dac85e13902c57dfbaaf388e4eda37c"},
{file = "black-23.10.0-cp38-cp38-win_amd64.whl", hash = "sha256:47c4510f70ec2e8f9135ba490811c071419c115e46f143e4dce2ac45afdcf4c9"},
{file = "black-23.10.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:76baba9281e5e5b230c9b7f83a96daf67a95e919c2dfc240d9e6295eab7b9204"},
{file = "black-23.10.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:a3c2ddb35f71976a4cfeca558848c2f2f89abc86b06e8dd89b5a65c1e6c0f22a"},
{file = "black-23.10.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db451a3363b1e765c172c3fd86213a4ce63fb8524c938ebd82919bf2a6e28c6a"},
{file = "black-23.10.0-cp39-cp39-win_amd64.whl", hash = "sha256:7fb5fc36bb65160df21498d5a3dd330af8b6401be3f25af60c6ebfe23753f747"},
{file = "black-23.10.0-py3-none-any.whl", hash = "sha256:e223b731a0e025f8ef427dd79d8cd69c167da807f5710add30cdf131f13dd62e"},
{file = "black-23.10.0.tar.gz", hash = "sha256:31b9f87b277a68d0e99d2905edae08807c007973eaa609da5f0c62def6b7c0bd"},
]
[package.dependencies]
@@ -600,20 +596,20 @@ smmap = ">=3.0.1,<6"
[[package]]
name = "gitpython"
version = "3.1.37"
version = "3.1.40"
description = "GitPython is a Python library used to interact with Git repositories"
optional = false
python-versions = ">=3.7"
files = [
{file = "GitPython-3.1.37-py3-none-any.whl", hash = "sha256:5f4c4187de49616d710a77e98ddf17b4782060a1788df441846bddefbb89ab33"},
{file = "GitPython-3.1.37.tar.gz", hash = "sha256:f9b9ddc0761c125d5780eab2d64be4873fc6817c2899cbcb34b02344bdc7bc54"},
{file = "GitPython-3.1.40-py3-none-any.whl", hash = "sha256:cf14627d5a8049ffbf49915732e5eddbe8134c3bdb9d476e6182b676fc573f8a"},
{file = "GitPython-3.1.40.tar.gz", hash = "sha256:22b126e9ffb671fdd0c129796343a02bf67bf2994b35449ffc9321aa755e18a4"},
]
[package.dependencies]
gitdb = ">=4.0.1,<5"
[package.extras]
test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-sugar"]
test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-instafail", "pytest-subtests", "pytest-sugar"]
[[package]]
name = "hiredis"
@@ -1341,13 +1337,13 @@ test = ["aiounittest", "tox", "twisted"]
[[package]]
name = "matrix-synapse-ldap3"
version = "0.2.2"
version = "0.3.0"
description = "An LDAP3 auth provider for Synapse"
optional = true
python-versions = ">=3.7"
files = [
{file = "matrix-synapse-ldap3-0.2.2.tar.gz", hash = "sha256:b388d95693486eef69adaefd0fd9e84463d52fe17b0214a00efcaa669b73cb74"},
{file = "matrix_synapse_ldap3-0.2.2-py3-none-any.whl", hash = "sha256:66ee4c85d7952c6c27fd04c09cdfdf4847b8e8b7d6a7ada6ba1100013bda060f"},
{file = "matrix-synapse-ldap3-0.3.0.tar.gz", hash = "sha256:8bb6517173164d4b9cc44f49de411d8cebdb2e705d5dd1ea1f38733c4a009e1d"},
{file = "matrix_synapse_ldap3-0.3.0-py3-none-any.whl", hash = "sha256:8b4d701f8702551e98cc1d8c20dbed532de5613584c08d0df22de376ba99159d"},
]
[package.dependencies]
@@ -1980,20 +1976,23 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
[[package]]
name = "pygithub"
version = "1.59.1"
version = "2.1.1"
description = "Use the full Github API v3"
optional = false
python-versions = ">=3.7"
files = [
{file = "PyGithub-1.59.1-py3-none-any.whl", hash = "sha256:3d87a822e6c868142f0c2c4bf16cce4696b5a7a4d142a7bd160e1bdf75bc54a9"},
{file = "PyGithub-1.59.1.tar.gz", hash = "sha256:c44e3a121c15bf9d3a5cc98d94c9a047a5132a9b01d22264627f58ade9ddc217"},
{file = "PyGithub-2.1.1-py3-none-any.whl", hash = "sha256:4b528d5d6f35e991ea5fd3f942f58748f24938805cb7fcf24486546637917337"},
{file = "PyGithub-2.1.1.tar.gz", hash = "sha256:ecf12c2809c44147bce63b047b3d2e9dac8a41b63e90fcb263c703f64936b97c"},
]
[package.dependencies]
deprecated = "*"
Deprecated = "*"
pyjwt = {version = ">=2.4.0", extras = ["crypto"]}
pynacl = ">=1.4.0"
python-dateutil = "*"
requests = ">=2.14.0"
typing-extensions = ">=4.0.0"
urllib3 = ">=1.26.0"
[[package]]
name = "pygments"
@@ -2137,7 +2136,7 @@ s2repoze = ["paste", "repoze.who", "zope.interface"]
name = "python-dateutil"
version = "2.8.2"
description = "Extensions to the standard Python datetime module"
optional = true
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
files = [
{file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
@@ -3106,13 +3105,13 @@ files = [
[[package]]
name = "types-pillow"
version = "10.0.0.3"
version = "10.1.0.0"
description = "Typing stubs for Pillow"
optional = false
python-versions = "*"
python-versions = ">=3.7"
files = [
{file = "types-Pillow-10.0.0.3.tar.gz", hash = "sha256:ae0c877d363da349bbb82c5463c9e78037290cc07d3714cb0ceaf5d2f7f5c825"},
{file = "types_Pillow-10.0.0.3-py3-none-any.whl", hash = "sha256:54a49f3c6a3f5e95ebeee396d7773dde22ce2515d594f9c0596c0a983558f0d4"},
{file = "types-Pillow-10.1.0.0.tar.gz", hash = "sha256:0f5e7cf010ed226800cb5821e87781e5d0e81257d948a9459baa74a8c8b7d822"},
{file = "types_Pillow-10.1.0.0-py3-none-any.whl", hash = "sha256:f97f596b6a39ddfd26da3eb67421062193e10732d2310f33898d36f9694331b5"},
]
[[package]]
@@ -3153,17 +3152,17 @@ files = [
[[package]]
name = "types-requests"
version = "2.31.0.2"
version = "2.31.0.10"
description = "Typing stubs for requests"
optional = false
python-versions = "*"
python-versions = ">=3.7"
files = [
{file = "types-requests-2.31.0.2.tar.gz", hash = "sha256:6aa3f7faf0ea52d728bb18c0a0d1522d9bfd8c72d26ff6f61bfc3d06a411cf40"},
{file = "types_requests-2.31.0.2-py3-none-any.whl", hash = "sha256:56d181c85b5925cbc59f4489a57e72a8b2166f18273fd8ba7b6fe0c0b986f12a"},
{file = "types-requests-2.31.0.10.tar.gz", hash = "sha256:dc5852a76f1eaf60eafa81a2e50aefa3d1f015c34cf0cba130930866b1b22a92"},
{file = "types_requests-2.31.0.10-py3-none-any.whl", hash = "sha256:b32b9a86beffa876c0c3ac99a4cd3b8b51e973fb8e3bd4e0a6bb32c7efad80fc"},
]
[package.dependencies]
types-urllib3 = "*"
urllib3 = ">=2"
[[package]]
name = "types-setuptools"
@@ -3176,17 +3175,6 @@ files = [
{file = "types_setuptools-68.2.0.0-py3-none-any.whl", hash = "sha256:77edcc843e53f8fc83bb1a840684841f3dc804ec94562623bfa2ea70d5a2ba1b"},
]
[[package]]
name = "types-urllib3"
version = "1.26.25.8"
description = "Typing stubs for urllib3"
optional = false
python-versions = "*"
files = [
{file = "types-urllib3-1.26.25.8.tar.gz", hash = "sha256:ecf43c42d8ee439d732a1110b4901e9017a79a38daca26f08e42c8460069392c"},
{file = "types_urllib3-1.26.25.8-py3-none-any.whl", hash = "sha256:95ea847fbf0bf675f50c8ae19a665baedcf07e6b4641662c4c3c72e7b2edf1a9"},
]
[[package]]
name = "typing-extensions"
version = "4.8.0"
@@ -3211,19 +3199,20 @@ files = [
[[package]]
name = "urllib3"
version = "1.26.17"
version = "2.0.7"
description = "HTTP library with thread-safe connection pooling, file post, and more."
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
python-versions = ">=3.7"
files = [
{file = "urllib3-1.26.17-py2.py3-none-any.whl", hash = "sha256:94a757d178c9be92ef5539b8840d48dc9cf1b2709c9d6b588232a055c524458b"},
{file = "urllib3-1.26.17.tar.gz", hash = "sha256:24d6a242c28d29af46c3fae832c36db3bbebcc533dd1bb549172cd739c82df21"},
{file = "urllib3-2.0.7-py3-none-any.whl", hash = "sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e"},
{file = "urllib3-2.0.7.tar.gz", hash = "sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84"},
]
[package.extras]
brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "webencodings"
+2 -2
View File
@@ -50,7 +50,7 @@ def request_registration(
url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),)
# Get the nonce
r = requests.get(url, verify=False)
r = requests.get(url)
if r.status_code != 200:
_print("ERROR! Received %d %s" % (r.status_code, r.reason))
@@ -88,7 +88,7 @@ def request_registration(
}
_print("Sending registration request...")
r = requests.post(url, json=data, verify=False)
r = requests.post(url, json=data)
if r.status_code != 200:
_print("ERROR! Received %d %s" % (r.status_code, r.reason))
+6 -1
View File
@@ -15,7 +15,6 @@
import enum
from typing import TYPE_CHECKING, Any, Optional
import attr
import attr.validators
from synapse.api.errors import LimitExceededError
@@ -419,3 +418,9 @@ class ExperimentalConfig(Config):
self.msc4028_push_encrypted_events = experimental.get(
"msc4028_push_encrypted_events", False
)
# MSC4072: Return an empty dict from /keys/claim for unknown devices or those
# with exhausted OTKs
self.msc4072_empty_dict_for_exhausted_devices = experimental.get(
"msc4072_empty_dict_for_exhausted_devices", False
)
+8 -3
View File
@@ -368,9 +368,14 @@ class ServerConfig(Config):
# Whether to enable user presence.
presence_config = config.get("presence") or {}
self.use_presence = presence_config.get("enabled")
if self.use_presence is None:
self.use_presence = config.get("use_presence", True)
presence_enabled = presence_config.get("enabled")
if presence_enabled is None:
presence_enabled = config.get("use_presence", True)
# Whether presence is enabled *at all*.
self.presence_enabled = bool(presence_enabled)
# Whether to internally track presence, requires that presence is enabled,
self.track_presence = self.presence_enabled and presence_enabled != "untracked"
# Custom presence router module
# This is the legacy way of configuring it (the config should now be put in the modules section)
+2 -2
View File
@@ -358,9 +358,9 @@ class WorkerConfig(Config):
"Must only specify one instance to handle `account_data` messages."
)
if len(self.writers.receipts) != 1:
if len(self.writers.receipts) == 0:
raise ConfigError(
"Must only specify one instance to handle `receipts` messages."
"Must specify at least one instance to handle `receipts` messages."
)
if len(self.writers.events) == 0:
+41 -7
View File
@@ -17,6 +17,7 @@ import re
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
@@ -45,6 +46,7 @@ from . import EventBase
if TYPE_CHECKING:
from synapse.handlers.relations import BundledAggregations
from synapse.server import HomeServer
# Split strings on "." but not "\." (or "\\\.").
@@ -56,6 +58,13 @@ CANONICALJSON_MAX_INT = (2**53) - 1
CANONICALJSON_MIN_INT = -CANONICALJSON_MAX_INT
# Module API callback that allows adding fields to the unsigned section of
# events that are sent to clients.
ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK = Callable[
[EventBase], Awaitable[JsonDict]
]
def prune_event(event: EventBase) -> EventBase:
"""Returns a pruned version of the given event, which removes all keys we
don't know about or think could potentially be dodgy.
@@ -509,7 +518,13 @@ class EventClientSerializer:
clients.
"""
def serialize_event(
def __init__(self, hs: "HomeServer") -> None:
self._store = hs.get_datastores().main
self._add_extra_fields_to_unsigned_client_event_callbacks: List[
ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
] = []
async def serialize_event(
self,
event: Union[JsonDict, EventBase],
time_now: int,
@@ -535,10 +550,21 @@ class EventClientSerializer:
serialized_event = serialize_event(event, time_now, config=config)
new_unsigned = {}
for callback in self._add_extra_fields_to_unsigned_client_event_callbacks:
u = await callback(event)
new_unsigned.update(u)
if new_unsigned:
# We do the `update` this way round so that modules can't clobber
# existing fields.
new_unsigned.update(serialized_event["unsigned"])
serialized_event["unsigned"] = new_unsigned
# Check if there are any bundled aggregations to include with the event.
if bundle_aggregations:
if event.event_id in bundle_aggregations:
self._inject_bundled_aggregations(
await self._inject_bundled_aggregations(
event,
time_now,
config,
@@ -548,7 +574,7 @@ class EventClientSerializer:
return serialized_event
def _inject_bundled_aggregations(
async def _inject_bundled_aggregations(
self,
event: EventBase,
time_now: int,
@@ -590,7 +616,7 @@ class EventClientSerializer:
# said that we should only include the `event_id`, `origin_server_ts` and
# `sender` of the edit; however MSC3925 proposes extending it to the whole
# of the edit, which is what we do here.
serialized_aggregations[RelationTypes.REPLACE] = self.serialize_event(
serialized_aggregations[RelationTypes.REPLACE] = await self.serialize_event(
event_aggregations.replace,
time_now,
config=config,
@@ -600,7 +626,7 @@ class EventClientSerializer:
if event_aggregations.thread:
thread = event_aggregations.thread
serialized_latest_event = self.serialize_event(
serialized_latest_event = await self.serialize_event(
thread.latest_event,
time_now,
config=config,
@@ -623,7 +649,7 @@ class EventClientSerializer:
"m.relations", {}
).update(serialized_aggregations)
def serialize_events(
async def serialize_events(
self,
events: Iterable[Union[JsonDict, EventBase]],
time_now: int,
@@ -645,7 +671,7 @@ class EventClientSerializer:
The list of serialized events
"""
return [
self.serialize_event(
await self.serialize_event(
event,
time_now,
config=config,
@@ -654,6 +680,14 @@ class EventClientSerializer:
for event in events
]
def register_add_extra_fields_to_unsigned_client_event_callback(
self, callback: ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
) -> None:
"""Register a callback that returns additions to the unsigned section of
serialized events.
"""
self._add_extra_fields_to_unsigned_client_event_callbacks.append(callback)
_PowerLevel = Union[str, int]
PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]
+7 -13
View File
@@ -84,7 +84,7 @@ from synapse.replication.http.federation import (
from synapse.storage.databases.main.lock import Lock
from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
from synapse.storage.roommember import MemberSummary
from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.types import JsonDict, JsonSerializable, StateMap, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
@@ -1000,19 +1000,13 @@ class FederationServer(FederationBase):
self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
) -> Dict[str, Any]:
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self._e2e_keys_handler.claim_local_one_time_keys(
query, always_include_fallback_keys=always_include_fallback_keys
json_result: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {}
await self._e2e_keys_handler.claim_local_one_time_keys(
query,
always_include_fallback_keys=always_include_fallback_keys,
result_dict=json_result,
)
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for result in results:
for user_id, device_keys in result.items():
for device_id, keys in device_keys.items():
for key_id, key in keys.items():
json_result.setdefault(user_id, {}).setdefault(device_id, {})[
key_id
] = key
logger.info(
"Claimed one-time-keys: %s",
",".join(
@@ -1395,7 +1389,7 @@ class FederationHandlerRegistry:
self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
if not self.config.server.use_presence and edu_type == EduTypes.PRESENCE:
if not self.config.server.track_presence and edu_type == EduTypes.PRESENCE:
return
# Check if we have a handler on this instance
+1 -1
View File
@@ -844,7 +844,7 @@ class FederationSender(AbstractFederationSender):
destinations (list[str])
"""
if not states or not self.hs.config.server.use_presence:
if not states or not self.hs.config.server.track_presence:
# No-op if presence is disabled.
return
+24 -20
View File
@@ -47,6 +47,7 @@ from synapse.types import (
DeviceListUpdates,
JsonDict,
JsonMapping,
MultiWriterStreamToken,
RoomAlias,
RoomStreamToken,
StreamKeyType,
@@ -217,7 +218,7 @@ class ApplicationServicesHandler:
def notify_interested_services_ephemeral(
self,
stream_key: StreamKeyType,
new_token: Union[int, RoomStreamToken],
new_token: Union[int, RoomStreamToken, MultiWriterStreamToken],
users: Collection[Union[str, UserID]],
) -> None:
"""
@@ -259,19 +260,6 @@ class ApplicationServicesHandler:
):
return
# Assert that new_token is an integer (and not a RoomStreamToken).
# All of the supported streams that this function handles use an
# integer to track progress (rather than a RoomStreamToken - a
# vector clock implementation) as they don't support multiple
# stream writers.
#
# As a result, we simply assert that new_token is an integer.
# If we do end up needing to pass a RoomStreamToken down here
# in the future, using RoomStreamToken.stream (the minimum stream
# position) to convert to an ascending integer value should work.
# Additional context: https://github.com/matrix-org/synapse/pull/11137
assert isinstance(new_token, int)
# Ignore to-device messages if the feature flag is not enabled
if (
stream_key == StreamKeyType.TO_DEVICE
@@ -286,6 +274,9 @@ class ApplicationServicesHandler:
):
return
# We know we're not a `RoomStreamToken` at this point.
assert not isinstance(new_token, RoomStreamToken)
# Check whether there are any appservices which have registered to receive
# ephemeral events.
#
@@ -327,7 +318,7 @@ class ApplicationServicesHandler:
self,
services: List[ApplicationService],
stream_key: StreamKeyType,
new_token: int,
new_token: Union[int, MultiWriterStreamToken],
users: Collection[Union[str, UserID]],
) -> None:
logger.debug("Checking interested services for %s", stream_key)
@@ -340,6 +331,7 @@ class ApplicationServicesHandler:
#
# Instead we simply grab the latest typing updates in _handle_typing
# and, if they apply to this application service, send it off.
assert isinstance(new_token, int)
events = await self._handle_typing(service, new_token)
if events:
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
@@ -350,15 +342,23 @@ class ApplicationServicesHandler:
(service.id, stream_key)
):
if stream_key == StreamKeyType.RECEIPT:
assert isinstance(new_token, MultiWriterStreamToken)
# We store appservice tokens as integers, so we ignore
# the `instance_map` components and instead simply
# follow the base stream position.
new_token = MultiWriterStreamToken(stream=new_token.stream)
events = await self._handle_receipts(service, new_token)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
# Persist the latest handled stream token for this appservice
await self.store.set_appservice_stream_type_pos(
service, "read_receipt", new_token
service, "read_receipt", new_token.stream
)
elif stream_key == StreamKeyType.PRESENCE:
assert isinstance(new_token, int)
events = await self._handle_presence(service, users, new_token)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
@@ -368,6 +368,7 @@ class ApplicationServicesHandler:
)
elif stream_key == StreamKeyType.TO_DEVICE:
assert isinstance(new_token, int)
# Retrieve a list of to-device message events, as well as the
# maximum stream token of the messages we were able to retrieve.
to_device_messages = await self._get_to_device_messages(
@@ -383,6 +384,7 @@ class ApplicationServicesHandler:
)
elif stream_key == StreamKeyType.DEVICE_LIST:
assert isinstance(new_token, int)
device_list_summary = await self._get_device_list_summary(
service, new_token
)
@@ -432,7 +434,7 @@ class ApplicationServicesHandler:
return typing
async def _handle_receipts(
self, service: ApplicationService, new_token: int
self, service: ApplicationService, new_token: MultiWriterStreamToken
) -> List[JsonMapping]:
"""
Return the latest read receipts that the given application service should receive.
@@ -455,15 +457,17 @@ class ApplicationServicesHandler:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
if new_token is not None and new_token <= from_key:
if new_token is not None and new_token.stream <= from_key:
logger.debug(
"Rejecting token lower than or equal to stored: %s" % (new_token,)
)
return []
from_token = MultiWriterStreamToken(stream=from_key)
receipts_source = self.event_sources.sources.receipt
receipts, _ = await receipts_source.get_new_events_as(
service=service, from_key=from_key, to_key=new_token
service=service, from_key=from_token, to_key=new_token
)
return receipts
@@ -857,7 +861,7 @@ class ApplicationServicesHandler:
Returns:
A tuple of:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
A map of user ID -> a map device ID -> a map of key ID -> key.
A copy of the input which has not been fulfilled (either because
they are not appservice users or the appservice does not support
+2 -2
View File
@@ -103,10 +103,10 @@ class DeactivateAccountHandler:
# Attempt to unbind any known bound threepids to this account from identity
# server(s).
bound_threepids = await self.store.user_get_bound_threepids(user_id)
for threepid in bound_threepids:
for medium, address in bound_threepids:
try:
result = await self._identity_handler.try_unbind_threepid(
user_id, threepid["medium"], threepid["address"], id_server
user_id, medium, address, id_server
)
except Exception:
# Do we want this to be a fatal error or should we carry on?
+2
View File
@@ -592,6 +592,8 @@ class DeviceHandler(DeviceWorkerHandler):
)
# Delete device messages asynchronously and in batches using the task scheduler
# We specify an upper stream id to avoid deleting non delivered messages
# if an user re-uses a device ID.
await self._task_scheduler.schedule_task(
DELETE_DEVICE_MSGS_TASK_NAME,
resource_id=device_id,
+118 -24
View File
@@ -32,6 +32,7 @@ from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.types import (
JsonDict,
JsonMapping,
JsonSerializable,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
@@ -560,7 +561,8 @@ class E2eKeysHandler:
self,
local_query: List[Tuple[str, str, str, int]],
always_include_fallback_keys: bool,
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
result_dict: Dict[str, Dict[str, Dict[str, JsonSerializable]]],
) -> None:
"""Claim one time keys for local users.
1. Attempt to claim OTKs from the database.
@@ -570,18 +572,34 @@ class E2eKeysHandler:
Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm).
always_include_fallback_keys: True to always include fallback keys.
Returns:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
result_dict: A dict to update with the results.
{user_id -> { device_id -> { key_id -> key string/object }}}
"""
def update_result_dict(
results: Mapping[str, Mapping[str, Mapping[str, JsonSerializable]]]
) -> None:
"""Stash results from a store query in `result_dict`"""
for user_id, device_keys in results.items():
user_result_dict = result_dict.setdefault(user_id, {})
for device_id, keys in device_keys.items():
device_result_dict = user_result_dict.setdefault(device_id, {})
device_result_dict.update(keys)
# Cap the number of OTKs that can be claimed at once to avoid abuse.
local_query = [
(user_id, device_id, algorithm, min(count, 5))
for user_id, device_id, algorithm, count in local_query
]
# prepopulate the response to make sure that all queried users/devices are
# included, even if the user/device is unknown or has run out of OTKs
if self.config.experimental.msc4072_empty_dict_for_exhausted_devices:
for user_id, device_id, _, _ in local_query:
result_dict.setdefault(user_id, {}).setdefault(device_id, {})
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
update_result_dict(otk_results)
# If the application services have not provided any keys via the C-S
# API, query it directly for one-time keys.
@@ -592,6 +610,7 @@ class E2eKeysHandler:
appservice_results,
not_found,
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
update_result_dict(appservice_results)
else:
appservice_results = {}
@@ -646,10 +665,7 @@ class E2eKeysHandler:
# For each user that does not have a one-time keys available, see if
# there is a fallback key.
fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)
# Return the results in order, each item from the input query should
# only appear once in the combined list.
return (otk_results, appservice_results, fallback_results)
update_result_dict(fallback_results)
@trace
async def claim_one_time_keys(
@@ -659,6 +675,25 @@ class E2eKeysHandler:
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict:
"""
Handle a /keys/claim request.
Handles requests for local users with a db lookup, and makes federation
requests for remote users.
Args:
query: map from user ID, to map from device ID, to map from algorithm name
to number of keys needed
(``{user_id: {device_id: {algorithm: number_of keys}}}``)
user: The user id of the requesting user
timeout: number of milliseconds to wait for the response from remote servers.
``config.federation.client_timeout_ms`` by default.
always_include_fallback_keys: True to always include fallback keys, even
for devices which still have one-time keys.
"""
local_query: List[Tuple[str, str, str, int]] = []
remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
@@ -672,22 +707,19 @@ class E2eKeysHandler:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = one_time_keys
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))
results = await self.claim_local_one_time_keys(
local_query, always_include_fallback_keys
log_kv(
{
"message": "claiming one time keys",
"local query": local_query,
"remote queries, by server": remote_queries,
}
)
# A map of user ID -> device ID -> key ID -> key.
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for result in results:
for user_id, device_keys in result.items():
for device_id, keys in device_keys.items():
for key_id, key in keys.items():
json_result.setdefault(user_id, {}).setdefault(
device_id, {}
).update({key_id: key})
json_result: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {}
await self.claim_local_one_time_keys(
local_query, always_include_fallback_keys, json_result
)
# Remote failures.
failures: Dict[str, JsonDict] = {}
@@ -700,9 +732,18 @@ class E2eKeysHandler:
remote_result = await self.federation.claim_client_keys(
user, destination, device_keys, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
try:
destination_result = filter_remote_claimed_keys(
device_keys,
remote_result,
self.config.experimental.msc4072_empty_dict_for_exhausted_devices,
)
except Exception as e:
logger.warning(
f"Error parsing /keys/claim response from server {destination}",
e,
)
raise
except Exception as e:
failure = _exception_to_failure(e)
@@ -710,6 +751,11 @@ class E2eKeysHandler:
set_tag("error", True)
set_tag("reason", str(failure))
else:
# only populate json_result once we know there will not be an entry in
# failures for this destination.
json_result.update(destination_result)
await make_deferred_yieldable(
defer.gatherResults(
[
@@ -1625,3 +1671,51 @@ class SigningKeyEduUpdater:
device_ids = device_ids + new_device_ids
await self._device_handler.notify_device_update(user_id, device_ids)
def filter_remote_claimed_keys(
destination_query: Dict[str, Dict[str, Dict[str, int]]],
remote_response: JsonDict,
msc4072_empty_dict_for_exhausted_devices: bool,
) -> JsonDict:
"""
Process the response from a federation /keys/claim request
Checks that there are no redundant entries, and that all the entries that
should be there are present.
Args:
destination_query: user->device->key map that was sent in the request to
this server
remote_response: response from the remote server
msc4072_empty_dict_for_exhausted_devices: true to include an entry in the
result for every queried device
Returns:
user->device->key map to be merged into the results
"""
remote_otks = remote_response["one_time_keys"]
destination_result: JsonDict = {}
if msc4072_empty_dict_for_exhausted_devices:
# We need to make sure there is an entry in destination_result for
# every queried (user, device) even if the remote server did not
# populate it; so we iterate the query and populate
# destination_result based on the federation result.
for user_id, user_query in destination_query.items():
remote_user_result = remote_otks.get(user_id, {})
destination_user_result = destination_result[user_id] = {}
for device_id in user_query.keys():
destination_user_result[device_id] = remote_user_result.get(
device_id, {}
)
else:
# We need to make sure that remote servers do not poison the
# result with data for users which do not belong to it, so we only
# copy data for users that were queried.
for user_id, keys in remote_otks.items():
if user_id in destination_query:
destination_result[user_id] = keys
return destination_result
+1 -1
View File
@@ -120,7 +120,7 @@ class EventStreamHandler:
events.extend(to_add)
chunks = self._event_serializer.serialize_events(
chunks = await self._event_serializer.serialize_events(
events,
time_now,
config=SerializeEventConfig(
+9 -9
View File
@@ -19,6 +19,8 @@ import logging
import urllib.parse
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
import attr
from synapse.api.errors import (
CodeMessageException,
Codes,
@@ -357,9 +359,9 @@ class IdentityHandler:
# Check to see if a session already exists and that it is not yet
# marked as validated
if session and session.get("validated_at") is None:
session_id = session["session_id"]
last_send_attempt = session["last_send_attempt"]
if session and session.validated_at is None:
session_id = session.session_id
last_send_attempt = session.last_send_attempt
# Check that the send_attempt is higher than previous attempts
if send_attempt <= last_send_attempt:
@@ -480,7 +482,6 @@ class IdentityHandler:
# We don't actually know which medium this 3PID is. Thus we first assume it's email,
# and if validation fails we try msisdn
validation_session = None
# Try to validate as email
if self.hs.config.email.can_verify_email:
@@ -488,19 +489,18 @@ class IdentityHandler:
validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
)
if validation_session:
return validation_session
if validation_session:
return attr.asdict(validation_session)
# Try to validate as msisdn
if self.hs.config.registration.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server
validation_session = await self.threepid_from_creds(
return await self.threepid_from_creds(
self.hs.config.registration.account_threepid_delegate_msisdn,
threepid_creds,
)
return validation_session
return None
async def proxy_msisdn_submit_token(
self, id_server: str, client_secret: str, sid: str, token: str
+9 -9
View File
@@ -145,7 +145,7 @@ class InitialSyncHandler:
joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
receipt = await self.store.get_linearized_receipts_for_rooms(
joined_rooms,
to_key=int(now_token.receipt_key),
to_key=now_token.receipt_key,
)
receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
@@ -173,7 +173,7 @@ class InitialSyncHandler:
d["inviter"] = event.sender
invite_event = await self.store.get_event(event.event_id)
d["invite"] = self._event_serializer.serialize_event(
d["invite"] = await self._event_serializer.serialize_event(
invite_event,
time_now,
config=serializer_options,
@@ -225,7 +225,7 @@ class InitialSyncHandler:
d["messages"] = {
"chunk": (
self._event_serializer.serialize_events(
await self._event_serializer.serialize_events(
messages,
time_now=time_now,
config=serializer_options,
@@ -235,7 +235,7 @@ class InitialSyncHandler:
"end": await end_token.to_string(self.store),
}
d["state"] = self._event_serializer.serialize_events(
d["state"] = await self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
config=serializer_options,
@@ -387,7 +387,7 @@ class InitialSyncHandler:
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
self._event_serializer.serialize_events(
await self._event_serializer.serialize_events(
messages, time_now, config=serialize_options
)
),
@@ -396,7 +396,7 @@ class InitialSyncHandler:
},
"state": (
# Don't bundle aggregations as this is a deprecated API.
self._event_serializer.serialize_events(
await self._event_serializer.serialize_events(
room_state.values(), time_now, config=serialize_options
)
),
@@ -420,7 +420,7 @@ class InitialSyncHandler:
time_now = self.clock.time_msec()
serialize_options = SerializeEventConfig(requester=requester)
# Don't bundle aggregations as this is a deprecated API.
state = self._event_serializer.serialize_events(
state = await self._event_serializer.serialize_events(
current_state.values(),
time_now,
config=serialize_options,
@@ -439,7 +439,7 @@ class InitialSyncHandler:
async def get_presence() -> List[JsonDict]:
# If presence is disabled, return an empty list
if not self.hs.config.server.use_presence:
if not self.hs.config.server.presence_enabled:
return []
states = await presence_handler.get_states(
@@ -497,7 +497,7 @@ class InitialSyncHandler:
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
self._event_serializer.serialize_events(
await self._event_serializer.serialize_events(
messages, time_now, config=serialize_options
)
),
+58 -17
View File
@@ -244,7 +244,7 @@ class MessageHandler:
)
room_state = room_state_events[membership_event_id]
events = self._event_serializer.serialize_events(
events = await self._event_serializer.serialize_events(
room_state.values(),
self.clock.time_msec(),
config=SerializeEventConfig(requester=requester),
@@ -999,7 +999,26 @@ class EventCreationHandler:
raise ShadowBanError()
if ratelimit:
await self.request_ratelimiter.ratelimit(requester, update=False)
room_id = event_dict["room_id"]
try:
room_version = await self.store.get_room_version(room_id)
except NotFoundError:
# The room doesn't exist.
raise AuthError(403, f"User {requester.user} not in room {room_id}")
if room_version.updated_redaction_rules:
redacts = event_dict["content"].get("redacts")
else:
redacts = event_dict.get("redacts")
is_admin_redaction = await self.is_admin_redaction(
event_type=event_dict["type"],
sender=event_dict["sender"],
redacts=redacts,
)
await self.request_ratelimiter.ratelimit(
requester, is_admin_redaction=is_admin_redaction, update=False
)
# We limit the number of concurrent event sends in a room so that we
# don't fork the DAG too much. If we don't limit then we can end up in
@@ -1508,6 +1527,18 @@ class EventCreationHandler:
first_event.room_id
)
if writer_instance != self._instance_name:
# Ratelimit before sending to the other event persister, to
# ensure that we correctly have ratelimits on both the event
# creators and event persisters.
if ratelimit:
for event, _ in events_and_context:
is_admin_redaction = await self.is_admin_redaction(
event.type, event.sender, event.redacts
)
await self.request_ratelimiter.ratelimit(
requester, is_admin_redaction=is_admin_redaction
)
try:
result = await self.send_events(
instance_name=writer_instance,
@@ -1538,6 +1569,7 @@ class EventCreationHandler:
# stream_ordering entry manually (as it was persisted on
# another worker).
event.internal_metadata.stream_ordering = stream_id
return event
event = await self.persist_and_notify_client_events(
@@ -1696,21 +1728,9 @@ class EventCreationHandler:
# can apply different ratelimiting. We do this by simply checking
# it's not a self-redaction (to avoid having to look up whether the
# user is actually admin or not).
is_admin_redaction = False
if event.type == EventTypes.Redaction:
assert event.redacts is not None
original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.as_is,
get_prev_content=False,
allow_rejected=False,
allow_none=True,
)
is_admin_redaction = bool(
original_event and event.sender != original_event.sender
)
is_admin_redaction = await self.is_admin_redaction(
event.type, event.sender, event.redacts
)
await self.request_ratelimiter.ratelimit(
requester, is_admin_redaction=is_admin_redaction
@@ -1930,6 +1950,27 @@ class EventCreationHandler:
return persisted_events[-1]
async def is_admin_redaction(
self, event_type: str, sender: str, redacts: Optional[str]
) -> bool:
"""Return whether the event is a redaction made by an admin, and thus
should use a different ratelimiter.
"""
if event_type != EventTypes.Redaction:
return False
assert redacts is not None
original_event = await self.store.get_event(
redacts,
redact_behaviour=EventRedactBehaviour.as_is,
get_prev_content=False,
allow_rejected=False,
allow_none=True,
)
return bool(original_event and sender != original_event.sender)
async def _maybe_kick_guest_users(
self, event: EventBase, context: EventContext
) -> None:
+2 -2
View File
@@ -657,7 +657,7 @@ class PaginationHandler:
chunk = {
"chunk": (
self._event_serializer.serialize_events(
await self._event_serializer.serialize_events(
events,
time_now,
config=serialize_options,
@@ -669,7 +669,7 @@ class PaginationHandler:
}
if state:
chunk["state"] = self._event_serializer.serialize_events(
chunk["state"] = await self._event_serializer.serialize_events(
state, time_now, config=serialize_options
)
+47 -31
View File
@@ -192,7 +192,8 @@ class BasePresenceHandler(abc.ABC):
self.state = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id
self._presence_enabled = hs.config.server.use_presence
self._presence_enabled = hs.config.server.presence_enabled
self._track_presence = hs.config.server.track_presence
self._federation = None
if hs.should_send_federation():
@@ -512,7 +513,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
)
async def _on_shutdown(self) -> None:
if self._presence_enabled:
if self._track_presence:
self.hs.get_replication_command_handler().send_command(
ClearUserSyncsCommand(self.instance_id)
)
@@ -524,7 +525,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
is_syncing: bool,
last_sync_ms: int,
) -> None:
if self._presence_enabled:
if self._track_presence:
self.hs.get_replication_command_handler().send_user_sync(
self.instance_id, user_id, device_id, is_syncing, last_sync_ms
)
@@ -571,7 +572,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
Called by the sync and events servlets to record that a user has connected to
this worker and is waiting for some events.
"""
if not affect_presence or not self._presence_enabled:
if not affect_presence or not self._track_presence:
return _NullContextManager()
# Note that this causes last_active_ts to be incremented which is not
@@ -702,8 +703,8 @@ class WorkerPresenceHandler(BasePresenceHandler):
user_id = target_user.to_string()
# If presence is disabled, no-op
if not self._presence_enabled:
# If tracking of presence is disabled, no-op
if not self._track_presence:
return
# Proxy request to instance that writes presence
@@ -723,7 +724,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
if not self._presence_enabled:
if not self._track_presence:
return
# Proxy request to instance that writes presence
@@ -760,7 +761,7 @@ class PresenceHandler(BasePresenceHandler):
] = {}
now = self.clock.time_msec()
if self._presence_enabled:
if self._track_presence:
for state in self.user_to_current_state.values():
# Create a psuedo-device to properly handle time outs. This will
# be overridden by any "real" devices within SYNC_ONLINE_TIMEOUT.
@@ -831,7 +832,7 @@ class PresenceHandler(BasePresenceHandler):
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
if self._presence_enabled:
if self._track_presence:
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
@@ -839,6 +840,9 @@ class PresenceHandler(BasePresenceHandler):
30, self.clock.looping_call, self._handle_timeouts, 5000
)
# Presence information is persisted, whether or not it is being tracked
# internally.
if self._presence_enabled:
self.clock.call_later(
60,
self.clock.looping_call,
@@ -854,7 +858,7 @@ class PresenceHandler(BasePresenceHandler):
)
# Used to handle sending of presence to newly joined users/servers
if self._presence_enabled:
if self._track_presence:
self.notifier.add_replication_callback(self.notify_new_event)
# Presence is best effort and quickly heals itself, so lets just always
@@ -905,7 +909,9 @@ class PresenceHandler(BasePresenceHandler):
)
async def _update_states(
self, new_states: Iterable[UserPresenceState], force_notify: bool = False
self,
new_states: Iterable[UserPresenceState],
force_notify: bool = False,
) -> None:
"""Updates presence of users. Sets the appropriate timeouts. Pokes
the notifier and federation if and only if the changed presence state
@@ -943,7 +949,7 @@ class PresenceHandler(BasePresenceHandler):
for new_state in new_states:
user_id = new_state.user_id
# Its fine to not hit the database here, as the only thing not in
# It's fine to not hit the database here, as the only thing not in
# the current state cache are OFFLINE states, where the only field
# of interest is last_active which is safe enough to assume is 0
# here.
@@ -957,6 +963,9 @@ class PresenceHandler(BasePresenceHandler):
is_mine=self.is_mine_id(user_id),
wheel_timer=self.wheel_timer,
now=now,
# When overriding disabled presence, don't kick off all the
# wheel timers.
persist=not self._track_presence,
)
if force_notify:
@@ -1072,7 +1081,7 @@ class PresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
if not self._presence_enabled:
if not self._track_presence:
return
user_id = user.to_string()
@@ -1124,7 +1133,7 @@ class PresenceHandler(BasePresenceHandler):
client that is being used by a user.
presence_state: The presence state indicated in the sync request
"""
if not affect_presence or not self._presence_enabled:
if not affect_presence or not self._track_presence:
return _NullContextManager()
curr_sync = self._user_device_to_num_current_syncs.get((user_id, device_id), 0)
@@ -1284,7 +1293,7 @@ class PresenceHandler(BasePresenceHandler):
async def incoming_presence(self, origin: str, content: JsonDict) -> None:
"""Called when we receive a `m.presence` EDU from a remote server."""
if not self._presence_enabled:
if not self._track_presence:
return
now = self.clock.time_msec()
@@ -1359,7 +1368,7 @@ class PresenceHandler(BasePresenceHandler):
raise SynapseError(400, "Invalid presence state")
# If presence is disabled, no-op
if not self._presence_enabled:
if not self._track_presence:
return
user_id = target_user.to_string()
@@ -2118,6 +2127,7 @@ def handle_update(
is_mine: bool,
wheel_timer: WheelTimer,
now: int,
persist: bool,
) -> Tuple[UserPresenceState, bool, bool]:
"""Given a presence update:
1. Add any appropriate timers.
@@ -2129,6 +2139,8 @@ def handle_update(
is_mine: Whether the user is ours
wheel_timer
now: Time now in ms
persist: True if this state should persist until another update occurs.
Skips insertion into wheel timers.
Returns:
3-tuple: `(new_state, persist_and_notify, federation_ping)` where:
@@ -2146,14 +2158,15 @@ def handle_update(
if is_mine:
if new_state.state == PresenceState.ONLINE:
# Idle timer
wheel_timer.insert(
now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER
)
if not persist:
wheel_timer.insert(
now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER
)
active = now - new_state.last_active_ts < LAST_ACTIVE_GRANULARITY
new_state = new_state.copy_and_replace(currently_active=active)
if active:
if active and not persist:
wheel_timer.insert(
now=now,
obj=user_id,
@@ -2162,11 +2175,12 @@ def handle_update(
if new_state.state != PresenceState.OFFLINE:
# User has stopped syncing
wheel_timer.insert(
now=now,
obj=user_id,
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
)
if not persist:
wheel_timer.insert(
now=now,
obj=user_id,
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
)
last_federate = new_state.last_federation_update_ts
if now - last_federate > FEDERATION_PING_INTERVAL:
@@ -2174,7 +2188,7 @@ def handle_update(
new_state = new_state.copy_and_replace(last_federation_update_ts=now)
federation_ping = True
if new_state.state == PresenceState.BUSY:
if new_state.state == PresenceState.BUSY and not persist:
wheel_timer.insert(
now=now,
obj=user_id,
@@ -2182,11 +2196,13 @@ def handle_update(
)
else:
wheel_timer.insert(
now=now,
obj=user_id,
then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT,
)
# An update for a remote user was received.
if not persist:
wheel_timer.insert(
now=now,
obj=user_id,
then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT,
)
# Check whether the change was something worth notifying about
if should_notify(prev_state, new_state, is_mine):
+10 -9
View File
@@ -20,6 +20,7 @@ from synapse.streams import EventSource
from synapse.types import (
JsonDict,
JsonMapping,
MultiWriterStreamToken,
ReadReceipt,
StreamKeyType,
UserID,
@@ -200,7 +201,7 @@ class ReceiptsHandler:
await self.federation_sender.send_read_receipt(receipt)
class ReceiptEventSource(EventSource[int, JsonMapping]):
class ReceiptEventSource(EventSource[MultiWriterStreamToken, JsonMapping]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.config = hs.config
@@ -273,13 +274,12 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
async def get_new_events(
self,
user: UserID,
from_key: int,
from_key: MultiWriterStreamToken,
limit: int,
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonMapping], int]:
from_key = int(from_key)
) -> Tuple[List[JsonMapping], MultiWriterStreamToken]:
to_key = self.get_current_key()
if from_key == to_key:
@@ -296,8 +296,11 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
return events, to_key
async def get_new_events_as(
self, from_key: int, to_key: int, service: ApplicationService
) -> Tuple[List[JsonMapping], int]:
self,
from_key: MultiWriterStreamToken,
to_key: MultiWriterStreamToken,
service: ApplicationService,
) -> Tuple[List[JsonMapping], MultiWriterStreamToken]:
"""Returns a set of new read receipt events that an appservice
may be interested in.
@@ -312,8 +315,6 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
appservice may be interested in.
* The current read receipt stream token.
"""
from_key = int(from_key)
if from_key == to_key:
return [], to_key
@@ -333,5 +334,5 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
return events, to_key
def get_current_key(self) -> int:
def get_current_key(self) -> MultiWriterStreamToken:
return self.store.get_max_receipt_stream_id()
+5 -3
View File
@@ -167,7 +167,7 @@ class RelationsHandler:
now = self._clock.time_msec()
serialize_options = SerializeEventConfig(requester=requester)
return_value: JsonDict = {
"chunk": self._event_serializer.serialize_events(
"chunk": await self._event_serializer.serialize_events(
events,
now,
bundle_aggregations=aggregations,
@@ -177,7 +177,9 @@ class RelationsHandler:
if include_original_event:
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
return_value["original_event"] = self._event_serializer.serialize_event(
return_value[
"original_event"
] = await self._event_serializer.serialize_event(
event,
now,
bundle_aggregations=None,
@@ -602,7 +604,7 @@ class RelationsHandler:
)
now = self._clock.time_msec()
serialized_events = self._event_serializer.serialize_events(
serialized_events = await self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)
+4 -4
View File
@@ -374,13 +374,13 @@ class SearchHandler:
serialize_options = SerializeEventConfig(requester=requester)
for context in contexts.values():
context["events_before"] = self._event_serializer.serialize_events(
context["events_before"] = await self._event_serializer.serialize_events(
context["events_before"],
time_now,
bundle_aggregations=aggregations,
config=serialize_options,
)
context["events_after"] = self._event_serializer.serialize_events(
context["events_after"] = await self._event_serializer.serialize_events(
context["events_after"],
time_now,
bundle_aggregations=aggregations,
@@ -390,7 +390,7 @@ class SearchHandler:
results = [
{
"rank": search_result.rank_map[e.event_id],
"result": self._event_serializer.serialize_event(
"result": await self._event_serializer.serialize_event(
e,
time_now,
bundle_aggregations=aggregations,
@@ -409,7 +409,7 @@ class SearchHandler:
if state_results:
rooms_cat_res["state"] = {
room_id: self._event_serializer.serialize_events(
room_id: await self._event_serializer.serialize_events(
state_events, time_now, config=serialize_options
)
for room_id, state_events in state_results.items()
+1 -4
View File
@@ -1206,10 +1206,7 @@ class SsoHandler:
# We have no guarantee that all the devices of that session are for the same
# `user_id`. Hence, we have to iterate over the list of devices and log them out
# one by one.
for device in devices:
user_id = device["user_id"]
device_id = device["device_id"]
for user_id, device_id in devices:
# If the user_id associated with that device/session is not the one we got
# out of the `sub` claim, skip that device and show log an error.
if expected_user_id is not None and user_id != expected_user_id:
+51 -10
View File
@@ -57,6 +57,7 @@ from synapse.types import (
DeviceListUpdates,
JsonDict,
JsonMapping,
MultiWriterStreamToken,
MutableStateMap,
Requester,
RoomStreamToken,
@@ -477,7 +478,11 @@ class SyncHandler:
event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
receipt_key = since_token.receipt_key if since_token else 0
receipt_key = (
since_token.receipt_key
if since_token
else MultiWriterStreamToken(stream=0)
)
receipt_source = self.event_sources.sources.receipt
receipts, receipt_key = await receipt_source.get_new_events(
@@ -500,12 +505,27 @@ class SyncHandler:
async def _load_filtered_recents(
self,
room_id: str,
sync_result_builder: "SyncResultBuilder",
sync_config: SyncConfig,
now_token: StreamToken,
upto_token: StreamToken,
since_token: Optional[StreamToken] = None,
potential_recents: Optional[List[EventBase]] = None,
newly_joined_room: bool = False,
) -> TimelineBatch:
"""Create a timeline batch for the room
Args:
room_id
sync_result_builder
sync_config
upto_token: The token up to which we should fetch (more) events.
If `potential_results` is non-empty then this is *start* of
the the list.
since_token
potential_recents: If non-empty, the events between the since token
and current token to send down to clients.
newly_joined_room
"""
with Measure(self.clock, "load_filtered_recents"):
timeline_limit = sync_config.filter_collection.timeline_limit()
block_all_timeline = (
@@ -521,6 +541,20 @@ class SyncHandler:
else:
limited = False
# Check if there is a gap, if so we need to mark this as limited and
# recalculate which events to send down.
gap_token = await self.store.get_timeline_gaps(
room_id,
since_token.room_key if since_token else None,
sync_result_builder.now_token.room_key,
)
if gap_token:
# There's a gap, so we need to ignore the passed in
# `potential_recents`, and reset `upto_token` to match.
potential_recents = None
upto_token = sync_result_builder.now_token
limited = True
log_kv({"limited": limited})
if potential_recents:
@@ -559,10 +593,10 @@ class SyncHandler:
recents = []
if not limited or block_all_timeline:
prev_batch_token = now_token
prev_batch_token = upto_token
if recents:
room_key = recents[0].internal_metadata.before
prev_batch_token = now_token.copy_and_replace(
prev_batch_token = upto_token.copy_and_replace(
StreamKeyType.ROOM, room_key
)
@@ -573,11 +607,15 @@ class SyncHandler:
filtering_factor = 2
load_limit = max(timeline_limit * filtering_factor, 10)
max_repeat = 5 # Only try a few times per room, otherwise
room_key = now_token.room_key
room_key = upto_token.room_key
end_key = room_key
since_key = None
if since_token and not newly_joined_room:
if since_token and gap_token:
# If there is a gap then we need to only include events after
# it.
since_key = gap_token
elif since_token and not newly_joined_room:
since_key = since_token.room_key
while limited and len(recents) < timeline_limit and max_repeat:
@@ -647,7 +685,7 @@ class SyncHandler:
recents = recents[-timeline_limit:]
room_key = recents[0].internal_metadata.before
prev_batch_token = now_token.copy_and_replace(StreamKeyType.ROOM, room_key)
prev_batch_token = upto_token.copy_and_replace(StreamKeyType.ROOM, room_key)
# Don't bother to bundle aggregations if the timeline is unlimited,
# as clients will have all the necessary information.
@@ -662,7 +700,9 @@ class SyncHandler:
return TimelineBatch(
events=recents,
prev_batch=prev_batch_token,
limited=limited or newly_joined_room,
# Also mark as limited if this is a new room or there has been a gap
# (to force client to paginate the gap).
limited=limited or newly_joined_room or gap_token is not None,
bundled_aggregations=bundled_aggregations,
)
@@ -1477,7 +1517,7 @@ class SyncHandler:
# Presence data is included if the server has it enabled and not filtered out.
include_presence_data = bool(
self.hs_config.server.use_presence
self.hs_config.server.presence_enabled
and not sync_config.filter_collection.blocks_all_presence()
)
# Device list updates are sent if a since token is provided.
@@ -2397,8 +2437,9 @@ class SyncHandler:
batch = await self._load_filtered_recents(
room_id,
sync_result_builder,
sync_config,
now_token=upto_token,
upto_token=upto_token,
since_token=since_token,
potential_recents=events,
newly_joined_room=newly_joined,
+3 -3
View File
@@ -187,9 +187,9 @@ class _BaseThreepidAuthChecker:
if row:
threepid = {
"medium": row["medium"],
"address": row["address"],
"validated_at": row["validated_at"],
"medium": row.medium,
"address": row.address,
"validated_at": row.validated_at,
}
# Valid threepid returned, delete from the db
+1 -1
View File
@@ -59,7 +59,7 @@ class BasicProxyCredentials(ProxyCredentials):
a Proxy-Authorization header.
"""
# Encode as base64 and prepend the authorization type
return b"Basic " + base64.encodebytes(self.username_password)
return b"Basic " + base64.b64encode(self.username_password)
@attr.s(auto_attribs=True)
+1 -4
View File
@@ -949,10 +949,7 @@ class MediaRepository:
deleted = 0
for media in old_media:
origin = media["media_origin"]
media_id = media["media_id"]
file_id = media["filesystem_id"]
for origin, media_id, file_id in old_media:
key = (origin, media_id)
logger.info("Deleting: %r", key)
+54
View File
@@ -23,6 +23,7 @@ from typing import (
Generator,
Iterable,
List,
Mapping,
Optional,
Tuple,
TypeVar,
@@ -39,6 +40,7 @@ from twisted.web.resource import Resource
from synapse.api import errors
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.config import ConfigError
from synapse.events import EventBase
from synapse.events.presence_router import (
@@ -46,6 +48,7 @@ from synapse.events.presence_router import (
GET_USERS_FOR_STATES_CALLBACK,
PresenceRouter,
)
from synapse.events.utils import ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
from synapse.handlers.account_data import ON_ACCOUNT_DATA_UPDATED_CALLBACK
from synapse.handlers.auth import (
CHECK_3PID_AUTH_CALLBACK,
@@ -257,6 +260,7 @@ class ModuleApi:
self.custom_template_dir = hs.config.server.custom_template_directory
self._callbacks = hs.get_module_api_callbacks()
self.msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled
self._event_serializer = hs.get_event_client_serializer()
try:
app_name = self._hs.config.email.email_app_name
@@ -488,6 +492,25 @@ class ModuleApi:
"""
self._hs.register_module_web_resource(path, resource)
def register_add_extra_fields_to_unsigned_client_event_callbacks(
self,
*,
add_field_to_unsigned_callback: Optional[
ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
] = None,
) -> None:
"""Registers a callback that can be used to add fields to the unsigned
section of events.
The callback is called every time an event is sent down to a client.
Added in Synapse 1.96.0
"""
if add_field_to_unsigned_callback is not None:
self._event_serializer.register_add_extra_fields_to_unsigned_client_event_callback(
add_field_to_unsigned_callback
)
#########################################################################
# The following methods can be called by the module at any point in time.
@@ -1184,6 +1207,37 @@ class ModuleApi:
presence_events, [destination]
)
async def set_presence_for_users(
self, users: Mapping[str, Tuple[str, Optional[str]]]
) -> None:
"""
Update the internal presence state of users.
This can be used for either local or remote users.
Note that this method can only be run on the process that is configured to write to the
presence stream. By default, this is the main process.
Added in Synapse v1.96.0.
"""
# We pull out the presence handler here to break a cyclic
# dependency between the presence router and module API.
presence_handler = self._hs.get_presence_handler()
from synapse.handlers.presence import PresenceHandler
assert isinstance(presence_handler, PresenceHandler)
states = await presence_handler.current_state_for_users(users.keys())
for user_id, (state, status_msg) in users.items():
prev_state = states.setdefault(user_id, UserPresenceState.default(user_id))
states[user_id] = prev_state.copy_and_replace(
state=state, status_msg=status_msg
)
await presence_handler._update_states(states.values(), force_notify=True)
def looping_background_call(
self,
f: Callable,
+43 -2
View File
@@ -21,11 +21,13 @@ from typing import (
Dict,
Iterable,
List,
Literal,
Optional,
Set,
Tuple,
TypeVar,
Union,
overload,
)
import attr
@@ -44,6 +46,7 @@ from synapse.metrics import LaterGauge
from synapse.streams.config import PaginationConfig
from synapse.types import (
JsonDict,
MultiWriterStreamToken,
PersistedEventPosition,
RoomStreamToken,
StrCollection,
@@ -127,7 +130,7 @@ class _NotifierUserStream:
def notify(
self,
stream_key: StreamKeyType,
stream_id: Union[int, RoomStreamToken],
stream_id: Union[int, RoomStreamToken, MultiWriterStreamToken],
time_now_ms: int,
) -> None:
"""Notify any listeners for this user of a new event from an
@@ -452,10 +455,48 @@ class Notifier:
except Exception:
logger.exception("Error pusher pool of event")
@overload
def on_new_event(
self,
stream_key: Literal[StreamKeyType.ROOM],
new_token: RoomStreamToken,
users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[StrCollection] = None,
) -> None:
...
@overload
def on_new_event(
self,
stream_key: Literal[StreamKeyType.RECEIPT],
new_token: MultiWriterStreamToken,
users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[StrCollection] = None,
) -> None:
...
@overload
def on_new_event(
self,
stream_key: Literal[
StreamKeyType.ACCOUNT_DATA,
StreamKeyType.DEVICE_LIST,
StreamKeyType.PRESENCE,
StreamKeyType.PUSH_RULES,
StreamKeyType.TO_DEVICE,
StreamKeyType.TYPING,
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
],
new_token: int,
users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[StrCollection] = None,
) -> None:
...
def on_new_event(
self,
stream_key: StreamKeyType,
new_token: Union[int, RoomStreamToken],
new_token: Union[int, RoomStreamToken, MultiWriterStreamToken],
users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[StrCollection] = None,
) -> None:
+1 -1
View File
@@ -238,7 +238,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
data[_STREAM_POSITION_KEY] = {
"streams": {
stream.NAME: stream.current_token(local_instance_name)
stream.NAME: stream.minimal_local_current_token()
for stream in streams
},
"instance_name": local_instance_name,
+2 -12
View File
@@ -126,8 +126,9 @@ class ReplicationDataHandler:
StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows]
)
elif stream_name == ReceiptsStream.NAME:
new_token = self.store.get_max_receipt_stream_id()
self.notifier.on_new_event(
StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows]
StreamKeyType.RECEIPT, new_token, rooms=[row.room_id for row in rows]
)
await self._pusher_pool.on_new_receipts({row.user_id for row in rows})
elif stream_name == ToDeviceStream.NAME:
@@ -279,14 +280,6 @@ class ReplicationDataHandler:
# may be streaming.
self.notifier.notify_replication()
def on_remote_server_up(self, server: str) -> None:
"""Called when get a new REMOTE_SERVER_UP command."""
# Let's wake up the transaction queue for the server in case we have
# pending stuff to send to it.
if self.send_handler:
self.send_handler.wake_destination(server)
async def wait_for_stream_position(
self,
instance_name: str,
@@ -405,9 +398,6 @@ class FederationSenderHandler:
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
def wake_destination(self, server: str) -> None:
self.federation_sender.wake_destination(server)
async def process_replication_rows(
self, stream_name: str, token: int, rows: list
) -> None:
-2
View File
@@ -657,8 +657,6 @@ class ReplicationCommandHandler:
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
) -> None:
"""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
self._notifier.notify_remote_server_up(cmd.data)
def on_LOCK_RELEASED(
+18 -1
View File
@@ -27,7 +27,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import PositionCommand
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
from synapse.replication.tcp.streams import EventsStream
from synapse.replication.tcp.streams._base import StreamRow, Token
from synapse.replication.tcp.streams._base import CachesStream, StreamRow, Token
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -204,6 +204,23 @@ class ReplicationStreamer:
# The token has advanced but there is no data to
# send, so we send a `POSITION` to inform other
# workers of the updated position.
#
# There are two reasons for this: 1) this instance
# requested a stream ID but didn't use it, or 2)
# this instance advanced its own stream position due
# to receiving notifications about other instances
# advancing their stream position.
# We skip sending `POSITION` for the `caches` stream
# for the second case as a) it generates a lot of
# traffic as every worker would echo each write, and
# b) nothing cares if a given worker's caches stream
# position lags.
if stream.NAME == CachesStream.NAME:
# If there haven't been any writes since the
# `last_token` then we're in the second case.
if stream.minimal_local_current_token() <= last_token:
continue
# Note: `last_token` may not *actually* be the
# last token we sent out in a RDATA or POSITION.
+83 -46
View File
@@ -33,6 +33,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
logger = logging.getLogger(__name__)
@@ -107,22 +108,10 @@ class Stream:
def __init__(
self,
local_instance_name: str,
current_token_function: Callable[[str], Token],
update_function: UpdateFunction,
):
"""Instantiate a Stream
`current_token_function` and `update_function` are callbacks which
should be implemented by subclasses.
`current_token_function` takes an instance name, which is a writer to
the stream, and returns the position in the stream of the writer (as
viewed from the current process). On the writer process this is where
the writer has successfully written up to, whereas on other processes
this is the position which we have received updates up to over
replication. (Note that most streams have a single writer and so their
implementations ignore the instance name passed in).
`update_function` is called to get updates for this stream between a
pair of stream tokens. See the `UpdateFunction` type definition for more
info.
@@ -133,12 +122,28 @@ class Stream:
update_function: callback go get stream updates, as above
"""
self.local_instance_name = local_instance_name
self.current_token = current_token_function
self.update_function = update_function
# The token from which we last asked for updates
self.last_token = self.current_token(self.local_instance_name)
def current_token(self, instance_name: str) -> Token:
"""This takes an instance name, which is a writer to
the stream, and returns the position in the stream of the writer (as
viewed from the current process).
"""
# We can't make this an abstract class as it makes mypy unhappy.
raise NotImplementedError()
def minimal_local_current_token(self) -> Token:
"""Tries to return a minimal current token for the local instance,
i.e. for writers this would be the last successful write.
If local instance is not a writer (or has written yet) then falls back
to returning the normal "current token".
"""
raise NotImplementedError()
def discard_updates_and_advance(self) -> None:
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
@@ -190,6 +195,25 @@ class Stream:
return updates, upto_token, limited
class _StreamFromIdGen(Stream):
"""Helper class for simple streams that use a stream ID generator"""
def __init__(
self,
local_instance_name: str,
update_function: UpdateFunction,
stream_id_gen: "AbstractStreamIdGenerator",
):
self._stream_id_gen = stream_id_gen
super().__init__(local_instance_name, update_function)
def current_token(self, instance_name: str) -> Token:
return self._stream_id_gen.get_current_token_for_writer(instance_name)
def minimal_local_current_token(self) -> Token:
return self._stream_id_gen.get_minimal_local_current_token()
def current_token_without_instance(
current_token: Callable[[], int]
) -> Callable[[str], int]:
@@ -242,17 +266,21 @@ class BackfillStream(Stream):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
self._current_token,
self.store.get_all_new_backfill_event_rows,
)
def _current_token(self, instance_name: str) -> int:
def current_token(self, instance_name: str) -> Token:
# The backfill stream over replication operates on *positive* numbers,
# which means we need to negate it.
return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)
def minimal_local_current_token(self) -> Token:
# The backfill stream over replication operates on *positive* numbers,
# which means we need to negate it.
return -self.store._backfill_id_gen.get_minimal_local_current_token()
class PresenceStream(Stream):
class PresenceStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PresenceStreamRow:
user_id: str
@@ -283,9 +311,7 @@ class PresenceStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_current_presence_token),
update_function,
hs.get_instance_name(), update_function, store._presence_id_gen
)
@@ -305,13 +331,18 @@ class PresenceFederationStream(Stream):
ROW_TYPE = PresenceFederationStreamRow
def __init__(self, hs: "HomeServer"):
federation_queue = hs.get_presence_handler().get_federation_queue()
self._federation_queue = hs.get_presence_handler().get_federation_queue()
super().__init__(
hs.get_instance_name(),
federation_queue.get_current_token,
federation_queue.get_replication_rows,
self._federation_queue.get_replication_rows,
)
def current_token(self, instance_name: str) -> Token:
return self._federation_queue.get_current_token(instance_name)
def minimal_local_current_token(self) -> Token:
return self._federation_queue.get_current_token(self.local_instance_name)
class TypingStream(Stream):
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -341,20 +372,25 @@ class TypingStream(Stream):
update_function: Callable[
[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
] = typing_writer_handler.get_all_typing_updates
current_token_function = typing_writer_handler.get_current_token
self.current_token_function = typing_writer_handler.get_current_token
else:
# Query the typing writer process
update_function = make_http_update_function(hs, self.NAME)
current_token_function = hs.get_typing_handler().get_current_token
self.current_token_function = hs.get_typing_handler().get_current_token
super().__init__(
hs.get_instance_name(),
current_token_without_instance(current_token_function),
update_function,
)
def current_token(self, instance_name: str) -> Token:
return self.current_token_function()
class ReceiptsStream(Stream):
def minimal_local_current_token(self) -> Token:
return self.current_token_function()
class ReceiptsStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ReceiptsStreamRow:
room_id: str
@@ -371,12 +407,12 @@ class ReceiptsStream(Stream):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_max_receipt_stream_id),
store.get_all_updated_receipts,
store._receipts_id_gen,
)
class PushRulesStream(Stream):
class PushRulesStream(_StreamFromIdGen):
"""A user has changed their push rules"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -387,20 +423,16 @@ class PushRulesStream(Stream):
ROW_TYPE = PushRulesStreamRow
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
self._current_token,
self.store.get_all_push_rule_updates,
store.get_all_push_rule_updates,
store._push_rules_stream_id_gen,
)
def _current_token(self, instance_name: str) -> int:
push_rules_token = self.store.get_max_push_rules_stream_id()
return push_rules_token
class PushersStream(Stream):
class PushersStream(_StreamFromIdGen):
"""A user has added/changed/removed a pusher"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -418,8 +450,8 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_pushers_stream_token),
store.get_all_updated_pushers_rows,
store._pushers_id_gen,
)
@@ -447,15 +479,20 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs: "HomeServer"):
store = hs.get_datastores().main
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
store.get_cache_stream_token_for_writer,
store.get_all_updated_caches,
self.store.get_all_updated_caches,
)
def current_token(self, instance_name: str) -> Token:
return self.store.get_cache_stream_token_for_writer(instance_name)
class DeviceListsStream(Stream):
def minimal_local_current_token(self) -> Token:
return self.current_token(self.local_instance_name)
class DeviceListsStream(_StreamFromIdGen):
"""Either a user has updated their devices or a remote server needs to be
told about a device update.
"""
@@ -473,8 +510,8 @@ class DeviceListsStream(Stream):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
current_token_without_instance(self.store.get_device_stream_token),
self._update_function,
self.store._device_list_id_gen,
)
async def _update_function(
@@ -525,7 +562,7 @@ class DeviceListsStream(Stream):
return updates, upper_limit_token, devices_limited or signatures_limited
class ToDeviceStream(Stream):
class ToDeviceStream(_StreamFromIdGen):
"""New to_device messages for a client"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -539,12 +576,12 @@ class ToDeviceStream(Stream):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_to_device_stream_token),
store.get_all_new_device_messages,
store._device_inbox_id_gen,
)
class AccountDataStream(Stream):
class AccountDataStream(_StreamFromIdGen):
"""Global or per room account data was changed"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -560,8 +597,8 @@ class AccountDataStream(Stream):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
current_token_without_instance(self.store.get_max_account_data_stream_id),
self._update_function,
self.store._account_data_id_gen,
)
async def _update_function(
+47 -6
View File
@@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
from collections import defaultdict
from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
import attr
from synapse.replication.tcp.streams._base import (
Stream,
StreamRow,
StreamUpdateResult,
Token,
_StreamFromIdGen,
)
if TYPE_CHECKING:
@@ -51,8 +52,19 @@ data part are:
* The state_key of the state which has changed
* The event id of the new state
A "state-all" row is sent whenever the "current state" in a room changes, but there are
too many state updates for a particular room in the same update. This replaces any
"state" rows on a per-room basis. The fields in the data part are:
* The room id for the state changes
"""
# Any room with more than _MAX_STATE_UPDATES_PER_ROOM will send a EventsStreamAllStateRow
# instead of individual EventsStreamEventRow. This is predominantly useful when
# purging large rooms.
_MAX_STATE_UPDATES_PER_ROOM = 150
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventsStreamRow:
@@ -111,15 +123,23 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
event_id: Optional[str]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventsStreamAllStateRow(BaseEventsStreamRow):
TypeId = "state-all"
room_id: str
_EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
EventsStreamEventRow,
EventsStreamCurrentStateRow,
EventsStreamAllStateRow,
)
TypeToRow = {Row.TypeId: Row for Row in _EventRows}
class EventsStream(Stream):
class EventsStream(_StreamFromIdGen):
"""We received a new event, or an event went from being an outlier to not"""
NAME = "events"
@@ -127,9 +147,7 @@ class EventsStream(Stream):
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
self._store._stream_id_gen.get_current_token_for_writer,
self._update_function,
hs.get_instance_name(), self._update_function, self._store._stream_id_gen
)
async def _update_function(
@@ -213,9 +231,28 @@ class EventsStream(Stream):
if stream_id <= upper_limit
)
# Separate out rooms that have many state updates, listeners should clear
# all state for those rooms.
state_updates_by_room = defaultdict(list)
for stream_id, room_id, _type, _state_key, _event_id in state_rows:
state_updates_by_room[room_id].append(stream_id)
state_all_rows = [
(stream_ids[-1], room_id)
for room_id, stream_ids in state_updates_by_room.items()
if len(stream_ids) >= _MAX_STATE_UPDATES_PER_ROOM
]
state_all_updates: Iterable[Tuple[int, Tuple]] = (
(max_stream_id, (EventsStreamAllStateRow.TypeId, (room_id,)))
for (max_stream_id, room_id) in state_all_rows
)
# Any remaining state updates are sent individually.
state_all_rooms = {room_id for _, room_id in state_all_rows}
state_updates: Iterable[Tuple[int, Tuple]] = (
(stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
for (stream_id, *rest) in state_rows
if rest[0] not in state_all_rooms
)
ex_outliers_updates: Iterable[Tuple[int, Tuple]] = (
@@ -224,7 +261,11 @@ class EventsStream(Stream):
)
# we need to return a sorted list, so merge them together.
updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates))
updates = list(
heapq.merge(
event_updates, state_all_updates, state_updates, ex_outliers_updates
)
)
return updates, upper_limit, limited
@classmethod
+11 -4
View File
@@ -18,6 +18,7 @@ import attr
from synapse.replication.tcp.streams._base import (
Stream,
Token,
current_token_without_instance,
make_http_update_function,
)
@@ -47,7 +48,7 @@ class FederationStream(Stream):
# will be a real FederationSender, which has stubs for current_token and
# get_replication_rows.)
federation_sender = hs.get_federation_sender()
current_token = current_token_without_instance(
self.current_token_func = current_token_without_instance(
federation_sender.get_current_token
)
update_function: Callable[
@@ -57,15 +58,21 @@ class FederationStream(Stream):
elif hs.should_send_federation():
# federation sender: Query master process
update_function = make_http_update_function(hs, self.NAME)
current_token = self._stub_current_token
self.current_token_func = self._stub_current_token
else:
# other worker: stub out the update function (we're not interested in
# any updates so when we get a POSITION we do nothing)
update_function = self._stub_update_function
current_token = self._stub_current_token
self.current_token_func = self._stub_current_token
super().__init__(hs.get_instance_name(), current_token, update_function)
super().__init__(hs.get_instance_name(), update_function)
def current_token(self, instance_name: str) -> Token:
return self.current_token_func(instance_name)
def minimal_local_current_token(self) -> Token:
return self.current_token(self.local_instance_name)
@staticmethod
def _stub_current_token(instance_name: str) -> int:
@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING
import attr
from synapse.replication.tcp.streams import Stream
from synapse.replication.tcp.streams._base import _StreamFromIdGen
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -27,7 +27,7 @@ class UnPartialStatedRoomStreamRow:
room_id: str
class UnPartialStatedRoomStream(Stream):
class UnPartialStatedRoomStream(_StreamFromIdGen):
"""
Stream to notify about rooms becoming un-partial-stated;
that is, when the background sync finishes such that we now have full state for
@@ -41,8 +41,8 @@ class UnPartialStatedRoomStream(Stream):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
store.get_un_partial_stated_rooms_token,
store.get_un_partial_stated_rooms_from_stream,
store._un_partial_stated_rooms_stream_id_gen,
)
@@ -56,7 +56,7 @@ class UnPartialStatedEventStreamRow:
rejection_status_changed: bool
class UnPartialStatedEventStream(Stream):
class UnPartialStatedEventStream(_StreamFromIdGen):
"""
Stream to notify about events becoming un-partial-stated.
"""
@@ -68,6 +68,6 @@ class UnPartialStatedEventStream(Stream):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
store.get_un_partial_stated_events_token,
store.get_un_partial_stated_events_from_stream,
store._un_partial_stated_events_stream_id_gen,
)
+13 -1
View File
@@ -85,7 +85,19 @@ class ListDestinationsRestServlet(RestServlet):
destinations, total = await self._store.get_destinations_paginate(
start, limit, destination, order_by, direction
)
response = {"destinations": destinations, "total": total}
response = {
"destinations": [
{
"destination": r[0],
"retry_last_ts": r[1],
"retry_interval": r[2],
"failure_ts": r[3],
"last_successful_stream_ordering": r[4],
}
for r in destinations
],
"total": total,
}
if (start + limit) < total:
response["next_token"] = str(start + len(destinations))
+16 -6
View File
@@ -444,7 +444,7 @@ class RoomStateRestServlet(RestServlet):
event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
events = await self.store.get_events(event_ids.values())
now = self.clock.time_msec()
room_state = self._event_serializer.serialize_events(events.values(), now)
room_state = await self._event_serializer.serialize_events(events.values(), now)
ret = {"state": room_state}
return HTTPStatus.OK, ret
@@ -724,7 +724,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
room_id, _ = await self.resolve_room_id(room_identifier)
extremities = await self.store.get_forward_extremities_for_room(room_id)
return HTTPStatus.OK, {"count": len(extremities), "results": extremities}
result = [
{
"event_id": ex[0],
"state_group": ex[1],
"depth": ex[2],
"received_ts": ex[3],
}
for ex in extremities
]
return HTTPStatus.OK, {"count": len(extremities), "results": result}
class RoomEventContextServlet(RestServlet):
@@ -779,22 +789,22 @@ class RoomEventContextServlet(RestServlet):
time_now = self.clock.time_msec()
results = {
"events_before": self._event_serializer.serialize_events(
"events_before": await self._event_serializer.serialize_events(
event_context.events_before,
time_now,
bundle_aggregations=event_context.aggregations,
),
"event": self._event_serializer.serialize_event(
"event": await self._event_serializer.serialize_event(
event_context.event,
time_now,
bundle_aggregations=event_context.aggregations,
),
"events_after": self._event_serializer.serialize_events(
"events_after": await self._event_serializer.serialize_events(
event_context.events_after,
time_now,
bundle_aggregations=event_context.aggregations,
),
"state": self._event_serializer.serialize_events(
"state": await self._event_serializer.serialize_events(
event_context.state, time_now
),
"start": event_context.start,
+12 -1
View File
@@ -108,7 +108,18 @@ class UserMediaStatisticsRestServlet(RestServlet):
users_media, total = await self.store.get_users_media_usage_paginate(
start, limit, from_ts, until_ts, order_by, direction, search_term
)
ret = {"users": users_media, "total": total}
ret = {
"users": [
{
"user_id": r[0],
"displayname": r[1],
"media_count": r[2],
"media_length": r[3],
}
for r in users_media
],
"total": total,
}
if (start + limit) < total:
ret["next_token"] = start + len(users_media)
+1 -1
View File
@@ -93,7 +93,7 @@ class EventRestServlet(RestServlet):
event = await self.event_handler.get_event(requester.user, None, event_id)
if event:
result = self._event_serializer.serialize_event(
result = await self._event_serializer.serialize_event(
event,
self.clock.time_msec(),
config=SerializeEventConfig(requester=requester),
+1 -1
View File
@@ -87,7 +87,7 @@ class NotificationsServlet(RestServlet):
"actions": pa.actions,
"ts": pa.received_ts,
"event": (
self._event_serializer.serialize_event(
await self._event_serializer.serialize_event(
notif_events[pa.event_id],
now,
config=serialize_options,
+2 -4
View File
@@ -42,15 +42,13 @@ class PresenceStatusRestServlet(RestServlet):
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self._use_presence = hs.config.server.use_presence
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if not self._use_presence:
if not self.hs.config.server.presence_enabled:
return 200, {"presence": "offline"}
if requester.user != user:
@@ -96,7 +94,7 @@ class PresenceStatusRestServlet(RestServlet):
except Exception:
raise SynapseError(400, "Unable to parse state")
if self._use_presence:
if self.hs.config.server.track_presence:
await self.presence_handler.set_state(user, requester.device_id, state)
return 200, {}
+5 -5
View File
@@ -859,7 +859,7 @@ class RoomEventServlet(RestServlet):
# per MSC2676, /rooms/{roomId}/event/{eventId}, should return the
# *original* event, rather than the edited version
event_dict = self._event_serializer.serialize_event(
event_dict = await self._event_serializer.serialize_event(
event,
self.clock.time_msec(),
bundle_aggregations=aggregations,
@@ -911,25 +911,25 @@ class RoomEventContextServlet(RestServlet):
time_now = self.clock.time_msec()
serializer_options = SerializeEventConfig(requester=requester)
results = {
"events_before": self._event_serializer.serialize_events(
"events_before": await self._event_serializer.serialize_events(
event_context.events_before,
time_now,
bundle_aggregations=event_context.aggregations,
config=serializer_options,
),
"event": self._event_serializer.serialize_event(
"event": await self._event_serializer.serialize_event(
event_context.event,
time_now,
bundle_aggregations=event_context.aggregations,
config=serializer_options,
),
"events_after": self._event_serializer.serialize_events(
"events_after": await self._event_serializer.serialize_events(
event_context.events_after,
time_now,
bundle_aggregations=event_context.aggregations,
config=serializer_options,
),
"state": self._event_serializer.serialize_events(
"state": await self._event_serializer.serialize_events(
event_context.state,
time_now,
config=serializer_options,
+4 -4
View File
@@ -384,7 +384,7 @@ class SyncRestServlet(RestServlet):
"""
invited = {}
for room in rooms:
invite = self._event_serializer.serialize_event(
invite = await self._event_serializer.serialize_event(
room.invite, time_now, config=serialize_options
)
unsigned = dict(invite.get("unsigned", {}))
@@ -415,7 +415,7 @@ class SyncRestServlet(RestServlet):
"""
knocked = {}
for room in rooms:
knock = self._event_serializer.serialize_event(
knock = await self._event_serializer.serialize_event(
room.knock, time_now, config=serialize_options
)
@@ -506,10 +506,10 @@ class SyncRestServlet(RestServlet):
event.room_id,
)
serialized_state = self._event_serializer.serialize_events(
serialized_state = await self._event_serializer.serialize_events(
state_events, time_now, config=serialize_options
)
serialized_timeline = self._event_serializer.serialize_events(
serialized_timeline = await self._event_serializer.serialize_events(
timeline_events,
time_now,
config=serialize_options,
+1 -1
View File
@@ -786,7 +786,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_event_client_serializer(self) -> EventClientSerializer:
return EventClientSerializer()
return EventClientSerializer(self)
@cache_in_self
def get_password_policy_handler(self) -> PasswordPolicyHandler:
+19 -42
View File
@@ -35,7 +35,6 @@ from typing import (
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)
@@ -606,13 +605,16 @@ class DatabasePool:
If the background updates have not completed, wait 15 sec and check again.
"""
updates = await self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
desc="check_background_updates",
updates = cast(
List[Tuple[str]],
await self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
desc="check_background_updates",
),
)
background_update_names = [x["update_name"] for x in updates]
background_update_names = [x[0] for x in updates]
for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
if update_name not in background_update_names:
@@ -1044,43 +1046,20 @@ class DatabasePool:
results = [dict(zip(col_headers, row)) for row in cursor]
return results
@overload
async def execute(
self, desc: str, decoder: Literal[None], query: str, *args: Any
) -> List[Tuple[Any, ...]]:
...
@overload
async def execute(
self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
) -> R:
...
async def execute(
self,
desc: str,
decoder: Optional[Callable[[Cursor], R]],
query: str,
*args: Any,
) -> Union[List[Tuple[Any, ...]], R]:
async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
"""Runs a single query for a result set.
Args:
desc: description of the transaction, for logging and metrics
decoder - The function which can resolve the cursor results to
something meaningful.
query - The query string to execute
*args - Query args.
Returns:
The result of decoder(results)
"""
def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]:
def interaction(txn: LoggingTransaction) -> List[Tuple[Any, ...]]:
txn.execute(query, args)
if decoder:
return decoder(txn)
else:
return txn.fetchall()
return txn.fetchall()
return await self.runInteraction(desc, interaction)
@@ -1804,9 +1783,9 @@ class DatabasePool:
keyvalues: Optional[Dict[str, Any]],
retcols: Collection[str],
desc: str = "simple_select_list",
) -> List[Dict[str, Any]]:
) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
more rows, returning the result as a list of tuples.
Args:
table: the table name
@@ -1817,8 +1796,7 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
Returns:
A list of dictionaries, one per result row, each a mapping between the
column names from `retcols` and that column's value for the row.
A list of tuples, one per result row, each the retcolumn's value for the row.
"""
return await self.runInteraction(
desc,
@@ -1836,9 +1814,9 @@ class DatabasePool:
table: str,
keyvalues: Optional[Dict[str, Any]],
retcols: Iterable[str],
) -> List[Dict[str, Any]]:
) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
more rows, returning the result as a list of tuples.
Args:
txn: Transaction object
@@ -1849,8 +1827,7 @@ class DatabasePool:
retcols: the names of the columns to return
Returns:
A list of dictionaries, one per result row, each a mapping between the
column names from `retcols` and that column's value for the row.
A list of tuples, one per result row, each the retcolumn's value for the row.
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1863,7 +1840,7 @@ class DatabasePool:
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
txn.execute(sql)
return cls.cursor_to_dict(txn)
return txn.fetchall()
async def simple_select_many_batch(
self,
+15 -8
View File
@@ -94,7 +94,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
hs.get_replication_notifier(),
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
extra_tables=[
("account_data", "stream_id"),
("room_tags_revisions", "stream_id"),
],
is_writer=self._instance_name in hs.config.worker.writers.account_data,
)
@@ -283,16 +286,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def get_account_data_for_room_txn(
txn: LoggingTransaction,
) -> Dict[str, JsonDict]:
rows = self.db_pool.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id, "room_id": room_id},
["account_data_type", "content"],
) -> Dict[str, JsonMapping]:
rows = cast(
List[Tuple[str, str]],
self.db_pool.simple_select_list_txn(
txn,
table="room_account_data",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=["account_data_type", "content"],
),
)
return {
row["account_data_type"]: db_to_json(row["content"]) for row in rows
account_data_type: db_to_json(content)
for account_data_type, content in rows
}
return await self.db_pool.runInteraction(
+9 -4
View File
@@ -197,16 +197,21 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A list of ApplicationServices, which may be empty.
"""
results = await self.db_pool.simple_select_list(
"application_services_state", {"state": state.value}, ["as_id"]
results = cast(
List[Tuple[str]],
await self.db_pool.simple_select_list(
table="application_services_state",
keyvalues={"state": state.value},
retcols=("as_id",),
),
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
services = []
for res in results:
for (as_id,) in results:
for service in as_list:
if service.id == res["as_id"]:
if service.id == as_id:
services.append(service)
return services
+8
View File
@@ -23,6 +23,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.replication.tcp.streams import BackfillStream, CachesStream
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamAllStateRow,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
@@ -264,6 +265,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
(data.state_key,)
)
self.get_rooms_for_user.invalidate((data.state_key,)) # type: ignore[attr-defined]
elif row.type == EventsStreamAllStateRow.TypeId:
assert isinstance(data, EventsStreamAllStateRow)
# Similar to the above, but the entire caches are invalidated. This is
# unfortunate for the membership caches, but should recover quickly.
self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) # type: ignore[attr-defined]
self.get_rooms_for_user_with_stream_ordering.invalidate_all() # type: ignore[attr-defined]
self.get_rooms_for_user.invalidate_all() # type: ignore[attr-defined]
else:
raise Exception("Unknown events stream row type %s" % (row.type,))
@@ -93,7 +93,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
"""
rows = await self.db_pool.execute(
"_censor_redactions_fetch", None, sql, before_ts, 100
"_censor_redactions_fetch", sql, before_ts, 100
)
updates = []
+14 -11
View File
@@ -508,21 +508,24 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
if device_id is not None:
keyvalues["device_id"] = device_id
res = await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
res = cast(
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
),
)
return {
(d["user_id"], d["device_id"]): DeviceLastConnectionInfo(
user_id=d["user_id"],
device_id=d["device_id"],
ip=d["ip"],
user_agent=d["user_agent"],
last_seen=d["last_seen"],
(user_id, device_id): DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id,
ip=ip,
user_agent=user_agent,
last_seen=last_seen,
)
for d in res
for user_id, ip, user_agent, device_id, last_seen in res
}
async def _get_user_ip_and_agents_from_database(
@@ -478,18 +478,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
log_kv({"message": "No changes in cache since last check"})
return 0
ROW_ID_NAME = self.database_engine.row_id_name
def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
limit_statement = "" if limit is None else f"LIMIT {limit}"
sql = f"""
DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN (
SELECT {ROW_ID_NAME} FROM device_inbox
WHERE user_id = ? AND device_id = ? AND stream_id <= ?
{limit_statement}
DELETE FROM device_inbox WHERE user_id = ? AND device_id = ? AND stream_id <= (
SELECT MAX(stream_id) FROM (
SELECT stream_id FROM device_inbox
WHERE user_id = ? AND device_id = ? AND stream_id <= ?
ORDER BY stream_id
{limit_statement}
) AS q1
)
"""
txn.execute(sql, (user_id, device_id, up_to_stream_id))
txn.execute(sql, (user_id, device_id, user_id, device_id, up_to_stream_id))
return txn.rowcount
count = await self.db_pool.runInteraction(
+41 -32
View File
@@ -283,7 +283,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
allow_none=True,
)
async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
async def get_devices_by_user(
self, user_id: str
) -> Dict[str, Dict[str, Optional[str]]]:
"""Retrieve all of a user's registered devices. Only returns devices
that are not marked as hidden.
@@ -291,20 +293,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_id:
Returns:
A mapping from device_id to a dict containing "device_id", "user_id"
and "display_name" for each device.
and "display_name" for each device. Display name may be null.
"""
devices = await self.db_pool.simple_select_list(
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user",
devices = cast(
List[Tuple[str, str, Optional[str]]],
await self.db_pool.simple_select_list(
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user",
),
)
return {d["device_id"]: d for d in devices}
return {
d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]}
for d in devices
}
async def get_devices_by_auth_provider_session_id(
self, auth_provider_id: str, auth_provider_session_id: str
) -> List[Dict[str, Any]]:
) -> List[Tuple[str, str]]:
"""Retrieve the list of devices associated with a SSO IdP session ID.
Args:
@@ -313,14 +321,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
Returns:
A list of dicts containing the device_id and the user_id of each device
"""
return await self.db_pool.simple_select_list(
table="device_auth_providers",
keyvalues={
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
retcols=("user_id", "device_id"),
desc="get_devices_by_auth_provider_session_id",
return cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="device_auth_providers",
keyvalues={
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
retcols=("user_id", "device_id"),
desc="get_devices_by_auth_provider_session_id",
),
)
@trace
@@ -821,15 +832,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
async def get_cached_devices_for_user(
self, user_id: str
) -> Mapping[str, JsonMapping]:
devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
desc="get_cached_devices_for_user",
devices = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
desc="get_cached_devices_for_user",
),
)
return {
device["device_id"]: db_to_json(device["content"]) for device in devices
}
return {device[0]: db_to_json(device[1]) for device in devices}
def get_cached_device_list_changes(
self,
@@ -882,7 +894,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
rows = await self.db_pool.execute(
"get_all_devices_changed",
None,
sql,
from_key,
to_key,
@@ -966,7 +977,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
WHERE from_user_id = ? AND stream_id > ?
"""
rows = await self.db_pool.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key
"get_users_whose_signatures_changed", sql, user_id, from_key
)
return {user for row in rows for user in db_to_json(row[0])}
else:
@@ -1080,7 +1091,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
The IDs of users whose device lists need resync.
"""
if user_ids:
row_tuples = cast(
rows = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync",
@@ -1090,11 +1101,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
desc="get_user_ids_requiring_device_list_resync_with_iterable",
),
)
return {row[0] for row in row_tuples}
else:
rows = cast(
List[Dict[str, str]],
List[Tuple[str]],
await self.db_pool.simple_select_list(
table="device_lists_remote_resync",
keyvalues=None,
@@ -1103,7 +1112,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
),
)
return {row["user_id"] for row in rows}
return {row[0] for row in rows}
async def mark_remote_users_device_caches_as_stale(
self, user_ids: StrCollection
+29 -20
View File
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, cast
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast
from typing_extensions import Literal, TypedDict
@@ -274,32 +274,41 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
if session_id:
keyvalues["session_id"] = session_id
rows = await self.db_pool.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
"user_id",
"room_id",
"session_id",
"first_message_index",
"forwarded_count",
"is_verified",
"session_data",
rows = cast(
List[Tuple[str, str, int, int, int, str]],
await self.db_pool.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
"room_id",
"session_id",
"first_message_index",
"forwarded_count",
"is_verified",
"session_data",
),
desc="get_e2e_room_keys",
),
desc="get_e2e_room_keys",
)
sessions: Dict[
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
] = {"rooms": {}}
for row in rows:
room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
room_entry["sessions"][row["session_id"]] = {
"first_message_index": row["first_message_index"],
"forwarded_count": row["forwarded_count"],
for (
room_id,
session_id,
first_message_index,
forwarded_count,
is_verified,
session_data,
) in rows:
room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}})
room_entry["sessions"][session_id] = {
"first_message_index": first_message_index,
"forwarded_count": forwarded_count,
# is_verified must be returned to the client as a boolean
"is_verified": bool(row["is_verified"]),
"session_data": db_to_json(row["session_data"]),
"is_verified": bool(is_verified),
"session_data": db_to_json(session_data),
}
return sessions
@@ -52,7 +52,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict, JsonMapping
from synapse.types import JsonDict, JsonMapping, JsonSerializable
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
@@ -155,7 +155,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"""
rows = await self.db_pool.execute(
"get_e2e_device_keys_for_federation_query_check",
None,
sql,
now_stream_id,
user_id,
@@ -1113,7 +1112,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str, int]]
) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
Dict[str, Dict[str, Dict[str, JsonSerializable]]],
List[Tuple[str, str, str, int]],
]:
"""Take a list of one time keys out of the database.
@@ -1122,7 +1122,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
Returns:
A tuple pf:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
A map of user ID -> a map device ID -> a map of key ID -> key
A copy of the input which has not been fulfilled.
"""
@@ -1215,7 +1215,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
]
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
results: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {}
missing: List[Tuple[str, str, str, int]] = []
for user_id, device_id, algorithm, count in query_list:
if self.database_engine.supports_returning:
@@ -1241,7 +1241,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
device_id, {}
)
for claim_row in claim_rows:
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
# The shape of the key depends on the algorithm: it is a dict for
# signed_curve25519, or a string for curve25519. In general, it
# is whatever the client chose to upload, since we dont validate it.
decoded_key: JsonSerializable = json_decoder.decode(claim_row[1])
device_results[claim_row[0]] = decoded_key
# Did we get enough OTKs?
count -= len(claim_rows)
if count:
@@ -1251,7 +1255,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
async def claim_e2e_fallback_keys(
self, query_list: Iterable[Tuple[str, str, str, bool]]
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
) -> Dict[str, Dict[str, Dict[str, JsonSerializable]]]:
"""Take a list of fallback keys out of the database.
Args:
@@ -1261,7 +1265,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
"""
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
results: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {}
for user_id, device_id, algorithm, mark_as_used in query_list:
row = await self.db_pool.simple_select_one(
table="e2e_fallback_keys_json",
@@ -1299,7 +1303,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
# The shape of the key depends on the algorithm: it is a dict for
# signed_curve25519, or a string for curve25519. In general, it
# is whatever the client chose to upload, since we dont validate it.
decoded_key: JsonSerializable = json_decoder.decode(key_json)
device_results[f"{algorithm}:{key_id}"] = decoded_key
return results
@@ -1898,21 +1898,23 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# keeping only the forward extremities (i.e. the events not referenced
# by other events in the queue). We do this so that we can always
# backpaginate in all the events we have dropped.
rows = await self.db_pool.simple_select_list(
table="federation_inbound_events_staging",
keyvalues={"room_id": room_id},
retcols=("event_id", "event_json"),
desc="prune_staged_events_in_room_fetch",
rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="federation_inbound_events_staging",
keyvalues={"room_id": room_id},
retcols=("event_id", "event_json"),
desc="prune_staged_events_in_room_fetch",
),
)
# Find the set of events referenced by those in the queue, as well as
# collecting all the event IDs in the queue.
referenced_events: Set[str] = set()
seen_events: Set[str] = set()
for row in rows:
event_id = row["event_id"]
for event_id, event_json in rows:
seen_events.add(event_id)
event_d = db_to_json(row["event_json"])
event_d = db_to_json(event_json)
# We don't bother parsing the dicts into full blown event objects,
# as that is needlessly expensive.
+49 -25
View File
@@ -2267,35 +2267,59 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events.
"""
# From the events passed in, add all of the prev events as backwards extremities.
# Ignore any events that are already backwards extrems or outliers.
query = (
"INSERT INTO event_backward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS ("
" SELECT 1 FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
" )"
# 1. Don't add an event as a extremity again if we already persisted it
# as a non-outlier.
# 2. Don't add an outlier as an extremity if it has no prev_events
" AND NOT EXISTS ("
" SELECT 1 FROM events"
" LEFT JOIN event_edges edge"
" ON edge.event_id = events.event_id"
" WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = FALSE OR edge.event_id IS NULL)"
" )"
room_id = events[0].room_id
potential_backwards_extremities = {
e_id
for ev in events
for e_id in ev.prev_event_ids()
if not ev.internal_metadata.is_outlier()
}
if not potential_backwards_extremities:
return
existing_events_outliers = self.db_pool.simple_select_many_txn(
txn,
table="events",
column="event_id",
iterable=potential_backwards_extremities,
keyvalues={"outlier": False},
retcols=("event_id",),
)
txn.execute_batch(
query,
[
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id)
for ev in events
for e_id in ev.prev_event_ids()
if not ev.internal_metadata.is_outlier()
],
potential_backwards_extremities.difference_update(
e for e, in existing_events_outliers
)
if potential_backwards_extremities:
self.db_pool.simple_upsert_many_txn(
txn,
table="event_backward_extremities",
key_names=("room_id", "event_id"),
key_values=[(room_id, ev) for ev in potential_backwards_extremities],
value_names=(),
value_values=(),
)
# Record the stream orderings where we have new gaps.
gap_events = [
(room_id, self._instance_name, ev.internal_metadata.stream_ordering)
for ev in events
if any(
e_id in potential_backwards_extremities
for e_id in ev.prev_event_ids()
)
]
self.db_pool.simple_insert_many_txn(
txn,
table="timeline_gaps",
keys=("room_id", "instance_name", "stream_ordering"),
values=gap_events,
)
# Delete all these events that we've already fetched and now know that their
# prev events are the new backwards extremeties.
query = (
@@ -1310,12 +1310,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
# indexes on it.
# We need to pass execute a dummy function to handle the txn's result otherwise
# it tries to call fetchall() on it and fails because there's no result to fetch.
await self.db_pool.execute(
await self.db_pool.runInteraction(
"background_analyze_new_stream_ordering_column",
lambda txn: None,
"ANALYZE events(stream_ordering2)",
lambda txn: txn.execute("ANALYZE events(stream_ordering2)"),
)
await self.db_pool.runInteraction(
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
from typing import Any, Dict, List
from typing import List, Optional, Tuple, cast
from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
@@ -91,12 +91,17 @@ class EventForwardExtremitiesStore(
async def get_forward_extremities_for_room(
self, room_id: str
) -> List[Dict[str, Any]]:
"""Get list of forward extremities for a room."""
) -> List[Tuple[str, int, int, Optional[int]]]:
"""
Get list of forward extremities for a room.
Returns:
A list of tuples of event_id, state_group, depth, and received_ts.
"""
def get_forward_extremities_for_room_txn(
txn: LoggingTransaction,
) -> List[Dict[str, Any]]:
) -> List[Tuple[str, int, int, Optional[int]]]:
sql = """
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities
@@ -106,7 +111,7 @@ class EventForwardExtremitiesStore(
"""
txn.execute(sql, (room_id,))
return self.db_pool.cursor_to_dict(txn)
return cast(List[Tuple[str, int, int, Optional[int]]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_forward_extremities_for_room",
@@ -2095,12 +2095,6 @@ class EventsWorkerStore(SQLBaseStore):
def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
sql = """
DELETE FROM event_txn_id
WHERE inserted_ts < ?
"""
txn.execute(sql, (one_day_ago,))
sql = """
DELETE FROM event_txn_id_device_id
WHERE inserted_ts < ?
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, FrozenSet
from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main import CacheInvalidationWorkerStore
@@ -42,13 +42,16 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
Returns:
the features currently enabled for the user
"""
enabled = await self.db_pool.simple_select_list(
"per_user_experimental_features",
{"user_id": user_id, "enabled": True},
["feature"],
enabled = cast(
List[Tuple[str]],
await self.db_pool.simple_select_list(
table="per_user_experimental_features",
keyvalues={"user_id": user_id, "enabled": True},
retcols=("feature",),
),
)
return frozenset(feature["feature"] for feature in enabled)
return frozenset(feature[0] for feature in enabled)
async def set_features_for_user(
self,
+19 -16
View File
@@ -248,17 +248,20 @@ class KeyStore(CacheInvalidationWorkerStore):
If we have multiple entries for a given key ID, returns the most recent.
"""
rows = await self.db_pool.simple_select_list(
table="server_keys_json",
keyvalues={"server_name": server_name},
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
rows = cast(
List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
await self.db_pool.simple_select_list(
table="server_keys_json",
keyvalues={"server_name": server_name},
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
desc="get_server_keys_json_for_remote",
),
desc="get_server_keys_json_for_remote",
)
if not rows:
@@ -266,14 +269,14 @@ class KeyStore(CacheInvalidationWorkerStore):
# We sort the rows by ts_added_ms so that the most recently added entry
# will stomp over older entries in the dictionary.
rows.sort(key=lambda r: r["ts_added_ms"])
rows.sort(key=lambda r: r[2])
return {
row["key_id"]: FetchKeyResultForRemote(
key_id: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
key_json=bytes(row["key_json"]),
valid_until_ts=row["ts_valid_until_ms"],
added_ts=row["ts_added_ms"],
key_json=bytes(key_json),
valid_until_ts=ts_valid_until_ms,
added_ts=ts_added_ms,
)
for row in rows
for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
}
@@ -437,25 +437,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]:
rows = await self.db_pool.simple_select_list(
"local_media_repository_thumbnails",
{"media_id": media_id},
(
"thumbnail_width",
"thumbnail_height",
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
rows = cast(
List[Tuple[int, int, str, str, int]],
await self.db_pool.simple_select_list(
"local_media_repository_thumbnails",
{"media_id": media_id},
(
"thumbnail_width",
"thumbnail_height",
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
),
desc="get_local_media_thumbnails",
),
desc="get_local_media_thumbnails",
)
return [
ThumbnailInfo(
width=row["thumbnail_width"],
height=row["thumbnail_height"],
method=row["thumbnail_method"],
type=row["thumbnail_type"],
length=row["thumbnail_length"],
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
)
for row in rows
]
@@ -568,25 +567,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_thumbnails(
self, origin: str, media_id: str
) -> List[ThumbnailInfo]:
rows = await self.db_pool.simple_select_list(
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
"thumbnail_width",
"thumbnail_height",
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
rows = cast(
List[Tuple[int, int, str, str, int]],
await self.db_pool.simple_select_list(
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
"thumbnail_width",
"thumbnail_height",
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
),
desc="get_remote_media_thumbnails",
),
desc="get_remote_media_thumbnails",
)
return [
ThumbnailInfo(
width=row["thumbnail_width"],
height=row["thumbnail_height"],
method=row["thumbnail_method"],
type=row["thumbnail_type"],
length=row["thumbnail_length"],
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
)
for row in rows
]
@@ -652,7 +650,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_ids(
self, before_ts: int, include_quarantined_media: bool
) -> List[Dict[str, str]]:
) -> List[Tuple[str, str, str]]:
"""
Retrieve a list of server name, media ID tuples from the remote media cache.
@@ -666,12 +664,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
A list of tuples containing:
* The server name of homeserver where the media originates from,
* The ID of the media.
* The filesystem ID.
"""
sql = """
SELECT media_origin, media_id, filesystem_id
FROM remote_media_cache
WHERE last_access_ts < ?
"""
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
" WHERE last_access_ts < ?"
)
if include_quarantined_media is False:
# Only include media that has not been quarantined
@@ -679,8 +679,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
AND quarantined_by IS NULL
"""
return await self.db_pool.execute(
"get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
return cast(
List[Tuple[str, str, str]],
await self.db_pool.execute("get_remote_media_ids", sql, before_ts),
)
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:

Some files were not shown because too many files have changed in this diff Show More