From 97c3d988161f69821f00b722aafaea4fcb31759f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rg=20Thalheim?= Date: Tue, 18 Jun 2024 17:21:51 +0200 Subject: [PATCH 1/4] register_new_matrix_user: add password-file flag (#17294) Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Co-authored-by: Andrew Morgan --- changelog.d/17294.feature | 2 ++ debian/changelog | 6 ++++++ debian/register_new_matrix_user.ronn | 8 ++++++-- synapse/_scripts/register_new_matrix_user.py | 20 +++++++++++++++----- 4 files changed, 29 insertions(+), 7 deletions(-) create mode 100644 changelog.d/17294.feature diff --git a/changelog.d/17294.feature b/changelog.d/17294.feature new file mode 100644 index 0000000000..33aac7b0bc --- /dev/null +++ b/changelog.d/17294.feature @@ -0,0 +1,2 @@ +`register_new_matrix_user` now supports a --password-file flag, which +is useful for scripting. diff --git a/debian/changelog b/debian/changelog index e9b05f8553..55e17bd868 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.109.0+nmu1) UNRELEASED; urgency=medium + + * `register_new_matrix_user` now supports a --password-file flag. + + -- Synapse Packaging team Tue, 18 Jun 2024 13:29:36 +0100 + matrix-synapse-py3 (1.109.0) stable; urgency=medium * New synapse release 1.109.0. diff --git a/debian/register_new_matrix_user.ronn b/debian/register_new_matrix_user.ronn index 0410b1f4cd..963e67c004 100644 --- a/debian/register_new_matrix_user.ronn +++ b/debian/register_new_matrix_user.ronn @@ -31,8 +31,12 @@ A sample YAML file accepted by `register_new_matrix_user` is described below: Local part of the new user. Will prompt if omitted. * `-p`, `--password`: - New password for user. Will prompt if omitted. Supplying the password - on the command line is not recommended. Use the STDIN instead. + New password for user. Will prompt if this option and `--password-file` are omitted. + Supplying the password on the command line is not recommended. + + * `--password-file`: + File containing the new password for user. If set, overrides `--password`. + This is a more secure alternative to specifying the password on the command line. * `-a`, `--admin`: Register new user as an admin. Will prompt if omitted. diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index 77a7129ee2..972b35e2dc 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -173,11 +173,18 @@ def main() -> None: default=None, help="Local part of the new user. Will prompt if omitted.", ) - parser.add_argument( + password_group = parser.add_mutually_exclusive_group() + password_group.add_argument( "-p", "--password", default=None, - help="New password for user. Will prompt if omitted.", + help="New password for user. Will prompt for a password if " + "this flag and `--password-file` are both omitted.", + ) + password_group.add_argument( + "--password-file", + default=None, + help="File containing the new password for user. 
If set, will override `--password`.", ) parser.add_argument( "-t", @@ -247,6 +254,11 @@ def main() -> None: print(_NO_SHARED_SECRET_OPTS_ERROR, file=sys.stderr) sys.exit(1) + if args.password_file: + password = _read_file(args.password_file, "password-file").strip() + else: + password = args.password + if args.server_url: server_url = args.server_url elif config is not None: @@ -269,9 +281,7 @@ def main() -> None: if args.admin or args.no_admin: admin = args.admin - register_new_user( - args.user, args.password, server_url, secret, admin, args.user_type - ) + register_new_user(args.user, password, server_url, secret, admin, args.user_type) def _read_file(file_path: Any, config_path: str) -> str: From 199223062aff38936aee50910418ddc81451dc9e Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 18 Jun 2024 16:54:19 +0100 Subject: [PATCH 2/4] Revert "Support MSC3916 by adding a federation `/download` endpoint" (#17325) --- changelog.d/17172.feature | 2 - changelog.d/17325.misc | 1 + .../federation/transport/server/__init__.py | 24 -- synapse/federation/transport/server/_base.py | 24 +- .../federation/transport/server/federation.py | 41 --- synapse/media/_base.py | 63 +---- synapse/media/media_repository.py | 18 +- synapse/media/media_storage.py | 223 +---------------- synapse/media/storage_provider.py | 40 +-- tests/federation/test_federation_media.py | 234 ------------------ tests/media/test_media_storage.py | 14 +- 11 files changed, 25 insertions(+), 659 deletions(-) delete mode 100644 changelog.d/17172.feature create mode 100644 changelog.d/17325.misc delete mode 100644 tests/federation/test_federation_media.py diff --git a/changelog.d/17172.feature b/changelog.d/17172.feature deleted file mode 100644 index 245dea815c..0000000000 --- a/changelog.d/17172.feature +++ /dev/null @@ -1,2 +0,0 @@ -Support [MSC3916](https://github.com/matrix-org/matrix-spec-proposals/blob/rav/authentication-for-media/proposals/3916-authentication-for-media.md) -by adding a federation /download endpoint (#17172). \ No newline at end of file diff --git a/changelog.d/17325.misc b/changelog.d/17325.misc new file mode 100644 index 0000000000..1a4ce7ceec --- /dev/null +++ b/changelog.d/17325.misc @@ -0,0 +1 @@ +This is a changelog so tests will run. 
\ No newline at end of file diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 266675c9b8..bac569e977 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -19,7 +19,6 @@ # [This file includes modifications made by New Vector Limited] # # -import inspect import logging from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Type @@ -34,7 +33,6 @@ from synapse.federation.transport.server.federation import ( FEDERATION_SERVLET_CLASSES, FederationAccountStatusServlet, FederationUnstableClientKeysClaimServlet, - FederationUnstableMediaDownloadServlet, ) from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import ( @@ -317,28 +315,6 @@ def register_servlets( ): continue - if servletclass == FederationUnstableMediaDownloadServlet: - if ( - not hs.config.server.enable_media_repo - or not hs.config.experimental.msc3916_authenticated_media_enabled - ): - continue - - # don't load the endpoint if the storage provider is incompatible - media_repo = hs.get_media_repository() - load_download_endpoint = True - for provider in media_repo.media_storage.storage_providers: - signature = inspect.signature(provider.backend.fetch) - if "federation" not in signature.parameters: - logger.warning( - f"Federation media `/download` endpoint will not be enabled as storage provider {provider.backend} is not compatible with this endpoint." - ) - load_download_endpoint = False - break - - if not load_download_endpoint: - continue - servletclass( hs=hs, authenticator=authenticator, diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index 4e2717b565..db0f5076a9 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -360,29 +360,13 @@ class BaseFederationServlet: "request" ) return None - if ( - func.__self__.__class__.__name__ # type: ignore - == "FederationUnstableMediaDownloadServlet" - ): - response = await func( - origin, content, request, *args, **kwargs - ) - else: - response = await func( - origin, content, request.args, *args, **kwargs - ) - else: - if ( - func.__self__.__class__.__name__ # type: ignore - == "FederationUnstableMediaDownloadServlet" - ): - response = await func( - origin, content, request, *args, **kwargs - ) - else: response = await func( origin, content, request.args, *args, **kwargs ) + else: + response = await func( + origin, content, request.args, *args, **kwargs + ) finally: # if we used the origin's context as the parent, add a new span using # the servlet span as a parent, so that we have a link diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 1f02451efa..a59734785f 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -44,13 +44,10 @@ from synapse.federation.transport.server._base import ( ) from synapse.http.servlet import ( parse_boolean_from_args, - parse_integer, parse_integer_from_args, parse_string_from_args, parse_strings_from_args, ) -from synapse.http.site import SynapseRequest -from synapse.media._base import DEFAULT_MAX_TIMEOUT_MS, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS from synapse.types import JsonDict from synapse.util import SYNAPSE_VERSION from synapse.util.ratelimitutils import FederationRateLimiter @@ -790,43 +787,6 @@ class 
FederationAccountStatusServlet(BaseFederationServerServlet): return 200, {"account_statuses": statuses, "failures": failures} -class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet): - """ - Implementation of new federation media `/download` endpoint outlined in MSC3916. Returns - a multipart/form-data response consisting of a JSON object and the requested media - item. This endpoint only returns local media. - """ - - PATH = "/media/download/(?P[^/]*)" - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3916" - RATELIMIT = True - - def __init__( - self, - hs: "HomeServer", - ratelimiter: FederationRateLimiter, - authenticator: Authenticator, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.media_repo = self.hs.get_media_repository() - - async def on_GET( - self, - origin: Optional[str], - content: Literal[None], - request: SynapseRequest, - media_id: str, - ) -> None: - max_timeout_ms = parse_integer( - request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS - ) - max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS) - await self.media_repo.get_local_media( - request, media_id, None, max_timeout_ms, federation=True - ) - - FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationSendServlet, FederationEventServlet, @@ -858,5 +818,4 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationV1SendKnockServlet, FederationMakeKnockServlet, FederationAccountStatusServlet, - FederationUnstableMediaDownloadServlet, ) diff --git a/synapse/media/_base.py b/synapse/media/_base.py index 19bca94170..3fbed6062f 100644 --- a/synapse/media/_base.py +++ b/synapse/media/_base.py @@ -25,16 +25,7 @@ import os import urllib from abc import ABC, abstractmethod from types import TracebackType -from typing import ( - TYPE_CHECKING, - Awaitable, - Dict, - Generator, - List, - Optional, - Tuple, - Type, -) +from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type import attr @@ -48,11 +39,6 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.util.stringutils import is_ascii -if TYPE_CHECKING: - from synapse.media.media_storage import MultipartResponder - from synapse.storage.databases.main.media_repository import LocalMedia - - logger = logging.getLogger(__name__) # list all text content types that will have the charset default to UTF-8 when @@ -274,53 +260,6 @@ def _can_encode_filename_as_token(x: str) -> bool: return True -async def respond_with_multipart_responder( - request: SynapseRequest, - responder: "Optional[MultipartResponder]", - media_info: "LocalMedia", -) -> None: - """ - Responds via a Multipart responder for the federation media `/download` requests - - Args: - request: the federation request to respond to - responder: the Multipart responder which will send the response - media_info: metadata about the media item - """ - if not responder: - respond_404(request) - return - - # If we have a responder we *must* use it as a context manager. 
- with responder: - if request._disconnected: - logger.warning( - "Not sending response to request %s, already disconnected.", request - ) - return - - logger.debug("Responding to media request with responder %s", responder) - if media_info.media_length is not None: - request.setHeader(b"Content-Length", b"%d" % (media_info.media_length,)) - request.setHeader( - b"Content-Type", b"multipart/mixed; boundary=%s" % responder.boundary - ) - - try: - await responder.write_to_consumer(request) - except Exception as e: - # The majority of the time this will be due to the client having gone - # away. Unfortunately, Twisted simply throws a generic exception at us - # in that case. - logger.warning("Failed to write to consumer: %s %s", type(e), e) - - # Unregister the producer, if it has one, so Twisted doesn't complain - if request.producer: - request.unregisterProducer() - - finish_request(request) - - async def respond_with_responder( request: SynapseRequest, responder: "Optional[Responder]", diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index c335e518a0..6ed56099ca 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -54,11 +54,10 @@ from synapse.media._base import ( ThumbnailInfo, get_filename_from_headers, respond_404, - respond_with_multipart_responder, respond_with_responder, ) from synapse.media.filepath import MediaFilePaths -from synapse.media.media_storage import MediaStorage, MultipartResponder +from synapse.media.media_storage import MediaStorage from synapse.media.storage_provider import StorageProviderWrapper from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.url_previewer import UrlPreviewer @@ -430,7 +429,6 @@ class MediaRepository: media_id: str, name: Optional[str], max_timeout_ms: int, - federation: bool = False, ) -> None: """Responds to requests for local media, if exists, or returns 404. @@ -442,7 +440,6 @@ class MediaRepository: the filename in the Content-Disposition header of the response. max_timeout_ms: the maximum number of milliseconds to wait for the media to be uploaded. 
- federation: whether the local media being fetched is for a federation request Returns: Resolves once a response has successfully been written to request @@ -462,17 +459,10 @@ class MediaRepository: file_info = FileInfo(None, media_id, url_cache=bool(url_cache)) - responder = await self.media_storage.fetch_media( - file_info, media_info, federation + responder = await self.media_storage.fetch_media(file_info) + await respond_with_responder( + request, responder, media_type, media_length, upload_name ) - if federation: - # this really should be a Multipart responder but just in case - assert isinstance(responder, MultipartResponder) - await respond_with_multipart_responder(request, responder, media_info) - else: - await respond_with_responder( - request, responder, media_type, media_length, upload_name - ) async def get_remote_media( self, diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py index 2f55d12b6b..b3cd3fd8f4 100644 --- a/synapse/media/media_storage.py +++ b/synapse/media/media_storage.py @@ -19,12 +19,9 @@ # # import contextlib -import json import logging import os import shutil -from contextlib import closing -from io import BytesIO from types import TracebackType from typing import ( IO, @@ -33,19 +30,14 @@ from typing import ( AsyncIterator, BinaryIO, Callable, - List, Optional, Sequence, Tuple, Type, - Union, ) -from uuid import uuid4 import attr -from zope.interface import implementer -from twisted.internet import defer, interfaces from twisted.internet.defer import Deferred from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender @@ -56,19 +48,15 @@ from synapse.logging.opentracing import start_active_span, trace, trace_with_opn from synapse.util import Clock from synapse.util.file_consumer import BackgroundFileConsumer -from ..storage.databases.main.media_repository import LocalMedia -from ..types import JsonDict from ._base import FileInfo, Responder from .filepath import MediaFilePaths if TYPE_CHECKING: - from synapse.media.storage_provider import StorageProviderWrapper + from synapse.media.storage_provider import StorageProvider from synapse.server import HomeServer logger = logging.getLogger(__name__) -CRLF = b"\r\n" - class MediaStorage: """Responsible for storing/fetching files from local sources. @@ -85,7 +73,7 @@ class MediaStorage: hs: "HomeServer", local_media_directory: str, filepaths: MediaFilePaths, - storage_providers: Sequence["StorageProviderWrapper"], + storage_providers: Sequence["StorageProvider"], ): self.hs = hs self.reactor = hs.get_reactor() @@ -181,23 +169,15 @@ class MediaStorage: raise e from None - async def fetch_media( - self, - file_info: FileInfo, - media_info: Optional[LocalMedia] = None, - federation: bool = False, - ) -> Optional[Responder]: + async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: """Attempts to fetch media described by file_info from the local cache and configured storage providers. Args: - file_info: Metadata about the media file - media_info: Metadata about the media item - federation: Whether this file is being fetched for a federation request + file_info Returns: - If the file was found returns a Responder (a Multipart Responder if the requested - file is for the federation /download endpoint), otherwise None. + Returns a Responder if the file was found, otherwise None. 
""" paths = [self._file_info_to_path(file_info)] @@ -217,19 +197,12 @@ class MediaStorage: local_path = os.path.join(self.local_media_directory, path) if os.path.exists(local_path): logger.debug("responding with local file %s", local_path) - if federation: - assert media_info is not None - boundary = uuid4().hex.encode("ascii") - return MultipartResponder( - open(local_path, "rb"), media_info, boundary - ) - else: - return FileResponder(open(local_path, "rb")) + return FileResponder(open(local_path, "rb")) logger.debug("local file %s did not exist", local_path) for provider in self.storage_providers: for path in paths: - res: Any = await provider.fetch(path, file_info, media_info, federation) + res: Any = await provider.fetch(path, file_info) if res: logger.debug("Streaming %s from %s", path, provider) return res @@ -343,7 +316,7 @@ class FileResponder(Responder): """Wraps an open file that can be sent to a request. Args: - open_file: A file like object to be streamed to the client, + open_file: A file like object to be streamed ot the client, is closed when finished streaming. """ @@ -364,38 +337,6 @@ class FileResponder(Responder): self.open_file.close() -class MultipartResponder(Responder): - """Wraps an open file, formats the response according to MSC3916 and sends it to a - federation request. - - Args: - open_file: A file like object to be streamed to the client, - is closed when finished streaming. - media_info: metadata about the media item - boundary: bytes to use for the multipart response boundary - """ - - def __init__(self, open_file: IO, media_info: LocalMedia, boundary: bytes) -> None: - self.open_file = open_file - self.media_info = media_info - self.boundary = boundary - - def write_to_consumer(self, consumer: IConsumer) -> Deferred: - return make_deferred_yieldable( - MultipartFileSender().beginFileTransfer( - self.open_file, consumer, self.media_info.media_type, {}, self.boundary - ) - ) - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - self.open_file.close() - - class SpamMediaException(NotFoundError): """The media was blocked by a spam checker, so we simply 404 the request (in the same way as if it was quarantined). @@ -429,151 +370,3 @@ class ReadableFileWrapper: # We yield to the reactor by sleeping for 0 seconds. await self.clock.sleep(0) - - -@implementer(interfaces.IProducer) -class MultipartFileSender: - """ - A producer that sends the contents of a file to a federation request in the format - outlined in MSC3916 - a multipart/format-data response where the first field is a - JSON object and the second is the requested file. - - This is a slight re-writing of twisted.protocols.basic.FileSender to achieve the format - outlined above. - """ - - CHUNK_SIZE = 2**14 - - lastSent = "" - deferred: Optional[defer.Deferred] = None - - def beginFileTransfer( - self, - file: IO, - consumer: IConsumer, - file_content_type: str, - json_object: JsonDict, - boundary: bytes, - ) -> Deferred: - """ - Begin transferring a file - - Args: - file: The file object to read data from - consumer: The synapse request to write the data to - file_content_type: The content-type of the file - json_object: The JSON object to write to the first field of the response - boundary: bytes to be used as the multipart/form-data boundary - - Returns: A deferred whose callback will be invoked when the file has - been completely written to the consumer. 
The last byte written to the - consumer is passed to the callback. - """ - self.file: Optional[IO] = file - self.consumer = consumer - self.json_field = json_object - self.json_field_written = False - self.content_type_written = False - self.file_content_type = file_content_type - self.boundary = boundary - self.deferred: Deferred = defer.Deferred() - self.consumer.registerProducer(self, False) - # while it's not entirely clear why this assignment is necessary, it mirrors - # the behavior in FileSender.beginFileTransfer and thus is preserved here - deferred = self.deferred - return deferred - - def resumeProducing(self) -> None: - # write the first field, which will always be a json field - if not self.json_field_written: - self.consumer.write(CRLF + b"--" + self.boundary + CRLF) - - content_type = Header(b"Content-Type", b"application/json") - self.consumer.write(bytes(content_type) + CRLF) - - json_field = json.dumps(self.json_field) - json_bytes = json_field.encode("utf-8") - self.consumer.write(json_bytes) - self.consumer.write(CRLF + b"--" + self.boundary + CRLF) - - self.json_field_written = True - - chunk: Any = "" - if self.file: - # if we haven't written the content type yet, do so - if not self.content_type_written: - type = self.file_content_type.encode("utf-8") - content_type = Header(b"Content-Type", type) - self.consumer.write(bytes(content_type) + CRLF) - self.content_type_written = True - - chunk = self.file.read(self.CHUNK_SIZE) - - if not chunk: - # we've reached the end of the file - self.consumer.write(CRLF + b"--" + self.boundary + b"--" + CRLF) - self.file = None - self.consumer.unregisterProducer() - - if self.deferred: - self.deferred.callback(self.lastSent) - self.deferred = None - return - - self.consumer.write(chunk) - self.lastSent = chunk[-1:] - - def pauseProducing(self) -> None: - pass - - def stopProducing(self) -> None: - if self.deferred: - self.deferred.errback(Exception("Consumer asked us to stop producing")) - self.deferred = None - - -class Header: - """ - `Header` This class is a tiny wrapper that produces - request headers. We can't use standard python header - class because it encodes unicode fields using =? bla bla ?= - encoding, which is correct, but no one in HTTP world expects - that, everyone wants utf-8 raw bytes. (stolen from treq.multipart) - - """ - - def __init__( - self, - name: bytes, - value: Any, - params: Optional[List[Tuple[Any, Any]]] = None, - ): - self.name = name - self.value = value - self.params = params or [] - - def add_param(self, name: Any, value: Any) -> None: - self.params.append((name, value)) - - def __bytes__(self) -> bytes: - with closing(BytesIO()) as h: - h.write(self.name + b": " + escape(self.value).encode("us-ascii")) - if self.params: - for name, val in self.params: - h.write(b"; ") - h.write(escape(name).encode("us-ascii")) - h.write(b"=") - h.write(b'"' + escape(val).encode("utf-8") + b'"') - h.seek(0) - return h.read() - - -def escape(value: Union[str, bytes]) -> str: - """ - This function prevents header values from corrupting the request, - a newline in the file name parameter makes form-data request unreadable - for a majority of parsers. 
(stolen from treq.multipart) - """ - if isinstance(value, bytes): - value = value.decode("utf-8") - return value.replace("\r", "").replace("\n", "").replace('"', '\\"') diff --git a/synapse/media/storage_provider.py b/synapse/media/storage_provider.py index a2d50adf65..06e5d27a53 100644 --- a/synapse/media/storage_provider.py +++ b/synapse/media/storage_provider.py @@ -24,16 +24,14 @@ import logging import os import shutil from typing import TYPE_CHECKING, Callable, Optional -from uuid import uuid4 from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background from synapse.logging.opentracing import start_active_span, trace_with_opname from synapse.util.async_helpers import maybe_awaitable -from ..storage.databases.main.media_repository import LocalMedia from ._base import FileInfo, Responder -from .media_storage import FileResponder, MultipartResponder +from .media_storage import FileResponder logger = logging.getLogger(__name__) @@ -57,21 +55,13 @@ class StorageProvider(metaclass=abc.ABCMeta): """ @abc.abstractmethod - async def fetch( - self, - path: str, - file_info: FileInfo, - media_info: Optional[LocalMedia] = None, - federation: bool = False, - ) -> Optional[Responder]: + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """Attempt to fetch the file described by file_info and stream it into writer. Args: path: Relative path of file in local cache file_info: The metadata of the file. - media_info: metadata of the media item - federation: Whether the requested media is for a federation request Returns: Returns a Responder if the provider has the file, otherwise returns None. @@ -134,13 +124,7 @@ class StorageProviderWrapper(StorageProvider): run_in_background(store) @trace_with_opname("StorageProviderWrapper.fetch") - async def fetch( - self, - path: str, - file_info: FileInfo, - media_info: Optional[LocalMedia] = None, - federation: bool = False, - ) -> Optional[Responder]: + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: if file_info.url_cache: # Files in the URL preview cache definitely aren't stored here, # so avoid any potentially slow I/O or network access. @@ -148,9 +132,7 @@ class StorageProviderWrapper(StorageProvider): # store_file is supposed to return an Awaitable, but guard # against improper implementations. 
- return await maybe_awaitable( - self.backend.fetch(path, file_info, media_info, federation) - ) + return await maybe_awaitable(self.backend.fetch(path, file_info)) class FileStorageProviderBackend(StorageProvider): @@ -190,23 +172,11 @@ class FileStorageProviderBackend(StorageProvider): ) @trace_with_opname("FileStorageProviderBackend.fetch") - async def fetch( - self, - path: str, - file_info: FileInfo, - media_info: Optional[LocalMedia] = None, - federation: bool = False, - ) -> Optional[Responder]: + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """See StorageProvider.fetch""" backup_fname = os.path.join(self.base_directory, path) if os.path.isfile(backup_fname): - if federation: - assert media_info is not None - boundary = uuid4().hex.encode("ascii") - return MultipartResponder( - open(backup_fname, "rb"), media_info, boundary - ) return FileResponder(open(backup_fname, "rb")) return None diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py deleted file mode 100644 index 1c89d19e99..0000000000 --- a/tests/federation/test_federation_media.py +++ /dev/null @@ -1,234 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# . -# -# Originally licensed under the Apache License, Version 2.0: -# . -# -# [This file includes modifications made by New Vector Limited] -# -# -import io -import os -import shutil -import tempfile -from typing import Optional - -from twisted.test.proto_helpers import MemoryReactor - -from synapse.media._base import FileInfo, Responder -from synapse.media.filepath import MediaFilePaths -from synapse.media.media_storage import MediaStorage -from synapse.media.storage_provider import ( - FileStorageProviderBackend, - StorageProviderWrapper, -) -from synapse.server import HomeServer -from synapse.storage.databases.main.media_repository import LocalMedia -from synapse.types import JsonDict, UserID -from synapse.util import Clock - -from tests import unittest -from tests.test_utils import SMALL_PNG -from tests.unittest import override_config - - -class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase): - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - super().prepare(reactor, clock, hs) - self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") - self.addCleanup(shutil.rmtree, self.test_dir) - self.primary_base_path = os.path.join(self.test_dir, "primary") - self.secondary_base_path = os.path.join(self.test_dir, "secondary") - - hs.config.media.media_store_path = self.primary_base_path - - storage_providers = [ - StorageProviderWrapper( - FileStorageProviderBackend(hs, self.secondary_base_path), - store_local=True, - store_remote=False, - store_synchronous=True, - ) - ] - - self.filepaths = MediaFilePaths(self.primary_base_path) - self.media_storage = MediaStorage( - hs, self.primary_base_path, self.filepaths, storage_providers - ) - self.media_repo = hs.get_media_repository() - - @override_config( - {"experimental_features": {"msc3916_authenticated_media_enabled": True}} - ) - def test_file_download(self) -> None: - content = 
io.BytesIO(b"file_to_stream") - content_uri = self.get_success( - self.media_repo.create_content( - "text/plain", - "test_upload", - content, - 46, - UserID.from_string("@user_id:whatever.org"), - ) - ) - # test with a text file - channel = self.make_signed_federation_request( - "GET", - f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}", - ) - self.pump() - self.assertEqual(200, channel.code) - - content_type = channel.headers.getRawHeaders("content-type") - assert content_type is not None - assert "multipart/mixed" in content_type[0] - assert "boundary" in content_type[0] - - # extract boundary - boundary = content_type[0].split("boundary=")[1] - # split on boundary and check that json field and expected value exist - stripped = channel.text_body.split("\r\n" + "--" + boundary) - # TODO: the json object expected will change once MSC3911 is implemented, currently - # {} is returned for all requests as a placeholder (per MSC3196) - found_json = any( - "\r\nContent-Type: application/json\r\n{}" in field for field in stripped - ) - self.assertTrue(found_json) - - # check that text file and expected value exist - found_file = any( - "\r\nContent-Type: text/plain\r\nfile_to_stream" in field - for field in stripped - ) - self.assertTrue(found_file) - - content = io.BytesIO(SMALL_PNG) - content_uri = self.get_success( - self.media_repo.create_content( - "image/png", - "test_png_upload", - content, - 67, - UserID.from_string("@user_id:whatever.org"), - ) - ) - # test with an image file - channel = self.make_signed_federation_request( - "GET", - f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}", - ) - self.pump() - self.assertEqual(200, channel.code) - - content_type = channel.headers.getRawHeaders("content-type") - assert content_type is not None - assert "multipart/mixed" in content_type[0] - assert "boundary" in content_type[0] - - # extract boundary - boundary = content_type[0].split("boundary=")[1] - # split on boundary and check that json field and expected value exist - body = channel.result.get("body") - assert body is not None - stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8")) - found_json = any( - b"\r\nContent-Type: application/json\r\n{}" in field - for field in stripped_bytes - ) - self.assertTrue(found_json) - - # check that png file exists and matches what was uploaded - found_file = any(SMALL_PNG in field for field in stripped_bytes) - self.assertTrue(found_file) - - @override_config( - {"experimental_features": {"msc3916_authenticated_media_enabled": False}} - ) - def test_disable_config(self) -> None: - content = io.BytesIO(b"file_to_stream") - content_uri = self.get_success( - self.media_repo.create_content( - "text/plain", - "test_upload", - content, - 46, - UserID.from_string("@user_id:whatever.org"), - ) - ) - channel = self.make_signed_federation_request( - "GET", - f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}", - ) - self.pump() - self.assertEqual(404, channel.code) - self.assertEqual(channel.json_body.get("errcode"), "M_UNRECOGNIZED") - - -class FakeFileStorageProviderBackend: - """ - Fake storage provider stub with incompatible `fetch` signature for testing - """ - - def __init__(self, hs: "HomeServer", config: str): - self.hs = hs - self.cache_directory = hs.config.media.media_store_path - self.base_directory = config - - def __str__(self) -> str: - return "FakeFileStorageProviderBackend[%s]" % (self.base_directory,) - - async def 
fetch( - self, path: str, file_info: FileInfo, media_info: Optional[LocalMedia] = None - ) -> Optional[Responder]: - pass - - -TEST_DIR = tempfile.mkdtemp(prefix="synapse-tests-") - - -class FederationUnstableMediaEndpointCompatibilityTest( - unittest.FederatingHomeserverTestCase -): - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - super().prepare(reactor, clock, hs) - self.test_dir = TEST_DIR - self.addCleanup(shutil.rmtree, self.test_dir) - self.media_repo = hs.get_media_repository() - - def default_config(self) -> JsonDict: - config = super().default_config() - primary_base_path = os.path.join(TEST_DIR, "primary") - config["media_storage_providers"] = [ - { - "module": "tests.federation.test_federation_media.FakeFileStorageProviderBackend", - "store_local": "True", - "store_remote": "False", - "store_synchronous": "False", - "config": {"directory": primary_base_path}, - } - ] - return config - - @override_config( - {"experimental_features": {"msc3916_authenticated_media_enabled": True}} - ) - def test_incompatible_storage_provider_fails_to_load_endpoint(self) -> None: - channel = self.make_signed_federation_request( - "GET", - "/_matrix/federation/unstable/org.matrix.msc3916/media/download/xyz", - ) - self.pump() - self.assertEqual(404, channel.code) - self.assertEqual(channel.json_body.get("errcode"), "M_UNRECOGNIZED") diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index 47a89e9c66..46d20ce775 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -49,10 +49,7 @@ from synapse.logging.context import make_deferred_yieldable from synapse.media._base import FileInfo, ThumbnailInfo from synapse.media.filepath import MediaFilePaths from synapse.media.media_storage import MediaStorage, ReadableFileWrapper -from synapse.media.storage_provider import ( - FileStorageProviderBackend, - StorageProviderWrapper, -) +from synapse.media.storage_provider import FileStorageProviderBackend from synapse.media.thumbnailer import ThumbnailProvider from synapse.module_api import ModuleApi from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers @@ -81,14 +78,7 @@ class MediaStorageTests(unittest.HomeserverTestCase): hs.config.media.media_store_path = self.primary_base_path - storage_providers = [ - StorageProviderWrapper( - FileStorageProviderBackend(hs, self.secondary_base_path), - store_local=True, - store_remote=False, - store_synchronous=True, - ) - ] + storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)] self.filepaths = MediaFilePaths(self.primary_base_path) self.media_storage = MediaStorage( From afaf2d9388f7012d0500932dad0af4bdb8d40d20 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 19 Jun 2024 10:05:39 +0100 Subject: [PATCH 3/4] Require the 'from' parameter for `/notifications` be an integer (#17283) Co-authored-by: Erik Johnston --- changelog.d/17283.bugfix | 1 + synapse/rest/client/notifications.py | 18 +- .../databases/main/event_push_actions.py | 2 +- tests/module_api/test_api.py | 2 +- tests/rest/client/test_notifications.py | 171 ++++++++++++++++-- 5 files changed, 173 insertions(+), 21 deletions(-) create mode 100644 changelog.d/17283.bugfix diff --git a/changelog.d/17283.bugfix b/changelog.d/17283.bugfix new file mode 100644 index 0000000000..98c1f05cc2 --- /dev/null +++ b/changelog.d/17283.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where an invalid 'from' parameter to 
[`/notifications`](https://spec.matrix.org/v1.10/client-server-api/#get_matrixclientv3notifications) would result in an Internal Server Error. \ No newline at end of file diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index be9b584748..168ce50d3f 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -32,6 +32,7 @@ from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.types import JsonDict +from ...api.errors import SynapseError from ._base import client_patterns if TYPE_CHECKING: @@ -56,7 +57,22 @@ class NotificationsServlet(RestServlet): requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() - from_token = parse_string(request, "from", required=False) + # While this is intended to be "string" to clients, the 'from' token + # is actually based on a numeric ID. So it must parse to an int. + from_token_str = parse_string(request, "from", required=False) + if from_token_str is not None: + # Parse to an integer. + try: + from_token = int(from_token_str) + except ValueError: + # If it doesn't parse to an integer, then this cannot possibly be a valid + # pagination token, as we only hand out integers. + raise SynapseError( + 400, 'Query parameter "from" contains unrecognised token' + ) + else: + from_token = None + limit = parse_integer(request, "limit", default=50) only = parse_string(request, "only", required=False) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index bdd0781c48..0ebf5b53d5 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -1829,7 +1829,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas async def get_push_actions_for_user( self, user_id: str, - before: Optional[str] = None, + before: Optional[int] = None, limit: int = 50, only_highlight: bool = False, ) -> List[UserPushAction]: diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 5eb1406a06..b6ba472d7d 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -688,7 +688,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): channel = self.make_request( "GET", - "/notifications?from=", + "/notifications", access_token=tok, ) self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/rest/client/test_notifications.py b/tests/rest/client/test_notifications.py index e9aa2e450e..e4b0455ce8 100644 --- a/tests/rest/client/test_notifications.py +++ b/tests/rest/client/test_notifications.py @@ -18,6 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # +from typing import List, Optional, Tuple from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -48,6 +49,14 @@ class HTTPPusherTests(HomeserverTestCase): self.sync_handler = homeserver.get_sync_handler() self.auth_handler = homeserver.get_auth_handler() + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + # Create a room + self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over 
federation. fed_transport_client = Mock(spec=["send_transaction"]) @@ -61,32 +70,22 @@ class HTTPPusherTests(HomeserverTestCase): """ Local users will get notified for invites """ - - user_id = self.register_user("user", "pass") - access_token = self.login("user", "pass") - other_user_id = self.register_user("otheruser", "pass") - other_access_token = self.login("otheruser", "pass") - - # Create a room - room = self.helper.create_room_as(user_id, tok=access_token) - # Check we start with no pushes - channel = self.make_request( - "GET", - "/notifications", - access_token=other_access_token, - ) - self.assertEqual(channel.code, 200, channel.result) - self.assertEqual(len(channel.json_body["notifications"]), 0, channel.json_body) + self._request_notifications(from_token=None, limit=1, expected_count=0) # Send an invite - self.helper.invite(room=room, src=user_id, targ=other_user_id, tok=access_token) + self.helper.invite( + room=self.room_id, + src=self.user_id, + targ=self.other_user_id, + tok=self.access_token, + ) # We should have a notification now channel = self.make_request( "GET", "/notifications", - access_token=other_access_token, + access_token=self.other_access_token, ) self.assertEqual(channel.code, 200) self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body) @@ -95,3 +94,139 @@ class HTTPPusherTests(HomeserverTestCase): "invite", channel.json_body, ) + + def test_pagination_of_notifications(self) -> None: + """ + Check that pagination of notifications works. + """ + # Check we start with no pushes + self._request_notifications(from_token=None, limit=1, expected_count=0) + + # Send an invite and have the other user join the room. + self.helper.invite( + room=self.room_id, + src=self.user_id, + targ=self.other_user_id, + tok=self.access_token, + ) + self.helper.join(self.room_id, self.other_user_id, tok=self.other_access_token) + + # Send 5 messages in the room and note down their event IDs. + sent_event_ids = [] + for _ in range(5): + resp = self.helper.send_event( + self.room_id, + "m.room.message", + {"body": "honk", "msgtype": "m.text"}, + tok=self.access_token, + ) + sent_event_ids.append(resp["event_id"]) + + # We expect to get notifications for messages in reverse order. + # So reverse this list of event IDs to make it easier to compare + # against later. + sent_event_ids.reverse() + + # We should have a few notifications now. Let's try and fetch the first 2. + notification_event_ids, _ = self._request_notifications( + from_token=None, limit=2, expected_count=2 + ) + + # Check we got the expected event IDs back. + self.assertEqual(notification_event_ids, sent_event_ids[:2]) + + # Try requesting again without a 'from' query parameter. We should get the + # same two notifications back. + notification_event_ids, next_token = self._request_notifications( + from_token=None, limit=2, expected_count=2 + ) + self.assertEqual(notification_event_ids, sent_event_ids[:2]) + + # Ask for the next 5 notifications, though there should only be + # 4 remaining; the next 3 messages and the invite. + # + # We need to use the "next_token" from the response as the "from" + # query parameter in the next request in order to paginate. + notification_event_ids, next_token = self._request_notifications( + from_token=next_token, limit=5, expected_count=4 + ) + # Ensure we chop off the invite on the end. 
+ notification_event_ids = notification_event_ids[:-1] + self.assertEqual(notification_event_ids, sent_event_ids[2:]) + + def _request_notifications( + self, from_token: Optional[str], limit: int, expected_count: int + ) -> Tuple[List[str], str]: + """ + Make a request to /notifications to get the latest events to be notified about. + + Only the event IDs are returned. The request is made by the "other user". + + Args: + from_token: An optional starting parameter. + limit: The maximum number of results to return. + expected_count: The number of events to expect in the response. + + Returns: + A list of event IDs that the client should be notified about. + Events are returned newest-first. + """ + # Construct the request path. + path = f"/notifications?limit={limit}" + if from_token is not None: + path += f"&from={from_token}" + + channel = self.make_request( + "GET", + path, + access_token=self.other_access_token, + ) + + self.assertEqual(channel.code, 200) + self.assertEqual( + len(channel.json_body["notifications"]), expected_count, channel.json_body + ) + + # Extract the necessary data from the response. + next_token = channel.json_body["next_token"] + event_ids = [ + event["event"]["event_id"] for event in channel.json_body["notifications"] + ] + + return event_ids, next_token + + def test_parameters(self) -> None: + """ + Test that appropriate errors are returned when query parameters are malformed. + """ + # Test that no parameters are required. + channel = self.make_request( + "GET", + "/notifications", + access_token=self.other_access_token, + ) + self.assertEqual(channel.code, 200) + + # Test that limit cannot be negative + channel = self.make_request( + "GET", + "/notifications?limit=-1", + access_token=self.other_access_token, + ) + self.assertEqual(channel.code, 400) + + # Test that the 'limit' parameter must be an integer. + channel = self.make_request( + "GET", + "/notifications?limit=foobar", + access_token=self.other_access_token, + ) + self.assertEqual(channel.code, 400) + + # Test that the 'from' parameter must be an integer. + channel = self.make_request( + "GET", + "/notifications?from=osborne", + access_token=self.other_access_token, + ) + self.assertEqual(channel.code, 400) From bdf82efea505c488953b46eb681b5a63c4e9655d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 19 Jun 2024 10:33:53 +0100 Subject: [PATCH 4/4] Handle large chain calc better (#17291) We calculate the auth chain links outside of the main persist event transaction to ensure that we do not block other event sending during the calculation. --- changelog.d/17291.misc | 1 + synapse/storage/controllers/persist_events.py | 12 + synapse/storage/databases/main/events.py | 261 +++++++++++++----- tests/storage/test_event_chain.py | 9 +- tests/storage/test_event_federation.py | 41 ++- 5 files changed, 236 insertions(+), 88 deletions(-) create mode 100644 changelog.d/17291.misc diff --git a/changelog.d/17291.misc b/changelog.d/17291.misc new file mode 100644 index 0000000000..b1f89a324d --- /dev/null +++ b/changelog.d/17291.misc @@ -0,0 +1 @@ +Do not block event sending/receiving while calulating large event auth chains. 
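As a standalone illustration of the pattern this commit adopts (a minimal sketch, assuming sqlite3 and toy table/function names rather than Synapse's real schema or API): the expensive link calculation runs first, with no transaction held, and the write transaction then covers only the cheap inserts.

import sqlite3

def calculate_links(event_ids):
    # Stand-in for the expensive chain-cover calculation: pure computation,
    # performed while no database transaction is held.
    return [(event_id, idx + 1, 1) for idx, event_id in enumerate(event_ids)]

def persist_links(conn, links):
    # The write transaction now covers only these cheap inserts.
    with conn:
        conn.executemany("INSERT INTO event_auth_chains VALUES (?, ?, ?)", links)

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE event_auth_chains "
    "(event_id TEXT, chain_id INTEGER, sequence_number INTEGER)"
)
links = calculate_links(["$a", "$b"])  # step 1: no transaction held
persist_links(conn, links)             # step 2: short write transaction
print(conn.execute("SELECT * FROM event_auth_chains").fetchall())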
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 84699a2ee1..d0e015bf19 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -617,6 +617,17 @@ class EventsPersistenceStorageController: room_id, chunk ) + with Measure(self._clock, "calculate_chain_cover_index_for_events"): + # We now calculate chain ID/sequence numbers for any state events we're + # persisting. We ignore out of band memberships as we're not in the room + # and won't have their auth chain (we'll fix it up later if we join the + # room). + # + # See: docs/auth_chain_difference_algorithm.md + new_event_links = await self.persist_events_store.calculate_chain_cover_index_for_events( + room_id, [e for e, _ in chunk] + ) + await self.persist_events_store._persist_events_and_state_updates( room_id, chunk, @@ -624,6 +635,7 @@ class EventsPersistenceStorageController: new_forward_extremities=new_forward_extremities, use_negative_stream_ordering=backfilled, inhibit_local_membership_updates=backfilled, + new_event_links=new_event_links, ) return replaced_events diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 66428e6c8e..c6df13c064 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -34,7 +34,6 @@ from typing import ( Optional, Set, Tuple, - Union, cast, ) @@ -100,6 +99,23 @@ class DeltaState: return not self.to_delete and not self.to_insert and not self.no_longer_in_room +@attr.s(slots=True, auto_attribs=True) +class NewEventChainLinks: + """Information about new auth chain links that need to be added to the DB. + + Attributes: + chain_id, sequence_number: the IDs corresponding to the event being + inserted, and the starting point of the links + links: Lists the links that need to be added, 2-tuple of the chain + ID/sequence number of the end point of the link. + """ + + chain_id: int + sequence_number: int + + links: List[Tuple[int, int]] = attr.Factory(list) + + class PersistEventsStore: """Contains all the functions for writing events to the database. @@ -148,6 +164,7 @@ class PersistEventsStore: *, state_delta_for_room: Optional[DeltaState], new_forward_extremities: Optional[Set[str]], + new_event_links: Dict[str, NewEventChainLinks], use_negative_stream_ordering: bool = False, inhibit_local_membership_updates: bool = False, ) -> None: @@ -217,6 +234,7 @@ class PersistEventsStore: inhibit_local_membership_updates=inhibit_local_membership_updates, state_delta_for_room=state_delta_for_room, new_forward_extremities=new_forward_extremities, + new_event_links=new_event_links, ) persist_event_counter.inc(len(events_and_contexts)) @@ -243,6 +261,87 @@ class PersistEventsStore: (room_id,), frozenset(new_forward_extremities) ) + async def calculate_chain_cover_index_for_events( + self, room_id: str, events: Collection[EventBase] + ) -> Dict[str, NewEventChainLinks]: + # Filter to state events, and ensure there are no duplicates. 
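+ # Deduplicate by event ID: each event should be assigned a chain position at most once.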
+ state_events = [] + seen_events = set() + for event in events: + if not event.is_state() or event.event_id in seen_events: + continue + + state_events.append(event) + seen_events.add(event.event_id) + + if not state_events: + return {} + + return await self.db_pool.runInteraction( + "_calculate_chain_cover_index_for_events", + self.calculate_chain_cover_index_for_events_txn, + room_id, + state_events, + ) + + def calculate_chain_cover_index_for_events_txn( + self, txn: LoggingTransaction, room_id: str, state_events: Collection[EventBase] + ) -> Dict[str, NewEventChainLinks]: + # We now calculate chain ID/sequence numbers for any state events we're + # persisting. We ignore out of band memberships as we're not in the room + # and won't have their auth chain (we'll fix it up later if we join the + # room). + # + # See: docs/auth_chain_difference_algorithm.md + + # We ignore legacy rooms that we aren't filling the chain cover index + # for. + row = self.db_pool.simple_select_one_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + retcols=("room_id", "has_auth_chain_index"), + allow_none=True, + ) + if row is None: + return {} + + # Filter out already persisted events. + rows = self.db_pool.simple_select_many_txn( + txn, + table="events", + column="event_id", + iterable=[e.event_id for e in state_events], + keyvalues={}, + retcols=("event_id",), + ) + already_persisted_events = {event_id for event_id, in rows} + state_events = [ + event + for event in state_events + if event.event_id in already_persisted_events + ] + + if not state_events: + return {} + + # We need to know the type/state_key and auth events of the events we're + # calculating chain IDs for. We don't rely on having the full Event + # instances as we'll potentially be pulling more events from the DB and + # we don't need the overhead of fetching/parsing the full event JSON. + event_to_types = {e.event_id: (e.type, e.state_key) for e in state_events} + event_to_auth_chain = {e.event_id: e.auth_event_ids() for e in state_events} + event_to_room_id = {e.event_id: e.room_id for e in state_events} + + return self._calculate_chain_cover_index( + txn, + self.db_pool, + self.store.event_chain_id_gen, + event_to_room_id, + event_to_types, + event_to_auth_chain, + ) + async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]: """Filter the supplied list of event_ids to get those which are prev_events of existing (non-outlier/rejected) events. @@ -358,6 +457,7 @@ class PersistEventsStore: inhibit_local_membership_updates: bool, state_delta_for_room: Optional[DeltaState], new_forward_extremities: Optional[Set[str]], + new_event_links: Dict[str, NewEventChainLinks], ) -> None: """Insert some number of room events into the necessary database tables. @@ -466,7 +566,9 @@ class PersistEventsStore: # Insert into event_to_state_groups. self._store_event_state_mappings_txn(txn, events_and_contexts) - self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts]) + self._persist_event_auth_chain_txn( + txn, [e for e, _ in events_and_contexts], new_event_links + ) # _store_rejected_events_txn filters out any events which were # rejected, and returns the filtered list. @@ -496,6 +598,7 @@ class PersistEventsStore: self, txn: LoggingTransaction, events: List[EventBase], + new_event_links: Dict[str, NewEventChainLinks], ) -> None: # We only care about state events, so this if there are no state events. 
if not any(e.is_state() for e in events): @@ -519,59 +622,8 @@ class PersistEventsStore: ], ) - # We now calculate chain ID/sequence numbers for any state events we're - # persisting. We ignore out of band memberships as we're not in the room - # and won't have their auth chain (we'll fix it up later if we join the - # room). - # - # See: docs/auth_chain_difference_algorithm.md - - # We ignore legacy rooms that we aren't filling the chain cover index - # for. - rows = cast( - List[Tuple[str, Optional[Union[int, bool]]]], - self.db_pool.simple_select_many_txn( - txn, - table="rooms", - column="room_id", - iterable={event.room_id for event in events if event.is_state()}, - keyvalues={}, - retcols=("room_id", "has_auth_chain_index"), - ), - ) - rooms_using_chain_index = { - room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index - } - - state_events = { - event.event_id: event - for event in events - if event.is_state() and event.room_id in rooms_using_chain_index - } - - if not state_events: - return - - # We need to know the type/state_key and auth events of the events we're - # calculating chain IDs for. We don't rely on having the full Event - # instances as we'll potentially be pulling more events from the DB and - # we don't need the overhead of fetching/parsing the full event JSON. - event_to_types = { - e.event_id: (e.type, e.state_key) for e in state_events.values() - } - event_to_auth_chain = { - e.event_id: e.auth_event_ids() for e in state_events.values() - } - event_to_room_id = {e.event_id: e.room_id for e in state_events.values()} - - self._add_chain_cover_index( - txn, - self.db_pool, - self.store.event_chain_id_gen, - event_to_room_id, - event_to_types, - event_to_auth_chain, - ) + if new_event_links: + self._persist_chain_cover_index(txn, self.db_pool, new_event_links) @classmethod def _add_chain_cover_index( @@ -583,6 +635,35 @@ class PersistEventsStore: event_to_types: Dict[str, Tuple[str, str]], event_to_auth_chain: Dict[str, StrCollection], ) -> None: + """Calculate and persist the chain cover index for the given events. + + Args: + event_to_room_id: Event ID to the room ID of the event + event_to_types: Event ID to type and state_key of the event + event_to_auth_chain: Event ID to list of auth event IDs of the + event (events with no auth events can be excluded). + """ + + new_event_links = cls._calculate_chain_cover_index( + txn, + db_pool, + event_chain_id_gen, + event_to_room_id, + event_to_types, + event_to_auth_chain, + ) + cls._persist_chain_cover_index(txn, db_pool, new_event_links) + + @classmethod + def _calculate_chain_cover_index( + cls, + txn: LoggingTransaction, + db_pool: DatabasePool, + event_chain_id_gen: SequenceGenerator, + event_to_room_id: Dict[str, str], + event_to_types: Dict[str, Tuple[str, str]], + event_to_auth_chain: Dict[str, StrCollection], + ) -> Dict[str, NewEventChainLinks]: """Calculate the chain cover index for the given events. Args: @@ -590,6 +671,10 @@ class PersistEventsStore: event_to_types: Event ID to type and state_key of the event event_to_auth_chain: Event ID to list of auth event IDs of the event (events with no auth events can be excluded). + + Returns: + A mapping with any new auth chain links we need to add, keyed by + event ID. """ # Map from event ID to chain ID/sequence number. 
@@ -708,11 +793,11 @@ class PersistEventsStore: room_id = event_to_room_id.get(event_id) if room_id: e_type, state_key = event_to_types[event_id] - db_pool.simple_insert_txn( + db_pool.simple_upsert_txn( txn, table="event_auth_chain_to_calculate", + keyvalues={"event_id": event_id}, values={ - "event_id": event_id, "room_id": room_id, "type": e_type, "state_key": state_key, @@ -724,7 +809,7 @@ class PersistEventsStore: break if not events_to_calc_chain_id_for: - return + return {} # Allocate chain ID/sequence numbers to each new event. new_chain_tuples = cls._allocate_chain_ids( @@ -739,23 +824,10 @@ class PersistEventsStore: ) chain_map.update(new_chain_tuples) - db_pool.simple_insert_many_txn( - txn, - table="event_auth_chains", - keys=("event_id", "chain_id", "sequence_number"), - values=[ - (event_id, c_id, seq) - for event_id, (c_id, seq) in new_chain_tuples.items() - ], - ) - - db_pool.simple_delete_many_txn( - txn, - table="event_auth_chain_to_calculate", - keyvalues={}, - column="event_id", - values=new_chain_tuples, - ) + to_return = { + event_id: NewEventChainLinks(chain_id, sequence_number) + for event_id, (chain_id, sequence_number) in new_chain_tuples.items() + } # Now we need to calculate any new links between chains caused by # the new events. @@ -825,10 +897,38 @@ class PersistEventsStore: auth_chain_id, auth_sequence_number = chain_map[auth_id] # Step 2a, add link between the event and auth event + to_return[event_id].links.append((auth_chain_id, auth_sequence_number)) chain_links.add_link( (chain_id, sequence_number), (auth_chain_id, auth_sequence_number) ) + return to_return + + @classmethod + def _persist_chain_cover_index( + cls, + txn: LoggingTransaction, + db_pool: DatabasePool, + new_event_links: Dict[str, NewEventChainLinks], + ) -> None: + db_pool.simple_insert_many_txn( + txn, + table="event_auth_chains", + keys=("event_id", "chain_id", "sequence_number"), + values=[ + (event_id, new_links.chain_id, new_links.sequence_number) + for event_id, new_links in new_event_links.items() + ], + ) + + db_pool.simple_delete_many_txn( + txn, + table="event_auth_chain_to_calculate", + keyvalues={}, + column="event_id", + values=new_event_links, + ) + db_pool.simple_insert_many_txn( txn, table="event_auth_chain_links", @@ -838,7 +938,16 @@ class PersistEventsStore: "target_chain_id", "target_sequence_number", ), - values=list(chain_links.get_additions()), + values=[ + ( + new_links.chain_id, + new_links.sequence_number, + target_chain_id, + target_sequence_number, + ) + for new_links in new_event_links.values() + for (target_chain_id, target_sequence_number) in new_links.links + ], ) @staticmethod diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 81feb3ec29..c4e216c308 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -447,7 +447,14 @@ class EventChainStoreTestCase(HomeserverTestCase): ) # Actually call the function that calculates the auth chain stuff. 
- persist_events_store._persist_event_auth_chain_txn(txn, events) + new_event_links = ( + persist_events_store.calculate_chain_cover_index_for_events_txn( + txn, events[0].room_id, [e for e in events if e.is_state()] + ) + ) + persist_events_store._persist_event_auth_chain_txn( + txn, events, new_event_links + ) self.get_success( persist_events_store.db_pool.runInteraction( diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 0a6253e22c..1832a23714 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -365,12 +365,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): }, ) + events = [ + cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id])) + for event_id in AUTH_GRAPH + ] + new_event_links = ( + self.persist_events.calculate_chain_cover_index_for_events_txn( + txn, room_id, [e for e in events if e.is_state()] + ) + ) self.persist_events._persist_event_auth_chain_txn( txn, - [ - cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id])) - for event_id in AUTH_GRAPH - ], + events, + new_event_links, ) self.get_success( @@ -628,13 +635,20 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) # Insert all events apart from 'B' + events = [ + cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) + for event_id in auth_graph + if event_id != "b" + ] + new_event_links = ( + self.persist_events.calculate_chain_cover_index_for_events_txn( + txn, room_id, [e for e in events if e.is_state()] + ) + ) self.persist_events._persist_event_auth_chain_txn( txn, - [ - cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) - for event_id in auth_graph - if event_id != "b" - ], + events, + new_event_links, ) # Now we insert the event 'B' without a chain cover, by temporarily @@ -647,9 +661,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): updatevalues={"has_auth_chain_index": False}, ) + events = [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))] + new_event_links = ( + self.persist_events.calculate_chain_cover_index_for_events_txn( + txn, room_id, [e for e in events if e.is_state()] + ) + ) self.persist_events._persist_event_auth_chain_txn( - txn, - [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))], + txn, events, new_event_links ) self.store.db_pool.simple_update_txn(
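The updated tests above all follow the same two-step pattern: compute the links with the new txn helper, then hand the result to the persist function. A minimal self-contained mirror of that shape (a plain dataclass standing in for the attrs-based NewEventChainLinks, with toy allocation logic; nothing below is the real Synapse API beyond the class name it imitates):

from dataclasses import dataclass, field
from typing import Dict, List, Tuple

@dataclass
class NewEventChainLinks:
    # Same shape as the attrs class added by this patch: the chain position
    # allocated to one event, plus the links that terminate at it.
    chain_id: int
    sequence_number: int
    links: List[Tuple[int, int]] = field(default_factory=list)

def calculate(event_ids: List[str]) -> Dict[str, NewEventChainLinks]:
    # Toy allocation: every event starts its own chain at sequence number 1.
    return {
        event_id: NewEventChainLinks(chain_id=idx + 1, sequence_number=1)
        for idx, event_id in enumerate(event_ids)
    }

def persist(event_ids: List[str], new_links: Dict[str, NewEventChainLinks]) -> None:
    # The persist step only consumes links computed in the first step.
    for event_id in event_ids:
        print(event_id, new_links[event_id])

event_ids = ["$a", "$b"]
persist(event_ids, calculate(event_ids))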