
Compare commits


45 Commits

Author SHA1 Message Date
Olivier Wilkinson (reivilibre) c91ab4bc55 Remove merge_into and just have merged which copies inputs to avoid footguns 2024-01-17 17:18:23 +00:00
Olivier Wilkinson (reivilibre) 29541fd994 Move stream_writers to their own field in the WorkerTemplate 2024-01-17 14:39:57 +00:00
Olivier Wilkinson (reivilibre) 2ff1de3b3c Newsfile
Signed-off-by: Olivier Wilkinson (reivilibre) <oliverw@matrix.org>
2024-01-10 12:15:14 +00:00
Olivier Wilkinson (reivilibre) ad4bb0e6b5 Tweak instantiate_worker_template, both in name, description and variable names 2024-01-10 12:14:50 +00:00
Olivier Wilkinson (reivilibre) 2f1d727038 Update comment on merged 2024-01-10 12:14:50 +00:00
Olivier Wilkinson (reivilibre) 3a46cf08e9 Fix comment and mutation bug on merge_worker_template_configs 2024-01-10 12:14:50 +00:00
Olivier Wilkinson (reivilibre) 259a808dc2 Docstring on WorkerTemplate 2024-01-10 12:14:50 +00:00
Olivier Wilkinson (reivilibre) 321d3590fe Add a --generate-only option 2023-11-17 12:03:56 +00:00
Olivier Wilkinson (reivilibre) fbafde81a1 Promote mark_filepath to constant 2023-11-17 12:03:56 +00:00
Olivier Wilkinson (reivilibre) 3bb21a9c26 Use merge_into when adding workers to the shared config 2023-11-17 12:03:56 +00:00
Olivier Wilkinson (reivilibre) f49dbc7ba7 Add sharding_allowed to the WorkerTemplate rather than having a separate function for that 2023-11-17 12:03:56 +00:00
Olivier Wilkinson (reivilibre) 7d8824e2dc Rename function to add_worker_to_instance_map given reduction of scope 2023-11-17 12:03:56 +00:00
Olivier Wilkinson (reivilibre) f38297b619 Remove special logic for adding stream_writers: just make it part of the extra config template 2023-11-17 12:03:56 +00:00
Olivier Wilkinson (reivilibre) 8b7463957f Add merge_into 2023-11-17 12:03:56 +00:00
Olivier Wilkinson (reivilibre) 26073fa778 Tweak comments 2023-11-17 12:03:56 +00:00
Olivier Wilkinson (reivilibre) 94a85b36f7 Convert listener_resources and endpoint_patterns to Set[str] 2023-11-17 12:03:56 +00:00
Olivier Wilkinson (reivilibre) 67d4fc8b99 Collapse WORKERS_CONFIG by removing entries with defaults 2023-11-17 12:03:56 +00:00
Olivier Wilkinson (reivilibre) ba3b6a4dfd Use a lambda for the worker name rather than search and replace later 2023-11-17 12:03:56 +00:00
Olivier Wilkinson (reivilibre) a22eb7dc15 Convert worker templates into dataclass 2023-11-17 12:03:56 +00:00
Olivier Wilkinson (reivilibre) b39a50a43b Remove obsolete "app" from worker templates 2023-11-17 12:03:56 +00:00
Erik Johnston 1b238e8837 Speed up persisting large number of outliers (#16649)
Recalculating the roots tuple every iteration could be very expensive, so instead let's do a topological sort.
2023-11-16 14:25:35 +00:00
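A minimal sketch of the batched approach (a Kahn-style topological sort yielding batches; illustrative only, not the exact `sorted_topologically_batched` helper used in the diff below):

```python
from typing import Dict, Iterable, List, Set, TypeVar

T = TypeVar("T")

def topologically_batched(deps: Dict[T, Set[T]]) -> Iterable[List[T]]:
    """Yield batches of nodes such that every node's dependencies appear
    in an earlier batch (Kahn's algorithm, batched)."""
    remaining = {node: set(d) for node, d in deps.items()}
    while remaining:
        # Nodes whose dependencies have all been emitted already.
        batch = [n for n, d in remaining.items() if not d]
        if not batch:
            raise ValueError("dependency cycle detected")
        for n in batch:
            del remaining[n]
        for d in remaining.values():
            d.difference_update(batch)
        yield batch

# Outliers must be persisted after their auth events, e.g.:
print(list(topologically_batched({"A": set(), "B": {"A"}, "C": {"A"}, "D": {"B", "C"}})))
# -> [['A'], ['B', 'C'], ['D']]
```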
Erik Johnston fef08cbee8 Fix sending out of order POSITION over replication (#16639)
If a worker reconnects to Redis we send out the current positions of all our streams. However, if we're also trying to send out a backlog of RDATA at the same time then we can end up sending a `POSITION` with the current token *before* we've sent all the RDATA before the current token.

This doesn't cause actual bugs as the receiving servers see the POSITION, fetch the relevant rows from the DB, and then ignore the old RDATA as they come in. However, this is inefficient, so it'd be better if we didn't send out-of-order positions.
2023-11-16 13:05:09 +00:00
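A hedged sketch of the fix (names are illustrative, not Synapse's actual replication classes): drain the queued RDATA up to the current token before announcing a `POSITION` for it.

```python
from typing import List, Tuple

class StreamSender:
    """Illustrative only: wire commands are collected in `outbox`."""

    def __init__(self) -> None:
        self.outbox: List[Tuple[str, object]] = []
        self.pending_rdata: List[Tuple[int, object]] = []  # (token, row)

    def queue_rdata(self, token: int, row: object) -> None:
        self.pending_rdata.append((token, row))

    def announce_position(self, current_token: int) -> None:
        # Send the RDATA backlog *before* the POSITION so that receivers
        # never see a position token ahead of rows still in flight.
        backlog = [p for p in self.pending_rdata if p[0] <= current_token]
        self.pending_rdata = [p for p in self.pending_rdata if p[0] > current_token]
        for token, row in backlog:
            self.outbox.append(("RDATA", (token, row)))
        self.outbox.append(("POSITION", current_token))
```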
Erik Johnston 898655fd12 More efficiently handle no-op POSITION (#16640)
We may receive `POSITION` commands where we already know that worker has
advanced past that position, so there is no point in handling it.
2023-11-16 12:32:17 +00:00
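On the receiving side, the optimisation amounts to tracking the last token seen per stream and dropping stale `POSITION`s; a hypothetical sketch:

```python
from typing import Dict

class PositionTracker:
    """Illustrative: skip POSITION commands we have already advanced past."""

    def __init__(self) -> None:
        self.current_token: Dict[str, int] = {}

    def should_handle_position(self, stream_name: str, token: int) -> bool:
        if token <= self.current_token.get(stream_name, -1):
            # Nothing new to fetch from the DB; drop the command cheaply.
            return False
        self.current_token[stream_name] = token
        return True
```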
reivilibre 830988ae72 Fix test not detecting tables with missing primary keys and missing replica identities, then add more replica identities. (#16647)
* Fix the CI query that did not detect all cases of missing primary keys

* Add more missing REPLICA IDENTITY entries

* Newsfile

Signed-off-by: Olivier Wilkinson (reivilibre) <oliverw@matrix.org>

---------

Signed-off-by: Olivier Wilkinson (reivilibre) <oliverw@matrix.org>
2023-11-16 12:26:27 +00:00
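The kind of check involved can be sketched as a Postgres catalog query (illustrative only; the real test lives in `tests/storage/test_database.py`):

```python
# Find ordinary tables whose effective replica identity is missing: either
# identity NOTHING, or identity DEFAULT with no primary key to fall back on.
MISSING_REPLICA_IDENTITY_SQL = """
SELECT c.relname
FROM pg_class c
JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = 'public'
  AND c.relkind = 'r'
  AND (
    c.relreplident = 'n'
    OR (c.relreplident = 'd' AND NOT EXISTS (
        SELECT 1 FROM pg_index i
        WHERE i.indrelid = c.oid AND i.indisprimary
    ))
  );
"""
```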
David Robertson 43d1aa75e8 Add an Admin API to temporarily grant the ability to update an existing cross-signing key without UIA (#16634) 2023-11-15 17:28:10 +00:00
Sumner Evans 999bd77d3a Asynchronous Uploads (#15503)
Support asynchronous uploads as defined in MSC2246.
2023-11-15 09:19:24 -05:00
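The client-side flow defined by MSC2246, roughly (a hedged sketch; endpoint paths and field names follow the MSC, and the server URL and token are placeholders):

```python
import requests

BASE = "https://example.org"  # assumption: your homeserver
HEADERS = {"Authorization": "Bearer <access_token>"}

# 1. Reserve an MXC URI up front; it can be referenced in events immediately.
created = requests.post(f"{BASE}/_matrix/media/v1/create", headers=HEADERS).json()
mxc = created["content_uri"]  # e.g. "mxc://example.org/abc123"
server_name, media_id = mxc[len("mxc://"):].split("/", 1)

# 2. Upload the bytes later, before the `unused_expires_at` deadline passes.
with open("photo.png", "rb") as f:
    requests.put(
        f"{BASE}/_matrix/media/v3/upload/{server_name}/{media_id}",
        headers={**HEADERS, "Content-Type": "image/png"},
        data=f,
    )
```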
Patrick Cloke 80922dc46e Add links to pre-1.0 changelog issue/PR references. (#16638) 2023-11-15 13:31:24 +00:00
Patrick Cloke f2f2c7c1f0 Use full GitHub links instead of bare issue numbers. (#16637) 2023-11-15 08:02:11 -05:00
Will Hunt 4dd18bdc2e Improve documentation for /_synapse/admin/v1/rooms/<room_id>/timestamp_to_event (#16631) 2023-11-14 11:43:44 -05:00
Nick Mills-Barrett 0e36a57b60 Remove whole table locks on push rule add/delete (#16051)
The statements are already executed within a transaction, so a table-level lock is unnecessary.
2023-11-13 16:57:44 +00:00
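The underlying reasoning, sketched with psycopg2 (table and column names are illustrative, not Synapse's exact schema):

```python
import psycopg2

conn = psycopg2.connect("dbname=synapse")  # assumption: illustrative DSN
with conn:  # one transaction: commits on success, rolls back on error
    with conn.cursor() as cur:
        # No explicit LOCK TABLE is needed: both statements execute
        # atomically inside this transaction, and the row-level locks
        # taken by DELETE/INSERT already serialise concurrent writers.
        cur.execute(
            "DELETE FROM push_rules WHERE user_name = %s AND rule_id = %s",
            ("@alice:example.org", "my_rule"),
        )
        cur.execute(
            "INSERT INTO push_rules (user_name, rule_id) VALUES (%s, %s)",
            ("@alice:example.org", "my_rule"),
        )
```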
reivilibre 69afe3f7a0 Add a Postgres REPLICA IDENTITY to tables that do not have an implicit one. This should allow use of Postgres logical replication. (#16456)
* Add Postgres replica identities to tables that don't have an implicit one

Fixes #16224

* Newsfile

Signed-off-by: Olivier Wilkinson (reivilibre) <oliverw@matrix.org>

* Move the delta to version 83 as we missed the boat for 82

* Add a test that all tables have a REPLICA IDENTITY

* Extend the test to include when indices are deleted

* isort

* black

* Fully qualify `oid` as it is a 'hidden attribute' in Postgres 11

* Update tests/storage/test_database.py

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>

* Add missed tables

---------

Signed-off-by: Olivier Wilkinson (reivilibre) <oliverw@matrix.org>
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
2023-11-13 16:03:22 +00:00
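The fix itself boils down to statements of this shape (table and index names are hypothetical):

```python
# Give a table a replica identity so that UPDATEs/DELETEs can be published
# via logical replication.
FIX_REPLICA_IDENTITY = [
    # Preferred: identify rows by an existing unique, non-null index.
    "ALTER TABLE some_table REPLICA IDENTITY USING INDEX some_table_uniq_idx;",
    # Fallback for tables with no suitable key: log the whole old row.
    "ALTER TABLE other_table REPLICA IDENTITY FULL;",
]
```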
David Robertson fb2554b11f Fix outbound_federation_restricted_to docs & note when added (#16628) 2023-11-13 14:26:49 +00:00
dependabot[bot] 7455b9e27d Bump serde from 1.0.190 to 1.0.192 (#16627)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-13 11:10:28 +00:00
dependabot[bot] 35fac66d20 Bump prometheus-client from 0.17.1 to 0.18.0 (#16626)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-13 11:09:30 +00:00
dependabot[bot] 69d1ee3feb Bump treq from 22.2.0 to 23.11.0 (#16623)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-13 11:08:31 +00:00
dependabot[bot] f92af19fa5 Bump types-pyopenssl from 23.2.0.2 to 23.3.0.0 (#16625)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-13 11:06:10 +00:00
dependabot[bot] 22a513014d Bump types-bleach from 6.1.0.0 to 6.1.0.1 (#16624)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-13 11:05:37 +00:00
dependabot[bot] ca7421b5fd Bump towncrier from 23.6.0 to 23.11.0 (#16622)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-13 11:05:02 +00:00
Patrick Cloke 2c6a7dfcbf Use attempt_to_set_autocommit everywhere. (#16615)
To avoid asserting the type of the database connection.
2023-11-09 16:19:42 -05:00
reivilibre dc7f068d9c Fix a long-standing bug where Synapse would not unbind third-party identifiers for Application Service users when deactivated and would not emit a compliant response. (#16617)
* Don't skip unbinding 3PIDs and returning success status when deactivating AS user

Fixes #16608

* Newsfile

Signed-off-by: Olivier Wilkinson (reivilibre) <oliverw@matrix.org>

---------

Signed-off-by: Olivier Wilkinson (reivilibre) <oliverw@matrix.org>
2023-11-09 20:18:25 +00:00
Patrick Cloke bc4372ad81 Use dbname instead of database for Postgres config. (#16618) 2023-11-09 14:40:45 -05:00
Patrick Cloke 9f514dd0fb Use _invalidate_cache_and_stream_bulk in more places. (#16616)
This takes advantage of the new bulk method in more places to
invalidate caches for many keys at once (and then to stream that
over replication).
2023-11-09 14:40:30 -05:00
Patrick Cloke ab3f1b3b53 Convert simple_select_one_txn and simple_select_one to return tuples. (#16612) 2023-11-09 11:13:31 -05:00
Patrick Cloke ff716b483b Return attrs for more media repo APIs. (#16611) 2023-11-09 11:00:30 -05:00
David Robertson 91587d4cf9 Bulk-invalidate e2e cached queries after claiming keys (#16613)
Co-authored-by: Patrick Cloke <patrickc@matrix.org>
2023-11-09 15:57:09 +00:00
145 changed files with 3729 additions and 1828 deletions
+6 -6
@@ -8,21 +8,21 @@
# If ignoring a pull request that was not squash merged, only the merge
# commit needs to be put here. Child commits will be resolved from it.
# Run black (#3679).
# Run black (https://github.com/matrix-org/synapse/pull/3679).
8b3d9b6b199abb87246f982d5db356f1966db925
# Black reformatting (#5482).
# Black reformatting (https://github.com/matrix-org/synapse/pull/5482).
32e7c9e7f20b57dd081023ac42d6931a8da9b3a3
# Target Python 3.5 with black (#8664).
# Target Python 3.5 with black (https://github.com/matrix-org/synapse/pull/8664).
aff1eb7c671b0a3813407321d2702ec46c71fa56
# Update black to 20.8b1 (#9381).
# Update black to 20.8b1 (https://github.com/matrix-org/synapse/pull/9381).
0a00b7ff14890987f09112a2ae696c61001e6cf1
# Convert tests/rest/admin/test_room.py to unix file endings (#7953).
# Convert tests/rest/admin/test_room.py to unix file endings (https://github.com/matrix-org/synapse/pull/7953).
c4268e3da64f1abb5b31deaeb5769adb6510c0a7
# Update black to 23.1.0 (#15103)
# Update black to 23.1.0 (https://github.com/matrix-org/synapse/pull/15103)
9bb2eac71962970d02842bca441f4bcdbbf93a11
Generated
+4 -4
@@ -332,18 +332,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "serde"
version = "1.0.190"
version = "1.0.192"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91d3c334ca1ee894a2c6f6ad698fe8c435b76d504b13d436f0685d648d6d96f7"
checksum = "bca2a08484b285dcb282d0f67b26cadc0df8b19f8c12502c13d966bf9482f001"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.190"
version = "1.0.192"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67c5609f394e5c2bd7fc51efda478004ea80ef42fee983d5c67a65e34f32c0e3"
checksum = "d6c7207fbec9faa48073f3e3074cbe553af6ea512d7c21ba46e434e70ea9fbc1"
dependencies = [
"proc-macro2",
"quote",
+1
@@ -0,0 +1 @@
Add support for asynchronous uploads as defined by [MSC2246](https://github.com/matrix-org/matrix-spec-proposals/pull/2246). Contributed by @sumnerevans at @beeper.
+1
@@ -0,0 +1 @@
Remove whole table locks on push rule modifications. Contributed by Nick @ Beeper (@fizzadar).
+1
@@ -0,0 +1 @@
Add a Postgres `REPLICA IDENTITY` to tables that do not have an implicit one. This should allow use of Postgres logical replication.
+1
@@ -0,0 +1 @@
Improve type hints.
+1
@@ -0,0 +1 @@
Improve type hints.
+1
@@ -0,0 +1 @@
Improve the performance of some operations in multi-worker deployments.
+1
@@ -0,0 +1 @@
Use more generic database methods.
+1
@@ -0,0 +1 @@
Improve the performance of some operations in multi-worker deployments.
+1
@@ -0,0 +1 @@
Fix a long-standing bug where Synapse would not unbind third-party identifiers for Application Service users when deactivated and would not emit a compliant response.
+1
@@ -0,0 +1 @@
Use `dbname` instead of the deprecated `database` connection parameter for psycopg2.
+1
@@ -0,0 +1 @@
Note that the option [`outbound_federation_restricted_to`](https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#outbound_federation_restricted_to) was added in Synapse 1.89.0, and fix a nearby formatting error.
+1
@@ -0,0 +1 @@
Update parameter information for the `/timestamp_to_event` admin API.
+1
@@ -0,0 +1 @@
Add an internal [Admin API endpoint](https://matrix-org.github.io/synapse/v1.97/usage/configuration/config_documentation.html#allow-replacing-master-cross-signing-key-without-user-interactive-auth) to temporarily grant the ability to update an existing cross-signing key without UIA.
+1
@@ -0,0 +1 @@
Improve references to GitHub issues.
+1
@@ -0,0 +1 @@
Improve references to GitHub issues.
+1
@@ -0,0 +1 @@
Fix sending out of order `POSITION` over replication, causing additional database load.
+1
@@ -0,0 +1 @@
More efficiently handle no-op `POSITION` over replication.
+1
@@ -0,0 +1 @@
Add a Postgres `REPLICA IDENTITY` to tables that do not have an implicit one. This should allow use of Postgres logical replication.
+1
@@ -0,0 +1 @@
Speed up persisting large number of outliers.
+1
@@ -0,0 +1 @@
Refactor the `configure_workers_and_start.py` script used internally by Complement.
+1 -1
@@ -1637,7 +1637,7 @@ matrix-synapse-py3 (0.99.3.1) stable; urgency=medium
matrix-synapse-py3 (0.99.3) stable; urgency=medium
[ Richard van der Hoff ]
* Fix warning during preconfiguration. (Fixes: #4819)
* Fix warning during preconfiguration. (Fixes: https://github.com/matrix-org/synapse/issues/4819)
[ Synapse Packaging team ]
* New synapse release 0.99.3.
@@ -6,7 +6,7 @@ command=/usr/local/bin/python -m synapse.app.complement_fork_starter
--config-path="{{ main_config_path }}"
--config-path=/conf/workers/shared.yaml
{%- for worker in workers %}
-- {{ worker.app }}
-- synapse.app.generic_worker
--config-path="{{ main_config_path }}"
--config-path=/conf/workers/shared.yaml
--config-path=/conf/workers/{{ worker.name }}.yaml
@@ -36,7 +36,7 @@ exitcodes=0
{% for worker in workers %}
[program:synapse_{{ worker.name }}]
command=/usr/local/bin/prefix-log /usr/local/bin/python -m {{ worker.app }}
command=/usr/local/bin/prefix-log /usr/local/bin/python -m synapse.app.generic_worker
--config-path="{{ main_config_path }}"
--config-path=/conf/workers/shared.yaml
--config-path=/conf/workers/{{ worker.name }}.yaml
+1 -1
@@ -3,7 +3,7 @@
# Values will change depending on whichever workers are selected when
# running that image.
worker_app: "{{ app }}"
worker_app: "synapse.app.generic_worker"
worker_name: "{{ name }}"
worker_listeners:
+257 -292
@@ -47,16 +47,21 @@
# in the project's README), this script may be run multiple times, and functionality should
# continue to work if so.
import dataclasses
import os
import platform
import re
import subprocess
import sys
from argparse import ArgumentParser
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from itertools import chain
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
@@ -78,9 +83,32 @@ MAIN_PROCESS_REPLICATION_PORT = 9093
MAIN_PROCESS_UNIX_SOCKET_PUBLIC_PATH = "/run/main_public.sock"
MAIN_PROCESS_UNIX_SOCKET_PRIVATE_PATH = "/run/main_private.sock"
# A simple name used as a placeholder in the WORKERS_CONFIG below. This will be replaced
# during processing with the name of the worker.
WORKER_PLACEHOLDER_NAME = "placeholder_name"
# We place a file at this path to indicate that the script has already been
# run and should not be run again.
MARKER_FILE_PATH = "/conf/workers_have_been_configured"
@dataclass
class WorkerTemplate:
"""
A definition of individual settings for a specific worker type.
A worker name can be fed into the template in order to generate a config.
These worker templates can be merged with `merge_worker_template_configs`
in order for a single worker to be made from multiple templates.
"""
listener_resources: Set[str] = field(default_factory=set)
endpoint_patterns: Set[str] = field(default_factory=set)
# (worker_name) -> {config}
shared_extra_conf: Callable[[str], Dict[str, Any]] = lambda _worker_name: {}
worker_extra_conf: str = ""
stream_writers: Set[str] = field(default_factory=set)
# True if and only if multiple of this worker type are allowed.
sharding_allowed: bool = True
# Workers with exposed endpoints need either "client", "federation", or "media" listener_resources
# Watching /_matrix/client needs a "client" listener
@@ -88,75 +116,60 @@ WORKER_PLACEHOLDER_NAME = "placeholder_name"
# Watching /_matrix/media and related needs a "media" listener
# Stream Writers require "client" and "replication" listeners because they
# have to attach by instance_map to the master process and have client endpoints.
WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
"pusher": {
"app": "synapse.app.generic_worker",
"listener_resources": [],
"endpoint_patterns": [],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"user_dir": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client"],
"endpoint_patterns": [
WORKERS_CONFIG: Dict[str, WorkerTemplate] = {
"pusher": WorkerTemplate(
shared_extra_conf=lambda worker_name: {
"pusher_instances": [worker_name],
}
),
"user_dir": WorkerTemplate(
listener_resources={"client"},
endpoint_patterns={
"^/_matrix/client/(api/v1|r0|v3|unstable)/user_directory/search$"
],
"shared_extra_conf": {
"update_user_directory_from_worker": WORKER_PLACEHOLDER_NAME
},
"worker_extra_conf": "",
},
"media_repository": {
"app": "synapse.app.generic_worker",
"listener_resources": ["media"],
"endpoint_patterns": [
shared_extra_conf=lambda worker_name: {
"update_user_directory_from_worker": worker_name
},
),
"media_repository": WorkerTemplate(
listener_resources={"media"},
endpoint_patterns={
"^/_matrix/media/",
"^/_synapse/admin/v1/purge_media_cache$",
"^/_synapse/admin/v1/room/.*/media.*$",
"^/_synapse/admin/v1/user/.*/media.*$",
"^/_synapse/admin/v1/media/.*$",
"^/_synapse/admin/v1/quarantine_media/.*$",
],
},
# The first configured media worker will run the media background jobs
"shared_extra_conf": {
shared_extra_conf=lambda worker_name: {
"enable_media_repo": False,
"media_instance_running_background_jobs": WORKER_PLACEHOLDER_NAME,
"media_instance_running_background_jobs": worker_name,
},
"worker_extra_conf": "enable_media_repo: true",
},
"appservice": {
"app": "synapse.app.generic_worker",
"listener_resources": [],
"endpoint_patterns": [],
"shared_extra_conf": {
"notify_appservices_from_worker": WORKER_PLACEHOLDER_NAME
worker_extra_conf="enable_media_repo: true",
),
"appservice": WorkerTemplate(
shared_extra_conf=lambda worker_name: {
"notify_appservices_from_worker": worker_name
},
"worker_extra_conf": "",
},
"federation_sender": {
"app": "synapse.app.generic_worker",
"listener_resources": [],
"endpoint_patterns": [],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"synchrotron": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client"],
"endpoint_patterns": [
),
"federation_sender": WorkerTemplate(
shared_extra_conf=lambda worker_name: {
"federation_sender_instances": [worker_name],
}
),
"synchrotron": WorkerTemplate(
listener_resources={"client"},
endpoint_patterns={
"^/_matrix/client/(v2_alpha|r0|v3)/sync$",
"^/_matrix/client/(api/v1|v2_alpha|r0|v3)/events$",
"^/_matrix/client/(api/v1|r0|v3)/initialSync$",
"^/_matrix/client/(api/v1|r0|v3)/rooms/[^/]+/initialSync$",
],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"client_reader": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client"],
"endpoint_patterns": [
},
),
"client_reader": WorkerTemplate(
listener_resources={"client"},
endpoint_patterns={
"^/_matrix/client/(api/v1|r0|v3|unstable)/publicRooms$",
"^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/joined_members$",
"^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/context/.*$",
@@ -184,14 +197,11 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
"^/_matrix/client/(api/v1|r0|v3|unstable)/directory/room/.*$",
"^/_matrix/client/(r0|v3|unstable)/capabilities$",
"^/_matrix/client/(r0|v3|unstable)/notifications$",
],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"federation_reader": {
"app": "synapse.app.generic_worker",
"listener_resources": ["federation"],
"endpoint_patterns": [
},
),
"federation_reader": WorkerTemplate(
listener_resources={"federation"},
endpoint_patterns={
"^/_matrix/federation/(v1|v2)/event/",
"^/_matrix/federation/(v1|v2)/state/",
"^/_matrix/federation/(v1|v2)/state_ids/",
@@ -211,97 +221,73 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
"^/_matrix/federation/(v1|v2)/user/devices/",
"^/_matrix/federation/(v1|v2)/get_groups_publicised$",
"^/_matrix/key/v2/query",
],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"federation_inbound": {
"app": "synapse.app.generic_worker",
"listener_resources": ["federation"],
"endpoint_patterns": ["/_matrix/federation/(v1|v2)/send/"],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"event_persister": {
"app": "synapse.app.generic_worker",
"listener_resources": ["replication"],
"endpoint_patterns": [],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"background_worker": {
"app": "synapse.app.generic_worker",
"listener_resources": [],
"endpoint_patterns": [],
},
),
"federation_inbound": WorkerTemplate(
listener_resources={"federation"},
endpoint_patterns={"/_matrix/federation/(v1|v2)/send/"},
),
"event_persister": WorkerTemplate(
listener_resources={"replication"},
stream_writers={"events"},
),
"background_worker": WorkerTemplate(
# This worker cannot be sharded. Therefore, there should only ever be one
# background worker. This is enforced for the safety of your database.
"shared_extra_conf": {"run_background_tasks_on": WORKER_PLACEHOLDER_NAME},
"worker_extra_conf": "",
},
"event_creator": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client"],
"endpoint_patterns": [
shared_extra_conf=lambda worker_name: {"run_background_tasks_on": worker_name},
sharding_allowed=False,
),
"event_creator": WorkerTemplate(
listener_resources={"client"},
endpoint_patterns={
"^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/redact",
"^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/send",
"^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$",
"^/_matrix/client/(api/v1|r0|v3|unstable)/join/",
"^/_matrix/client/(api/v1|r0|v3|unstable)/knock/",
"^/_matrix/client/(api/v1|r0|v3|unstable)/profile/",
],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"frontend_proxy": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client", "replication"],
"endpoint_patterns": ["^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload"],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"account_data": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client", "replication"],
"endpoint_patterns": [
},
),
"frontend_proxy": WorkerTemplate(
listener_resources={"client", "replication"},
endpoint_patterns={"^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload"},
),
"account_data": WorkerTemplate(
listener_resources={"client", "replication"},
endpoint_patterns={
"^/_matrix/client/(r0|v3|unstable)/.*/tags",
"^/_matrix/client/(r0|v3|unstable)/.*/account_data",
],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"presence": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client", "replication"],
"endpoint_patterns": ["^/_matrix/client/(api/v1|r0|v3|unstable)/presence/"],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"receipts": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client", "replication"],
"endpoint_patterns": [
},
stream_writers={"account_data"},
sharding_allowed=False,
),
"presence": WorkerTemplate(
listener_resources={"client", "replication"},
endpoint_patterns={"^/_matrix/client/(api/v1|r0|v3|unstable)/presence/"},
stream_writers={"presence"},
sharding_allowed=False,
),
"receipts": WorkerTemplate(
listener_resources={"client", "replication"},
endpoint_patterns={
"^/_matrix/client/(r0|v3|unstable)/rooms/.*/receipt",
"^/_matrix/client/(r0|v3|unstable)/rooms/.*/read_markers",
],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"to_device": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client", "replication"],
"endpoint_patterns": ["^/_matrix/client/(r0|v3|unstable)/sendToDevice/"],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"typing": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client", "replication"],
"endpoint_patterns": [
"^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/typing"
],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
},
stream_writers={"receipts"},
sharding_allowed=False,
),
"to_device": WorkerTemplate(
listener_resources={"client", "replication"},
endpoint_patterns={"^/_matrix/client/(r0|v3|unstable)/sendToDevice/"},
stream_writers={"to_device"},
sharding_allowed=False,
),
"typing": WorkerTemplate(
listener_resources={"client", "replication"},
endpoint_patterns={"^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/typing"},
stream_writers={"typing"},
sharding_allowed=False,
),
}
# Templates for sections that may be inserted multiple times in config files
@@ -336,6 +322,45 @@ def flush_buffers() -> None:
sys.stderr.flush()
def merged(a: Any, b: Any) -> Any:
"""
Merges `a` and `b` together, returning the result.
The merge is performed with the following rules:
- dicts: values with the same key will be merged recursively
- lists: the entries of `b` will be appended to (a copy of) `a`
- primitives: they will be checked for equality; inequality will result
in a ValueError
It is an error for `a` and `b` to be of different types.
"""
if isinstance(a, dict) and isinstance(b, dict):
result = {}
for key in set(a.keys()) | set(b.keys()):
if key in a and key in b:
result[key] = merged(a[key], b[key])
elif key in a:
result[key] = deepcopy(a[key])
else:
result[key] = deepcopy(b[key])
return result
elif isinstance(a, list) and isinstance(b, list):
return deepcopy(a) + deepcopy(b)
elif type(a) != type(b):
raise TypeError(f"Cannot merge {type(a).__name__} and {type(b).__name__}")
elif a != b:
raise ValueError(f"Cannot merge primitive values: {a!r} != {b!r}")
if type(a) not in {str, int, float, bool, None.__class__}:
raise TypeError(
f"Cannot use `merged` on type {a} as it may not be safe (must either be an immutable primitive or must have special copy/merge logic)"
)
return a
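# Example usage (added for illustration; not part of the script): merging two
# nested worker config fragments with `merged`:
#
#   merged(
#       {"stream_writers": {"events": ["w1"]}, "foo": 1},
#       {"stream_writers": {"events": ["w2"], "typing": ["w3"]}, "foo": 1},
#   )
#   # -> {"stream_writers": {"events": ["w1", "w2"], "typing": ["w3"]}, "foo": 1}
#   # Conflicting primitives (e.g. "foo": 1 vs "foo": 2) raise ValueError.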
def convert(src: str, dst: str, **template_vars: object) -> None:
"""Generate a file from a template
@@ -364,138 +389,84 @@ def convert(src: str, dst: str, **template_vars: object) -> None:
outfile.write(rendered)
def add_worker_roles_to_shared_config(
def add_worker_to_instance_map(
shared_config: dict,
worker_types_set: Set[str],
worker_name: str,
worker_port: int,
) -> None:
"""Given a dictionary representing a config file shared across all workers,
append appropriate worker information to it for the current worker_type instance.
"""
Update the shared config map to add the worker in the instance_map.
Args:
shared_config: The config dict that all worker instances share (after being
converted to YAML)
worker_types_set: The type of worker (one of those defined in WORKERS_CONFIG).
This list can be a single worker type or multiple.
worker_name: The name of the worker instance.
worker_port: The HTTP replication port that the worker instance is listening on.
"""
# The instance_map config field marks the workers that write to various replication
# streams
instance_map = shared_config.setdefault("instance_map", {})
# This is a list of the stream_writers that there can be only one of. Events can be
# sharded, and therefore doesn't belong here.
singular_stream_writers = [
"account_data",
"presence",
"receipts",
"to_device",
"typing",
]
# Worker-type specific sharding config. Now a single worker can fulfill multiple
# roles, check each.
if "pusher" in worker_types_set:
shared_config.setdefault("pusher_instances", []).append(worker_name)
if "federation_sender" in worker_types_set:
shared_config.setdefault("federation_sender_instances", []).append(worker_name)
if "event_persister" in worker_types_set:
# Event persisters write to the events stream, so we need to update
# the list of event stream writers
shared_config.setdefault("stream_writers", {}).setdefault("events", []).append(
worker_name
)
# Map of stream writer instance names to host/ports combos
if os.environ.get("SYNAPSE_USE_UNIX_SOCKET", False):
instance_map[worker_name] = {
"path": f"/run/worker.{worker_port}",
}
else:
instance_map[worker_name] = {
"host": "localhost",
"port": worker_port,
}
# Update the list of stream writers. It's convenient that the name of the worker
# type is the same as the stream to write. Iterate over the whole list in case there
# is more than one.
for worker in worker_types_set:
if worker in singular_stream_writers:
shared_config.setdefault("stream_writers", {}).setdefault(
worker, []
).append(worker_name)
# Map of stream writer instance names to host/ports combos
# For now, all stream writers need http replication ports
if os.environ.get("SYNAPSE_USE_UNIX_SOCKET", False):
instance_map[worker_name] = {
"path": f"/run/worker.{worker_port}",
}
else:
instance_map[worker_name] = {
"host": "localhost",
"port": worker_port,
}
if os.environ.get("SYNAPSE_USE_UNIX_SOCKET", False):
instance_map[worker_name] = {
"path": f"/run/worker.{worker_port}",
}
else:
instance_map[worker_name] = {
"host": "localhost",
"port": worker_port,
}
def merge_worker_template_configs(
existing_dict: Optional[Dict[str, Any]],
to_be_merged_dict: Dict[str, Any],
left: WorkerTemplate,
right: WorkerTemplate,
) -> WorkerTemplate:
"""Merges two templates together, returning a new template that includes
the listeners, endpoint patterns and configuration from both.
Does not mutate the input templates.
"""
return WorkerTemplate(
# include listener resources from both
listener_resources=left.listener_resources | right.listener_resources,
# include endpoint patterns from both
endpoint_patterns=left.endpoint_patterns | right.endpoint_patterns,
# merge shared config dictionaries; the worker name will be replaced later
shared_extra_conf=lambda worker_name: merged(
left.shared_extra_conf(worker_name),
right.shared_extra_conf(worker_name),
),
# There is only one worker type that has a 'worker_extra_conf' and it is
# the media_repo. Since duplicate worker types on the same worker don't
# work, this is fine.
worker_extra_conf=(left.worker_extra_conf + right.worker_extra_conf),
# (This is unused, but in principle sharding this hybrid worker type
# would be allowed if both constituent types are shardable)
sharding_allowed=left.sharding_allowed and right.sharding_allowed,
# include stream writers from both
stream_writers=left.stream_writers | right.stream_writers,
)
def instantiate_worker_template(
template: WorkerTemplate, worker_name: str
) -> Dict[str, Any]:
"""When given an existing dict of worker template configuration consisting with both
dicts and lists, merge new template data from WORKERS_CONFIG(or create) and
return new dict.
"""Given a worker template, instantiate it into a worker configuration
(which is currently represented as a dictionary).
Args:
existing_dict: Either an existing worker template or a fresh blank one.
to_be_merged_dict: The template from WORKERS_CONFIGS to be merged into
existing_dict.
Returns: The newly merged together dict values.
template: The WorkerTemplate to template
worker_name: The name of the worker to use.
Returns: worker configuration dictionary
"""
new_dict: Dict[str, Any] = {}
if not existing_dict:
# It doesn't exist yet, just use the new dict(but take a copy not a reference)
new_dict = to_be_merged_dict.copy()
else:
for i in to_be_merged_dict.keys():
if (i == "endpoint_patterns") or (i == "listener_resources"):
# merge the two lists, remove duplicates
new_dict[i] = list(set(existing_dict[i] + to_be_merged_dict[i]))
elif i == "shared_extra_conf":
# merge dictionary's, the worker name will be replaced later
new_dict[i] = {**existing_dict[i], **to_be_merged_dict[i]}
elif i == "worker_extra_conf":
# There is only one worker type that has a 'worker_extra_conf' and it is
# the media_repo. Since duplicate worker types on the same worker don't
# work, this is fine.
new_dict[i] = existing_dict[i] + to_be_merged_dict[i]
else:
# Everything else should be identical, like "app", which only works
# because all apps are now generic_workers.
new_dict[i] = to_be_merged_dict[i]
return new_dict
def insert_worker_name_for_worker_config(
existing_dict: Dict[str, Any], worker_name: str
) -> Dict[str, Any]:
"""Insert a given worker name into the worker's configuration dict.
Args:
existing_dict: The worker_config dict that is imported into shared_config.
worker_name: The name of the worker to insert.
Returns: Copy of the dict with newly inserted worker name
"""
dict_to_edit = existing_dict.copy()
for k, v in dict_to_edit["shared_extra_conf"].items():
# Only proceed if it's the placeholder name string
if v == WORKER_PLACEHOLDER_NAME:
dict_to_edit["shared_extra_conf"][k] = worker_name
return dict_to_edit
worker_config_dict = dataclasses.asdict(template)
stream_writers_dict = {writer: worker_name for writer in template.stream_writers}
worker_config_dict["shared_extra_conf"] = merged(
template.shared_extra_conf(worker_name), stream_writers_dict
)
worker_config_dict["endpoint_patterns"] = sorted(template.endpoint_patterns)
worker_config_dict["listener_resources"] = sorted(template.listener_resources)
return worker_config_dict
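# Example usage (illustration only, mirroring what generate_worker_files
# does below):
#
#   template = merge_worker_template_configs(
#       WORKERS_CONFIG["account_data"], WORKERS_CONFIG["receipts"]
#   )
#   config = instantiate_worker_template(template, "worker1")
#   # config["stream_writers"] == {"account_data", "receipts"}
#   # config["shared_extra_conf"] == {"account_data": "worker1",
#   #                                 "receipts": "worker1"}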
def apply_requested_multiplier_for_worker(worker_types: List[str]) -> List[str]:
@@ -540,23 +511,6 @@ def apply_requested_multiplier_for_worker(worker_types: List[str]) -> List[str]:
return new_worker_types
def is_sharding_allowed_for_worker_type(worker_type: str) -> bool:
"""Helper to check to make sure worker types that cannot have multiples do not.
Args:
worker_type: The type of worker to check against.
Returns: True if allowed, False if not
"""
return worker_type not in [
"background_worker",
"account_data",
"presence",
"receipts",
"typing",
"to_device",
]
def split_and_strip_string(
given_string: str, split_char: str, max_split: SupportsIndex = -1
) -> List[str]:
@@ -682,7 +636,7 @@ def parse_worker_types(
)
if worker_type in worker_type_shard_counter:
if not is_sharding_allowed_for_worker_type(worker_type):
if not WORKERS_CONFIG[worker_type].sharding_allowed:
error(
f"There can be only a single worker with {worker_type} "
"type. Please recount and remove."
@@ -811,36 +765,35 @@ def generate_worker_files(
# Map locations to upstreams (corresponding to worker types) in Nginx
# but only if we use the appropriate worker type
for worker_type in all_worker_types_in_use:
for endpoint_pattern in WORKERS_CONFIG[worker_type]["endpoint_patterns"]:
for endpoint_pattern in sorted(WORKERS_CONFIG[worker_type].endpoint_patterns):
nginx_locations[endpoint_pattern] = f"http://{worker_type}"
# For each worker type specified by the user, create config values and write its
# yaml config file
for worker_name, worker_types_set in requested_worker_types.items():
# The collected and processed data will live here.
worker_config: Dict[str, Any] = {}
worker_template: WorkerTemplate = WorkerTemplate()
# Merge all worker config templates for this worker into a single config
for worker_type in worker_types_set:
copy_of_template_config = WORKERS_CONFIG[worker_type].copy()
# Merge worker type template configuration data. It's a combination of lists
# and dicts, so use this helper.
worker_config = merge_worker_template_configs(
worker_config, copy_of_template_config
worker_template = merge_worker_template_configs(
worker_template, WORKERS_CONFIG[worker_type]
)
# Replace placeholder names in the config template with the actual worker name.
worker_config = insert_worker_name_for_worker_config(worker_config, worker_name)
worker_config: Dict[str, Any] = instantiate_worker_template(
worker_template, worker_name
)
worker_config.update(
{"name": worker_name, "port": str(worker_port), "config_path": config_path}
)
# Update the shared config with any worker_type specific options. The first of a
# given worker_type needs to stay assigned and not be replaced.
worker_config["shared_extra_conf"].update(shared_config)
shared_config = worker_config["shared_extra_conf"]
# Update the shared config with any options needed to enable this worker.
shared_config = merged(shared_config, worker_config["shared_extra_conf"])
if using_unix_sockets:
healthcheck_urls.append(
f"--unix-socket /run/worker.{worker_port} http://localhost/health"
@@ -848,10 +801,10 @@ def generate_worker_files(
else:
healthcheck_urls.append("http://localhost:%d/health" % (worker_port,))
# Update the shared config with sharding-related options if necessary
add_worker_roles_to_shared_config(
shared_config, worker_types_set, worker_name, worker_port
)
# Add all workers to the `instance_map`
# Technically only certain types of workers, such as stream writers, are needed
# here but it is simpler just to be consistent.
add_worker_to_instance_map(shared_config, worker_name, worker_port)
# Enable the worker in supervisord
worker_descriptors.append(worker_config)
@@ -1018,6 +971,14 @@ def generate_worker_log_config(
def main(args: List[str], environ: MutableMapping[str, str]) -> None:
parser = ArgumentParser()
parser.add_argument(
"--generate-only",
action="store_true",
help="Only generate configuration; don't run Synapse.",
)
opts = parser.parse_args(args)
config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data")
config_path = environ.get("SYNAPSE_CONFIG_PATH", config_dir + "/homeserver.yaml")
data_dir = environ.get("SYNAPSE_DATA_DIR", "/data")
@@ -1034,8 +995,8 @@ def main(args: List[str], environ: MutableMapping[str, str]) -> None:
log("Base homeserver config exists—not regenerating")
# This script may be run multiple times (mostly by Complement, see note at top of
# file). Don't re-configure workers in this instance.
mark_filepath = "/conf/workers_have_been_configured"
if not os.path.exists(mark_filepath):
if not os.path.exists(MARKER_FILE_PATH):
# Collect and validate worker_type requests
# Read the desired worker configuration from the environment
worker_types_env = environ.get("SYNAPSE_WORKER_TYPES", "").strip()
@@ -1054,11 +1015,15 @@ def main(args: List[str], environ: MutableMapping[str, str]) -> None:
generate_worker_files(environ, config_path, data_dir, requested_worker_types)
# Mark workers as being configured
with open(mark_filepath, "w") as f:
with open(MARKER_FILE_PATH, "w") as f:
f.write("")
else:
log("Worker config exists—not regenerating")
if opts.generate_only:
log("--generate-only: won't run Synapse")
return
# Lifted right out of start.py
jemallocpath = "/usr/lib/%s-linux-gnu/libjemalloc.so.2" % (platform.machine(),)
@@ -1081,4 +1046,4 @@ def main(args: List[str], environ: MutableMapping[str, str]) -> None:
if __name__ == "__main__":
main(sys.argv, os.environ)
main(sys.argv[1:], os.environ)
+2 -1
@@ -536,7 +536,8 @@ The following query parameters are available:
**Response**
* `event_id` - converted from timestamp
* `event_id` - The event ID closest to the given timestamp.
* `origin_server_ts` - The timestamp of the event in milliseconds since the Unix epoch.
# Block Room API
The Block Room admin API allows server admins to block and unblock rooms,
+37
@@ -773,6 +773,43 @@ Note: The token will expire if the *admin* user calls `/logout/all` from any
of their devices, but the token will *not* expire if the target user does the
same.
## Allow replacing master cross-signing key without User-Interactive Auth
This endpoint is not intended for server administrator usage;
we describe it here for completeness.
This API temporarily permits a user to replace their master cross-signing key
without going through
[user-interactive authentication](https://spec.matrix.org/v1.8/client-server-api/#user-interactive-authentication-api) (UIA).
This is useful when Synapse has delegated its authentication to the
[Matrix Authentication Service](https://github.com/matrix-org/matrix-authentication-service/);
as Synapse cannot perform UIA in these circumstances.
The API is
```http request
POST /_synapse/admin/v1/users/<user_id>/_allow_cross_signing_replacement_without_uia
{}
```
If the user does not exist, or does exist but has no master cross-signing key,
this will return with status code `404 Not Found`.
Otherwise, a response body like the following is returned, with status `200 OK`:
```json
{
"updatable_without_uia_before_ms": 1234567890
}
```
The response body is a JSON object with a single field:
- `updatable_without_uia_before_ms`: integer. The timestamp in milliseconds
before which the user is permitted to replace their cross-signing key without
going through UIA.
_Added in Synapse 1.97.0._
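For example, a hedged sketch of calling this endpoint (the admin access token and homeserver URL are placeholders; any HTTP client works):

```python
import requests

resp = requests.post(
    "https://example.org/_synapse/admin/v1/users/"
    "@alice:example.org/_allow_cross_signing_replacement_without_uia",
    headers={"Authorization": "Bearer <admin_access_token>"},
    json={},
)
resp.raise_for_status()
print(resp.json()["updatable_without_uia_before_ms"])
```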
## User devices
File diff suppressed because it is too large.
+1 -1
@@ -66,7 +66,7 @@ database:
args:
user: <user>
password: <pass>
database: <db>
dbname: <db>
host: <host>
cp_min: 5
cp_max: 10
@@ -1447,7 +1447,7 @@ database:
args:
user: synapse_user
password: secretpassword
database: synapse
dbname: synapse
host: localhost
port: 5432
cp_min: 5
@@ -1526,7 +1526,7 @@ databases:
args:
user: synapse_user
password: secretpassword
database: synapse_main
dbname: synapse_main
host: localhost
port: 5432
cp_min: 5
@@ -1539,7 +1539,7 @@ databases:
args:
user: synapse_user
password: secretpassword
database: synapse_state
dbname: synapse_state
host: localhost
port: 5432
cp_min: 5
@@ -1753,6 +1753,19 @@ rc_third_party_invite:
burst_count: 10
```
---
### `rc_media_create`
This option ratelimits creation of MXC URIs via the `/_matrix/media/v1/create`
endpoint based on the account that's creating the media. Defaults to
`per_second: 10`, `burst_count: 50`.
Example configuration:
```yaml
rc_media_create:
per_second: 10
burst_count: 50
```
---
### `rc_federation`
Defines limits on federation requests.
@@ -1814,6 +1827,27 @@ Example configuration:
media_store_path: "DATADIR/media_store"
```
---
### `max_pending_media_uploads`
How many *pending media uploads* can a given user have? A pending media upload
is a created MXC URI that (a) has not expired (its `unused_expires_at` timestamp
has not passed) and (b) for which the media has not yet been uploaded. Defaults to 5.
Example configuration:
```yaml
max_pending_media_uploads: 5
```
---
### `unused_expiration_time`
How long to wait in milliseconds before expiring created media IDs. Defaults to
"24h"
Example configuration:
```yaml
unused_expiration_time: "1h"
```
---
### `media_storage_providers`
Media storage providers allow media to be stored in different
@@ -4219,6 +4253,9 @@ outbound_federation_restricted_to:
Also see the [worker
documentation](../../workers.md#restrict-outbound-federation-traffic-to-a-specific-set-of-workers)
for more info.
_Added in Synapse 1.89.0._
---
### `run_background_tasks_on`
Generated
+21 -35
@@ -416,19 +416,6 @@ files = [
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[[package]]
name = "click-default-group"
version = "1.2.2"
description = "Extends click.Group to invoke a command without explicit subcommand name"
optional = false
python-versions = "*"
files = [
{file = "click-default-group-1.2.2.tar.gz", hash = "sha256:d9560e8e8dfa44b3562fbc9425042a0fd6d21956fcc2db0077f63f34253ab904"},
]
[package.dependencies]
click = "*"
[[package]]
name = "colorama"
version = "0.4.6"
@@ -1742,13 +1729,13 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytes
[[package]]
name = "prometheus-client"
version = "0.17.1"
version = "0.18.0"
description = "Python client for the Prometheus monitoring system."
optional = false
python-versions = ">=3.6"
python-versions = ">=3.8"
files = [
{file = "prometheus_client-0.17.1-py3-none-any.whl", hash = "sha256:e537f37160f6807b8202a6fc4764cdd19bac5480ddd3e0d463c3002b34462101"},
{file = "prometheus_client-0.17.1.tar.gz", hash = "sha256:21e674f39831ae3f8acde238afd9a27a37d0d2fb5a28ea094f0ce25d2cbf2091"},
{file = "prometheus_client-0.18.0-py3-none-any.whl", hash = "sha256:8de3ae2755f890826f4b6479e5571d4f74ac17a81345fe69a6778fdb92579184"},
{file = "prometheus_client-0.18.0.tar.gz", hash = "sha256:35f7a8c22139e2bb7ca5a698e92d38145bc8dc74c1c0bf56f25cca886a764e17"},
]
[package.extras]
@@ -2906,18 +2893,17 @@ files = [
[[package]]
name = "towncrier"
version = "23.6.0"
version = "23.11.0"
description = "Building newsfiles for your project."
optional = false
python-versions = ">=3.7"
python-versions = ">=3.8"
files = [
{file = "towncrier-23.6.0-py3-none-any.whl", hash = "sha256:da552f29192b3c2b04d630133f194c98e9f14f0558669d427708e203fea4d0a5"},
{file = "towncrier-23.6.0.tar.gz", hash = "sha256:fc29bd5ab4727c8dacfbe636f7fb5dc53b99805b62da1c96b214836159ff70c1"},
{file = "towncrier-23.11.0-py3-none-any.whl", hash = "sha256:2e519ca619426d189e3c98c99558fe8be50c9ced13ea1fc20a4a353a95d2ded7"},
{file = "towncrier-23.11.0.tar.gz", hash = "sha256:13937c247e3f8ae20ac44d895cf5f96a60ad46cfdcc1671759530d7837d9ee5d"},
]
[package.dependencies]
click = "*"
click-default-group = "*"
importlib-resources = {version = ">=5", markers = "python_version < \"3.10\""}
incremental = "*"
jinja2 = "*"
@@ -2928,13 +2914,13 @@ dev = ["furo", "packaging", "sphinx (>=5)", "twisted"]
[[package]]
name = "treq"
version = "22.2.0"
version = "23.11.0"
description = "High-level Twisted HTTP Client API"
optional = false
python-versions = ">=3.6"
files = [
{file = "treq-22.2.0-py3-none-any.whl", hash = "sha256:27d95b07c5c14be3e7b280416139b036087617ad5595be913b1f9b3ce981b9b2"},
{file = "treq-22.2.0.tar.gz", hash = "sha256:df757e3f141fc782ede076a604521194ffcb40fa2645cf48e5a37060307f52ec"},
{file = "treq-23.11.0-py3-none-any.whl", hash = "sha256:f494c2218d61cab2cabbee37cd6606d3eea9d16cf14190323095c95d22c467e9"},
{file = "treq-23.11.0.tar.gz", hash = "sha256:0914ff929fd1632ce16797235260f8bc19d20ff7c459c1deabd65b8c68cbeac5"},
]
[package.dependencies]
@@ -2942,11 +2928,11 @@ attrs = "*"
hyperlink = ">=21.0.0"
incremental = "*"
requests = ">=2.1.0"
Twisted = {version = ">=18.7.0", extras = ["tls"]}
Twisted = {version = ">=22.10.0", extras = ["tls"]}
[package.extras]
dev = ["httpbin (==0.5.0)", "pep8", "pyflakes"]
docs = ["sphinx (>=1.4.8)"]
dev = ["httpbin (==0.7.0)", "pep8", "pyflakes", "werkzeug (==2.0.3)"]
docs = ["sphinx (<7.0.0)"]
[[package]]
name = "twine"
@@ -3047,13 +3033,13 @@ twisted = "*"
[[package]]
name = "types-bleach"
version = "6.1.0.0"
version = "6.1.0.1"
description = "Typing stubs for bleach"
optional = false
python-versions = ">=3.7"
files = [
{file = "types-bleach-6.1.0.0.tar.gz", hash = "sha256:3cf0e55d4618890a00af1151f878b2e2a7a96433850b74e12bede7663d774532"},
{file = "types_bleach-6.1.0.0-py3-none-any.whl", hash = "sha256:f0bc75d0f6475036ac69afebf37c41d116dfba78dae55db80437caf0fcd35c28"},
{file = "types-bleach-6.1.0.1.tar.gz", hash = "sha256:1e43c437e734a90efe4f40ebfe831057599568d3b275939ffbd6094848a18a27"},
{file = "types_bleach-6.1.0.1-py3-none-any.whl", hash = "sha256:f83f80e0709f13d809a9c79b958a1089df9b99e68059287beb196e38967e4ddf"},
]
[[package]]
@@ -3127,13 +3113,13 @@ files = [
[[package]]
name = "types-pyopenssl"
version = "23.2.0.2"
version = "23.3.0.0"
description = "Typing stubs for pyOpenSSL"
optional = false
python-versions = "*"
python-versions = ">=3.7"
files = [
{file = "types-pyOpenSSL-23.2.0.2.tar.gz", hash = "sha256:6a010dac9ecd42b582d7dd2cc3e9e40486b79b3b64bb2fffba1474ff96af906d"},
{file = "types_pyOpenSSL-23.2.0.2-py3-none-any.whl", hash = "sha256:19536aa3debfbe25a918cf0d898e9f5fbbe6f3594a429da7914bf331deb1b342"},
{file = "types-pyOpenSSL-23.3.0.0.tar.gz", hash = "sha256:5ffb077fe70b699c88d5caab999ae80e192fe28bf6cda7989b7e79b1e4e2dcd3"},
{file = "types_pyOpenSSL-23.3.0.0-py3-none-any.whl", hash = "sha256:00171433653265843b7469ddb9f3c86d698668064cc33ef10537822156130ebf"},
]
[package.dependencies]
+5 -4
@@ -192,7 +192,7 @@ phonenumbers = ">=8.2.0"
# we use GaugeHistogramMetric, which was added in prom-client 0.4.0.
prometheus-client = ">=0.4.0"
# we use `order`, which arrived in attrs 19.2.0.
# Note: 21.1.0 broke `/sync`, see #9936
# Note: 21.1.0 broke `/sync`, see https://github.com/matrix-org/synapse/issues/9936
attrs = ">=19.2.0,!=21.1.0"
netaddr = ">=0.7.18"
# Jinja 2.x is incompatible with MarkupSafe>=2.1. To ensure that admins do not
@@ -357,7 +357,7 @@ commonmark = ">=0.9.1"
pygithub = ">=1.55"
# The following are executed as commands by the release script.
twine = "*"
# Towncrier min version comes from #3425. Rationale unclear.
# Towncrier min version comes from https://github.com/matrix-org/synapse/pull/3425. Rationale unclear.
towncrier = ">=18.6.0rc1"
# Used for checking the Poetry lockfile
@@ -377,8 +377,9 @@ furo = ">=2022.12.7,<2024.0.0"
[build-system]
# The upper bounds here are defensive, intended to prevent situations like
# #13849 and #14079 where we see buildtime or runtime errors caused by build
# system changes.
# https://github.com/matrix-org/synapse/issues/13849 and
# https://github.com/matrix-org/synapse/issues/14079 where we see buildtime or
# runtime errors caused by build system changes.
# We are happy to raise these upper bounds upon request,
# provided we check that it's safe to do so (i.e. that CI passes).
requires = ["poetry-core>=1.1.0,<=1.7.0", "setuptools_rust>=1.3,<=1.8.1"]
+1 -2
@@ -348,8 +348,7 @@ class Porter:
backward_chunk = 0
already_ported = 0
else:
forward_chunk = row["forward_rowid"]
backward_chunk = row["backward_rowid"]
forward_chunk, backward_chunk = row
if total_to_port is None:
already_ported, total_to_port = await self._get_total_count_to_port(
+2
@@ -83,6 +83,8 @@ class Codes(str, Enum):
USER_DEACTIVATED = "M_USER_DEACTIVATED"
# USER_LOCKED = "M_USER_LOCKED"
USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED"
NOT_YET_UPLOADED = "M_NOT_YET_UPLOADED"
CANNOT_OVERWRITE_MEDIA = "M_CANNOT_OVERWRITE_MEDIA"
# Part of MSC3848
# https://github.com/matrix-org/matrix-spec-proposals/pull/3848
+2 -2
@@ -104,8 +104,8 @@ logger = logging.getLogger("synapse.app.generic_worker")
class GenericWorkerStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
# FIXME(https://github.com/matrix-org/synapse/issues/3714): We need to add
# UserDirectoryStore as we write directly rather than going via the correct worker.
UserDirectoryStore,
StatsStore,
UIAuthWorkerStore,
+7
@@ -204,3 +204,10 @@ class RatelimitConfig(Config):
"rc_third_party_invite",
defaults={"per_second": 0.0025, "burst_count": 5},
)
# Ratelimit create media requests:
self.rc_media_create = RatelimitSettings.parse(
config,
"rc_media_create",
defaults={"per_second": 10, "burst_count": 50},
)
+6
@@ -141,6 +141,12 @@ class ContentRepositoryConfig(Config):
"prevent_media_downloads_from", []
)
self.unused_expiration_time = self.parse_duration(
config.get("unused_expiration_time", "24h")
)
self.max_pending_media_uploads = config.get("max_pending_media_uploads", 5)
self.media_store_path = self.ensure_directory(
config.get("media_store_path", "media_store")
)
+2 -2
@@ -581,14 +581,14 @@ class FederationSender(AbstractFederationSender):
"get_joined_hosts", str(sg)
)
if destinations is None:
# Add logging to help track down #13444
# Add logging to help track down https://github.com/matrix-org/synapse/issues/13444
logger.info(
"Unexpectedly did not have cached destinations for %s / %s",
sg,
event.event_id,
)
else:
# Add logging to help track down #13444
# Add logging to help track down https://github.com/matrix-org/synapse/issues/13444
logger.info(
"Unexpectedly did not have cached prev group for %s",
event.event_id,
+13 -7
@@ -1450,19 +1450,25 @@ class E2eKeysHandler:
return desired_key_data
async def is_cross_signing_set_up_for_user(self, user_id: str) -> bool:
async def check_cross_signing_setup(self, user_id: str) -> Tuple[bool, bool]:
"""Checks if the user has cross-signing set up
Args:
user_id: The user to check
Returns:
True if the user has cross-signing set up, False otherwise
Returns: a 2-tuple of booleans
- whether the user has cross-signing set up, and
- whether the user's master cross-signing key may be replaced without UIA.
"""
existing_master_key = await self.store.get_e2e_cross_signing_key(
user_id, "master"
)
return existing_master_key is not None
(
exists,
ts_replacable_without_uia_before,
) = await self.store.get_master_cross_signing_key_updatable_before(user_id)
if ts_replacable_without_uia_before is None:
return exists, False
else:
return exists, self.clock.time_msec() < ts_replacable_without_uia_before
def _check_cross_signing_key(
+8 -12
@@ -88,7 +88,7 @@ from synapse.types import (
)
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.iterutils import batch_iter, partition
from synapse.util.iterutils import batch_iter, partition, sorted_topologically_batched
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import shortstr
@@ -748,7 +748,7 @@ class FederationEventHandler:
# fetching fresh state for the room if the missing event
# can't be found, which slightly reduces our security.
# it may also increase our DAG extremity count for the room,
# causing additional state resolution? See #1760.
# causing additional state resolution? See https://github.com/matrix-org/synapse/issues/1760.
# However, fetching state doesn't hold the linearizer lock
# apparently.
#
@@ -1669,14 +1669,13 @@ class FederationEventHandler:
# XXX: it might be possible to kick this process off in parallel with fetching
# the events.
while event_map:
# build a list of events whose auth events are not in the queue.
roots = tuple(
ev
for ev in event_map.values()
if not any(aid in event_map for aid in ev.auth_event_ids())
)
# We need to persist an event's auth events before the event.
auth_graph = {
ev: [event_map[e_id] for e_id in ev.auth_event_ids() if e_id in event_map]
for ev in event_map.values()
}
for roots in sorted_topologically_batched(event_map.values(), auth_graph):
if not roots:
# if *none* of the remaining events are ready, that means
# we have a loop. This either means a bug in our logic, or that
@@ -1698,9 +1697,6 @@ class FederationEventHandler:
await self._auth_and_persist_outliers_inner(room_id, roots)
for ev in roots:
del event_map[ev.event_id]
async def _auth_and_persist_outliers_inner(
self, room_id: str, fetched_events: Collection[EventBase]
) -> None:
+1 -1
@@ -1816,7 +1816,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
# the same token repeatedly.
#
# Hence this guard where we just return nothing so that the sync
# doesn't return. C.f. #5503.
# doesn't return. C.f. https://github.com/matrix-org/synapse/issues/5503.
return [], max_token
# Figure out which other users this user should explicitly receive
+9 -6
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
import random
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union
from synapse.api.errors import (
AuthError,
@@ -23,6 +23,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import parse_and_validate_mxc_uri
@@ -306,7 +307,9 @@ class ProfileHandler:
server_name = host
if self._is_mine_server_name(server_name):
media_info = await self.store.get_local_media(media_id)
media_info: Optional[
Union[LocalMedia, RemoteMedia]
] = await self.store.get_local_media(media_id)
else:
media_info = await self.store.get_cached_remote_media(server_name, media_id)
@@ -322,12 +325,12 @@ class ProfileHandler:
if self.max_avatar_size:
# Ensure avatar does not exceed max allowed avatar size
if media_info["media_length"] > self.max_avatar_size:
if media_info.media_length > self.max_avatar_size:
logger.warning(
"Forbidding avatar change to %s: %d bytes is above the allowed size "
"limit",
mxc,
media_info["media_length"],
media_info.media_length,
)
return False
@@ -335,12 +338,12 @@ class ProfileHandler:
# Ensure the avatar's file type is allowed
if (
self.allowed_avatar_mimetypes
and media_info["media_type"] not in self.allowed_avatar_mimetypes
and media_info.media_type not in self.allowed_avatar_mimetypes
):
logger.warning(
"Forbidding avatar change to %s: mimetype %s not allowed",
mxc,
media_info["media_type"],
media_info.media_type,
)
return False
+3 -3
@@ -269,7 +269,7 @@ class RoomCreationHandler:
self,
requester: Requester,
old_room_id: str,
old_room: Dict[str, Any],
old_room: Tuple[bool, str, bool],
new_room_id: str,
new_version: RoomVersion,
tombstone_event: EventBase,
@@ -279,7 +279,7 @@ class RoomCreationHandler:
Args:
requester: the user requesting the upgrade
old_room_id: the id of the room to be replaced
old_room: a dict containing room information for the room to be replaced,
old_room: a tuple containing room information for the room to be replaced,
as returned by `RoomWorkerStore.get_room`.
new_room_id: the id of the replacement room
new_version: the version to upgrade the room to
@@ -299,7 +299,7 @@ class RoomCreationHandler:
await self.store.store_room(
room_id=new_room_id,
room_creator_user_id=user_id,
is_public=old_room["is_public"],
is_public=old_room[0],
room_version=new_version,
)
+2 -1
@@ -1260,7 +1260,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Add new room to the room directory if the old room was there
# Remove old room from the room directory
old_room = await self.store.get_room(old_room_id)
if old_room is not None and old_room["is_public"]:
# If the old room exists and is public.
if old_room is not None and old_room[0]:
await self.store.set_room_is_public(old_room_id, False)
await self.store.set_room_is_public(room_id, True)
+1 -1
@@ -806,7 +806,7 @@ class SsoHandler:
media_id = profile["avatar_url"].split("/")[-1]
if self._is_mine_server_name(server_name):
media = await self._media_repo.store.get_local_media(media_id)
if media is not None and upload_name == media["upload_name"]:
if media is not None and upload_name == media.upload_name:
logger.info("skipping saving the user avatar")
return True
+2 -2
@@ -399,7 +399,7 @@ class SyncHandler:
#
# If that happens, we mustn't cache it, so that when the client comes back
# with the same cache token, we don't immediately return the same empty
# result, causing a tightloop. (#8518)
# result, causing a tightloop. (https://github.com/matrix-org/synapse/issues/8518)
if result.next_batch == since_token:
cache_context.should_cache = False
@@ -1003,7 +1003,7 @@ class SyncHandler:
# always make sure we LL ourselves so we know we're in the room
# (if we are) to fix https://github.com/vector-im/riot-web/issues/7209
# We only need apply this on full state syncs given we disabled
# LL for incr syncs in #3840.
# LL for incr syncs in https://github.com/matrix-org/synapse/pull/3840.
# We don't insert ourselves into `members_to_fetch`, because in some
# rare cases (an empty event batch with a now_token after the user's
# leave in a partial state room which another local user has
+4 -4
@@ -184,8 +184,8 @@ class UserDirectoryHandler(StateDeltasHandler):
"""Called to update index of our local user profiles when they change
irrespective of any rooms the user may be in.
"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
# FIXME(https://github.com/matrix-org/synapse/issues/3714): We should
# probably do this in the same worker as all the other changes.
if await self.store.should_include_local_user_in_dir(user_id):
await self.store.update_profile_in_user_dir(
@@ -194,8 +194,8 @@ class UserDirectoryHandler(StateDeltasHandler):
async def handle_local_user_deactivated(self, user_id: str) -> None:
"""Called when a user ID is deactivated"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
# FIXME(https://github.com/matrix-org/synapse/issues/3714): We should
# probably do this in the same worker as all the other changes.
await self.store.remove_from_user_dir(user_id)
async def _unsafe_process(self) -> None:
+8 -6
@@ -465,7 +465,7 @@ class MatrixFederationHttpClient:
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
'M_UNRECOGNIZED' errcode. This is a workaround for Synapse <= v0.99.3
due to #3622.
due to https://github.com/matrix-org/synapse/issues/3622.
Args:
request: details of request to be sent
@@ -958,9 +958,9 @@ class MatrixFederationHttpClient:
requests).
try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end
of the request. Workaround for #3622 in Synapse <= v0.99.3. This
will be attempted before backing off if backing off has been
enabled.
of the request. Workaround for https://github.com/matrix-org/synapse/issues/3622
in Synapse <= v0.99.3. This will be attempted before backing off if
backing off has been enabled.
parser: The parser to use to decode the response. Defaults to
parsing as JSON.
backoff_on_all_error_codes: Back off if we get any error response
@@ -1155,7 +1155,8 @@ class MatrixFederationHttpClient:
try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.
the request. Workaround for https://github.com/matrix-org/synapse/issues/3622
in Synapse <= v0.99.3.
parser: The parser to use to decode the response. Defaults to
parsing as JSON.
@@ -1250,7 +1251,8 @@ class MatrixFederationHttpClient:
try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.
the request. Workaround for https://github.com/matrix-org/synapse/issues/3622
in Synapse <= v0.99.3.
parser: The parser to use to decode the response. Defaults to
parsing as JSON.
+6
@@ -83,6 +83,12 @@ INLINE_CONTENT_TYPES = [
"audio/x-flac",
]
# Default timeout_ms for download and thumbnail requests
DEFAULT_MAX_TIMEOUT_MS = 20_000
# Maximum allowed timeout_ms for download and thumbnail requests
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS = 60_000
def respond_404(request: SynapseRequest) -> None:
assert request.path is not None
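These constants pair with the new `timeout_ms` query parameter that the download and thumbnail servlets later in this diff accept. The consumption pattern there is to take the client's value (or the default) and clamp it, so a caller cannot hold a request open past the allowed maximum:

# As used by the servlets below (parse_integer is synapse.http.servlet's):
max_timeout_ms = parse_integer(request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS)
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)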
+243 -41
@@ -19,6 +19,7 @@ import shutil
from io import BytesIO
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import attr
from matrix_common.types.mxc_uri import MXCUri
import twisted.internet.error
@@ -26,13 +27,16 @@ import twisted.web.http
from twisted.internet.defer import Deferred
from synapse.api.errors import (
Codes,
FederationDeniedError,
HttpResponseException,
NotFoundError,
RequestSendFailed,
SynapseError,
cs_error,
)
from synapse.config.repository import ThumbnailRequirement
from synapse.http.server import respond_with_json
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.logging.opentracing import trace
@@ -50,6 +54,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
from synapse.media.url_previewer import UrlPreviewer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
from synapse.types import UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
@@ -78,6 +83,8 @@ class MediaRepository:
self.store = hs.get_datastores().main
self.max_upload_size = hs.config.media.max_upload_size
self.max_image_pixels = hs.config.media.max_image_pixels
self.unused_expiration_time = hs.config.media.unused_expiration_time
self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads
Thumbnailer.set_limits(self.max_image_pixels)
@@ -183,6 +190,117 @@ class MediaRepository:
else:
self.recently_accessed_locals.add(media_id)
@trace
async def create_media_id(self, auth_user: UserID) -> Tuple[str, int]:
"""Create and store a media ID for a local user and return the MXC URI and its
expiration.
Args:
auth_user: The user_id of the uploader
Returns:
A tuple containing the MXC URI of the stored content and the timestamp at
which the MXC URI expires.
"""
media_id = random_string(24)
now = self.clock.time_msec()
await self.store.store_local_media_id(
media_id=media_id,
time_now_ms=now,
user_id=auth_user,
)
return f"mxc://{self.server_name}/{media_id}", now + self.unused_expiration_time
@trace
async def reached_pending_media_limit(self, auth_user: UserID) -> Tuple[bool, int]:
"""Check if the user is over the limit for pending media uploads.
Args:
auth_user: The user_id of the uploader
Returns:
A tuple with a boolean and an integer indicating whether the user has too
many pending media uploads and the timestamp at which the first pending
media will expire, respectively.
"""
pending, first_expiration_ts = await self.store.count_pending_media(
user_id=auth_user
)
return pending >= self.max_pending_media_uploads, first_expiration_ts
@trace
async def verify_can_upload(self, media_id: str, auth_user: UserID) -> None:
"""Verify that the media ID can be uploaded to by the given user. This
function checks that:
* the media ID exists
* the media ID does not already have content
* the user uploading is the same as the one who created the media ID
* the media ID has not expired
Args:
media_id: The media ID to verify
auth_user: The user_id of the uploader
"""
media = await self.store.get_local_media(media_id)
if media is None:
raise SynapseError(404, "Unknow media ID", errcode=Codes.NOT_FOUND)
if media.user_id != auth_user.to_string():
raise SynapseError(
403,
"Only the creator of the media ID can upload to it",
errcode=Codes.FORBIDDEN,
)
if media.media_length is not None:
raise SynapseError(
409,
"Media ID already has content",
errcode=Codes.CANNOT_OVERWRITE_MEDIA,
)
expired_time_ms = self.clock.time_msec() - self.unused_expiration_time
if media.created_ts < expired_time_ms:
raise NotFoundError("Media ID has expired")
@trace
async def update_content(
self,
media_id: str,
media_type: str,
upload_name: Optional[str],
content: IO,
content_length: int,
auth_user: UserID,
) -> None:
"""Update the content of the given media ID.
Args:
media_id: The media ID to replace.
media_type: The content type of the file.
upload_name: The name of the file, if provided.
content: A file like object that is the content to store
content_length: The length of the content
auth_user: The user_id of the uploader
"""
file_info = FileInfo(server_name=None, file_id=media_id)
fname = await self.media_storage.store_file(content, file_info)
logger.info("Stored local media in file %r", fname)
await self.store.update_local_media(
media_id=media_id,
media_type=media_type,
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
)
try:
await self._generate_thumbnails(None, media_id, media_id, media_type)
except Exception as e:
logger.info("Failed to generate thumbnails: %s", e)
@trace
async def create_content(
self,
@@ -229,8 +347,74 @@ class MediaRepository:
return MXCUri(self.server_name, media_id)
def respond_not_yet_uploaded(self, request: SynapseRequest) -> None:
respond_with_json(
request,
504,
cs_error("Media has not been uploaded yet", code=Codes.NOT_YET_UPLOADED),
send_cors=True,
)
async def get_local_media_info(
self, request: SynapseRequest, media_id: str, max_timeout_ms: int
) -> Optional[LocalMedia]:
"""Gets the info dictionary for given local media ID. If the media has
not been uploaded yet, this function will wait up to ``max_timeout_ms``
milliseconds for the media to be uploaded.
Args:
request: The incoming request.
media_id: The media ID of the content. (This is the same as
the file_id for local content.)
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
Returns:
Either the info for the given local media ID or
``None``. If ``None``, then no further processing is necessary as
this function will send the necessary JSON response.
"""
wait_until = self.clock.time_msec() + max_timeout_ms
while True:
# Get the info for the media
media_info = await self.store.get_local_media(media_id)
if not media_info:
logger.info("Media %s is unknown", media_id)
respond_404(request)
return None
if media_info.quarantined_by:
logger.info("Media %s is quarantined", media_id)
respond_404(request)
return None
# The file has been uploaded, so stop looping
if media_info.media_length is not None:
return media_info
# Check if the media ID has expired and still hasn't been uploaded to.
now = self.clock.time_msec()
expired_time_ms = now - self.unused_expiration_time
if media_info.created_ts < expired_time_ms:
logger.info("Media %s has expired without being uploaded", media_id)
respond_404(request)
return None
if now >= wait_until:
break
await self.clock.sleep(0.5)
logger.info("Media %s has not yet been uploaded", media_id)
self.respond_not_yet_uploaded(request)
return None
async def get_local_media(
self, request: SynapseRequest, media_id: str, name: Optional[str]
self,
request: SynapseRequest,
media_id: str,
name: Optional[str],
max_timeout_ms: int,
) -> None:
"""Responds to requests for local media, if exists, or returns 404.
@@ -240,23 +424,24 @@ class MediaRepository:
the file_id for local content.)
name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
Returns:
Resolves once a response has successfully been written to request
"""
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
respond_404(request)
media_info = await self.get_local_media_info(request, media_id, max_timeout_ms)
if not media_info:
return
self.mark_recently_accessed(None, media_id)
media_type = media_info["media_type"]
media_type = media_info.media_type
if not media_type:
media_type = "application/octet-stream"
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
url_cache = media_info["url_cache"]
media_length = media_info.media_length
upload_name = name if name else media_info.upload_name
url_cache = media_info.url_cache
file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
@@ -271,6 +456,7 @@ class MediaRepository:
server_name: str,
media_id: str,
name: Optional[str],
max_timeout_ms: int,
) -> None:
"""Respond to requests for remote media.
@@ -280,6 +466,8 @@ class MediaRepository:
media_id: The media ID of the content (as defined by the remote server).
name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
Returns:
Resolves once a response has successfully been written to request
@@ -305,27 +493,33 @@ class MediaRepository:
key = (server_name, media_id)
async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl(
server_name, media_id
server_name, media_id, max_timeout_ms
)
# We deliberately stream the file outside the lock
if responder:
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
if responder and media_info:
upload_name = name if name else media_info.upload_name
await respond_with_responder(
request, responder, media_type, media_length, upload_name
request,
responder,
media_info.media_type,
media_info.media_length,
upload_name,
)
else:
respond_404(request)
async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
async def get_remote_media_info(
self, server_name: str, media_id: str, max_timeout_ms: int
) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading
if necessary.
Args:
server_name: Remote server_name where the media originated.
media_id: The media ID of the content (as defined by the remote server).
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
Returns:
The media info of the file
@@ -341,7 +535,7 @@ class MediaRepository:
key = (server_name, media_id)
async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl(
server_name, media_id
server_name, media_id, max_timeout_ms
)
# Ensure we actually use the responder so that it releases resources
@@ -352,8 +546,8 @@ class MediaRepository:
return media_info
async def _get_remote_media_impl(
self, server_name: str, media_id: str
) -> Tuple[Optional[Responder], dict]:
self, server_name: str, media_id: str, max_timeout_ms: int
) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -361,6 +555,8 @@ class MediaRepository:
server_name: Remote server_name where the media originated.
media_id: The media ID of the content (as defined by the
remote server).
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
Returns:
A tuple of responder and the media info of the file.
@@ -373,15 +569,17 @@ class MediaRepository:
# If we have an entry in the DB, try and look for it
if media_info:
file_id = media_info["filesystem_id"]
file_id = media_info.filesystem_id
file_info = FileInfo(server_name, file_id)
if media_info["quarantined_by"]:
if media_info.quarantined_by:
logger.info("Media is quarantined")
raise NotFoundError()
if not media_info["media_type"]:
media_info["media_type"] = "application/octet-stream"
if not media_info.media_type:
media_info = attr.evolve(
media_info, media_type="application/octet-stream"
)
responder = await self.media_storage.fetch_media(file_info)
if responder:
@@ -391,8 +589,7 @@ class MediaRepository:
try:
media_info = await self._download_remote_file(
server_name,
media_id,
server_name, media_id, max_timeout_ms
)
except SynapseError:
raise
@@ -403,9 +600,9 @@ class MediaRepository:
if not media_info:
raise e
file_id = media_info["filesystem_id"]
if not media_info["media_type"]:
media_info["media_type"] = "application/octet-stream"
file_id = media_info.filesystem_id
if not media_info.media_type:
media_info = attr.evolve(media_info, media_type="application/octet-stream")
file_info = FileInfo(server_name, file_id)
# We generate thumbnails even if another process downloaded the media
@@ -415,7 +612,7 @@ class MediaRepository:
# otherwise they'll request thumbnails and get a 404 if they're not
# ready yet.
await self._generate_thumbnails(
server_name, media_id, file_id, media_info["media_type"]
server_name, media_id, file_id, media_info.media_type
)
responder = await self.media_storage.fetch_media(file_info)
@@ -425,7 +622,8 @@ class MediaRepository:
self,
server_name: str,
media_id: str,
) -> dict:
max_timeout_ms: int,
) -> RemoteMedia:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -434,7 +632,8 @@ class MediaRepository:
media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
file_id: Local file ID
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
Returns:
The media info of the file.
@@ -458,7 +657,8 @@ class MediaRepository:
# tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't
# end up with a routing loop.
"allow_remote": "false"
"allow_remote": "false",
"timeout_ms": str(max_timeout_ms),
},
)
except RequestSendFailed as e:
@@ -518,7 +718,7 @@ class MediaRepository:
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
time_now_ms=time_now_ms,
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
@@ -526,15 +726,17 @@ class MediaRepository:
logger.info("Stored remote media in file %r", fname)
media_info = {
"media_type": media_type,
"media_length": length,
"upload_name": upload_name,
"created_ts": time_now_ms,
"filesystem_id": file_id,
}
return media_info
return RemoteMedia(
media_origin=server_name,
media_id=media_id,
media_type=media_type,
media_length=length,
upload_name=upload_name,
created_ts=time_now_ms,
filesystem_id=file_id,
last_access_ts=time_now_ms,
quarantined_by=None,
)
def _get_thumbnail_requirements(
self, media_type: str
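Taken together, `create_media_id`, `reached_pending_media_limit`, `verify_can_upload` and `update_content` form the server side of asynchronous uploads: a client reserves a media ID first and uploads the bytes to it before the reservation expires, while `get_local_media_info` gives downloaders a bounded wait for content that is reserved but not yet uploaded (responding 504 M_NOT_YET_UPLOADED once the wait runs out). A rough client-side sketch of the flow against the servlets added later in this diff (httpx, the headers and the error handling are illustrative assumptions):

import httpx

async def async_upload(base_url: str, access_token: str, data: bytes) -> str:
    headers = {"Authorization": f"Bearer {access_token}"}
    async with httpx.AsyncClient(base_url=base_url, headers=headers) as client:
        # Step 1: reserve a media ID; it expires if left unused.
        resp = await client.post("/_matrix/media/v1/create")
        resp.raise_for_status()
        content_uri = resp.json()["content_uri"]  # mxc://<server>/<media_id>
        server_name, media_id = content_uri.removeprefix("mxc://").split("/", 1)

        # Step 2: upload the content to the reserved ID before it expires.
        resp = await client.put(
            f"/_matrix/media/v3/upload/{server_name}/{media_id}",
            content=data,
            headers={"Content-Type": "application/octet-stream"},
        )
        resp.raise_for_status()
        return content_uri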
+5 -6
@@ -240,15 +240,14 @@ class UrlPreviewer:
cache_result = await self.store.get_url_cache(url, ts)
if (
cache_result
and cache_result["expires_ts"] > ts
and cache_result["response_code"] / 100 == 2
and cache_result.expires_ts > ts
and cache_result.response_code // 100 == 2
):
# It may be stored as text in the database, not as bytes (such as
# PostgreSQL). If so, encode it back before handing it on.
og = cache_result["og"]
if isinstance(og, str):
og = og.encode("utf8")
return og
if isinstance(cache_result.og, str):
return cache_result.og.encode("utf8")
return cache_result.og
# If this URL can be accessed via an allowed oEmbed, use that instead.
url_to_download = url
+2 -1
@@ -1860,7 +1860,8 @@ class PublicRoomListManager:
if not room:
return False
return room.get("is_public", False)
# The first item is whether the room is public.
return room[0]
async def add_room_to_public_room_list(self, room_id: str) -> None:
"""Publishes a room to the public room list.
@@ -295,7 +295,8 @@ class ThirdPartyEventRulesModuleApiCallbacks:
raise
except SynapseError as e:
# FIXME: Being able to throw SynapseErrors is relied upon by
# some modules. PR #10386 accidentally broke this ability.
# some modules. PR https://github.com/matrix-org/synapse/pull/10386
# accidentally broke this ability.
# That said, we aren't keen on exposing this implementation detail
# to modules and we should one day have a proper way to do what
# is wanted.
+53 -19
@@ -257,6 +257,11 @@ class ReplicationCommandHandler:
if hs.config.redis.redis_enabled:
self._notifier.add_lock_released_callback(self.on_lock_released)
# Marks if we should send POSITION commands for all streams ASAP. This
# is checked by the `ReplicationStreamer` which manages sending
# RDATA/POSITION commands
self._should_announce_positions = True
def subscribe_to_channel(self, channel_name: str) -> None:
"""
Indicates that we wish to subscribe to a Redis channel by name.
@@ -397,29 +402,23 @@ class ReplicationCommandHandler:
return self._streams_to_replicate
def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand) -> None:
self.send_positions_to_connection(conn)
self.send_positions_to_connection()
def send_positions_to_connection(self, conn: IReplicationConnection) -> None:
def send_positions_to_connection(self) -> None:
"""Send current position of all streams this process is source of to
the connection.
"""
# We respond with current position of all streams this instance
# replicates.
for stream in self.get_streams_to_replicate():
# Note that we use the current token as the prev token here (rather
# than stream.last_token), as we can't be sure that there have been
# no rows written between last token and the current token (since we
# might be racing with the replication sending bg process).
current_token = stream.current_token(self._instance_name)
self.send_command(
PositionCommand(
stream.NAME,
self._instance_name,
current_token,
current_token,
)
)
self._should_announce_positions = True
self._notifier.notify_replication()
def should_announce_positions(self) -> bool:
"""Check if we should send POSITION commands for all streams ASAP."""
return self._should_announce_positions
def will_announce_positions(self) -> None:
"""Mark that we're about to send POSITIONs out for all streams."""
self._should_announce_positions = False
def on_USER_SYNC(
self, conn: IReplicationConnection, cmd: UserSyncCommand
@@ -588,6 +587,21 @@ class ReplicationCommandHandler:
logger.debug("Handling '%s %s'", cmd.NAME, cmd.to_line())
# Check if we can early discard this position. We can only do so for
# connected streams.
stream = self._streams[cmd.stream_name]
if stream.can_discard_position(
cmd.instance_name, cmd.prev_token, cmd.new_token
) and self.is_stream_connected(conn, cmd.stream_name):
logger.debug(
"Discarding redundant POSITION %s/%s %s %s",
cmd.instance_name,
cmd.stream_name,
cmd.prev_token,
cmd.new_token,
)
return
self._add_command_to_stream_queue(conn, cmd)
async def _process_position(
@@ -599,6 +613,18 @@ class ReplicationCommandHandler:
"""
stream = self._streams[stream_name]
if stream.can_discard_position(
cmd.instance_name, cmd.prev_token, cmd.new_token
) and self.is_stream_connected(conn, cmd.stream_name):
logger.debug(
"Discarding redundant POSITION %s/%s %s %s",
cmd.instance_name,
cmd.stream_name,
cmd.prev_token,
cmd.new_token,
)
return
# We're about to go and catch up with the stream, so remove from set
# of connected streams.
for streams in self._streams_by_connection.values():
@@ -626,8 +652,9 @@ class ReplicationCommandHandler:
# for why this can happen.
logger.info(
"Fetching replication rows for '%s' between %i and %i",
"Fetching replication rows for '%s' / %s between %i and %i",
stream_name,
cmd.instance_name,
current_token,
cmd.new_token,
)
@@ -657,6 +684,13 @@ class ReplicationCommandHandler:
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
def is_stream_connected(
self, conn: IReplicationConnection, stream_name: str
) -> bool:
"""Return if stream has been successfully connected and is ready to
receive updates"""
return stream_name in self._streams_by_connection.get(conn, ())
def on_REMOTE_SERVER_UP(
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
) -> None:
+1 -1
@@ -141,7 +141,7 @@ class RedisSubscriber(SubscriberProtocol):
# We send out our positions when there is a new connection in case the
# other side missed updates. We do this for Redis connections as the
# other side won't know we've connected and so won't issue a REPLICATE.
self.synapse_handler.send_positions_to_connection(self)
self.synapse_handler.send_positions_to_connection()
def messageReceived(self, pattern: str, channel: str, message: str) -> None:
"""Received a message from redis."""
+16 -1
@@ -123,7 +123,7 @@ class ReplicationStreamer:
# We check up front to see if anything has actually changed, as we get
# poked because of changes that happened on other instances.
if all(
if not self.command_handler.should_announce_positions() and all(
stream.last_token == stream.current_token(self._instance_name)
for stream in self.streams
):
@@ -158,6 +158,21 @@ class ReplicationStreamer:
all_streams = list(all_streams)
random.shuffle(all_streams)
if self.command_handler.should_announce_positions():
# We need to send out POSITIONs for all streams, usually
# because a worker has reconnected.
self.command_handler.will_announce_positions()
for stream in all_streams:
self.command_handler.send_command(
PositionCommand(
stream.NAME,
self._instance_name,
stream.last_token,
stream.last_token,
)
)
for stream in all_streams:
if stream.last_token == stream.current_token(
self._instance_name
+18
@@ -144,6 +144,16 @@ class Stream:
"""
raise NotImplementedError()
def can_discard_position(
self, instance_name: str, prev_token: int, new_token: int
) -> bool:
"""Whether or not a position command for this stream can be discarded.
Useful for streams that can never go backwards and where we already know
the stream ID for the instance has advanced.
"""
return False
def discard_updates_and_advance(self) -> None:
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
@@ -221,6 +231,14 @@ class _StreamFromIdGen(Stream):
def minimal_local_current_token(self) -> Token:
return self._stream_id_gen.get_minimal_local_current_token()
def can_discard_position(
self, instance_name: str, prev_token: int, new_token: int
) -> bool:
# These streams can't go backwards, so we know we can ignore any
# positions where the tokens are from before the current token.
return new_token <= self.current_token(instance_name)
def current_token_without_instance(
current_token: Callable[[], int]
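The base class is deliberately conservative and never discards, while `_StreamFromIdGen` can opt in because its tokens only ever grow. A toy illustration of the predicate (an illustrative class, not Synapse code):

class _FakeStream:
    def __init__(self, current: int) -> None:
        self._current = current

    def current_token(self, instance_name: str) -> int:
        return self._current

    def can_discard_position(
        self, instance_name: str, prev_token: int, new_token: int
    ) -> bool:
        # Monotonic stream: a POSITION that doesn't advance past what we
        # already know carries no new information.
        return new_token <= self.current_token(instance_name)

stream = _FakeStream(current=10)
assert stream.can_discard_position("worker1", 5, 8)       # stale: discard
assert not stream.can_discard_position("worker1", 9, 12)  # advances: keep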
+2
@@ -88,6 +88,7 @@ from synapse.rest.admin.users import (
UserByThreePid,
UserMembershipRestServlet,
UserRegisterServlet,
UserReplaceMasterCrossSigningKeyRestServlet,
UserRestServletV2,
UsersRestServletV2,
UserTokenRestServlet,
@@ -292,6 +293,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ListDestinationsRestServlet(hs).register(http_server)
RoomMessagesRestServlet(hs).register(http_server)
RoomTimestampToEventRestServlet(hs).register(http_server)
UserReplaceMasterCrossSigningKeyRestServlet(hs).register(http_server)
UserByExternalId(hs).register(http_server)
UserByThreePid(hs).register(http_server)
+4 -4
@@ -413,8 +413,8 @@ class RoomMembersRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id)
if not ret:
room = await self.store.get_room(room_id)
if not room:
raise NotFoundError("Room not found")
members = await self.store.get_users_in_room(room_id)
@@ -442,8 +442,8 @@ class RoomStateRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id)
if not ret:
room = await self.store.get_room(room_id)
if not room:
raise NotFoundError("Room not found")
event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
+40
@@ -1270,6 +1270,46 @@ class AccountDataRestServlet(RestServlet):
}
class UserReplaceMasterCrossSigningKeyRestServlet(RestServlet):
"""Allow a given user to replace their master cross-signing key without UIA.
This replacement is permitted for a limited period (currently 10 minutes).
While this is exposed via the admin API, this is intended for use by the
Matrix Authentication Service rather than server admins.
"""
PATTERNS = admin_patterns(
"/users/(?P<user_id>[^/]*)/_allow_cross_signing_replacement_without_uia"
)
REPLACEMENT_PERIOD_MS = 10 * 60 * 1000 # 10 minutes
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
self._store = hs.get_datastores().main
async def on_POST(
self,
request: SynapseRequest,
user_id: str,
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self._auth, request)
if user_id is None:
raise NotFoundError("User not found")
timestamp = (
await self._store.allow_master_cross_signing_key_replacement_without_uia(
user_id, self.REPLACEMENT_PERIOD_MS
)
)
if timestamp is None:
raise NotFoundError("User has no master cross-signing key")
return HTTPStatus.OK, {"updatable_without_uia_before_ms": timestamp}
class UserByExternalId(RestServlet):
"""Find a user based on an external ID from an auth provider"""
+8 -11
@@ -299,19 +299,16 @@ class DeactivateAccountRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
# allow ASes to deactivate their own users
if requester.app_service:
await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), body.erase, requester
# allow ASes to deactivate their own users:
# ASes don't need user-interactive auth
if not requester.app_service:
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body.dict(exclude_unset=True),
"deactivate your account",
)
return 200, {}
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body.dict(exclude_unset=True),
"deactivate your account",
)
result = await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), body.erase, requester, id_server=body.id_server
)
+1 -1
@@ -147,7 +147,7 @@ class ClientDirectoryListServer(RestServlet):
if room is None:
raise NotFoundError("Unknown room")
return 200, {"visibility": "public" if room["is_public"] else "private"}
return 200, {"visibility": "public" if room[0] else "private"}
class PutBody(RequestBodyModel):
visibility: Literal["public", "private"] = "public"
+11 -5
@@ -376,9 +376,10 @@ class SigningKeyUploadServlet(RestServlet):
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
is_cross_signing_setup = (
await self.e2e_keys_handler.is_cross_signing_set_up_for_user(user_id)
)
(
is_cross_signing_setup,
master_key_updatable_without_uia,
) = await self.e2e_keys_handler.check_cross_signing_setup(user_id)
# Before MSC3967 we required UIA both when setting up cross signing for the
# first time and when resetting the device signing key. With MSC3967 we only
@@ -386,9 +387,14 @@ class SigningKeyUploadServlet(RestServlet):
# time. Because there is no UIA in MSC3861, for now we throw an error if the
# user tries to reset the device signing key when MSC3861 is enabled, but allow
# first-time setup.
#
# XXX: We now have a get-out clause by which MAS can temporarily mark the master
# key as replaceable. It should do its own equivalent of user interactive auth
# before doing so.
if self.hs.config.experimental.msc3861.enabled:
# There is no way to reset the device signing key with MSC3861
if is_cross_signing_setup:
# The auth service has to explicitly mark the master key as replaceable
# without UIA to reset the device signing key with MSC3861.
if is_cross_signing_setup and not master_key_updatable_without_uia:
raise SynapseError(
HTTPStatus.NOT_IMPLEMENTED,
"Resetting cross signing keys is not yet supported with MSC3861",
+83
@@ -0,0 +1,83 @@
# Copyright 2023 Beeper Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING
from synapse.api.errors import LimitExceededError
from synapse.api.ratelimiting import Ratelimiter
from synapse.http.server import respond_with_json
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.media.media_repository import MediaRepository
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class CreateResource(RestServlet):
PATTERNS = [re.compile("/_matrix/media/v1/create")]
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads
# A rate limiter for creating new media IDs.
self._create_media_rate_limiter = Ratelimiter(
store=hs.get_datastores().main,
clock=self.clock,
cfg=hs.config.ratelimiting.rc_media_create,
)
async def on_POST(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request)
# If the user is over the rate limit for creating media, reject the request.
await self._create_media_rate_limiter.ratelimit(requester)
(
reached_pending_limit,
first_expiration_ts,
) = await self.media_repo.reached_pending_media_limit(requester.user)
if reached_pending_limit:
raise LimitExceededError(
limiter_name="max_pending_media_uploads",
retry_after_ms=first_expiration_ts - self.clock.time_msec(),
)
content_uri, unused_expires_at = await self.media_repo.create_media_id(
requester.user
)
logger.info(
"Created Media URI %r that if unused will expire at %d",
content_uri,
unused_expires_at,
)
respond_with_json(
request,
200,
{
"content_uri": content_uri,
"unused_expires_at": unused_expires_at,
},
send_cors=True,
)
+15 -7
@@ -17,9 +17,13 @@ import re
from typing import TYPE_CHECKING, Optional
from synapse.http.server import set_corp_headers, set_cors_headers
from synapse.http.servlet import RestServlet, parse_boolean
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
from synapse.http.site import SynapseRequest
from synapse.media._base import respond_404
from synapse.media._base import (
DEFAULT_MAX_TIMEOUT_MS,
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
respond_404,
)
from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING:
@@ -65,12 +69,16 @@ class DownloadResource(RestServlet):
)
# Limited non-standard form of CSP for IE11
request.setHeader(b"X-Content-Security-Policy", b"sandbox;")
request.setHeader(
b"Referrer-Policy",
b"no-referrer",
request.setHeader(b"Referrer-Policy", b"no-referrer")
max_timeout_ms = parse_integer(
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
)
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
if self._is_mine_server_name(server_name):
await self.media_repo.get_local_media(request, media_id, file_name)
await self.media_repo.get_local_media(
request, media_id, file_name, max_timeout_ms
)
else:
allow_remote = parse_boolean(request, "allow_remote", default=True)
if not allow_remote:
@@ -83,5 +91,5 @@ class DownloadResource(RestServlet):
return
await self.media_repo.get_remote_media(
request, server_name, media_id, file_name
request, server_name, media_id, file_name, max_timeout_ms
)
@@ -18,10 +18,11 @@ from synapse.config._base import ConfigError
from synapse.http.server import HttpServer, JsonResource
from .config_resource import MediaConfigResource
from .create_resource import CreateResource
from .download_resource import DownloadResource
from .preview_url_resource import PreviewUrlResource
from .thumbnail_resource import ThumbnailResource
from .upload_resource import UploadResource
from .upload_resource import AsyncUploadServlet, UploadServlet
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -91,8 +92,9 @@ class MediaRepositoryResource(JsonResource):
# Note that many of these should not exist as v1 endpoints, but empirically
# a lot of traffic still goes to them.
UploadResource(hs, media_repo).register(http_server)
CreateResource(hs, media_repo).register(http_server)
UploadServlet(hs, media_repo).register(http_server)
AsyncUploadServlet(hs, media_repo).register(http_server)
DownloadResource(hs, media_repo).register(http_server)
ThumbnailResource(hs, media_repo, media_repo.media_storage).register(
http_server
+50 -31
@@ -23,6 +23,8 @@ from synapse.http.server import respond_with_json, set_corp_headers, set_cors_he
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.media._base import (
DEFAULT_MAX_TIMEOUT_MS,
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
FileInfo,
ThumbnailInfo,
respond_404,
@@ -75,15 +77,19 @@ class ThumbnailResource(RestServlet):
method = parse_string(request, "method", "scale")
# TODO Parse the Accept header to get a prioritised list of thumbnail types.
m_type = "image/png"
max_timeout_ms = parse_integer(
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
)
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
if self._is_mine_server_name(server_name):
if self.dynamic_thumbnails:
await self._select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type
request, media_id, width, height, method, m_type, max_timeout_ms
)
else:
await self._respond_local_thumbnail(
request, media_id, width, height, method, m_type
request, media_id, width, height, method, m_type, max_timeout_ms
)
self.media_repo.mark_recently_accessed(None, media_id)
else:
@@ -95,14 +101,21 @@ class ThumbnailResource(RestServlet):
respond_404(request)
return
if self.dynamic_thumbnails:
await self._select_or_generate_remote_thumbnail(
request, server_name, media_id, width, height, method, m_type
)
else:
await self._respond_remote_thumbnail(
request, server_name, media_id, width, height, method, m_type
)
remote_resp_function = (
self._select_or_generate_remote_thumbnail
if self.dynamic_thumbnails
else self._respond_remote_thumbnail
)
await remote_resp_function(
request,
server_name,
media_id,
width,
height,
method,
m_type,
max_timeout_ms,
)
self.media_repo.mark_recently_accessed(server_name, media_id)
async def _respond_local_thumbnail(
@@ -113,15 +126,12 @@ class ThumbnailResource(RestServlet):
height: int,
method: str,
m_type: str,
max_timeout_ms: int,
) -> None:
media_info = await self.store.get_local_media(media_id)
media_info = await self.media_repo.get_local_media_info(
request, media_id, max_timeout_ms
)
if not media_info:
respond_404(request)
return
if media_info["quarantined_by"]:
logger.info("Media is quarantined")
respond_404(request)
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
@@ -134,7 +144,7 @@ class ThumbnailResource(RestServlet):
thumbnail_infos,
media_id,
media_id,
url_cache=bool(media_info["url_cache"]),
url_cache=bool(media_info.url_cache),
server_name=None,
)
@@ -146,15 +156,13 @@ class ThumbnailResource(RestServlet):
desired_height: int,
desired_method: str,
desired_type: str,
max_timeout_ms: int,
) -> None:
media_info = await self.store.get_local_media(media_id)
media_info = await self.media_repo.get_local_media_info(
request, media_id, max_timeout_ms
)
if not media_info:
respond_404(request)
return
if media_info["quarantined_by"]:
logger.info("Media is quarantined")
respond_404(request)
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
@@ -168,7 +176,7 @@ class ThumbnailResource(RestServlet):
file_info = FileInfo(
server_name=None,
file_id=media_id,
url_cache=media_info["url_cache"],
url_cache=bool(media_info.url_cache),
thumbnail=info,
)
@@ -188,7 +196,7 @@ class ThumbnailResource(RestServlet):
desired_height,
desired_method,
desired_type,
url_cache=bool(media_info["url_cache"]),
url_cache=bool(media_info.url_cache),
)
if file_path:
@@ -206,14 +214,20 @@ class ThumbnailResource(RestServlet):
desired_height: int,
desired_method: str,
desired_type: str,
max_timeout_ms: int,
) -> None:
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
media_info = await self.media_repo.get_remote_media_info(
server_name, media_id, max_timeout_ms
)
if not media_info:
respond_404(request)
return
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
file_id = media_info["filesystem_id"]
file_id = media_info.filesystem_id
for info in thumbnail_infos:
t_w = info.width == desired_width
@@ -224,7 +238,7 @@ class ThumbnailResource(RestServlet):
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
server_name=server_name,
file_id=media_info["filesystem_id"],
file_id=file_id,
thumbnail=info,
)
@@ -263,11 +277,16 @@ class ThumbnailResource(RestServlet):
height: int,
method: str,
m_type: str,
max_timeout_ms: int,
) -> None:
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
media_info = await self.media_repo.get_remote_media_info(
server_name, media_id, max_timeout_ms
)
if not media_info:
return
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
@@ -280,7 +299,7 @@ class ThumbnailResource(RestServlet):
m_type,
thumbnail_infos,
media_id,
media_info["filesystem_id"],
media_info.filesystem_id,
url_cache=False,
server_name=server_name,
)
+68 -7
@@ -15,7 +15,7 @@
import logging
import re
from typing import IO, TYPE_CHECKING, Dict, List, Optional
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Tuple
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import respond_with_json
@@ -29,23 +29,24 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# The name of the lock to use when uploading media.
_UPLOAD_MEDIA_LOCK_NAME = "upload_media"
class UploadResource(RestServlet):
PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload")]
class BaseUploadServlet(RestServlet):
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
self.filepaths = media_repo.filepaths
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.auth = hs.get_auth()
self.max_upload_size = hs.config.media.max_upload_size
self.clock = hs.get_clock()
async def on_POST(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request)
def _get_file_metadata(
self, request: SynapseRequest
) -> Tuple[int, Optional[str], str]:
raw_content_length = request.getHeader("Content-Length")
if raw_content_length is None:
raise SynapseError(msg="Request must specify a Content-Length", code=400)
@@ -88,6 +89,16 @@ class UploadResource(RestServlet):
# disposition = headers.getRawHeaders(b"Content-Disposition")[0]
# TODO(markjh): parse content-disposition
return content_length, upload_name, media_type
class UploadServlet(BaseUploadServlet):
PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload$")]
async def on_POST(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request)
content_length, upload_name, media_type = self._get_file_metadata(request)
try:
content: IO = request.content # type: ignore
content_uri = await self.media_repo.create_content(
@@ -103,3 +114,53 @@ class UploadResource(RestServlet):
respond_with_json(
request, 200, {"content_uri": str(content_uri)}, send_cors=True
)
class AsyncUploadServlet(BaseUploadServlet):
PATTERNS = [
re.compile(
"/_matrix/media/v3/upload/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
)
]
async def on_PUT(
self, request: SynapseRequest, server_name: str, media_id: str
) -> None:
requester = await self.auth.get_user_by_req(request)
if server_name != self.server_name:
raise SynapseError(
404,
"Non-local server name specified",
errcode=Codes.NOT_FOUND,
)
lock = await self.store.try_acquire_lock(_UPLOAD_MEDIA_LOCK_NAME, media_id)
if not lock:
raise SynapseError(
409,
"Media ID cannot be overwritten",
errcode=Codes.CANNOT_OVERWRITE_MEDIA,
)
async with lock:
await self.media_repo.verify_can_upload(media_id, requester.user)
content_length, upload_name, media_type = self._get_file_metadata(request)
try:
content: IO = request.content # type: ignore
await self.media_repo.update_content(
media_id,
media_type,
upload_name,
content,
content_length,
requester.user,
)
except SpamMediaException:
# For uploading of media we want to respond with a 400, instead of
# the default 404, as that would just be confusing.
raise SynapseError(400, "Bad content")
logger.info("Uploaded content for media ID %r", media_id)
respond_with_json(request, 200, {}, send_cors=True)
+12 -6
@@ -49,7 +49,11 @@ else:
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
logger = logging.getLogger(__name__)
@@ -746,10 +750,10 @@ class BackgroundUpdater:
The named index will be dropped upon completion of the new index.
"""
def create_index_psql(conn: Connection) -> None:
def create_index_psql(conn: "LoggingDatabaseConnection") -> None:
conn.rollback()
# postgres insists on autocommit for the index
conn.set_session(autocommit=True) # type: ignore
conn.engine.attempt_to_set_autocommit(conn.conn, True)
try:
c = conn.cursor()
@@ -793,9 +797,9 @@ class BackgroundUpdater:
undo_timeout_sql = f"SET statement_timeout = {default_timeout}"
conn.cursor().execute(undo_timeout_sql)
conn.set_session(autocommit=False) # type: ignore
conn.engine.attempt_to_set_autocommit(conn.conn, False)
def create_index_sqlite(conn: Connection) -> None:
def create_index_sqlite(conn: "LoggingDatabaseConnection") -> None:
# Sqlite doesn't support concurrent creation of indexes.
#
# We assume that sqlite doesn't give us invalid indices; however
@@ -825,7 +829,9 @@ class BackgroundUpdater:
c.execute(sql)
if isinstance(self.db_pool.engine, engines.PostgresEngine):
runner: Optional[Callable[[Connection], None]] = create_index_psql
runner: Optional[
Callable[[LoggingDatabaseConnection], None]
] = create_index_psql
elif psql_only:
runner = None
else:
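Moving the autocommit toggle behind `conn.engine.attempt_to_set_autocommit` keeps psycopg2-specific connection APIs (the old `conn.set_session(autocommit=...)` with its `# type: ignore`) out of the background updater. A hedged sketch of what the engine hook plausibly looks like; the method bodies are assumptions, not the real Synapse implementations:

class PostgresEngine:
    def attempt_to_set_autocommit(self, conn, autocommit: bool) -> None:
        # psycopg2 connections expose set_session(autocommit=...).
        conn.set_session(autocommit=autocommit)

class Sqlite3Engine:
    def attempt_to_set_autocommit(self, conn, autocommit: bool) -> None:
        # No equivalent knob is needed: the concurrent-index path that
        # wants autocommit is Postgres-only, so doing nothing is safe.
        pass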
+6 -6
@@ -1116,7 +1116,7 @@ class DatabasePool:
def simple_insert_many_txn(
txn: LoggingTransaction,
table: str,
keys: Collection[str],
keys: Sequence[str],
values: Collection[Iterable[Any]],
) -> None:
"""Executes an INSERT query on the named table.
@@ -1597,7 +1597,7 @@ class DatabasePool:
retcols: Collection[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one",
) -> Dict[str, Any]:
) -> Tuple[Any, ...]:
...
@overload
@@ -1608,7 +1608,7 @@ class DatabasePool:
retcols: Collection[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
...
async def simple_select_one(
@@ -1618,7 +1618,7 @@ class DatabasePool:
retcols: Collection[str],
allow_none: bool = False,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
@@ -2127,7 +2127,7 @@ class DatabasePool:
keyvalues: Dict[str, Any],
retcols: Collection[str],
allow_none: bool = False,
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
if keyvalues:
@@ -2145,7 +2145,7 @@ class DatabasePool:
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
return dict(zip(retcols, row))
return row
async def simple_delete_one(
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
+1 -1
@@ -45,7 +45,7 @@ class Databases(Generic[DataStoreT]):
"""
databases: List[DatabasePool]
main: "DataStore" # FIXME: #11165: actually an instance of `main_store_class`
main: "DataStore" # FIXME: https://github.com/matrix-org/synapse/issues/11165: actually an instance of `main_store_class`
state: StateGroupDataStore
persist_events: Optional[PersistEventsStore]
+18 -6
@@ -747,8 +747,16 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
# Invalidate the cache for any ignored users which were added or removed.
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
self._invalidate_cache_and_stream_bulk(
txn,
self.ignored_by,
[
(ignored_user_id,)
for ignored_user_id in (
previously_ignored_users ^ currently_ignored_users
)
],
)
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
async def remove_account_data_for_user(
@@ -824,10 +832,14 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
# Invalidate the cache for ignored users which were removed.
for ignored_user_id in previously_ignored_users:
self._invalidate_cache_and_stream(
txn, self.ignored_by, (ignored_user_id,)
)
self._invalidate_cache_and_stream_bulk(
txn,
self.ignored_by,
[
(ignored_user_id,)
for ignored_user_id in previously_ignored_users
],
)
# Invalidate for this user the cache tracking ignored users.
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
+71 -4
@@ -483,6 +483,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
def _invalidate_cache_and_stream_bulk(
self,
txn: LoggingTransaction,
cache_func: CachedFunction,
key_tuples: Collection[Tuple[Any, ...]],
) -> None:
"""A bulk version of _invalidate_cache_and_stream.
Locally invalidate every key-tuple in `key_tuples`, then emit invalidations
for each key-tuple over replication.
This implementation is more efficient than a loop which repeatedly calls the
non-bulk version.
"""
if not key_tuples:
return
for keys in key_tuples:
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication_bulk(
txn, cache_func.__name__, key_tuples
)
def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: CachedFunction
) -> None:
@@ -564,10 +588,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if isinstance(self.database_engine, PostgresEngine):
assert self._cache_id_gen is not None
# get_next() returns a context manager which is designed to wrap
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
stream_id = self._cache_id_gen.get_next_txn(txn)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
@@ -586,6 +606,53 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
},
)
def _send_invalidation_to_replication_bulk(
self,
txn: LoggingTransaction,
cache_name: str,
key_tuples: Collection[Tuple[Any, ...]],
) -> None:
"""Announce the invalidation of multiple (but not all) cache entries.
This is more efficient than repeated calls to the non-bulk version. It should
NOT be used to invalidate the entire cache: use
`_send_invalidation_to_replication` with keys=None.
Note that this does *not* invalidate the cache locally.
Args:
txn
cache_name
key_tuples: Key-tuples to invalidate. Assumed to be non-empty.
"""
if isinstance(self.database_engine, PostgresEngine):
assert self._cache_id_gen is not None
stream_ids = self._cache_id_gen.get_next_mult_txn(txn, len(key_tuples))
ts = self._clock.time_msec()
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self.db_pool.simple_insert_many_txn(
txn,
table="cache_invalidation_stream_by_instance",
keys=(
"stream_id",
"instance_name",
"cache_func",
"keys",
"invalidation_ts",
),
values=[
# We convert key_tuples to a list here because psycopg2 serialises
# lists as pq arrays, but serialises tuples as "composite types".
# (We need an array because the `keys` column has type `text[]`.)
# See:
# https://www.psycopg.org/docs/usage.html#adapt-list
# https://www.psycopg.org/docs/usage.html#adapt-tuple
(stream_id, self._instance_name, cache_name, list(key_tuple), ts)
for stream_id, key_tuple in zip(stream_ids, key_tuples)
],
)
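The list-versus-tuple adaptation difference called out in the comment above can be seen directly with psycopg2's mogrify (this needs a live connection; the DSN is an assumption):

import psycopg2

conn = psycopg2.connect("dbname=synapse")  # illustrative DSN
cur = conn.cursor()
print(cur.mogrify("%s", ([1, 2, 3],)))  # b'ARRAY[1,2,3]': a Postgres array
print(cur.mogrify("%s", ((1, 2, 3),)))  # b'(1, 2, 3)': a composite/row value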
def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
if self._cache_id_gen:
return self._cache_id_gen.get_current_token_for_writer(instance_name)
+13 -30
@@ -255,33 +255,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
A dict containing the device information, or `None` if the device does not
exist.
"""
return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
allow_none=True,
)
async def get_device_opt(
self, user_id: str, device_id: str
) -> Optional[Dict[str, Any]]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
Args:
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
A dict containing the device information, or None if the device does not exist.
"""
return await self.db_pool.simple_select_one(
row = await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
allow_none=True,
)
if row is None:
return None
return {"user_id": row[0], "device_id": row[1], "display_name": row[2]}
async def get_devices_by_user(
self, user_id: str
@@ -1221,9 +1204,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
retcols=["device_id", "device_data"],
allow_none=True,
)
return (
(row["device_id"], json_decoder.decode(row["device_data"])) if row else None
)
return (row[0], json_decoder.decode(row[1])) if row else None
def _store_dehydrated_device_txn(
self,
@@ -2326,13 +2307,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
`FALSE` have not been converted.
"""
row = await self.db_pool.simple_select_one(
table="device_lists_changes_converted_stream_position",
keyvalues={},
retcols=["stream_id", "room_id"],
desc="get_device_change_last_converted_pos",
return cast(
Tuple[int, str],
await self.db_pool.simple_select_one(
table="device_lists_changes_converted_stream_position",
keyvalues={},
retcols=["stream_id", "room_id"],
desc="get_device_change_last_converted_pos",
),
)
return row["stream_id"], row["room_id"]
async def set_device_change_last_converted_pos(
self,
+19 -12
@@ -506,19 +506,26 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
# it isn't there.
raise StoreError(404, "No backup with that version exists")
result = self.db_pool.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
retcols=("version", "algorithm", "auth_data", "etag"),
allow_none=False,
row = cast(
Tuple[int, str, str, Optional[int]],
self.db_pool.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={
"user_id": user_id,
"version": this_version,
"deleted": 0,
},
retcols=("version", "algorithm", "auth_data", "etag"),
allow_none=False,
),
)
assert result is not None # see comment on `simple_select_one_txn`
result["auth_data"] = db_to_json(result["auth_data"])
result["version"] = str(result["version"])
if result["etag"] is None:
result["etag"] = 0
return result
return {
"auth_data": db_to_json(row[2]),
"version": str(row[0]),
"algorithm": row[1],
"etag": 0 if row[3] is None else row[3],
}
return await self.db_pool.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
@@ -1237,13 +1237,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
if (user_id, device_id) in seen_user_device:
continue
seen_user_device.add((user_id, device_id))
self._invalidate_cache_and_stream(
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
)
self._invalidate_cache_and_stream_bulk(
txn, self.get_e2e_unused_fallback_key_types, seen_user_device
)
return results
@@ -1268,9 +1266,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
if row is None:
continue
key_id = row["key_id"]
key_json = row["key_json"]
used = row["used"]
key_id, key_json, used = row
# Mark fallback key as used if not already.
if not used and mark_as_used:
@@ -1376,17 +1372,62 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
)
seen_user_device: Set[Tuple[str, str]] = set()
for user_id, device_id, _, _, _ in otk_rows:
if (user_id, device_id) in seen_user_device:
continue
seen_user_device.add((user_id, device_id))
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
seen_user_device = {
(user_id, device_id) for user_id, device_id, _, _, _ in otk_rows
}
self._invalidate_cache_and_stream_bulk(
txn,
self.count_e2e_one_time_keys,
seen_user_device,
)
return otk_rows
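The loop-with-seen-set is collapsed into a set comprehension feeding a single bulk invalidation. A small sketch of the deduplication effect (the invalidate helper here is hypothetical):

rows = [("@a:x", "D1", "k1"), ("@a:x", "D1", "k2"), ("@b:x", "D2", "k1")]
seen_user_device = {(user_id, device_id) for user_id, device_id, _ in rows}

def invalidate_bulk(keys):
    for key in sorted(keys):
        print("invalidate", key)  # stand-in for cache + replication-stream invalidation

invalidate_bulk(seen_user_device)  # two invalidations instead of three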
async def get_master_cross_signing_key_updatable_before(
self, user_id: str
) -> Tuple[bool, Optional[int]]:
"""Get time before which a master cross-signing key may be replaced without UIA.
(UIA means "User-Interactive Auth".)
There are three cases to distinguish:
(1) No master cross-signing key.
(2) The key exists, but there is no replace-without-UIA timestamp in the DB.
(3) The key exists, and has such a timestamp recorded.
Returns: a 2-tuple of:
- a boolean: is there a master cross-signing key already?
- an optional timestamp, directly taken from the DB.
In terms of the cases above, these are:
(1) (False, None).
(2) (True, None).
(3) (True, <timestamp in ms>).
"""
def impl(txn: LoggingTransaction) -> Tuple[bool, Optional[int]]:
# We want to distinguish between three cases:
txn.execute(
"""
SELECT updatable_without_uia_before_ms
FROM e2e_cross_signing_keys
WHERE user_id = ? AND keytype = 'master'
ORDER BY stream_id DESC
LIMIT 1
""",
(user_id,),
)
row = cast(Optional[Tuple[Optional[int]]], txn.fetchone())
if row is None:
return False, None
return True, row[0]
return await self.db_pool.runInteraction(
"e2e_cross_signing_keys",
impl,
)
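A worked sketch of how a caller might interpret the three documented cases (the helper name is illustrative, not part of this changeset):

def may_replace_without_uia(result, now_ms):
    exists, updatable_before_ms = result
    if not exists:
        return False  # case (1): no master cross-signing key
    if updatable_before_ms is None:
        return False  # case (2): key exists, no grace period recorded
    return now_ms < updatable_before_ms  # case (3): within the grace period

assert may_replace_without_uia((False, None), 0) is False
assert may_replace_without_uia((True, 2_000), 1_000) is True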
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def __init__(
@@ -1634,3 +1675,42 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
],
desc="add_e2e_signing_key",
)
async def allow_master_cross_signing_key_replacement_without_uia(
self, user_id: str, duration_ms: int
) -> Optional[int]:
"""Mark this user's latest master key as being replaceable without UIA.
Said replacement will only be permitted for a short time after calling this
function. That time period is controlled by the duration argument.
Returns:
None, if there is no such key.
Otherwise, the timestamp before which replacement is allowed without UIA.
"""
timestamp = self._clock.time_msec() + duration_ms
def impl(txn: LoggingTransaction) -> Optional[int]:
txn.execute(
"""
UPDATE e2e_cross_signing_keys
SET updatable_without_uia_before_ms = ?
WHERE stream_id = (
SELECT stream_id
FROM e2e_cross_signing_keys
WHERE user_id = ? AND keytype = 'master'
ORDER BY stream_id DESC
LIMIT 1
)
""",
(timestamp, user_id),
)
if txn.rowcount == 0:
return None
return timestamp
return await self.db_pool.runInteraction(
"allow_master_cross_signing_key_replacement_without_uia",
impl,
)
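The UPDATE's correlated subquery pins the change to the user's newest master-key row; the returned timestamp is simply now plus the duration. A sketch of that arithmetic:

def replacement_deadline(now_ms: int, duration_ms: int) -> int:
    # The key stays replaceable without UIA until this timestamp.
    return now_ms + duration_ms

assert replacement_deadline(1_700_000_000_000, 10 * 60 * 1000) == 1_700_000_600_000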
@@ -193,7 +193,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
# If the room has an auth chain index.
if room[1]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_ids_chains",
@@ -411,7 +412,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
# If the room has an auth chain index.
if room[1]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_difference_chains",
@@ -1437,24 +1439,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
if event_lookup_result is not None:
event_type, depth, stream_ordering = event_lookup_result
logger.debug(
"_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
room_id,
seed_event_id,
event_lookup_result["depth"],
event_lookup_result["stream_ordering"],
event_lookup_result["type"],
depth,
stream_ordering,
event_type,
)
if event_lookup_result["depth"]:
queue.put(
(
-event_lookup_result["depth"],
-event_lookup_result["stream_ordering"],
seed_event_id,
event_lookup_result["type"],
)
)
if depth:
queue.put((-depth, -stream_ordering, seed_event_id, event_type))
while not queue.empty() and len(event_id_results) < limit:
try:
+1 -2
@@ -1934,8 +1934,7 @@ class PersistEventsStore:
if row is None:
return
redacted_relates_to = row["relates_to_id"]
rel_type = row["relation_type"]
redacted_relates_to, rel_type = row
self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
@@ -425,7 +425,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
"""Background update to clean out extremities that should have been
deleted previously.
Mainly used to deal with the aftermath of #5269.
Mainly used to deal with the aftermath of https://github.com/matrix-org/synapse/issues/5269.
"""
# This works by first copying all existing forward extremities into the
@@ -558,7 +558,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
logger.info(
"Deleted %d forward extremities of %d checked, to clean up #5269",
"Deleted %d forward extremities of %d checked, to clean up matrix-org/synapse#5269",
deleted,
len(original_set),
)
@@ -1222,14 +1222,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
# Iterate the parent IDs and invalidate caches.
for parent_id in {r[1] for r in relations_to_insert}:
cache_tuple = (parent_id,)
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined]
)
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined]
)
cache_tuples = {(r[1],) for r in relations_to_insert}
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
txn, self.get_relations_for_event, cache_tuples # type: ignore[attr-defined]
)
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
txn, self.get_thread_summary, cache_tuples # type: ignore[attr-defined]
)
if results:
latest_event_id = results[-1][0]
@@ -1312,7 +1312,8 @@ class EventsWorkerStore(SQLBaseStore):
room_version: Optional[RoomVersion]
if not room_version_id:
# this should only happen for out-of-band membership events which
# arrived before #6983 landed. For all other events, we should have
# arrived before https://github.com/matrix-org/synapse/issues/6983
# landed. For all other events, we should have
# an entry in the 'rooms' table.
#
# However, the 'out_of_band_membership' flag is unreliable for older
@@ -1323,7 +1324,8 @@ class EventsWorkerStore(SQLBaseStore):
"Room %s for event %s is unknown" % (d["room_id"], event_id)
)
# so, assuming this is an out-of-band-invite that arrived before #6983
# so, assuming this is an out-of-band-invite that arrived before
# https://github.com/matrix-org/synapse/issues/6983
# landed, we know that the room version must be v5 or earlier (because
# v6 hadn't been invented at that point, so invites from such rooms
# would have been rejected.)
@@ -1998,7 +2000,7 @@ class EventsWorkerStore(SQLBaseStore):
if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,))
return int(res["topological_ordering"]), int(res["stream_ordering"])
return int(res[0]), int(res[1])
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
+10 -7
@@ -107,13 +107,16 @@ class KeyStore(CacheInvalidationWorkerStore):
# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
for key_id in verify_keys:
self._invalidate_cache_and_stream(
txn, self._get_server_keys_json, ((server_name, key_id),)
)
self._invalidate_cache_and_stream(
txn, self.get_server_key_json_for_remote, (server_name, key_id)
)
self._invalidate_cache_and_stream_bulk(
txn,
self._get_server_keys_json,
[((server_name, key_id),) for key_id in verify_keys],
)
self._invalidate_cache_and_stream_bulk(
txn,
self.get_server_key_json_for_remote,
[(server_name, key_id) for key_id in verify_keys],
)
await self.db_pool.runInteraction(
"store_server_keys_response", store_server_keys_response_txn
@@ -15,9 +15,7 @@
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Optional,
@@ -51,12 +49,34 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
class LocalMedia:
media_id: str
media_type: str
media_length: int
media_length: Optional[int]
upload_name: str
created_ts: int
url_cache: Optional[str]
last_access_ts: int
quarantined_by: Optional[str]
safe_from_quarantine: bool
user_id: Optional[str]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RemoteMedia:
media_origin: str
media_id: str
media_type: str
media_length: int
upload_name: Optional[str]
filesystem_id: str
created_ts: int
last_access_ts: int
quarantined_by: Optional[str]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class UrlCache:
response_code: int
expires_ts: int
og: Union[str, bytes]
class MediaSortOrder(Enum):
@@ -130,6 +150,13 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
self._drop_media_index_without_method,
)
if hs.config.media.can_load_media_repo:
self.unused_expiration_time: Optional[
int
] = hs.config.media.unused_expiration_time
else:
self.unused_expiration_time = None
async def _drop_media_index_without_method(
self, progress: JsonDict, batch_size: int
) -> int:
@@ -165,13 +192,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
super().__init__(database, db_conn, hs)
self.server_name: str = hs.hostname
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
async def get_local_media(self, media_id: str) -> Optional[LocalMedia]:
"""Get the metadata for a local piece of media
Returns:
None if the media_id doesn't exist.
"""
return await self.db_pool.simple_select_one(
row = await self.db_pool.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@@ -181,11 +208,27 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"created_ts",
"quarantined_by",
"url_cache",
"last_access_ts",
"safe_from_quarantine",
"user_id",
),
allow_none=True,
desc="get_local_media",
)
if row is None:
return None
return LocalMedia(
media_id=media_id,
media_type=row[0],
media_length=row[1],
upload_name=row[2],
created_ts=row[3],
quarantined_by=row[4],
url_cache=row[5],
last_access_ts=row[6],
safe_from_quarantine=row[7],
user_id=row[8],
)
async def get_local_media_by_user_paginate(
self,
@@ -236,9 +279,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length,
upload_name,
created_ts,
url_cache,
last_access_ts,
quarantined_by,
safe_from_quarantine
safe_from_quarantine,
user_id
FROM local_media_repository
WHERE user_id = ?
ORDER BY {order_by_column} {order}, media_id ASC
@@ -257,9 +302,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length=row[2],
upload_name=row[3],
created_ts=row[4],
last_access_ts=row[5],
quarantined_by=row[6],
safe_from_quarantine=bool(row[7]),
url_cache=row[5],
last_access_ts=row[6],
quarantined_by=row[7],
safe_from_quarantine=bool(row[8]),
user_id=row[9],
)
for row in txn
]
@@ -356,6 +403,23 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_local_media_ids", _get_local_media_ids_txn
)
@trace
async def store_local_media_id(
self,
media_id: str,
time_now_ms: int,
user_id: UserID,
) -> None:
await self.db_pool.simple_insert(
"local_media_repository",
{
"media_id": media_id,
"created_ts": time_now_ms,
"user_id": user_id.to_string(),
},
desc="store_local_media_id",
)
@trace
async def store_local_media(
self,
@@ -381,6 +445,30 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_media",
)
async def update_local_media(
self,
media_id: str,
media_type: str,
upload_name: Optional[str],
media_length: int,
user_id: UserID,
url_cache: Optional[str] = None,
) -> None:
await self.db_pool.simple_update_one(
"local_media_repository",
keyvalues={
"user_id": user_id.to_string(),
"media_id": media_id,
},
updatevalues={
"media_type": media_type,
"upload_name": upload_name,
"media_length": media_length,
"url_cache": url_cache,
},
desc="update_local_media",
)
async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None:
"""Mark a local media as safe or unsafe from quarantining."""
await self.db_pool.simple_update_one(
@@ -390,51 +478,72 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="mark_local_media_as_safe",
)
async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
async def count_pending_media(self, user_id: UserID) -> Tuple[int, int]:
"""Count the number of pending media for a user.
Returns:
A tuple of two integers: the total pending media requests and the earliest
expiration timestamp.
"""
def get_pending_media_txn(txn: LoggingTransaction) -> Tuple[int, int]:
sql = """
SELECT COUNT(*), MIN(created_ts)
FROM local_media_repository
WHERE user_id = ?
AND created_ts > ?
AND media_length IS NULL
"""
assert self.unused_expiration_time is not None
txn.execute(
sql,
(
user_id.to_string(),
self._clock.time_msec() - self.unused_expiration_time,
),
)
row = txn.fetchone()
if not row:
return 0, 0
return row[0], (row[1] + self.unused_expiration_time if row[1] else 0)
return await self.db_pool.runInteraction(
"get_pending_media", get_pending_media_txn
)
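A sketch of the return-value computation: MIN(created_ts) plus the configured unused_expiration_time yields the earliest expiry, or 0 when nothing is pending:

def pending_summary(count, min_created_ts, unused_expiration_time_ms):
    if not min_created_ts:
        return count, 0
    return count, min_created_ts + unused_expiration_time_ms

assert pending_summary(2, 1_000, 60_000) == (2, 61_000)
assert pending_summary(0, None, 60_000) == (0, 0)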
async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
None if the URL isn't cached.
"""
def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]:
# get the most recently cached result (relative to the given ts)
sql = (
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
" FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts <= ?"
" ORDER BY download_ts DESC LIMIT 1"
)
sql = """
SELECT response_code, expires_ts, og
FROM local_media_repository_url_cache
WHERE url = ? AND download_ts <= ?
ORDER BY download_ts DESC LIMIT 1
"""
txn.execute(sql, (url, ts))
row = txn.fetchone()
if not row:
# ...or if we've requested a timestamp older than the oldest
# copy in the cache, return the oldest copy (if any)
sql = (
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
" FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts > ?"
" ORDER BY download_ts ASC LIMIT 1"
)
sql = """
SELECT response_code, expires_ts, og
FROM local_media_repository_url_cache
WHERE url = ? AND download_ts > ?
ORDER BY download_ts ASC LIMIT 1
"""
txn.execute(sql, (url, ts))
row = txn.fetchone()
if not row:
return None
return dict(
zip(
(
"response_code",
"etag",
"expires_ts",
"og",
"media_id",
"download_ts",
),
row,
)
)
return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2])
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
@@ -444,7 +553,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
response_code: int,
etag: Optional[str],
expires_ts: int,
og: Optional[str],
og: str,
media_id: str,
download_ts: int,
) -> None:
@@ -510,8 +619,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_cached_remote_media(
self, origin: str, media_id: str
) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
) -> Optional[RemoteMedia]:
row = await self.db_pool.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
@@ -520,11 +629,25 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"upload_name",
"created_ts",
"filesystem_id",
"last_access_ts",
"quarantined_by",
),
allow_none=True,
desc="get_cached_remote_media",
)
if row is None:
return row
return RemoteMedia(
media_origin=origin,
media_id=media_id,
media_type=row[0],
media_length=row[1],
upload_name=row[2],
created_ts=row[3],
filesystem_id=row[4],
last_access_ts=row[5],
quarantined_by=row[6],
)
async def store_cached_remote_media(
self,
@@ -623,10 +746,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
t_width: int,
t_height: int,
t_type: str,
) -> Optional[Dict[str, Any]]:
) -> Optional[ThumbnailInfo]:
"""Fetch the thumbnail info of given width, height and type."""
return await self.db_pool.simple_select_one(
row = await self.db_pool.simple_select_one(
table="remote_media_cache_thumbnails",
keyvalues={
"media_origin": origin,
@@ -641,11 +764,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
"filesystem_id",
),
allow_none=True,
desc="get_remote_media_thumbnail",
)
if row is None:
return None
return ThumbnailInfo(
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
)
@trace
async def store_remote_media_thumbnail(
@@ -317,7 +317,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
if user_id:
is_support = self.is_support_user_txn(txn, user_id)
if not is_support:
# We do this manually here to avoid hitting #6791
# We do this manually here to avoid hitting https://github.com/matrix-org/synapse/issues/6791
self.db_pool.simple_upsert_txn(
txn,
table="monthly_active_users",
+5 -4
@@ -363,10 +363,11 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
# for their user ID.
value_values=[(presence_stream_id,) for _ in user_ids],
)
for user_id in user_ids:
self._invalidate_cache_and_stream(
txn, self._get_full_presence_stream_token_for_user, (user_id,)
)
self._invalidate_cache_and_stream_bulk(
txn,
self._get_full_presence_stream_token_for_user,
[(user_id,) for user_id in user_ids],
)
return await self.db_pool.runInteraction(
"add_users_to_send_full_presence_to", _add_users_to_send_full_presence_to
+11 -17
@@ -13,7 +13,6 @@
# limitations under the License.
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -138,23 +137,18 @@ class ProfileWorkerStore(SQLBaseStore):
return 50
async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
try:
profile = await self.db_pool.simple_select_one(
table="profiles",
keyvalues={"full_user_id": user_id.to_string()},
retcols=("displayname", "avatar_url"),
desc="get_profileinfo",
)
except StoreError as e:
if e.code == 404:
# no match
return ProfileInfo(None, None)
else:
raise
return ProfileInfo(
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
profile = await self.db_pool.simple_select_one(
table="profiles",
keyvalues={"full_user_id": user_id.to_string()},
retcols=("displayname", "avatar_url"),
desc="get_profileinfo",
allow_none=True,
)
if profile is None:
# no match
return ProfileInfo(None, None)
return ProfileInfo(avatar_url=profile[1], display_name=profile[0])
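allow_none=True turns the 404 StoreError into a None return, removing the try/except. A sketch of the simplified control flow (the select helper is a stand-in):

def get_profileinfo(select_one):
    profile = select_one()  # stand-in for simple_select_one(..., allow_none=True)
    if profile is None:
        return (None, None)  # ProfileInfo(None, None)
    displayname, avatar_url = profile  # retcols order: displayname, avatar_url
    return (avatar_url, displayname)

assert get_profileinfo(lambda: None) == (None, None)
assert get_profileinfo(lambda: ("Alice", None)) == (None, "Alice")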
async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
+21 -12
@@ -295,19 +295,28 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# so make sure to keep this actually last.
txn.execute("DROP TABLE events_to_purge")
for event_id, should_delete in event_rows:
self._invalidate_cache_and_stream(
txn, self._get_state_group_for_event, (event_id,)
)
self._invalidate_cache_and_stream_bulk(
txn,
self._get_state_group_for_event,
[(event_id,) for event_id, _ in event_rows],
)
# XXX: This is racy, since have_seen_events could be called between the
# transaction completing and the invalidation running. On the other hand,
# that's no different to calling `have_seen_events` just before the
# event is deleted from the database.
# XXX: This is racy, since have_seen_events could be called between the
# transaction completing and the invalidation running. On the other hand,
# that's no different to calling `have_seen_events` just before the
# event is deleted from the database.
self._invalidate_cache_and_stream_bulk(
txn,
self.have_seen_event,
[
(room_id, event_id)
for event_id, should_delete in event_rows
if should_delete
],
)
for event_id, should_delete in event_rows:
if should_delete:
self._invalidate_cache_and_stream(
txn, self.have_seen_event, (room_id, event_id)
)
self.invalidate_get_event_cache_after_txn(txn, event_id)
logger.info("[purge] done")
@@ -485,7 +494,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# - room_tags_revisions
# The problem with these is that they are largeish and there is no room_id
# index on them. In any case we should be clearing out 'stream' tables
# periodically anyway (#5888)
# periodically anyway (https://github.com/matrix-org/synapse/issues/5888)
self._invalidate_caches_for_room_and_stream(txn, room_id)
+27 -17
@@ -449,27 +449,28 @@ class PushRuleStore(PushRulesWorkerStore):
before: str,
after: str,
) -> None:
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.
self.database_engine.lock_table(txn, "push_rules")
relative_to_rule = before or after
res = self.db_pool.simple_select_one_txn(
txn,
table="push_rules",
keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
retcols=["priority_class", "priority"],
allow_none=True,
)
sql = """
SELECT priority, priority_class FROM push_rules
WHERE user_name = ? AND rule_id = ?
"""
if not res:
if isinstance(self.database_engine, PostgresEngine):
sql += " FOR UPDATE"
else:
# Annoyingly SQLite doesn't support row level locking, so lock the whole table
self.database_engine.lock_table(txn, "push_rules")
txn.execute(sql, (user_id, relative_to_rule))
row = txn.fetchone()
if row is None:
raise RuleNotFoundException(
"before/after rule not found: %s" % (relative_to_rule,)
)
base_priority_class = res["priority_class"]
base_rule_priority = res["priority"]
base_rule_priority, base_priority_class = row
if base_priority_class != priority_class:
raise InconsistentRuleException(
@@ -517,9 +518,18 @@ class PushRuleStore(PushRulesWorkerStore):
conditions_json: str,
actions_json: str,
) -> None:
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.
self.database_engine.lock_table(txn, "push_rules")
if isinstance(self.database_engine, PostgresEngine):
# Postgres doesn't do FOR UPDATE on aggregate functions, so select the rows first
# then re-select the count/max below.
sql = """
SELECT * FROM push_rules
WHERE user_name = ? and priority_class = ?
FOR UPDATE
"""
txn.execute(sql, (user_id, priority_class))
else:
# Annoyingly SQLite doesn't support row level locking, so lock the whole table
self.database_engine.lock_table(txn, "push_rules")
# find the highest priority rule in that class
sql = (
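The table-wide lock is replaced by row-level locking where the engine supports it: on Postgres the SELECT gains FOR UPDATE, while on SQLite the whole-table lock remains. A minimal sketch of that branch:

def locking_select(base_sql: str, is_postgres: bool) -> str:
    # On Postgres, lock just the selected push_rules rows.
    # SQLite has no row-level locking; the caller locks the table instead.
    return base_sql + " FOR UPDATE" if is_postgres else base_sql

sql = "SELECT priority, priority_class FROM push_rules WHERE user_name = ? AND rule_id = ?"
assert locking_select(sql, True).endswith("FOR UPDATE")
assert locking_select(sql, False) == sql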
+2 -2
@@ -701,8 +701,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
allow_none=True,
)
stream_ordering = int(res["stream_ordering"]) if res else None
rx_ts = res["received_ts"] if res else 0
stream_ordering = int(res[0]) if res else None
rx_ts = res[1] if res else 0
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
+60 -47
@@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
account timestamp as milliseconds since the epoch. None if the account
has not been renewed using the current token yet.
"""
ret_dict = await self.db_pool.simple_select_one(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
desc="get_user_from_renewal_token",
)
return (
ret_dict["user_id"],
ret_dict["expiration_ts_ms"],
ret_dict["token_used_ts_ms"],
return cast(
Tuple[str, int, Optional[int]],
await self.db_pool.simple_select_one(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
desc="get_user_from_renewal_token",
),
)
async def get_renewal_token_for_user(self, user_id: str) -> str:
@@ -564,16 +561,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
updatevalues={"shadow_banned": shadow_banned},
)
# In order for this to apply immediately, clear the cache for this user.
tokens = self.db_pool.simple_select_onecol_txn(
tokens = self.db_pool.simple_select_list_txn(
txn,
table="access_tokens",
keyvalues={"user_id": user_id},
retcol="token",
retcols=("token",),
)
self._invalidate_cache_and_stream_bulk(
txn, self.get_user_by_access_token, tokens
)
for token in tokens:
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (token,)
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
@@ -989,16 +985,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns:
user id, or None if no user id/threepid mapping exists
"""
ret = self.db_pool.simple_select_one_txn(
return self.db_pool.simple_select_one_onecol_txn(
txn,
"user_threepids",
{"medium": medium, "address": address},
["user_id"],
"user_id",
True,
)
if ret:
return ret["user_id"]
return None
async def user_add_threepid(
self,
@@ -1435,16 +1428,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
if res is None:
return False
uses_allowed, pending, completed, expiry_time = res
# Check if the token has expired
now = self._clock.time_msec()
if res["expiry_time"] and res["expiry_time"] < now:
if expiry_time and expiry_time < now:
return False
# Check if the token has been used up
if (
res["uses_allowed"]
and res["pending"] + res["completed"] >= res["uses_allowed"]
):
if uses_allowed and pending + completed >= uses_allowed:
return False
# Otherwise, the token is valid
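The tuple unpack replaces repeated dict lookups; the validity rules themselves are unchanged. A sketch of the check:

def token_valid(uses_allowed, pending, completed, expiry_time, now_ms):
    if expiry_time and expiry_time < now_ms:
        return False  # token has expired
    if uses_allowed and pending + completed >= uses_allowed:
        return False  # all uses claimed (pending registrations count too)
    return True

assert token_valid(None, 0, 0, None, 123)      # unlimited token, no expiry
assert not token_valid(3, 1, 2, None, 123)     # used up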
@@ -1490,8 +1482,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
# about None not being indexable.
res = cast(
Dict[str, Any],
pending, completed = cast(
Tuple[int, int],
self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
@@ -1506,8 +1498,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"registration_tokens",
keyvalues={"token": token},
updatevalues={
"completed": res["completed"] + 1,
"pending": res["pending"] - 1,
"completed": completed + 1,
"pending": pending - 1,
},
)
@@ -1585,13 +1577,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns:
A dict, or None if the token doesn't exist.
"""
return await self.db_pool.simple_select_one(
row = await self.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
allow_none=True,
desc="get_one_registration_token",
)
if row is None:
return None
return {
"token": row[0],
"uses_allowed": row[1],
"pending": row[2],
"completed": row[3],
"expiry_time": row[4],
}
async def generate_registration_token(
self, length: int, chars: str
@@ -1714,7 +1715,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return None
# Get all info about the token so it can be sent in the response
return self.db_pool.simple_select_one_txn(
result = self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
@@ -1728,6 +1729,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
allow_none=True,
)
if result is None:
return result
return {
"token": result[0],
"uses_allowed": result[1],
"pending": result[2],
"completed": result[3],
"expiry_time": result[4],
}
return await self.db_pool.runInteraction(
"update_registration_token", _update_registration_token_txn
)
@@ -1939,11 +1951,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
keyvalues={"token": token},
updatevalues={"used_ts": ts},
)
user_id = values["user_id"]
expiry_ts = values["expiry_ts"]
used_ts = values["used_ts"]
auth_provider_id = values["auth_provider_id"]
auth_provider_session_id = values["auth_provider_session_id"]
(
user_id,
expiry_ts,
used_ts,
auth_provider_id,
auth_provider_session_id,
) = values
# Token was already used
if used_ts is not None:
@@ -2668,10 +2682,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
)
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
for token, _, _ in tokens_and_devices:
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (token,)
)
self._invalidate_cache_and_stream_bulk(
txn,
self.get_user_by_access_token,
[(token,) for token, _, _ in tokens_and_devices],
)
txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
@@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
# reason, the next check is on the client secret, which is NOT NULL,
# so we don't have to worry about the client secret matching by
# accident.
row = {"client_secret": None, "validated_at": None}
row = None, None
else:
raise ThreepidValidationError("Unknown session_id")
retrieved_client_secret = row["client_secret"]
validated_at = row["validated_at"]
retrieved_client_secret, validated_at = row
row = self.db_pool.simple_select_one_txn(
txn,
@@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
raise ThreepidValidationError(
"Validation token not found or has expired"
)
expires = row["expires"]
next_link = row["next_link"]
expires, next_link = row
if retrieved_client_secret != client_secret:
raise ThreepidValidationError(
+27 -18
@@ -213,21 +213,31 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]:
"""Retrieve a room.
Args:
room_id: The ID of the room to retrieve.
Returns:
A dict containing the room information, or None if the room is unknown.
A tuple containing the room information:
* True if the room is public
* True if the room has an auth chain index
or None if the room is unknown.
"""
return await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
desc="get_room",
allow_none=True,
row = cast(
Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]],
await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("is_public", "has_auth_chain_index"),
desc="get_room",
allow_none=True,
),
)
if row is None:
return row
return bool(row[0]), bool(row[1])
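The Optional[Union[int, bool]] cast reflects that SQLite returns 0/1 where Postgres returns booleans, hence the bool() normalisation. A sketch:

def normalise_room_row(row):
    if row is None:
        return None
    is_public, has_auth_chain_index = row
    return bool(is_public), bool(has_auth_chain_index)

assert normalise_room_row((1, 0)) == (True, False)         # SQLite-style row
assert normalise_room_row((True, False)) == (True, False)  # Postgres-style row
assert normalise_room_row(None) is None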
async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
"""Retrieve room with statistics.
@@ -794,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
)
if row:
return RatelimitOverride(
messages_per_second=row["messages_per_second"],
burst_count=row["burst_count"],
)
return RatelimitOverride(messages_per_second=row[0], burst_count=row[1])
else:
return None
@@ -1371,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
join.
"""
result = await self.db_pool.simple_select_one(
table="partial_state_rooms",
keyvalues={"room_id": room_id},
retcols=("join_event_id", "device_lists_stream_id"),
desc="get_join_event_id_for_partial_state",
return cast(
Tuple[str, int],
await self.db_pool.simple_select_one(
table="partial_state_rooms",
keyvalues={"room_id": room_id},
retcols=("join_event_id", "device_lists_stream_id"),
desc="get_join_event_id_for_partial_state",
),
)
return result["join_event_id"], result["device_lists_stream_id"]
def get_un_partial_stated_rooms_token(self, instance_name: str) -> int:
return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer(
+11 -8
@@ -559,17 +559,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
"non-local user %s" % (user_id,),
)
results_dict = await self.db_pool.simple_select_one(
"local_current_membership",
{"room_id": room_id, "user_id": user_id},
("membership", "event_id"),
allow_none=True,
desc="get_local_current_membership_for_user_in_room",
results = cast(
Optional[Tuple[str, str]],
await self.db_pool.simple_select_one(
"local_current_membership",
{"room_id": room_id, "user_id": user_id},
("membership", "event_id"),
allow_none=True,
desc="get_local_current_membership_for_user_in_room",
),
)
if not results_dict:
if not results:
return None, None
return results_dict.get("membership"), results_dict.get("event_id")
return results
@cached(max_entries=500000, iterable=True)
async def get_rooms_for_user_with_stream_ordering(
+4 -4
@@ -275,7 +275,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# we have to set autocommit, because postgres refuses to
# CREATE INDEX CONCURRENTLY without it.
conn.set_session(autocommit=True)
conn.engine.attempt_to_set_autocommit(conn.conn, True)
try:
c = conn.cursor()
@@ -301,7 +301,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# we should now be able to delete the GIST index.
c.execute("DROP INDEX IF EXISTS event_search_fts_idx_gist")
finally:
conn.set_session(autocommit=False)
conn.engine.attempt_to_set_autocommit(conn.conn, False)
if isinstance(self.database_engine, PostgresEngine):
await self.db_pool.runWithConnection(create_index)
@@ -323,7 +323,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
def create_index(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
conn.set_session(autocommit=True)
conn.engine.attempt_to_set_autocommit(conn.conn, True)
c = conn.cursor()
# We create with NULLS FIRST so that when we search *backwards*
@@ -340,7 +340,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
ON event_search(origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)
"""
)
conn.set_session(autocommit=False)
conn.engine.attempt_to_set_autocommit(conn.conn, False)
await self.db_pool.runWithConnection(create_index)
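The direct conn.set_session(...) calls move behind an engine method. A hedged sketch of what such a method could do (assumed behaviour, not the real implementation):

class EngineSketch:
    def attempt_to_set_autocommit(self, conn, autocommit: bool) -> None:
        # Hypothetical: delegate to psycopg2-style connections, and silently
        # skip engines that have no autocommit toggle.
        setter = getattr(conn, "set_session", None)
        if setter is not None:
            setter(autocommit=autocommit)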
+12 -18
@@ -1014,9 +1014,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="get_position_for_event",
)
return PersistedEventPosition(
row["instance_name"] or "master", row["stream_ordering"]
)
return PersistedEventPosition(row[1] or "master", row[0])
async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
"""The stream token for an event
@@ -1033,9 +1031,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
)
return RoomStreamToken(
topological=row["topological_ordering"], stream=row["stream_ordering"]
)
return RoomStreamToken(topological=row[1], stream=row[0])
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
"""Gets the topological token in a room after or at the given stream
@@ -1180,26 +1176,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
results = self.db_pool.simple_select_one_txn(
txn,
"events",
keyvalues={"event_id": event_id, "room_id": room_id},
retcols=["stream_ordering", "topological_ordering"],
stream_ordering, topological_ordering = cast(
Tuple[int, int],
self.db_pool.simple_select_one_txn(
txn,
"events",
keyvalues={"event_id": event_id, "room_id": room_id},
retcols=["stream_ordering", "topological_ordering"],
),
)
# This cannot happen as `allow_none=False`.
assert results is not None
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
topological=results["topological_ordering"] - 1,
stream=results["stream_ordering"],
topological=topological_ordering - 1, stream=stream_ordering
)
after_token = RoomStreamToken(
topological=results["topological_ordering"],
stream=results["stream_ordering"],
topological=topological_ordering, stream=stream_ordering
)
rows, start_token = self._paginate_room_events_txn(
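A sketch of the token asymmetry noted above: backwards pagination includes the event at the token while forwards pagination excludes it, so before_token sits one topological position earlier:

def tokens_for_event(topological: int, stream: int):
    before_token = (topological - 1, stream)  # backwards: include the event
    after_token = (topological, stream)       # forwards: exclude the event
    return before_token, after_token

assert tokens_for_event(10, 99) == ((9, 99), (10, 99))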
@@ -183,39 +183,27 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
Returns: the task if available, `None` otherwise
"""
row = await self.db_pool.simple_select_one(
table="scheduled_tasks",
keyvalues={"id": id},
retcols=(
"id",
"action",
"status",
"timestamp",
"resource_id",
"params",
"result",
"error",
row = cast(
Optional[ScheduledTaskRow],
await self.db_pool.simple_select_one(
table="scheduled_tasks",
keyvalues={"id": id},
retcols=(
"id",
"action",
"status",
"timestamp",
"resource_id",
"params",
"result",
"error",
),
allow_none=True,
desc="get_scheduled_task",
),
allow_none=True,
desc="get_scheduled_task",
)
return (
TaskSchedulerWorkerStore._convert_row_to_task(
(
row["id"],
row["action"],
row["status"],
row["timestamp"],
row["resource_id"],
row["params"],
row["result"],
row["error"],
)
)
if row
else None
)
return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
async def delete_scheduled_task(self, id: str) -> None:
"""Delete a specific task from its id.
+8 -12
@@ -118,19 +118,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
txn,
table="received_transactions",
keyvalues={"transaction_id": transaction_id, "origin": origin},
retcols=(
"transaction_id",
"origin",
"ts",
"response_code",
"response_json",
"has_been_referenced",
),
retcols=("response_code", "response_json"),
allow_none=True,
)
if result and result["response_code"]:
return result["response_code"], db_to_json(result["response_json"])
# If the result exists and the response code is non-zero.
if result and result[0]:
return result[0], db_to_json(result[1])
else:
return None
@@ -200,8 +194,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
# check we have a row and retry_last_ts is not null or zero
# (retry_last_ts can't be negative)
if result and result["retry_last_ts"]:
return DestinationRetryTimings(**result)
if result and result[1]:
return DestinationRetryTimings(
failure_ts=result[0], retry_last_ts=result[1], retry_interval=result[2]
)
else:
return None
+16 -15
@@ -122,9 +122,13 @@ class UIAuthWorkerStore(SQLBaseStore):
desc="get_ui_auth_session",
)
result["clientdict"] = db_to_json(result["clientdict"])
return UIAuthSessionData(session_id, **result)
return UIAuthSessionData(
session_id,
clientdict=db_to_json(result[0]),
uri=result[1],
method=result[2],
description=result[3],
)
async def mark_ui_auth_stage_complete(
self,
@@ -231,18 +235,15 @@ class UIAuthWorkerStore(SQLBaseStore):
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
) -> None:
# Get the current value.
result = cast(
Dict[str, Any],
self.db_pool.simple_select_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
),
result = self.db_pool.simple_select_one_onecol_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcol="serverdict",
)
# Update it and add it back to the database.
serverdict = db_to_json(result["serverdict"])
serverdict = db_to_json(result)
serverdict[key] = value
self.db_pool.simple_update_one_txn(
@@ -265,14 +266,14 @@ class UIAuthWorkerStore(SQLBaseStore):
Raises:
StoreError if the session cannot be found.
"""
result = await self.db_pool.simple_select_one(
result = await self.db_pool.simple_select_one_onecol(
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
retcol="serverdict",
desc="get_ui_auth_session_data",
)
serverdict = db_to_json(result["serverdict"])
serverdict = db_to_json(result)
return serverdict.get(key, default)
@@ -20,7 +20,6 @@ from typing import (
Collection,
Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
@@ -833,13 +832,25 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
allow_none=True,
desc="get_user_in_directory",
async def _get_user_in_directory(
self, user_id: str
) -> Optional[Tuple[Optional[str], Optional[str]]]:
"""
Fetch the user information in the user directory.
Returns:
None if the user is unknown, otherwise a tuple of display name and
avatar URL (both of which may be None).
"""
return cast(
Optional[Tuple[Optional[str], Optional[str]]],
await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
allow_none=True,
desc="get_user_in_directory",
),
)
async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:
@@ -492,7 +492,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
conn.rollback()
if isinstance(self.database_engine, PostgresEngine):
# postgres insists on autocommit for the index
conn.set_session(autocommit=True)
conn.engine.attempt_to_set_autocommit(conn.conn, True)
try:
txn = conn.cursor()
txn.execute(
@@ -501,7 +501,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
)
txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
finally:
conn.set_session(autocommit=False)
conn.engine.attempt_to_set_autocommit(conn.conn, False)
else:
txn = conn.cursor()
txn.execute(
+2 -1
@@ -38,7 +38,8 @@ class PostgresEngine(
super().__init__(psycopg2, database_config)
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
# Disables passing `bytes` to txn.execute, c.f.
# https://github.com/matrix-org/synapse/issues/6186. If you do
# actually want to use bytes, then wrap it in `bytearray`.
def _disable_bytes_adapter(_: bytes) -> NoReturn:
raise Exception("Passing bytes to DB is disabled.")

Some files were not shown because too many files have changed in this diff.