1
0

Compare commits

..

19 Commits

Author SHA1 Message Date
Erik Johnston
4d97a894e1 Newsfiles 2024-05-29 12:59:45 +01:00
Erik Johnston
c7d897def7 Cache invalidation keys must be str 2024-05-29 12:59:19 +01:00
Erik Johnston
74dd17dad4 Drop fewer caches when we delete old rooms 2024-05-29 12:59:19 +01:00
Erik Johnston
726006cdf2 Don't invalidate all get_relations_for_event on history purge (#17083)
This is a tree cache already, so may as well move the room ID to the
front and use that
2024-05-29 12:57:10 +01:00
Erik Johnston
967b6948b0 Change allow_unsafe_locale to also apply on new databases (#17238)
We relax this as there are use cases where this is safe, though it is
still highly recommended that people avoid using it.
2024-05-29 12:04:13 +01:00
Erik Johnston
d7198dfb95 Ignore attempts to send to-device messages to bad users (#17240)
Currently sending a to-device message to a user ID with a dodgy
destination is accepted, but then ends up spamming the logs when we try
and send to the destination.

An alternative would be to reject the request, but I'm slightly nervous
that could break things.
2024-05-29 11:52:48 +01:00
Erik Johnston
94ef2f4f5d Handle duplicate OTK uploads racing (#17241)
Currently this causes one of then to 500.
2024-05-29 11:16:00 +01:00
Erik Johnston
bb5a692946 Fix slipped logging context when media rejected (#17239)
When a module rejects a piece of media we end up trying to close the
same logging context twice.

Instead of fixing the existing code we refactor to use an async context
manager, which is easier to write correctly.
2024-05-29 11:14:42 +01:00
Olivier 'reivilibre
ad179b0136 Merge branch 'master' into develop 2024-05-28 13:34:19 +01:00
dependabot[bot]
5147ce294a Bump phonenumbers from 8.13.35 to 8.13.37 (#17235) 2024-05-28 13:26:37 +01:00
Olivier 'reivilibre
f35bc08d39 1.108.0 2024-05-28 11:54:28 +01:00
dependabot[bot]
f2616edb73 Bump pyicu from 2.13 to 2.13.1 (#17236) 2024-05-28 11:28:58 +01:00
dependabot[bot]
86a2a0258f Bump pyopenssl from 24.0.0 to 24.1.0 (#17234) 2024-05-28 11:28:32 +01:00
dependabot[bot]
0893ee9af8 Bump prometheus-client from 0.19.0 to 0.20.0 (#17233) 2024-05-28 11:28:16 +01:00
dependabot[bot]
887f773472 Bump serde from 1.0.202 to 1.0.203 (#17232) 2024-05-28 11:27:51 +01:00
Shay
9edb725ebc Support MSC3916 by adding unstable media endpoints to _matrix/client (#17213)
[MSC3916](https://github.com/matrix-org/matrix-spec-proposals/blob/rav/authentication-for-media/proposals/3916-authentication-for-media.md)
adds new media endpoints under `_matrix/client`. This PR adds the
`/preview_url`, `/config`, and `/thumbnail` endpoints. `/download` will
be added in a follow-up PR once the work for the federation `/download`
endpoint is complete (see
https://github.com/element-hq/synapse/pull/17172).

Should be reviewable commit-by-commit.
2024-05-24 09:47:37 +01:00
Eric Eastwood
c97251d5ba Add Sliding Sync /sync/e2ee endpoint for To-Device messages (#17167)
This is being introduced as part of Sliding Sync but doesn't have any sliding window component. It's just a way to get E2EE events without having to sit through a big initial sync  (`/sync` v2). And we can avoid encryption events being backed up by the main sync response or vice-versa.

Part of some Sliding Sync simplification/experimentation. See [this discussion](https://github.com/element-hq/synapse/pull/17167#discussion_r1610495866) for why it may not be as useful as we thought.

Based on:

 - https://github.com/matrix-org/matrix-spec-proposals/pull/3575
 - https://github.com/matrix-org/matrix-spec-proposals/pull/3885
 - https://github.com/matrix-org/matrix-spec-proposals/pull/3884
2024-05-23 12:06:16 -05:00
reivilibre
7e2412265d Log exceptions when failing to auto-join new user according to the auto_join_rooms option. (#17176)
Would have been useful for tracking down #16878.

Signed-off-by: Olivier 'reivilibre <oliverw@matrix.org>
2024-05-22 14:22:33 +01:00
reivilibre
7ef00b7628 Add logging to tasks managed by the task scheduler, showing CPU and database usage. (#17219)
The log format is the same as the request log format, except:

- fields that are specific to HTTP requests have been removed
- the task's params are included at the end of the log line.

These log lines are emitted:
- when the task function finishes — both completion and failure (and I
suppose it is possible for a task to become schedulable again?)
- every 5 minutes whilst it is running

Closes #17217.

---------

Signed-off-by: Olivier 'reivilibre <oliverw@matrix.org>
2024-05-22 14:12:58 +01:00
47 changed files with 3576 additions and 943 deletions

View File

@@ -1,3 +1,10 @@
# Synapse 1.108.0 (2024-05-28)
No significant changes since 1.108.0rc1.
# Synapse 1.108.0rc1 (2024-05-21)
### Features

8
Cargo.lock generated
View File

@@ -485,18 +485,18 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "serde"
version = "1.0.202"
version = "1.0.203"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "226b61a0d411b2ba5ff6d7f73a476ac4f8bb900373459cd00fab8512828ba395"
checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.202"
version = "1.0.203"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6048858004bcff69094cd972ed40a32500f153bd3be9f716b2eed2e8217c4838"
checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba"
dependencies = [
"proc-macro2",
"quote",

1
changelog.d/17083.misc Normal file
View File

@@ -0,0 +1 @@
Improve DB usage when fetching related events.

1
changelog.d/17128.misc Normal file
View File

@@ -0,0 +1 @@
Improve DB usage when fetching related events.

View File

@@ -0,0 +1 @@
Add experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync/e2ee` endpoint for To-Device messages and device encryption info.

1
changelog.d/17176.misc Normal file
View File

@@ -0,0 +1 @@
Log exceptions when failing to auto-join new user according to the `auto_join_rooms` option.

View File

@@ -0,0 +1 @@
Support MSC3916 by adding unstable media endpoints to `_matrix/client` (#17213).

View File

@@ -0,0 +1 @@
Add logging to tasks managed by the task scheduler, showing CPU and database usage.

View File

@@ -1 +0,0 @@
Support loading pluggable modules from Docker images.

1
changelog.d/17238.misc Normal file
View File

@@ -0,0 +1 @@
Change the `allow_unsafe_locale` config option to also apply when setting up new databases.

1
changelog.d/17239.misc Normal file
View File

@@ -0,0 +1 @@
Fix errors in logs about closing incorrect logging contexts when media gets rejected by a module.

1
changelog.d/17240.bugfix Normal file
View File

@@ -0,0 +1 @@
Ignore attempts to send to-device messages to bad users, to avoid log spam when we try to connect to the bad server.

1
changelog.d/17241.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix handling of duplicate concurrent uploading of device one-time-keys.

6
debian/changelog vendored
View File

@@ -1,3 +1,9 @@
matrix-synapse-py3 (1.108.0) stable; urgency=medium
* New Synapse release 1.108.0.
-- Synapse Packaging team <packages@matrix.org> Tue, 28 May 2024 11:54:22 +0100
matrix-synapse-py3 (1.108.0~rc1) stable; urgency=medium
* New Synapse release 1.108.0rc1.

View File

@@ -243,26 +243,3 @@ healthcheck:
Jemalloc is embedded in the image and will be used instead of the default allocator.
You can read about jemalloc by reading the Synapse
[Admin FAQ](https://element-hq.github.io/synapse/latest/usage/administration/admin_faq.html#help-synapse-is-slow-and-eats-all-my-ramcpu).
## Modules
Synapse supports loading additional modules, using the
[`modules`](https://element-hq.github.io/synapse/latest/modules/index.html)
config. Synapse will look for these modules `/modules`.
To install a package, simply run:
```
pip install --target <module_directory> <package>
```
Where `<module_directory>` is the directory mounted to `/modules`, and
`<package>` is either the package name or a path to the package. See
`pip install` for more details.
**Note**: Packages already installed as part of Synapse cannot be overridden by
different versions of the package in `/modules`, e.g. if the Synapse version
uses Twisted 24.3.0 then installing Twisted 23.10.0 in `/modules` won't have any
effect. This can cause issues if the required version of a package is different
between Synapse and the module being installed.

View File

@@ -269,15 +269,6 @@ running with 'migrate_config'. See the README for more details.
args += ["--config-path", config_path]
# Add the `/modules` directly to python search path, which allows users to
# add custom modules.
#
# We want to add the directory *last* so that nothing can overwrite the
# existing package versions. Therefore we load the current path and append
# `/modules` to that
path = ":".join(sys.path)
environ["PYTHONPATH"] = f"{path}:/modules"
log("Starting synapse with args " + " ".join(args))
args = [sys.executable] + args

View File

@@ -242,12 +242,11 @@ host all all ::1/128 ident
### Fixing incorrect `COLLATE` or `CTYPE`
Synapse will refuse to set up a new database if it has the wrong values of
`COLLATE` and `CTYPE` set. Synapse will also refuse to start an existing database with incorrect values
of `COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the
`database` section of the config, is set to true. Using different locales can cause issues if the locale library is updated from
underneath the database, or if a different version of the locale is used on any
replicas.
Synapse will refuse to start when using a database with incorrect values of
`COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the
`database` section of the config, is set to true. Using different locales can
cause issues if the locale library is updated from underneath the database, or
if a different version of the locale is used on any replicas.
If you have a database with an unsafe locale, the safest way to fix the issue is to dump the database and recreate it with
the correct locale parameter (as shown above). It is also possible to change the

24
poetry.lock generated
View File

@@ -1536,13 +1536,13 @@ files = [
[[package]]
name = "phonenumbers"
version = "8.13.35"
version = "8.13.37"
description = "Python version of Google's common library for parsing, formatting, storing and validating international phone numbers."
optional = false
python-versions = "*"
files = [
{file = "phonenumbers-8.13.35-py2.py3-none-any.whl", hash = "sha256:58286a8e617bd75f541e04313b28c36398be6d4443a778c85e9617a93c391310"},
{file = "phonenumbers-8.13.35.tar.gz", hash = "sha256:64f061a967dcdae11e1c59f3688649e697b897110a33bb74d5a69c3e35321245"},
{file = "phonenumbers-8.13.37-py2.py3-none-any.whl", hash = "sha256:4ea00ef5012422c08c7955c21131e7ae5baa9a3ef52cf2d561e963f023006b80"},
{file = "phonenumbers-8.13.37.tar.gz", hash = "sha256:bd315fed159aea0516f7c367231810fe8344d5bec26156b88fa18374c11d1cf2"},
]
[[package]]
@@ -1673,13 +1673,13 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytes
[[package]]
name = "prometheus-client"
version = "0.19.0"
version = "0.20.0"
description = "Python client for the Prometheus monitoring system."
optional = false
python-versions = ">=3.8"
files = [
{file = "prometheus_client-0.19.0-py3-none-any.whl", hash = "sha256:c88b1e6ecf6b41cd8fb5731c7ae919bf66df6ec6fafa555cd6c0e16ca169ae92"},
{file = "prometheus_client-0.19.0.tar.gz", hash = "sha256:4585b0d1223148c27a225b10dbec5ae9bc4c81a99a3fa80774fa6209935324e1"},
{file = "prometheus_client-0.20.0-py3-none-any.whl", hash = "sha256:cde524a85bce83ca359cc837f28b8c0db5cac7aa653a588fd7e84ba061c329e7"},
{file = "prometheus_client-0.20.0.tar.gz", hash = "sha256:287629d00b147a32dcb2be0b9df905da599b2d82f80377083ec8463309a4bb89"},
]
[package.extras]
@@ -1915,12 +1915,12 @@ plugins = ["importlib-metadata"]
[[package]]
name = "pyicu"
version = "2.13"
version = "2.13.1"
description = "Python extension wrapping the ICU C++ API"
optional = true
python-versions = "*"
files = [
{file = "PyICU-2.13.tar.gz", hash = "sha256:d481be888975df3097c2790241bbe8518f65c9676a74957cdbe790e559c828f6"},
{file = "PyICU-2.13.1.tar.gz", hash = "sha256:d4919085eaa07da12bade8ee721e7bbf7ade0151ca0f82946a26c8f4b98cdceb"},
]
[[package]]
@@ -1997,13 +1997,13 @@ tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"]
[[package]]
name = "pyopenssl"
version = "24.0.0"
version = "24.1.0"
description = "Python wrapper module around the OpenSSL library"
optional = false
python-versions = ">=3.7"
files = [
{file = "pyOpenSSL-24.0.0-py3-none-any.whl", hash = "sha256:ba07553fb6fd6a7a2259adb9b84e12302a9a8a75c44046e8bb5d3e5ee887e3c3"},
{file = "pyOpenSSL-24.0.0.tar.gz", hash = "sha256:6aa33039a93fffa4563e655b61d11364d01264be8ccb49906101e02a334530bf"},
{file = "pyOpenSSL-24.1.0-py3-none-any.whl", hash = "sha256:17ed5be5936449c5418d1cd269a1a9e9081bc54c17aed272b45856a3d3dc86ad"},
{file = "pyOpenSSL-24.1.0.tar.gz", hash = "sha256:cabed4bfaa5df9f1a16c0ef64a0cb65318b5cd077a7eda7d6970131ca2f41a6f"},
]
[package.dependencies]
@@ -2011,7 +2011,7 @@ cryptography = ">=41.0.5,<43"
[package.extras]
docs = ["sphinx (!=5.2.0,!=5.2.0.post0,!=7.2.5)", "sphinx-rtd-theme"]
test = ["flaky", "pretend", "pytest (>=3.0.1)"]
test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"]
[[package]]
name = "pysaml2"

View File

@@ -96,7 +96,7 @@ module-name = "synapse.synapse_rust"
[tool.poetry]
name = "matrix-synapse"
version = "1.108.0rc1"
version = "1.108.0"
description = "Homeserver for the Matrix decentralised comms protocol"
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
license = "AGPL-3.0-or-later"

View File

@@ -332,6 +332,9 @@ class ExperimentalConfig(Config):
# MSC3391: Removing account data.
self.msc3391_enabled = experimental.get("msc3391_enabled", False)
# MSC3575 (Sliding Sync API endpoints)
self.msc3575_enabled: bool = experimental.get("msc3575_enabled", False)
# MSC3773: Thread notifications
self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)
@@ -436,3 +439,7 @@ class ExperimentalConfig(Config):
self.msc4115_membership_on_events = experimental.get(
"msc4115_membership_on_events", False
)
self.msc3916_authenticated_media_enabled = experimental.get(
"msc3916_authenticated_media_enabled", False
)

View File

@@ -236,6 +236,13 @@ class DeviceMessageHandler:
local_messages = {}
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
if not UserID.is_valid(user_id):
logger.warning(
"Ignoring attempt to send device message to invalid user: %r",
user_id,
)
continue
# add an opentracing log entry for each message
for device_id, message_content in by_device.items():
log_kv(

View File

@@ -53,6 +53,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
ONE_TIME_KEY_UPLOAD = "one_time_key_upload_lock"
class E2eKeysHandler:
def __init__(self, hs: "HomeServer"):
self.config = hs.config
@@ -62,6 +65,7 @@ class E2eKeysHandler:
self._appservice_handler = hs.get_application_service_handler()
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
self._worker_lock_handler = hs.get_worker_locks_handler()
federation_registry = hs.get_federation_registry()
@@ -855,45 +859,53 @@ class E2eKeysHandler:
async def _upload_one_time_keys_for_user(
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
) -> None:
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(),
device_id,
user_id,
time_now,
)
# We take out a lock so that we don't have to worry about a client
# sending duplicate requests.
lock_key = f"{user_id}_{device_id}"
async with self._worker_lock_handler.acquire_lock(
ONE_TIME_KEY_UPLOAD, lock_key
):
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(),
device_id,
user_id,
time_now,
)
# make a list of (alg, id, key) tuples
key_list = []
for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((algorithm, key_id, key_obj))
# make a list of (alg, id, key) tuples
key_list = []
for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((algorithm, key_id, key_obj))
# First we check if we have already persisted any of the keys.
existing_key_map = await self.store.get_e2e_one_time_keys(
user_id, device_id, [k_id for _, k_id, _ in key_list]
)
# First we check if we have already persisted any of the keys.
existing_key_map = await self.store.get_e2e_one_time_keys(
user_id, device_id, [k_id for _, k_id, _ in key_list]
)
new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
for algorithm, key_id, key in key_list:
ex_json = existing_key_map.get((algorithm, key_id), None)
if ex_json:
if not _one_time_keys_match(ex_json, key):
raise SynapseError(
400,
(
"One time key %s:%s already exists. "
"Old key: %s; new key: %r"
new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
for algorithm, key_id, key in key_list:
ex_json = existing_key_map.get((algorithm, key_id), None)
if ex_json:
if not _one_time_keys_match(ex_json, key):
raise SynapseError(
400,
(
"One time key %s:%s already exists. "
"Old key: %s; new key: %r"
)
% (algorithm, key_id, ex_json, key),
)
% (algorithm, key_id, ex_json, key),
else:
new_keys.append(
(algorithm, key_id, encode_canonical_json(key).decode("ascii"))
)
else:
new_keys.append(
(algorithm, key_id, encode_canonical_json(key).decode("ascii"))
)
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
await self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, new_keys
)
async def upload_signing_keys_for_user(
self, user_id: str, keys: JsonDict

View File

@@ -590,7 +590,7 @@ class RegistrationHandler:
# moving away from bare excepts is a good thing to do.
logger.error("Failed to join new user to %r: %r", r, e)
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
logger.error("Failed to join new user to %r: %r", r, e, exc_info=True)
async def _auto_join_rooms(self, user_id: str) -> None:
"""Automatically joins users to auto join rooms - creating the room in the first place

View File

@@ -393,9 +393,9 @@ class RelationsHandler:
# Attempt to find another event to use as the latest event.
potential_events, _ = await self._main_store.get_relations_for_event(
room_id,
event_id,
event,
room_id,
RelationTypes.THREAD,
direction=Direction.FORWARDS,
)

View File

@@ -28,11 +28,14 @@ from typing import (
Dict,
FrozenSet,
List,
Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
overload,
)
import attr
@@ -128,6 +131,8 @@ class SyncVersion(Enum):
# Traditional `/sync` endpoint
SYNC_V2 = "sync_v2"
# Part of MSC3575 Sliding Sync
E2EE_SYNC = "e2ee_sync"
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -280,6 +285,26 @@ class SyncResult:
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class E2eeSyncResult:
"""
Attributes:
next_batch: Token for the next sync
to_device: List of direct messages for the device.
device_lists: List of user_ids whose devices have changed
device_one_time_keys_count: Dict of algorithm to count for one time keys
for this device
device_unused_fallback_key_types: List of key types that have an unused fallback
key
"""
next_batch: StreamToken
to_device: List[JsonDict]
device_lists: DeviceListUpdates
device_one_time_keys_count: JsonMapping
device_unused_fallback_key_types: List[str]
class SyncHandler:
def __init__(self, hs: "HomeServer"):
self.hs_config = hs.config
@@ -322,6 +347,31 @@ class SyncHandler:
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
@overload
async def wait_for_sync_for_user(
self,
requester: Requester,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.SYNC_V2],
request_key: SyncRequestKey,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
) -> SyncResult: ...
@overload
async def wait_for_sync_for_user(
self,
requester: Requester,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.E2EE_SYNC],
request_key: SyncRequestKey,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
) -> E2eeSyncResult: ...
@overload
async def wait_for_sync_for_user(
self,
requester: Requester,
@@ -331,7 +381,18 @@ class SyncHandler:
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
) -> SyncResult:
) -> Union[SyncResult, E2eeSyncResult]: ...
async def wait_for_sync_for_user(
self,
requester: Requester,
sync_config: SyncConfig,
sync_version: SyncVersion,
request_key: SyncRequestKey,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
) -> Union[SyncResult, E2eeSyncResult]:
"""Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result.
@@ -344,8 +405,10 @@ class SyncHandler:
since_token: The point in the stream to sync from.
timeout: How long to wait for new data to arrive before giving up.
full_state: Whether to return the full state for each room.
Returns:
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
When `SyncVersion.E2EE_SYNC`, returns a `E2eeSyncResult`.
"""
# If the user is not part of the mau group, then check that limits have
# not been exceeded (if not part of the group by this point, almost certain
@@ -366,6 +429,29 @@ class SyncHandler:
logger.debug("Returning sync response for %s", user_id)
return res
@overload
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.SYNC_V2],
since_token: Optional[StreamToken],
timeout: int,
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
) -> SyncResult: ...
@overload
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.E2EE_SYNC],
since_token: Optional[StreamToken],
timeout: int,
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
) -> E2eeSyncResult: ...
@overload
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,
@@ -374,7 +460,17 @@ class SyncHandler:
timeout: int,
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
) -> SyncResult:
) -> Union[SyncResult, E2eeSyncResult]: ...
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: SyncVersion,
since_token: Optional[StreamToken],
timeout: int,
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
) -> Union[SyncResult, E2eeSyncResult]:
"""The start of the machinery that produces a /sync response.
See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
@@ -417,14 +513,16 @@ class SyncHandler:
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
result: SyncResult = await self.current_sync_for_user(
sync_config, sync_version, since_token, full_state=full_state
result: Union[SyncResult, E2eeSyncResult] = (
await self.current_sync_for_user(
sync_config, sync_version, since_token, full_state=full_state
)
)
else:
# Otherwise, we wait for something to happen and report it to the user.
async def current_sync_callback(
before_token: StreamToken, after_token: StreamToken
) -> SyncResult:
) -> Union[SyncResult, E2eeSyncResult]:
return await self.current_sync_for_user(
sync_config, sync_version, since_token
)
@@ -456,14 +554,43 @@ class SyncHandler:
return result
@overload
async def current_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.SYNC_V2],
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult: ...
@overload
async def current_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.E2EE_SYNC],
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> E2eeSyncResult: ...
@overload
async def current_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: SyncVersion,
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult:
"""Generates the response body of a sync result, represented as a SyncResult.
) -> Union[SyncResult, E2eeSyncResult]: ...
async def current_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: SyncVersion,
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> Union[SyncResult, E2eeSyncResult]:
"""
Generates the response body of a sync result, represented as a
`SyncResult`/`E2eeSyncResult`.
This is a wrapper around `generate_sync_result` which starts an open tracing
span to track the sync. See `generate_sync_result` for the next part of your
@@ -474,15 +601,25 @@ class SyncHandler:
sync_version: Determines what kind of sync response to generate.
since_token: The point in the stream to sync from.p.
full_state: Whether to return the full state for each room.
Returns:
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
When `SyncVersion.E2EE_SYNC`, returns a `E2eeSyncResult`.
"""
with start_active_span("sync.current_sync_for_user"):
log_kv({"since_token": since_token})
# Go through the `/sync` v2 path
if sync_version == SyncVersion.SYNC_V2:
sync_result: SyncResult = await self.generate_sync_result(
sync_config, since_token, full_state
sync_result: Union[SyncResult, E2eeSyncResult] = (
await self.generate_sync_result(
sync_config, since_token, full_state
)
)
# Go through the MSC3575 Sliding Sync `/sync/e2ee` path
elif sync_version == SyncVersion.E2EE_SYNC:
sync_result = await self.generate_e2ee_sync_result(
sync_config, since_token
)
else:
raise Exception(
@@ -1691,6 +1828,96 @@ class SyncHandler:
next_batch=sync_result_builder.now_token,
)
async def generate_e2ee_sync_result(
self,
sync_config: SyncConfig,
since_token: Optional[StreamToken] = None,
) -> E2eeSyncResult:
"""
Generates the response body of a MSC3575 Sliding Sync `/sync/e2ee` result.
This is represented by a `E2eeSyncResult` struct, which is built from small
pieces using a `SyncResultBuilder`. The `sync_result_builder` is passed as a
mutable ("inout") parameter to various helper functions. These retrieve and
process the data which forms the sync body, often writing to the
`sync_result_builder` to store their output.
At the end, we transfer data from the `sync_result_builder` to a new `E2eeSyncResult`
instance to signify that the sync calculation is complete.
"""
user_id = sync_config.user.to_string()
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service:
# We no longer support AS users using /sync directly.
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
sync_result_builder = await self.get_sync_result_builder(
sync_config,
since_token,
full_state=False,
)
# 1. Calculate `to_device` events
await self._generate_sync_entry_for_to_device(sync_result_builder)
# 2. Calculate `device_lists`
# Device list updates are sent if a since token is provided.
device_lists = DeviceListUpdates()
include_device_list_updates = bool(since_token and since_token.device_list_key)
if include_device_list_updates:
# Note that _generate_sync_entry_for_rooms sets sync_result_builder.joined, which
# is used in calculate_user_changes below.
#
# TODO: Running `_generate_sync_entry_for_rooms()` is a lot of work just to
# figure out the membership changes/derived info needed for
# `_generate_sync_entry_for_device_list()`. In the future, we should try to
# refactor this away.
(
newly_joined_rooms,
newly_left_rooms,
) = await self._generate_sync_entry_for_rooms(sync_result_builder)
# This uses the sync_result_builder.joined which is set in
# `_generate_sync_entry_for_rooms`, if that didn't find any joined
# rooms for some reason it is a no-op.
(
newly_joined_or_invited_or_knocked_users,
newly_left_users,
) = sync_result_builder.calculate_user_changes()
device_lists = await self._generate_sync_entry_for_device_list(
sync_result_builder,
newly_joined_rooms=newly_joined_rooms,
newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
newly_left_rooms=newly_left_rooms,
newly_left_users=newly_left_users,
)
# 3. Calculate `device_one_time_keys_count` and `device_unused_fallback_key_types`
device_id = sync_config.device_id
one_time_keys_count: JsonMapping = {}
unused_fallback_key_types: List[str] = []
if device_id:
# TODO: We should have a way to let clients differentiate between the states of:
# * no change in OTK count since the provided since token
# * the server has zero OTKs left for this device
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
unused_fallback_key_types = list(
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
return E2eeSyncResult(
to_device=sync_result_builder.to_device,
device_lists=device_lists,
device_one_time_keys_count=one_time_keys_count,
device_unused_fallback_key_types=unused_fallback_key_types,
next_batch=sync_result_builder.now_token,
)
async def get_sync_result_builder(
self,
sync_config: SyncConfig,
@@ -1889,7 +2116,7 @@ class SyncHandler:
users_that_have_changed = (
await self._device_handler.get_device_changes_in_shared_rooms(
user_id,
sync_result_builder.joined_room_ids,
joined_room_ids,
from_token=since_token,
now_token=sync_result_builder.now_token,
)

View File

@@ -650,7 +650,7 @@ class MediaRepository:
file_info = FileInfo(server_name=server_name, file_id=file_id)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
length, headers = await self.client.download_media(
server_name,
@@ -693,8 +693,6 @@ class MediaRepository:
)
raise SynapseError(502, "Failed to fetch remote media")
await finish()
if b"Content-Type" in headers:
media_type = headers[b"Content-Type"][0].decode("ascii")
else:
@@ -1045,14 +1043,9 @@ class MediaRepository:
),
)
with self.media_storage.store_into_file(file_info) as (
f,
fname,
finish,
):
async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
await self.media_storage.write_to_file(t_byte_source, f)
await finish()
finally:
t_byte_source.close()

View File

@@ -27,10 +27,9 @@ from typing import (
IO,
TYPE_CHECKING,
Any,
Awaitable,
AsyncIterator,
BinaryIO,
Callable,
Generator,
Optional,
Sequence,
Tuple,
@@ -97,11 +96,9 @@ class MediaStorage:
the file path written to in the primary media store
"""
with self.store_into_file(file_info) as (f, fname, finish_cb):
async with self.store_into_file(file_info) as (f, fname):
# Write to the main media repository
await self.write_to_file(source, f)
# Write to the other storage providers
await finish_cb()
return fname
@@ -111,32 +108,27 @@ class MediaStorage:
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@trace_with_opname("MediaStorage.store_into_file")
@contextlib.contextmanager
def store_into_file(
@contextlib.asynccontextmanager
async def store_into_file(
self, file_info: FileInfo
) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
"""Context manager used to get a file like object to write into, as
) -> AsyncIterator[Tuple[BinaryIO, str]]:
"""Async Context manager used to get a file like object to write into, as
described by file_info.
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
like object that can be written to, fname is the absolute path of file
on disk, and finish_cb is a function that returns an awaitable.
Actually yields a 2-tuple (file, fname,), where file is a file
like object that can be written to and fname is the absolute path of file
on disk.
fname can be used to read the contents from after upload, e.g. to
generate thumbnails.
finish_cb must be called and waited on after the file has been successfully been
written to. Should not be called if there was an error. Checks for spam and
stores the file into the configured storage providers.
Args:
file_info: Info about the file to store
Example:
with media_storage.store_into_file(info) as (f, fname, finish_cb):
async with media_storage.store_into_file(info) as (f, fname,):
# .. write into f ...
await finish_cb()
"""
path = self._file_info_to_path(file_info)
@@ -145,62 +137,42 @@ class MediaStorage:
dirname = os.path.dirname(fname)
os.makedirs(dirname, exist_ok=True)
finished_called = [False]
main_media_repo_write_trace_scope = start_active_span(
"writing to main media repo"
)
main_media_repo_write_trace_scope.__enter__()
try:
with open(fname, "wb") as f:
async def finish() -> None:
# When someone calls finish, we assume they are done writing to the main media repo
main_media_repo_write_trace_scope.__exit__(None, None, None)
with start_active_span("writing to other storage providers"):
# Ensure that all writes have been flushed and close the
# file.
f.flush()
f.close()
spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam(
ReadableFileWrapper(self.clock, fname), file_info
)
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
logger.info("Blocking media due to spam checker")
# Note that we'll delete the stored media, due to the
# try/except below. The media also won't be stored in
# the DB.
# We currently ignore any additional field returned by
# the spam-check API.
raise SpamMediaException(errcode=spam_check[0])
for provider in self.storage_providers:
with start_active_span(str(provider)):
await provider.store_file(path, file_info)
finished_called[0] = True
yield f, fname, finish
except Exception as e:
with main_media_repo_write_trace_scope:
try:
main_media_repo_write_trace_scope.__exit__(
type(e), None, e.__traceback__
with open(fname, "wb") as f:
yield f, fname
except Exception as e:
try:
os.remove(fname)
except Exception:
pass
raise e from None
with start_active_span("writing to other storage providers"):
spam_check = (
await self._spam_checker_module_callbacks.check_media_file_for_spam(
ReadableFileWrapper(self.clock, fname), file_info
)
os.remove(fname)
except Exception:
pass
raise e from None
if not finished_called:
exc = Exception("Finished callback not called")
main_media_repo_write_trace_scope.__exit__(
type(exc), None, exc.__traceback__
)
raise exc
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
logger.info("Blocking media due to spam checker")
# Note that we'll delete the stored media, due to the
# try/except below. The media also won't be stored in
# the DB.
# We currently ignore any additional field returned by
# the spam-check API.
raise SpamMediaException(errcode=spam_check[0])
for provider in self.storage_providers:
with start_active_span(str(provider)):
await provider.store_file(path, file_info)
async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
"""Attempts to fetch media described by file_info from the local cache

View File

@@ -22,11 +22,27 @@
import logging
from io import BytesIO
from types import TracebackType
from typing import Optional, Tuple, Type
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
from PIL import Image
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
from synapse.http.server import respond_with_json
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace
from synapse.media._base import (
FileInfo,
ThumbnailInfo,
respond_404,
respond_with_file,
respond_with_responder,
)
from synapse.media.media_storage import MediaStorage
if TYPE_CHECKING:
from synapse.media.media_repository import MediaRepository
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -231,3 +247,471 @@ class Thumbnailer:
def __del__(self) -> None:
# Make sure we actually do close the image, rather than leak data.
self.close()
class ThumbnailProvider:
def __init__(
self,
hs: "HomeServer",
media_repo: "MediaRepository",
media_storage: MediaStorage,
):
self.hs = hs
self.media_repo = media_repo
self.media_storage = media_storage
self.store = hs.get_datastores().main
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
async def respond_local_thumbnail(
self,
request: SynapseRequest,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
max_timeout_ms: int,
) -> None:
media_info = await self.media_repo.get_local_media_info(
request, media_id, max_timeout_ms
)
if not media_info:
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
await self._select_and_respond_with_thumbnail(
request,
width,
height,
method,
m_type,
thumbnail_infos,
media_id,
media_id,
url_cache=bool(media_info.url_cache),
server_name=None,
)
async def select_or_generate_local_thumbnail(
self,
request: SynapseRequest,
media_id: str,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
max_timeout_ms: int,
) -> None:
media_info = await self.media_repo.get_local_media_info(
request, media_id, max_timeout_ms
)
if not media_info:
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
t_w = info.width == desired_width
t_h = info.height == desired_height
t_method = info.method == desired_method
t_type = info.type == desired_type
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
server_name=None,
file_id=media_id,
url_cache=bool(media_info.url_cache),
thumbnail=info,
)
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request, responder, info.type, info.length
)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = await self.media_repo.generate_local_exact_thumbnail(
media_id,
desired_width,
desired_height,
desired_method,
desired_type,
url_cache=bool(media_info.url_cache),
)
if file_path:
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
raise SynapseError(400, "Failed to generate thumbnail.")
async def select_or_generate_remote_thumbnail(
self,
request: SynapseRequest,
server_name: str,
media_id: str,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
max_timeout_ms: int,
) -> None:
media_info = await self.media_repo.get_remote_media_info(
server_name, media_id, max_timeout_ms
)
if not media_info:
respond_404(request)
return
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
file_id = media_info.filesystem_id
for info in thumbnail_infos:
t_w = info.width == desired_width
t_h = info.height == desired_height
t_method = info.method == desired_method
t_type = info.type == desired_type
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
thumbnail=info,
)
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request, responder, info.type, info.length
)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = await self.media_repo.generate_remote_exact_thumbnail(
server_name,
file_id,
media_id,
desired_width,
desired_height,
desired_method,
desired_type,
)
if file_path:
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
raise SynapseError(400, "Failed to generate thumbnail.")
async def respond_remote_thumbnail(
self,
request: SynapseRequest,
server_name: str,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
max_timeout_ms: int,
) -> None:
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
media_info = await self.media_repo.get_remote_media_info(
server_name, media_id, max_timeout_ms
)
if not media_info:
return
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
await self._select_and_respond_with_thumbnail(
request,
width,
height,
method,
m_type,
thumbnail_infos,
media_id,
media_info.filesystem_id,
url_cache=False,
server_name=server_name,
)
async def _select_and_respond_with_thumbnail(
self,
request: SynapseRequest,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos: List[ThumbnailInfo],
media_id: str,
file_id: str,
url_cache: bool,
server_name: Optional[str] = None,
) -> None:
"""
Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
Args:
request: The incoming request.
desired_width: The desired width, the returned thumbnail may be larger than this.
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of thumbnail info of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
"""
logger.debug(
"_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
media_id,
desired_width,
desired_height,
desired_method,
thumbnail_infos,
)
# If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
# different code path to handle it.
assert not self.dynamic_thumbnails
if thumbnail_infos:
file_info = self._select_thumbnail(
desired_width,
desired_height,
desired_method,
desired_type,
thumbnail_infos,
file_id,
url_cache,
server_name,
)
if not file_info:
logger.info("Couldn't find a thumbnail matching the desired inputs")
respond_404(request)
return
# The thumbnail property must exist.
assert file_info.thumbnail is not None
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request,
responder,
file_info.thumbnail.type,
file_info.thumbnail.length,
)
return
# If we can't find the thumbnail we regenerate it. This can happen
# if e.g. we've deleted the thumbnails but still have the original
# image somewhere.
#
# Since we have an entry for the thumbnail in the DB we a) know we
# have have successfully generated the thumbnail in the past (so we
# don't need to worry about repeatedly failing to generate
# thumbnails), and b) have already calculated that appropriate
# width/height/method so we can just call the "generate exact"
# methods.
# First let's check that we do actually have the original image
# still. This will throw a 404 if we don't.
# TODO: We should refetch the thumbnails for remote media.
await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
)
if server_name:
await self.media_repo.generate_remote_exact_thumbnail(
server_name,
file_id=file_id,
media_id=media_id,
t_width=file_info.thumbnail.width,
t_height=file_info.thumbnail.height,
t_method=file_info.thumbnail.method,
t_type=file_info.thumbnail.type,
)
else:
await self.media_repo.generate_local_exact_thumbnail(
media_id=media_id,
t_width=file_info.thumbnail.width,
t_height=file_info.thumbnail.height,
t_method=file_info.thumbnail.method,
t_type=file_info.thumbnail.type,
url_cache=url_cache,
)
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
request,
responder,
file_info.thumbnail.type,
file_info.thumbnail.length,
)
else:
# This might be because:
# 1. We can't create thumbnails for the given media (corrupted or
# unsupported file type), or
# 2. The thumbnailing process never ran or errored out initially
# when the media was first uploaded (these bugs should be
# reported and fixed).
# Note that we don't attempt to generate a thumbnail now because
# `dynamic_thumbnails` is disabled.
logger.info("Failed to find any generated thumbnails")
assert request.path is not None
respond_with_json(
request,
400,
cs_error(
"Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
% (
request.path.decode(),
", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
),
code=Codes.UNKNOWN,
),
send_cors=True,
)
def _select_thumbnail(
self,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos: List[ThumbnailInfo],
file_id: str,
url_cache: bool,
server_name: Optional[str],
) -> Optional[FileInfo]:
"""
Choose an appropriate thumbnail from the previously generated thumbnails.
Args:
desired_width: The desired width, the returned thumbnail may be larger than this.
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of thumbnail infos of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
Returns:
The thumbnail which best matches the desired parameters.
"""
desired_method = desired_method.lower()
# The chosen thumbnail.
thumbnail_info = None
d_w = desired_width
d_h = desired_height
if desired_method == "crop":
# Thumbnails that match equal or larger sizes of desired width/height.
crop_info_list: List[
Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
] = []
# Other thumbnails.
crop_info_list2: List[
Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info.method != "crop":
continue
t_w = info.width
t_h = info.height
aspect_quality = abs(d_w * t_h - d_h * t_w)
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info.type
length_quality = info.length
if t_w >= d_w or t_h >= d_h:
crop_info_list.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
else:
crop_info_list2.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
# Pick the most appropriate thumbnail. Some values of `desired_width` and
# `desired_height` may result in a tie, in which case we avoid comparing on
# the thumbnail info and pick the thumbnail that appears earlier
# in the list of candidates.
if crop_info_list:
thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
elif crop_info_list2:
thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
elif desired_method == "scale":
# Thumbnails that match equal or larger sizes of desired width/height.
info_list: List[Tuple[int, bool, int, ThumbnailInfo]] = []
# Other thumbnails.
info_list2: List[Tuple[int, bool, int, ThumbnailInfo]] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info.method != "scale":
continue
t_w = info.width
t_h = info.height
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info.type
length_quality = info.length
if t_w >= d_w or t_h >= d_h:
info_list.append((size_quality, type_quality, length_quality, info))
else:
info_list2.append(
(size_quality, type_quality, length_quality, info)
)
# Pick the most appropriate thumbnail. Some values of `desired_width` and
# `desired_height` may result in a tie, in which case we avoid comparing on
# the thumbnail info and pick the thumbnail that appears earlier
# in the list of candidates.
if info_list:
thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
elif info_list2:
thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
if thumbnail_info:
return FileInfo(
file_id=file_id,
url_cache=url_cache,
server_name=server_name,
thumbnail=thumbnail_info,
)
# No matching thumbnail was found.
return None

View File

@@ -592,7 +592,7 @@ class UrlPreviewer:
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
async with self.media_storage.store_into_file(file_info) as (f, fname):
if url.startswith("data:"):
if not allow_data_urls:
raise SynapseError(
@@ -603,8 +603,6 @@ class UrlPreviewer:
else:
download_result = await self._download_url(url, f)
await finish()
try:
time_now_ms = self.clock.time_msec()

View File

@@ -0,0 +1,205 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
import logging
import re
from synapse.http.server import (
HttpServer,
respond_with_json,
respond_with_json_bytes,
set_corp_headers,
set_cors_headers,
)
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.media._base import (
DEFAULT_MAX_TIMEOUT_MS,
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
respond_404,
)
from synapse.media.media_repository import MediaRepository
from synapse.media.media_storage import MediaStorage
from synapse.media.thumbnailer import ThumbnailProvider
from synapse.server import HomeServer
from synapse.util.stringutils import parse_and_validate_server_name
logger = logging.getLogger(__name__)
class UnstablePreviewURLServlet(RestServlet):
"""
Same as `GET /_matrix/media/r0/preview_url`, this endpoint provides a generic preview API
for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix
specific additions).
This does have trade-offs compared to other designs:
* Pros:
* Simple and flexible; can be used by any clients at any point
* Cons:
* If each homeserver provides one of these independently, all the homeservers in a
room may needlessly DoS the target URI
* The URL metadata must be stored somewhere, rather than just using Matrix
itself to store the media.
* Matrix cannot be used to distribute the metadata between homeservers.
"""
PATTERNS = [
re.compile(r"^/_matrix/client/unstable/org.matrix.msc3916/media/preview_url$")
]
def __init__(
self,
hs: "HomeServer",
media_repo: "MediaRepository",
media_storage: MediaStorage,
):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.media_repo = media_repo
self.media_storage = media_storage
assert self.media_repo.url_previewer is not None
self.url_previewer = self.media_repo.url_previewer
async def on_GET(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request)
url = parse_string(request, "url", required=True)
ts = parse_integer(request, "ts")
if ts is None:
ts = self.clock.time_msec()
og = await self.url_previewer.preview(url, requester.user, ts)
respond_with_json_bytes(request, 200, og, send_cors=True)
class UnstableMediaConfigResource(RestServlet):
PATTERNS = [
re.compile(r"^/_matrix/client/unstable/org.matrix.msc3916/media/config$")
]
def __init__(self, hs: "HomeServer"):
super().__init__()
config = hs.config
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.media.max_upload_size}
async def on_GET(self, request: SynapseRequest) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
class UnstableThumbnailResource(RestServlet):
PATTERNS = [
re.compile(
"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
)
]
def __init__(
self,
hs: "HomeServer",
media_repo: "MediaRepository",
media_storage: MediaStorage,
):
super().__init__()
self.store = hs.get_datastores().main
self.media_repo = media_repo
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self._is_mine_server_name = hs.is_mine_server_name
self._server_name = hs.hostname
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
self.thumbnailer = ThumbnailProvider(hs, media_repo, media_storage)
self.auth = hs.get_auth()
async def on_GET(
self, request: SynapseRequest, server_name: str, media_id: str
) -> None:
# Validate the server name, raising if invalid
parse_and_validate_server_name(server_name)
await self.auth.get_user_by_req(request)
set_cors_headers(request)
set_corp_headers(request)
width = parse_integer(request, "width", required=True)
height = parse_integer(request, "height", required=True)
method = parse_string(request, "method", "scale")
# TODO Parse the Accept header to get an prioritised list of thumbnail types.
m_type = "image/png"
max_timeout_ms = parse_integer(
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
)
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
if self._is_mine_server_name(server_name):
if self.dynamic_thumbnails:
await self.thumbnailer.select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms
)
else:
await self.thumbnailer.respond_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms
)
self.media_repo.mark_recently_accessed(None, media_id)
else:
# Don't let users download media from configured domains, even if it
# is already downloaded. This is Trust & Safety tooling to make some
# media inaccessible to local users.
# See `prevent_media_downloads_from` config docs for more info.
if server_name in self.prevent_media_downloads_from:
respond_404(request)
return
remote_resp_function = (
self.thumbnailer.select_or_generate_remote_thumbnail
if self.dynamic_thumbnails
else self.thumbnailer.respond_remote_thumbnail
)
await remote_resp_function(
request,
server_name,
media_id,
width,
height,
method,
m_type,
max_timeout_ms,
)
self.media_repo.mark_recently_accessed(server_name, media_id)
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.config.experimental.msc3916_authenticated_media_enabled:
media_repo = hs.get_media_repository()
if hs.config.media.url_preview_enabled:
UnstablePreviewURLServlet(
hs, media_repo, media_repo.media_storage
).register(http_server)
UnstableMediaConfigResource(hs).register(http_server)
UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register(
http_server
)

View File

@@ -567,5 +567,176 @@ class SyncRestServlet(RestServlet):
return result
class SlidingSyncE2eeRestServlet(RestServlet):
"""
API endpoint for MSC3575 Sliding Sync `/sync/e2ee`. This is being introduced as part
of Sliding Sync but doesn't have any sliding window component. It's just a way to
get E2EE events without having to sit through a big initial sync (`/sync` v2). And
we can avoid encryption events being backed up by the main sync response.
Having To-Device messages split out to this sync endpoint also helps when clients
need to have 2 or more sync streams open at a time, e.g a push notification process
and a main process. This can cause the two processes to race to fetch the To-Device
events, resulting in the need for complex synchronisation rules to ensure the token
is correctly and atomically exchanged between processes.
GET parameters::
timeout(int): How long to wait for new events in milliseconds.
since(batch_token): Batch token when asking for incremental deltas.
Response JSON::
{
"next_batch": // batch token for the next /sync
"to_device": {
// list of to-device events
"events": [
{
"content: { "algorithm": "m.olm.v1.curve25519-aes-sha2", "ciphertext": { ... }, "org.matrix.msgid": "abcd", "session_id": "abcd" },
"type": "m.room.encrypted",
"sender": "@alice:example.com",
}
// ...
]
},
"device_lists": {
"changed": ["@alice:example.com"],
"left": ["@bob:example.com"]
},
"device_one_time_keys_count": {
"signed_curve25519": 50
},
"device_unused_fallback_key_types": [
"signed_curve25519"
]
}
"""
PATTERNS = client_patterns(
"/org.matrix.msc3575/sync/e2ee$", releases=[], v1=False, unstable=True
)
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self.sync_handler = hs.get_sync_handler()
# Filtering only matters for the `device_lists` because it requires a bunch of
# derived information from rooms (see how `_generate_sync_entry_for_rooms()`
# prepares a bunch of data for `_generate_sync_entry_for_device_list()`).
self.only_member_events_filter_collection = FilterCollection(
self.hs,
{
"room": {
# We only care about membership events for the `device_lists`.
# Membership will tell us whether a user has joined/left a room and
# if there are new devices to encrypt for.
"timeline": {
"types": ["m.room.member"],
},
"state": {
"types": ["m.room.member"],
},
# We don't want any extra account_data generated because it's not
# returned by this endpoint. This helps us avoid work in
# `_generate_sync_entry_for_rooms()`
"account_data": {
"not_types": ["*"],
},
# We don't want any extra ephemeral data generated because it's not
# returned by this endpoint. This helps us avoid work in
# `_generate_sync_entry_for_rooms()`
"ephemeral": {
"not_types": ["*"],
},
},
# We don't want any extra account_data generated because it's not
# returned by this endpoint. (This is just here for good measure)
"account_data": {
"not_types": ["*"],
},
# We don't want any extra presence data generated because it's not
# returned by this endpoint. (This is just here for good measure)
"presence": {
"not_types": ["*"],
},
},
)
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = requester.user
device_id = requester.device_id
timeout = parse_integer(request, "timeout", default=0)
since = parse_string(request, "since")
sync_config = SyncConfig(
user=user,
filter_collection=self.only_member_events_filter_collection,
is_guest=requester.is_guest,
device_id=device_id,
)
since_token = None
if since is not None:
since_token = await StreamToken.from_string(self.store, since)
# Request cache key
request_key = (
SyncVersion.E2EE_SYNC,
user,
timeout,
since,
)
# Gather data for the response
sync_result = await self.sync_handler.wait_for_sync_for_user(
requester,
sync_config,
SyncVersion.E2EE_SYNC,
request_key,
since_token=since_token,
timeout=timeout,
full_state=False,
)
# The client may have disconnected by now; don't bother to serialize the
# response if so.
if request._disconnected:
logger.info("Client has disconnected; not serializing response.")
return 200, {}
response: JsonDict = defaultdict(dict)
response["next_batch"] = await sync_result.next_batch.to_string(self.store)
if sync_result.to_device:
response["to_device"] = {"events": sync_result.to_device}
if sync_result.device_lists.changed:
response["device_lists"]["changed"] = list(sync_result.device_lists.changed)
if sync_result.device_lists.left:
response["device_lists"]["left"] = list(sync_result.device_lists.left)
# We always include this because https://github.com/vector-im/element-android/issues/3725
# The spec isn't terribly clear on when this can be omitted and how a client would tell
# the difference between "no keys present" and "nothing changed" in terms of whole field
# absent / individual key type entry absent
# Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456
response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count
# https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
# states that this field should always be included, as long as the server supports the feature.
response["device_unused_fallback_key_types"] = (
sync_result.device_unused_fallback_key_types
)
return 200, response
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)
if hs.config.experimental.msc3575_enabled:
SlidingSyncE2eeRestServlet(hs).register(http_server)

View File

@@ -22,23 +22,18 @@
import logging
import re
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
from synapse.http.server import respond_with_json, set_corp_headers, set_cors_headers
from synapse.http.server import set_corp_headers, set_cors_headers
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.media._base import (
DEFAULT_MAX_TIMEOUT_MS,
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
FileInfo,
ThumbnailInfo,
respond_404,
respond_with_file,
respond_with_responder,
)
from synapse.media.media_storage import MediaStorage
from synapse.media.thumbnailer import ThumbnailProvider
from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING:
@@ -66,10 +61,11 @@ class ThumbnailResource(RestServlet):
self.store = hs.get_datastores().main
self.media_repo = media_repo
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self._is_mine_server_name = hs.is_mine_server_name
self._server_name = hs.hostname
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self.thumbnail_provider = ThumbnailProvider(hs, media_repo, media_storage)
async def on_GET(
self, request: SynapseRequest, server_name: str, media_id: str
@@ -91,11 +87,11 @@ class ThumbnailResource(RestServlet):
if self._is_mine_server_name(server_name):
if self.dynamic_thumbnails:
await self._select_or_generate_local_thumbnail(
await self.thumbnail_provider.select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms
)
else:
await self._respond_local_thumbnail(
await self.thumbnail_provider.respond_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms
)
self.media_repo.mark_recently_accessed(None, media_id)
@@ -109,9 +105,9 @@ class ThumbnailResource(RestServlet):
return
remote_resp_function = (
self._select_or_generate_remote_thumbnail
self.thumbnail_provider.select_or_generate_remote_thumbnail
if self.dynamic_thumbnails
else self._respond_remote_thumbnail
else self.thumbnail_provider.respond_remote_thumbnail
)
await remote_resp_function(
request,
@@ -124,457 +120,3 @@ class ThumbnailResource(RestServlet):
max_timeout_ms,
)
self.media_repo.mark_recently_accessed(server_name, media_id)
async def _respond_local_thumbnail(
self,
request: SynapseRequest,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
max_timeout_ms: int,
) -> None:
media_info = await self.media_repo.get_local_media_info(
request, media_id, max_timeout_ms
)
if not media_info:
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
await self._select_and_respond_with_thumbnail(
request,
width,
height,
method,
m_type,
thumbnail_infos,
media_id,
media_id,
url_cache=bool(media_info.url_cache),
server_name=None,
)
async def _select_or_generate_local_thumbnail(
self,
request: SynapseRequest,
media_id: str,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
max_timeout_ms: int,
) -> None:
media_info = await self.media_repo.get_local_media_info(
request, media_id, max_timeout_ms
)
if not media_info:
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
t_w = info.width == desired_width
t_h = info.height == desired_height
t_method = info.method == desired_method
t_type = info.type == desired_type
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
server_name=None,
file_id=media_id,
url_cache=bool(media_info.url_cache),
thumbnail=info,
)
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request, responder, info.type, info.length
)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = await self.media_repo.generate_local_exact_thumbnail(
media_id,
desired_width,
desired_height,
desired_method,
desired_type,
url_cache=bool(media_info.url_cache),
)
if file_path:
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
raise SynapseError(400, "Failed to generate thumbnail.")
async def _select_or_generate_remote_thumbnail(
self,
request: SynapseRequest,
server_name: str,
media_id: str,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
max_timeout_ms: int,
) -> None:
media_info = await self.media_repo.get_remote_media_info(
server_name, media_id, max_timeout_ms
)
if not media_info:
respond_404(request)
return
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
file_id = media_info.filesystem_id
for info in thumbnail_infos:
t_w = info.width == desired_width
t_h = info.height == desired_height
t_method = info.method == desired_method
t_type = info.type == desired_type
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
thumbnail=info,
)
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request, responder, info.type, info.length
)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = await self.media_repo.generate_remote_exact_thumbnail(
server_name,
file_id,
media_id,
desired_width,
desired_height,
desired_method,
desired_type,
)
if file_path:
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
raise SynapseError(400, "Failed to generate thumbnail.")
async def _respond_remote_thumbnail(
self,
request: SynapseRequest,
server_name: str,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
max_timeout_ms: int,
) -> None:
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
media_info = await self.media_repo.get_remote_media_info(
server_name, media_id, max_timeout_ms
)
if not media_info:
return
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
await self._select_and_respond_with_thumbnail(
request,
width,
height,
method,
m_type,
thumbnail_infos,
media_id,
media_info.filesystem_id,
url_cache=False,
server_name=server_name,
)
async def _select_and_respond_with_thumbnail(
self,
request: SynapseRequest,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos: List[ThumbnailInfo],
media_id: str,
file_id: str,
url_cache: bool,
server_name: Optional[str] = None,
) -> None:
"""
Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
Args:
request: The incoming request.
desired_width: The desired width, the returned thumbnail may be larger than this.
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of thumbnail info of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
"""
logger.debug(
"_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
media_id,
desired_width,
desired_height,
desired_method,
thumbnail_infos,
)
# If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
# different code path to handle it.
assert not self.dynamic_thumbnails
if thumbnail_infos:
file_info = self._select_thumbnail(
desired_width,
desired_height,
desired_method,
desired_type,
thumbnail_infos,
file_id,
url_cache,
server_name,
)
if not file_info:
logger.info("Couldn't find a thumbnail matching the desired inputs")
respond_404(request)
return
# The thumbnail property must exist.
assert file_info.thumbnail is not None
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request,
responder,
file_info.thumbnail.type,
file_info.thumbnail.length,
)
return
# If we can't find the thumbnail we regenerate it. This can happen
# if e.g. we've deleted the thumbnails but still have the original
# image somewhere.
#
# Since we have an entry for the thumbnail in the DB we a) know we
# have have successfully generated the thumbnail in the past (so we
# don't need to worry about repeatedly failing to generate
# thumbnails), and b) have already calculated that appropriate
# width/height/method so we can just call the "generate exact"
# methods.
# First let's check that we do actually have the original image
# still. This will throw a 404 if we don't.
# TODO: We should refetch the thumbnails for remote media.
await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
)
if server_name:
await self.media_repo.generate_remote_exact_thumbnail(
server_name,
file_id=file_id,
media_id=media_id,
t_width=file_info.thumbnail.width,
t_height=file_info.thumbnail.height,
t_method=file_info.thumbnail.method,
t_type=file_info.thumbnail.type,
)
else:
await self.media_repo.generate_local_exact_thumbnail(
media_id=media_id,
t_width=file_info.thumbnail.width,
t_height=file_info.thumbnail.height,
t_method=file_info.thumbnail.method,
t_type=file_info.thumbnail.type,
url_cache=url_cache,
)
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
request,
responder,
file_info.thumbnail.type,
file_info.thumbnail.length,
)
else:
# This might be because:
# 1. We can't create thumbnails for the given media (corrupted or
# unsupported file type), or
# 2. The thumbnailing process never ran or errored out initially
# when the media was first uploaded (these bugs should be
# reported and fixed).
# Note that we don't attempt to generate a thumbnail now because
# `dynamic_thumbnails` is disabled.
logger.info("Failed to find any generated thumbnails")
assert request.path is not None
respond_with_json(
request,
400,
cs_error(
"Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
% (
request.path.decode(),
", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
),
code=Codes.UNKNOWN,
),
send_cors=True,
)
def _select_thumbnail(
self,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos: List[ThumbnailInfo],
file_id: str,
url_cache: bool,
server_name: Optional[str],
) -> Optional[FileInfo]:
"""
Choose an appropriate thumbnail from the previously generated thumbnails.
Args:
desired_width: The desired width, the returned thumbnail may be larger than this.
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of thumbnail infos of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
Returns:
The thumbnail which best matches the desired parameters.
"""
desired_method = desired_method.lower()
# The chosen thumbnail.
thumbnail_info = None
d_w = desired_width
d_h = desired_height
if desired_method == "crop":
# Thumbnails that match equal or larger sizes of desired width/height.
crop_info_list: List[
Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
] = []
# Other thumbnails.
crop_info_list2: List[
Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info.method != "crop":
continue
t_w = info.width
t_h = info.height
aspect_quality = abs(d_w * t_h - d_h * t_w)
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info.type
length_quality = info.length
if t_w >= d_w or t_h >= d_h:
crop_info_list.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
else:
crop_info_list2.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
# Pick the most appropriate thumbnail. Some values of `desired_width` and
# `desired_height` may result in a tie, in which case we avoid comparing on
# the thumbnail info and pick the thumbnail that appears earlier
# in the list of candidates.
if crop_info_list:
thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
elif crop_info_list2:
thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
elif desired_method == "scale":
# Thumbnails that match equal or larger sizes of desired width/height.
info_list: List[Tuple[int, bool, int, ThumbnailInfo]] = []
# Other thumbnails.
info_list2: List[Tuple[int, bool, int, ThumbnailInfo]] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info.method != "scale":
continue
t_w = info.width
t_h = info.height
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info.type
length_quality = info.length
if t_w >= d_w or t_h >= d_h:
info_list.append((size_quality, type_quality, length_quality, info))
else:
info_list2.append(
(size_quality, type_quality, length_quality, info)
)
# Pick the most appropriate thumbnail. Some values of `desired_width` and
# `desired_height` may result in a tie, in which case we avoid comparing on
# the thumbnail info and pick the thumbnail that appears earlier
# in the list of candidates.
if info_list:
thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
elif info_list2:
thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
if thumbnail_info:
return FileInfo(
file_id=file_id,
url_cache=url_cache,
server_name=server_name,
thumbnail=thumbnail_info,
)
# No matching thumbnail was found.
return None

View File

@@ -148,6 +148,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self._attempt_to_invalidate_cache("get_local_users_in_room", (room_id,))
self._attempt_to_invalidate_cache("does_pair_of_users_share_a_room", None)
self._attempt_to_invalidate_cache("get_user_in_room_with_profile", None)
self._attempt_to_invalidate_cache("get_invited_rooms_for_local_user", None)
self._attempt_to_invalidate_cache(
"get_rooms_for_user_with_stream_ordering", None
)

View File

@@ -43,6 +43,7 @@ from synapse.storage.database import (
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import StrCollection
from synapse.util.caches.descriptors import CachedFunction
from synapse.util.iterutils import batch_iter
@@ -233,8 +234,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)
room_id = row.keys[0]
self._invalidate_caches_for_room_events(room_id)
self._invalidate_caches_for_room(room_id)
server_joined = bool(row.keys.get(1, "true"))
self._invalidate_caches_for_room(room_id, server_joined)
else:
self._attempt_to_invalidate_cache(row.cache_func, row.keys)
@@ -318,7 +319,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._invalidate_local_get_event_cache(redacts) # type: ignore[attr-defined]
# Caches which might leak edits must be invalidated for the event being
# redacted.
self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
self._attempt_to_invalidate_cache(
"get_relations_for_event",
(
room_id,
redacts,
),
)
self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))
@@ -345,7 +352,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)
if relates_to:
self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
self._attempt_to_invalidate_cache(
"get_relations_for_event",
(
room_id,
relates_to,
),
)
self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
@@ -376,21 +389,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._invalidate_local_get_event_cache_room_id(room_id) # type: ignore[attr-defined]
self._attempt_to_invalidate_cache("have_seen_event", (room_id,))
self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,))
self._attempt_to_invalidate_cache(
"get_unread_event_push_actions_by_room_for_user", (room_id,)
)
self._attempt_to_invalidate_cache("get_relations_for_event", (room_id,))
self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
self._attempt_to_invalidate_cache("get_relations_for_event", None)
self._attempt_to_invalidate_cache("get_applicable_edit", None)
self._attempt_to_invalidate_cache("get_thread_id", None)
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", None)
self._attempt_to_invalidate_cache("get_invited_rooms_for_local_user", None)
self._attempt_to_invalidate_cache(
"get_rooms_for_user_with_stream_ordering", None
)
self._attempt_to_invalidate_cache("get_rooms_for_user", None)
self._attempt_to_invalidate_cache("did_forget", None)
self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None)
self._attempt_to_invalidate_cache("get_references_for_event", None)
@@ -405,17 +413,28 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache("_get_joined_profile_from_event_id", None)
def _invalidate_caches_for_room_and_stream(
self, txn: LoggingTransaction, room_id: str
self,
txn: LoggingTransaction,
room_id: str,
server_in_room: bool,
) -> None:
"""Invalidate caches associated with rooms, and stream to replication.
Used when we delete rooms.
Args:
txn
room_id
server_in_room: Whether the server was joined or invited to the
room when we deleted it.
"""
self._send_invalidation_to_replication(txn, DELETE_ROOM_CACHE_NAME, [room_id])
txn.call_after(self._invalidate_caches_for_room, room_id)
self._send_invalidation_to_replication(
txn, DELETE_ROOM_CACHE_NAME, [room_id, "true" if server_in_room else ""]
)
txn.call_after(self._invalidate_caches_for_room, room_id, server_in_room)
def _invalidate_caches_for_room(self, room_id: str) -> None:
def _invalidate_caches_for_room(self, room_id: str, server_in_room: bool) -> None:
"""Invalidate caches associated with rooms.
Used when we delete rooms.
@@ -427,8 +446,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache("get_account_data_for_room", None)
self._attempt_to_invalidate_cache("get_account_data_for_room_and_type", None)
self._attempt_to_invalidate_cache("get_aliases_for_room", (room_id,))
self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,))
self._attempt_to_invalidate_cache("_get_forward_extremeties_for_room", None)
if server_in_room:
self._attempt_to_invalidate_cache(
"get_latest_event_ids_in_room", (room_id,)
)
self._attempt_to_invalidate_cache("_get_forward_extremeties_for_room", None)
# And delete state caches.
self._invalidate_state_caches_all(room_id)
self._attempt_to_invalidate_cache(
"get_unread_event_push_actions_by_room_for_user", (room_id,)
)
@@ -441,19 +468,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"_get_partial_state_servers_at_join", (room_id,)
)
self._attempt_to_invalidate_cache("is_partial_state_room", (room_id,))
self._attempt_to_invalidate_cache("get_invited_rooms_for_local_user", None)
self._attempt_to_invalidate_cache(
"get_current_hosts_in_room_ordered", (room_id,)
)
self._attempt_to_invalidate_cache("did_forget", None)
self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None)
self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
self._attempt_to_invalidate_cache("get_room_version_id", (room_id,))
# And delete state caches.
self._invalidate_state_caches_all(room_id)
async def invalidate_cache_and_stream(
self, cache_name: str, keys: Tuple[Any, ...]
) -> None:
@@ -548,7 +569,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
for chunk in batch_iter(members_changed, 50):
keys = itertools.chain([room_id], chunk)
self._send_invalidation_to_replication(
txn, CURRENT_STATE_CACHE_NAME, keys
txn, CURRENT_STATE_CACHE_NAME, list(keys)
)
else:
# if no members changed, we still need to invalidate the other caches.
@@ -567,7 +588,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)
def _send_invalidation_to_replication(
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
self, txn: LoggingTransaction, cache_name: str, keys: Optional[StrCollection]
) -> None:
"""Notifies replication that given cache has been invalidated.

View File

@@ -1923,7 +1923,12 @@ class PersistEventsStore:
# Any relation information for the related event must be cleared.
self.store._invalidate_cache_and_stream(
txn, self.store.get_relations_for_event, (redacted_relates_to,)
txn,
self.store.get_relations_for_event,
(
room_id,
redacted_relates_to,
),
)
if rel_type == RelationTypes.REFERENCE:
self.store._invalidate_cache_and_stream(

View File

@@ -1181,7 +1181,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
results = list(txn)
# (event_id, parent_id, rel_type) for each relation
relations_to_insert: List[Tuple[str, str, str]] = []
relations_to_insert: List[Tuple[str, str, str, str]] = []
for event_id, event_json_raw in results:
try:
event_json = db_to_json(event_json_raw)
@@ -1214,7 +1214,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if not isinstance(parent_id, str):
continue
relations_to_insert.append((event_id, parent_id, rel_type))
room_id = event_json["room_id"]
relations_to_insert.append((room_id, event_id, parent_id, rel_type))
# Insert the missing data, note that we upsert here in case the event
# has already been processed.
@@ -1223,18 +1224,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="event_relations",
key_names=("event_id",),
key_values=[(r[0],) for r in relations_to_insert],
key_values=[(r[1],) for r in relations_to_insert],
value_names=("relates_to_id", "relation_type"),
value_values=[r[1:] for r in relations_to_insert],
value_values=[r[2:] for r in relations_to_insert],
)
# Iterate the parent IDs and invalidate caches.
cache_tuples = {(r[1],) for r in relations_to_insert}
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
txn, self.get_relations_for_event, cache_tuples # type: ignore[attr-defined]
txn,
self.get_relations_for_event, # type: ignore[attr-defined]
{
(
r[0], # room_id
r[2], # parent_id
)
for r in relations_to_insert
},
)
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
txn, self.get_thread_summary, cache_tuples # type: ignore[attr-defined]
txn,
self.get_thread_summary, # type: ignore[attr-defined]
{(r[1],) for r in relations_to_insert},
)
if results:

View File

@@ -22,6 +22,7 @@
import logging
from typing import Any, List, Set, Tuple, cast
from synapse.api.constants import Membership
from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
@@ -376,6 +377,20 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
(room_id,),
)
server_in_room = self._check_host_room_membership_txn( # type: ignore[attr-defined]
txn,
room_id,
host=self.hs.hostname,
membership=Membership.JOIN,
)
if not server_in_room:
server_in_room = self._check_host_room_membership_txn( # type: ignore[attr-defined]
txn,
room_id,
host=self.hs.hostname,
membership=Membership.INVITE,
)
# First, fetch all the state groups that should be deleted, before
# we delete that information.
txn.execute(
@@ -503,6 +518,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# index on them. In any case we should be clearing out 'stream' tables
# periodically anyway (https://github.com/matrix-org/synapse/issues/5888)
self._invalidate_caches_for_room_and_stream(txn, room_id)
self._invalidate_caches_for_room_and_stream(txn, room_id, server_in_room)
return state_groups

View File

@@ -169,9 +169,9 @@ class RelationsWorkerStore(SQLBaseStore):
@cached(uncached_args=("event",), tree=True)
async def get_relations_for_event(
self,
room_id: str,
event_id: str,
event: EventBase,
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
limit: int = 5,

View File

@@ -962,6 +962,17 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
async def _check_host_room_membership(
self, room_id: str, host: str, membership: str
) -> bool:
return await self.db_pool.runInteraction(
"is_host_joined",
self._check_host_room_membership_txn,
room_id,
host,
membership,
)
def _check_host_room_membership_txn(
self, txn: LoggingTransaction, room_id: str, host: str, membership: str
) -> bool:
if "%" in host or "_" in host:
raise Exception("Invalid host name")
@@ -980,14 +991,14 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
# the returned user actually has the correct domain.
like_clause = "%:" + host
rows = await self.db_pool.execute(
"is_host_joined", sql, membership, room_id, like_clause
)
txn.execute(sql, (membership, room_id, like_clause))
if not rows:
row = txn.fetchone()
if not row:
return False
user_id = rows[0][0]
user_id = row[0]
if get_domain_from_id(user_id) != host:
# This can only happen if the host name has something funky in it
raise Exception("Invalid host name")

View File

@@ -142,6 +142,10 @@ class PostgresEngine(
apply stricter checks on new databases versus existing database.
"""
allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
if allow_unsafe_locale:
return
collation, ctype = self.get_db_locale(txn)
errors = []
@@ -155,7 +159,9 @@ class PostgresEngine(
if errors:
raise IncorrectDatabaseSetup(
"Database is incorrectly configured:\n\n%s\n\n"
"See docs/postgres.md for more information." % ("\n".join(errors))
"See docs/postgres.md for more information. You can override this check by"
"setting 'allow_unsafe_locale' to true in the database config.",
"\n".join(errors),
)
def convert_param_style(self, sql: str) -> str:

View File

@@ -24,7 +24,12 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
from twisted.python.failure import Failure
from synapse.logging.context import nested_logging_context
from synapse.logging.context import (
ContextResourceUsage,
LoggingContext,
nested_logging_context,
set_current_context,
)
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import (
run_as_background_process,
@@ -81,6 +86,8 @@ class TaskScheduler:
MAX_CONCURRENT_RUNNING_TASKS = 5
# Time from the last task update after which we will log a warning
LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs
# Report a running task's status and usage every so often.
OCCASIONAL_REPORT_INTERVAL_MS = 5 * 60 * 1000 # 5 minutes
def __init__(self, hs: "HomeServer"):
self._hs = hs
@@ -346,6 +353,33 @@ class TaskScheduler:
assert task.id not in self._running_tasks
await self._store.delete_scheduled_task(task.id)
@staticmethod
def _log_task_usage(
state: str, task: ScheduledTask, usage: ContextResourceUsage, active_time: float
) -> None:
"""
Log a line describing the state and usage of a task.
The log line is inspired by / a copy of the request log line format,
but with irrelevant fields removed.
active_time: Time that the task has been running for, in seconds.
"""
logger.info(
"Task %s: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
" [%d dbevts] %r, %r",
state,
active_time,
usage.ru_utime,
usage.ru_stime,
usage.db_sched_duration_sec,
usage.db_txn_duration_sec,
int(usage.db_txn_count),
usage.evt_db_fetch_count,
task.resource_id,
task.params,
)
async def _launch_task(self, task: ScheduledTask) -> None:
"""Launch a scheduled task now.
@@ -360,8 +394,32 @@ class TaskScheduler:
)
function = self._actions[task.action]
def _occasional_report(
task_log_context: LoggingContext, start_time: float
) -> None:
"""
Helper to log a 'Task continuing' line every so often.
"""
current_time = self._clock.time()
calling_context = set_current_context(task_log_context)
try:
usage = task_log_context.get_resource_usage()
TaskScheduler._log_task_usage(
"continuing", task, usage, current_time - start_time
)
finally:
set_current_context(calling_context)
async def wrapper() -> None:
with nested_logging_context(task.id):
with nested_logging_context(task.id) as log_context:
start_time = self._clock.time()
occasional_status_call = self._clock.looping_call(
_occasional_report,
TaskScheduler.OCCASIONAL_REPORT_INTERVAL_MS,
log_context,
start_time,
)
try:
(status, result, error) = await function(task)
except Exception:
@@ -383,6 +441,13 @@ class TaskScheduler:
)
self._running_tasks.remove(task.id)
current_time = self._clock.time()
usage = log_context.get_resource_usage()
TaskScheduler._log_task_usage(
status.value, task, usage, current_time - start_time
)
occasional_status_call.stop()
# Try launch a new task since we've finished with this one.
self._clock.call_later(0.1, self._launch_scheduled_tasks)

View File

@@ -18,6 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
import itertools
import os
import shutil
import tempfile
@@ -46,11 +47,11 @@ from synapse.media._base import FileInfo, ThumbnailInfo
from synapse.media.filepath import MediaFilePaths
from synapse.media.media_storage import MediaStorage, ReadableFileWrapper
from synapse.media.storage_provider import FileStorageProviderBackend
from synapse.media.thumbnailer import ThumbnailProvider
from synapse.module_api import ModuleApi
from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
from synapse.rest import admin
from synapse.rest.client import login
from synapse.rest.media.thumbnail_resource import ThumbnailResource
from synapse.rest.client import login, media
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias
from synapse.util import Clock
@@ -153,68 +154,54 @@ class _TestImage:
is_inline: bool = True
@parameterized_class(
("test_image",),
[
# small png
(
_TestImage(
SMALL_PNG,
b"image/png",
b".png",
unhexlify(
b"89504e470d0a1a0a0000000d4948445200000020000000200806"
b"000000737a7af40000001a49444154789cedc101010000008220"
b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
b"44ae426082"
),
unhexlify(
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
b"0000001f15c4890000000d49444154789c636060606000000005"
b"0001a5f645400000000049454e44ae426082"
),
),
),
# small png with transparency.
(
_TestImage(
unhexlify(
b"89504e470d0a1a0a0000000d49484452000000010000000101000"
b"00000376ef9240000000274524e5300010194fdae0000000a4944"
b"4154789c636800000082008177cd72b60000000049454e44ae426"
b"082"
),
b"image/png",
b".png",
# Note that we don't check the output since it varies across
# different versions of Pillow.
),
),
# small lossless webp
(
_TestImage(
unhexlify(
b"524946461a000000574542505650384c0d0000002f0000001007"
b"1011118888fe0700"
),
b"image/webp",
b".webp",
),
),
# an empty file
(
_TestImage(
b"",
b"image/gif",
b".gif",
expected_found=False,
unable_to_thumbnail=True,
),
),
# An SVG.
(
_TestImage(
b"""<?xml version="1.0"?>
small_png = _TestImage(
SMALL_PNG,
b"image/png",
b".png",
unhexlify(
b"89504e470d0a1a0a0000000d4948445200000020000000200806"
b"000000737a7af40000001a49444154789cedc101010000008220"
b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
b"44ae426082"
),
unhexlify(
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
b"0000001f15c4890000000d49444154789c636060606000000005"
b"0001a5f645400000000049454e44ae426082"
),
)
small_png_with_transparency = _TestImage(
unhexlify(
b"89504e470d0a1a0a0000000d49484452000000010000000101000"
b"00000376ef9240000000274524e5300010194fdae0000000a4944"
b"4154789c636800000082008177cd72b60000000049454e44ae426"
b"082"
),
b"image/png",
b".png",
# Note that we don't check the output since it varies across
# different versions of Pillow.
)
small_lossless_webp = _TestImage(
unhexlify(
b"524946461a000000574542505650384c0d0000002f0000001007" b"1011118888fe0700"
),
b"image/webp",
b".webp",
)
empty_file = _TestImage(
b"",
b"image/gif",
b".gif",
expected_found=False,
unable_to_thumbnail=True,
)
SVG = _TestImage(
b"""<?xml version="1.0"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
@@ -223,19 +210,32 @@ class _TestImage:
<circle cx="100" cy="100" r="50" stroke="black"
stroke-width="5" fill="red" />
</svg>""",
b"image/svg",
b".svg",
expected_found=False,
unable_to_thumbnail=True,
is_inline=False,
),
),
],
b"image/svg",
b".svg",
expected_found=False,
unable_to_thumbnail=True,
is_inline=False,
)
test_images = [
small_png,
small_png_with_transparency,
small_lossless_webp,
empty_file,
SVG,
]
urls = [
"_matrix/media/r0/thumbnail",
"_matrix/client/unstable/org.matrix.msc3916/media/thumbnail",
]
@parameterized_class(("test_image", "url"), itertools.product(test_images, urls))
class MediaRepoTests(unittest.HomeserverTestCase):
servlets = [media.register_servlets]
test_image: ClassVar[_TestImage]
hijack_auth = True
user_id = "@test:user"
url: ClassVar[str]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.fetches: List[
@@ -298,6 +298,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"config": {"directory": self.storage_path},
}
config["media_storage_providers"] = [provider_config]
config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
hs = self.setup_test_homeserver(config=config, federation_http_client=client)
@@ -502,7 +503,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=scale"
channel = self.make_request(
"GET",
f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
f"/{self.url}/{self.media_id}{params}",
shorthand=False,
await_result=False,
)
@@ -530,7 +531,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
f"/{self.url}/{self.media_id}{params}",
shorthand=False,
await_result=False,
)
@@ -566,12 +567,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=" + method
channel = self.make_request(
"GET",
f"/_matrix/media/r0/thumbnail/{self.media_id}{params}",
f"/{self.url}/{self.media_id}{params}",
shorthand=False,
await_result=False,
)
self.pump()
headers = {
b"Content-Length": [b"%d" % (len(self.test_image.data))],
b"Content-Type": [self.test_image.content_type],
@@ -580,7 +580,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
(self.test_image.data, (len(self.test_image.data), headers))
)
self.pump()
if expected_found:
self.assertEqual(channel.code, 200)
@@ -603,7 +602,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.json_body,
{
"errcode": "M_UNKNOWN",
"error": "Cannot find any thumbnails for the requested media ('/_matrix/media/r0/thumbnail/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
"error": f"Cannot find any thumbnails for the requested media ('/{self.url}/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
},
)
else:
@@ -613,7 +612,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.json_body,
{
"errcode": "M_NOT_FOUND",
"error": "Not found '/_matrix/media/r0/thumbnail/example.com/12345'",
"error": f"Not found '/{self.url}/example.com/12345'",
},
)
@@ -625,12 +624,12 @@ class MediaRepoTests(unittest.HomeserverTestCase):
content_type = self.test_image.content_type.decode()
media_repo = self.hs.get_media_repository()
thumbnail_resouce = ThumbnailResource(
thumbnail_provider = ThumbnailProvider(
self.hs, media_repo, media_repo.media_storage
)
self.assertIsNotNone(
thumbnail_resouce._select_thumbnail(
thumbnail_provider._select_thumbnail(
desired_width=desired_size,
desired_height=desired_size,
desired_method=method,

View File

@@ -24,8 +24,8 @@ from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError
from synapse.rest import admin, devices, room, sync
from synapse.rest.client import account, keys, login, register
from synapse.rest import admin, devices, sync
from synapse.rest.client import keys, login, register
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
@@ -33,146 +33,6 @@ from synapse.util import Clock
from tests import unittest
class DeviceListsTestCase(unittest.HomeserverTestCase):
"""Tests regarding device list changes."""
servlets = [
admin.register_servlets_for_client_rest_resource,
login.register_servlets,
register.register_servlets,
account.register_servlets,
room.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def test_receiving_local_device_list_changes(self) -> None:
"""Tests that a local users that share a room receive each other's device list
changes.
"""
# Register two users
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
bob_user_id = self.register_user("bob", "ponyponypony")
bob_access_token = self.login(bob_user_id, "ponyponypony")
# Create a room for them to coexist peacefully in
new_room_id = self.helper.create_room_as(
alice_user_id, is_public=True, tok=alice_access_token
)
self.assertIsNotNone(new_room_id)
# Have Bob join the room
self.helper.invite(
new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
)
self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
# Now have Bob initiate an initial sync (in order to get a since token)
channel = self.make_request(
"GET",
"/sync",
access_token=bob_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch_token = channel.json_body["next_batch"]
# ...and then an incremental sync. This should block until the sync stream is woken up,
# which we hope will happen as a result of Alice updating their device list.
bob_sync_channel = self.make_request(
"GET",
f"/sync?since={next_batch_token}&timeout=30000",
access_token=bob_access_token,
# Start the request, then continue on.
await_result=False,
)
# Have alice update their device list
channel = self.make_request(
"PUT",
f"/devices/{test_device_id}",
{
"display_name": "New Device Name",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# Check that bob's incremental sync contains the updated device list.
# If not, the client would only receive the device list update on the
# *next* sync.
bob_sync_channel.await_result()
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
"changed", []
)
self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
def test_not_receiving_local_device_list_changes(self) -> None:
"""Tests a local users DO NOT receive device updates from each other if they do not
share a room.
"""
# Register two users
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
bob_user_id = self.register_user("bob", "ponyponypony")
bob_access_token = self.login(bob_user_id, "ponyponypony")
# These users do not share a room. They are lonely.
# Have Bob initiate an initial sync (in order to get a since token)
channel = self.make_request(
"GET",
"/sync",
access_token=bob_access_token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
next_batch_token = channel.json_body["next_batch"]
# ...and then an incremental sync. This should block until the sync stream is woken up,
# which we hope will happen as a result of Alice updating their device list.
bob_sync_channel = self.make_request(
"GET",
f"/sync?since={next_batch_token}&timeout=1000",
access_token=bob_access_token,
# Start the request, then continue on.
await_result=False,
)
# Have alice update their device list
channel = self.make_request(
"PUT",
f"/devices/{test_device_id}",
{
"display_name": "New Device Name",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# Check that bob's incremental sync does not contain the updated device list.
bob_sync_channel.await_result()
self.assertEqual(
bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body
)
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
"changed", []
)
self.assertNotIn(
alice_user_id, changed_device_lists, bob_sync_channel.json_body
)
class DevicesTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,

File diff suppressed because it is too large Load Diff

View File

@@ -18,15 +18,39 @@
# [This file includes modifications made by New Vector Limited]
#
#
from parameterized import parameterized_class
from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
from synapse.types import JsonDict
from tests.unittest import HomeserverTestCase, override_config
@parameterized_class(
("sync_endpoint", "experimental_features"),
[
("/sync", {}),
(
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
# Enable sliding sync
{"msc3575_enabled": True},
),
],
)
class SendToDeviceTestCase(HomeserverTestCase):
"""
Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`.
Attributes:
sync_endpoint: The endpoint under test to use for syncing.
experimental_features: The experimental features homeserver config to use.
"""
sync_endpoint: str
experimental_features: JsonDict
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -34,6 +58,11 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = self.experimental_features
return config
def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered"""
@@ -54,7 +83,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
# check it appears
channel = self.make_request("GET", "/sync", access_token=user2_tok)
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
expected_result = {
"events": [
@@ -67,15 +96,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
}
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should re-appear if we do another sync
channel = self.make_request("GET", "/sync", access_token=user2_tok)
# it should re-appear if we do another sync because the to-device message is not
# deleted until we acknowledge it by sending a `?since=...` parameter in the
# next sync request corresponding to the `next_batch` value from the response.
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should *not* appear if we do an incremental sync
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET", f"/sync?since={sync_token}", access_token=user2_tok
"GET",
f"{self.sync_endpoint}?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
@@ -99,15 +132,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
self.assertEqual(chan.code, 200, chan.result)
# now sync: we should get two of the three
channel = self.make_request("GET", "/sync", access_token=user2_tok)
# now sync: we should get two of the three (because burst_count=2)
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
for i in range(2):
self.assertEqual(
msgs[i],
{"sender": user1, "type": "m.room_key_request", "content": {"idx": i}},
{
"sender": user1,
"type": "m.room_key_request",
"content": {"idx": i},
},
)
sync_token = channel.json_body["next_batch"]
@@ -125,7 +162,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
"GET", f"/sync?since={sync_token}", access_token=user2_tok
"GET",
f"{self.sync_endpoint}?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
@@ -159,7 +198,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
# now sync: we should get two of the three
channel = self.make_request("GET", "/sync", access_token=user2_tok)
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
@@ -193,7 +232,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
"GET", f"/sync?since={sync_token}", access_token=user2_tok
"GET",
f"{self.sync_endpoint}?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
@@ -217,7 +258,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
user2_tok = self.login("u2", "pass", "d2")
# Do an initial sync
channel = self.make_request("GET", "/sync", access_token=user2_tok)
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
sync_token = channel.json_body["next_batch"]
@@ -233,7 +274,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request(
"GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
"GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
@@ -241,7 +284,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
"GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])

View File

@@ -21,7 +21,7 @@
import json
from typing import List
from parameterized import parameterized
from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor
@@ -688,24 +688,180 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.json_body)
@parameterized_class(
("sync_endpoint", "experimental_features"),
[
("/sync", {}),
(
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
# Enable sliding sync
{"msc3575_enabled": True},
),
],
)
class DeviceListSyncTestCase(unittest.HomeserverTestCase):
"""
Tests regarding device list (`device_lists`) changes.
Attributes:
sync_endpoint: The endpoint under test to use for syncing.
experimental_features: The experimental features homeserver config to use.
"""
sync_endpoint: str
experimental_features: JsonDict
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = self.experimental_features
return config
def test_receiving_local_device_list_changes(self) -> None:
"""Tests that a local users that share a room receive each other's device list
changes.
"""
# Register two users
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
bob_user_id = self.register_user("bob", "ponyponypony")
bob_access_token = self.login(bob_user_id, "ponyponypony")
# Create a room for them to coexist peacefully in
new_room_id = self.helper.create_room_as(
alice_user_id, is_public=True, tok=alice_access_token
)
self.assertIsNotNone(new_room_id)
# Have Bob join the room
self.helper.invite(
new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
)
self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
# Now have Bob initiate an initial sync (in order to get a since token)
channel = self.make_request(
"GET",
self.sync_endpoint,
access_token=bob_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch_token = channel.json_body["next_batch"]
# ...and then an incremental sync. This should block until the sync stream is woken up,
# which we hope will happen as a result of Alice updating their device list.
bob_sync_channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000",
access_token=bob_access_token,
# Start the request, then continue on.
await_result=False,
)
# Have alice update their device list
channel = self.make_request(
"PUT",
f"/devices/{test_device_id}",
{
"display_name": "New Device Name",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check that bob's incremental sync contains the updated device list.
# If not, the client would only receive the device list update on the
# *next* sync.
bob_sync_channel.await_result()
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
"changed", []
)
self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
def test_not_receiving_local_device_list_changes(self) -> None:
"""Tests a local users DO NOT receive device updates from each other if they do not
share a room.
"""
# Register two users
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
bob_user_id = self.register_user("bob", "ponyponypony")
bob_access_token = self.login(bob_user_id, "ponyponypony")
# These users do not share a room. They are lonely.
# Have Bob initiate an initial sync (in order to get a since token)
channel = self.make_request(
"GET",
self.sync_endpoint,
access_token=bob_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch_token = channel.json_body["next_batch"]
# ...and then an incremental sync. This should block until the sync stream is woken up,
# which we hope will happen as a result of Alice updating their device list.
bob_sync_channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000",
access_token=bob_access_token,
# Start the request, then continue on.
await_result=False,
)
# Have alice update their device list
channel = self.make_request(
"PUT",
f"/devices/{test_device_id}",
{
"display_name": "New Device Name",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check that bob's incremental sync does not contain the updated device list.
bob_sync_channel.await_result()
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
"changed", []
)
self.assertNotIn(
alice_user_id, changed_device_lists, bob_sync_channel.json_body
)
def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
"""Tests that a user with no rooms still receives their own device list updates"""
device_id = "TESTDEVICE"
test_device_id = "TESTDEVICE"
# Register a user and login, creating a device
self.user_id = self.register_user("kermit", "monkey")
self.tok = self.login("kermit", "monkey", device_id=device_id)
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Request an initial sync
channel = self.make_request("GET", "/sync", access_token=self.tok)
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch = channel.json_body["next_batch"]
@@ -713,19 +869,19 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
# It won't return until something has happened
incremental_sync_channel = self.make_request(
"GET",
f"/sync?since={next_batch}&timeout=30000",
access_token=self.tok,
f"{self.sync_endpoint}?since={next_batch}&timeout=30000",
access_token=alice_access_token,
await_result=False,
)
# Change our device's display name
channel = self.make_request(
"PUT",
f"devices/{device_id}",
f"devices/{test_device_id}",
{
"display_name": "freeze ray",
},
access_token=self.tok,
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -739,7 +895,230 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
).get("changed", [])
self.assertIn(
self.user_id, device_list_changes, incremental_sync_channel.json_body
alice_user_id, device_list_changes, incremental_sync_channel.json_body
)
@parameterized_class(
("sync_endpoint", "experimental_features"),
[
("/sync", {}),
(
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
# Enable sliding sync
{"msc3575_enabled": True},
),
],
)
class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
"""
Tests regarding device one time keys (`device_one_time_keys_count`) changes.
Attributes:
sync_endpoint: The endpoint under test to use for syncing.
experimental_features: The experimental features homeserver config to use.
"""
sync_endpoint: str
experimental_features: JsonDict
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = self.experimental_features
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.e2e_keys_handler = hs.get_e2e_keys_handler()
def test_no_device_one_time_keys(self) -> None:
"""
Tests when no one time keys set, it still has the default `signed_curve25519` in
`device_one_time_keys_count`
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
self.assertDictEqual(
channel.json_body["device_one_time_keys_count"],
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
{"signed_curve25519": 0},
channel.json_body["device_one_time_keys_count"],
)
def test_returns_device_one_time_keys(self) -> None:
"""
Tests that one time keys for the device/user are counted correctly in the `/sync`
response
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Upload one time keys for the user/device
keys: JsonDict = {
"alg1:k1": "key1",
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"alg2:k3": {"key": "key3"},
}
res = self.get_success(
self.e2e_keys_handler.upload_keys_for_user(
alice_user_id, test_device_id, {"one_time_keys": keys}
)
)
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
self.assertDictEqual(
res,
{"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}},
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
self.assertDictEqual(
channel.json_body["device_one_time_keys_count"],
{"alg1": 1, "alg2": 2, "signed_curve25519": 0},
channel.json_body["device_one_time_keys_count"],
)
@parameterized_class(
("sync_endpoint", "experimental_features"),
[
("/sync", {}),
(
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
# Enable sliding sync
{"msc3575_enabled": True},
),
],
)
class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
"""
Tests regarding device one time keys (`device_unused_fallback_key_types`) changes.
Attributes:
sync_endpoint: The endpoint under test to use for syncing.
experimental_features: The experimental features homeserver config to use.
"""
sync_endpoint: str
experimental_features: JsonDict
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = self.experimental_features
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.e2e_keys_handler = hs.get_e2e_keys_handler()
def test_no_device_unused_fallback_key(self) -> None:
"""
Test when no unused fallback key is set, it just returns an empty list. The MSC
says "The device_unused_fallback_key_types parameter must be present if the
server supports fallback keys.",
https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
self.assertListEqual(
channel.json_body["device_unused_fallback_key_types"],
[],
channel.json_body["device_unused_fallback_key_types"],
)
def test_returns_device_one_time_keys(self) -> None:
"""
Tests that device unused fallback key type is returned correctly in the `/sync`
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# We shouldn't have any unused fallback keys yet
res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
)
self.assertEqual(res, [])
# Upload a fallback key for the user/device
fallback_key = {"alg1:k1": "fallback_key1"}
self.get_success(
self.e2e_keys_handler.upload_keys_for_user(
alice_user_id,
test_device_id,
{"fallback_keys": fallback_key},
)
)
# We should now have an unused alg1 key
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
)
self.assertEqual(fallback_res, ["alg1"], fallback_res)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for the unused fallback key types
self.assertListEqual(
channel.json_body["device_unused_fallback_key_types"],
["alg1"],
channel.json_body["device_unused_fallback_key_types"],
)

View File

@@ -44,13 +44,13 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
# from a regular 404.
file_id = "abcdefg12345"
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
with hs.get_media_repository().media_storage.store_into_file(file_info) as (
f,
fname,
finish,
):
f.write(SMALL_PNG)
self.get_success(finish())
media_storage = hs.get_media_repository().media_storage
ctx = media_storage.store_into_file(file_info)
(f, fname) = self.get_success(ctx.__aenter__())
f.write(SMALL_PNG)
self.get_success(ctx.__aexit__(None, None, None))
self.get_success(
self.store.store_cached_remote_media(