diff --git a/changelog.d/19365.feature b/changelog.d/19365.feature new file mode 100644 index 0000000000..c35afdc179 --- /dev/null +++ b/changelog.d/19365.feature @@ -0,0 +1 @@ +Support sending and receiving [MSC4354 Sticky Event](https://github.com/matrix-org/matrix-spec-proposals/pull/4354) metadata. \ No newline at end of file diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index 101ff153a5..120b3b9496 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -139,6 +139,8 @@ experimental_features: msc4155_enabled: true # Thread Subscriptions msc4306_enabled: true + # Sticky Events + msc4354_enabled: true server_notices: system_mxid_localpart: _server diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 9b6a68e929..b8ef5dac50 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -24,7 +24,9 @@ """Contains constants from the specification.""" import enum -from typing import Final +from typing import Final, TypedDict + +from synapse.util.duration import Duration # the max size of a (canonical-json-encoded) event MAX_PDU_SIZE = 65536 @@ -292,6 +294,8 @@ class EventUnsignedContentFields: # Requesting user's membership, per MSC4115 MEMBERSHIP: Final = "membership" + STICKY_TTL: Final = "msc4354_sticky_duration_ttl_ms" + class MTextFields: """Fields found inside m.text content blocks.""" @@ -377,3 +381,40 @@ class Direction(enum.Enum): class ProfileFields: DISPLAYNAME: Final = "displayname" AVATAR_URL: Final = "avatar_url" + + +class StickyEventField(TypedDict): + """ + Dict content of the `sticky` part of an event. + """ + + duration_ms: int + + +class StickyEvent: + QUERY_PARAM_NAME: Final = "org.matrix.msc4354.sticky_duration_ms" + """ + Query parameter used by clients for setting the sticky duration of an event they are sending. + + Applies to: + - /rooms/.../send/... + - /rooms/.../state/... 
+ """ + + EVENT_FIELD_NAME: Final = "msc4354_sticky" + """ + Name of the field in the top-level event dict that contains the sticky event dict. + """ + + MAX_DURATION: Duration = Duration(hours=1) + """ + Maximum stickiness duration as specified in MSC4354. + Ensures that data in the /sync response can go down and not grow unbounded. + """ + + MAX_EVENTS_IN_SYNC: Final = 100 + """ + Maximum number of sticky events to include in /sync. + + This is the default specified in the MSC. Chosen arbitrarily. + """ diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 0a4abd1839..159cd44237 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -102,6 +102,7 @@ from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.databases.main.sliding_sync import SlidingSyncStore from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.storage.databases.main.stats import StatsStore +from synapse.storage.databases.main.sticky_events import StickyEventsWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore from synapse.storage.databases.main.task_scheduler import TaskSchedulerWorkerStore @@ -137,6 +138,7 @@ class GenericWorkerStore( RoomWorkerStore, DirectoryWorkerStore, ThreadSubscriptionsWorkerStore, + StickyEventsWorkerStore, PushRulesWorkerStore, ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 9b6ff482cf..b6c8b8c062 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -580,5 +580,11 @@ class ExperimentalConfig(Config): # (and MSC4308: Thread Subscriptions extension to Sliding Sync) self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False) + # MSC4354: Sticky Events + # Tracked in: 
https://github.com/element-hq/synapse/issues/19409 + # Note that sticky events persisted before this feature is enabled will not be + # considered sticky by the local homeserver. + self.msc4354_enabled: bool = experimental.get("msc4354_enabled", False) + # MSC4380: Invite blocking self.msc4380_enabled: bool = experimental.get("msc4380_enabled", False) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index ec8ab9506b..996be88cb2 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -127,7 +127,9 @@ class WriterLocations: """Specifies the instances that write various streams. Attributes: - events: The instances that write to the event and backfill streams. + events: The instances that write to the event, backfill and `sticky_events` streams. + (`sticky_events` is written to during event persistence so must be handled by the + same stream writers.) typing: The instances that write to the typing stream. Currently can only be a single instance. to_device: The instances that write to the to_device stream. 
Currently diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index c7eaf7eda2..e6162997dd 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -36,7 +36,12 @@ from typing import ( import attr from unpaddedbase64 import encode_base64 -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + RelationTypes, + StickyEvent, +) from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.synapse_rust.events import EventInternalMetadata from synapse.types import ( @@ -44,6 +49,7 @@ from synapse.types import ( StrCollection, ) from synapse.util.caches import intern_dict +from synapse.util.duration import Duration from synapse.util.frozenutils import freeze if TYPE_CHECKING: @@ -318,6 +324,28 @@ class EventBase(metaclass=abc.ABCMeta): # this will be a no-op if the event dict is already frozen. self._dict = freeze(self._dict) + def sticky_duration(self) -> Duration | None: + """ + Returns the effective sticky duration of this event, or None + if the event does not have a sticky duration. + (Sticky Events are a MSC4354 feature.) + + Clamps the sticky duration to the maximum allowed duration. + """ + sticky_obj = self.get_dict().get(StickyEvent.EVENT_FIELD_NAME, None) + if type(sticky_obj) is not dict: + return None + sticky_duration_ms = sticky_obj.get("duration_ms", None) + # MSC: Clamp to 0 and MAX_DURATION (1 hour) + # We use `type(...) 
is int` to avoid accepting bools as `isinstance(True, int)` + # (bool is a subclass of int) + if type(sticky_duration_ms) is int and sticky_duration_ms >= 0: + return min( + Duration(milliseconds=sticky_duration_ms), + StickyEvent.MAX_DURATION, + ) + return None + def __str__(self) -> str: return self.__repr__() diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 6a2812109d..2cd1bf6106 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any import attr from signedjson.types import SigningKey -from synapse.api.constants import MAX_DEPTH, EventTypes +from synapse.api.constants import MAX_DEPTH, EventTypes, StickyEvent, StickyEventField from synapse.api.room_versions import ( KNOWN_EVENT_FORMAT_VERSIONS, EventFormatVersions, @@ -89,6 +89,10 @@ class EventBuilder: content: JsonDict = attr.Factory(dict) unsigned: JsonDict = attr.Factory(dict) + sticky: StickyEventField | None = None + """ + Fields for MSC4354: Sticky Events + """ # These only exist on a subset of events, so they raise AttributeError if # someone tries to get them when they don't exist. 
@@ -269,6 +273,9 @@ class EventBuilder: if self._origin_server_ts is not None: event_dict["origin_server_ts"] = self._origin_server_ts + if self.sticky is not None: + event_dict[StickyEvent.EVENT_FIELD_NAME] = self.sticky + return create_local_event_from_event_dict( clock=self._clock, hostname=self._hostname, @@ -318,6 +325,7 @@ class EventBuilderFactory: unsigned=key_values.get("unsigned", {}), redacts=key_values.get("redacts", None), origin_server_ts=key_values.get("origin_server_ts", None), + sticky=key_values.get(StickyEvent.EVENT_FIELD_NAME, None), ) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index c58d1d42bc..7e41716f1e 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Optional from twisted.internet.interfaces import IDelayedCall -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, StickyEvent, StickyEventField from synapse.api.errors import ShadowBanError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME @@ -333,6 +333,7 @@ class DelayedEventsHandler: origin_server_ts: int | None, content: JsonDict, delay: int, + sticky_duration_ms: int | None, ) -> str: """ Creates a new delayed event and schedules its delivery. @@ -346,7 +347,9 @@ class DelayedEventsHandler: If None, the timestamp will be the actual time when the event is sent. content: The content of the event to be sent. delay: How long (in milliseconds) to wait before automatically sending the event. - + sticky_duration_ms: If an MSC4354 sticky event: the sticky duration (in milliseconds). + The event will be attempted to be reliably delivered to clients and remote servers + during its sticky period. Returns: The ID of the added delayed event. 
Raises: @@ -382,6 +385,7 @@ class DelayedEventsHandler: origin_server_ts=origin_server_ts, content=content, delay=delay, + sticky_duration_ms=sticky_duration_ms, ) if self._repl_client is not None: @@ -570,7 +574,10 @@ class DelayedEventsHandler: if event.state_key is not None: event_dict["state_key"] = event.state_key - + if event.sticky_duration_ms is not None: + event_dict[StickyEvent.EVENT_FIELD_NAME] = StickyEventField( + duration_ms=event.sticky_duration_ms + ) ( sent_event, _, diff --git a/synapse/notifier.py b/synapse/notifier.py index cf3923110e..93d438def7 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -526,6 +526,7 @@ class Notifier: StreamKeyType.TYPING, StreamKeyType.UN_PARTIAL_STATED_ROOMS, StreamKeyType.THREAD_SUBSCRIPTIONS, + StreamKeyType.STICKY_EVENTS, ], new_token: int, users: Collection[str | UserID] | None = None, diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index fdda932ead..bc7e46d4c9 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -43,7 +43,10 @@ from synapse.replication.tcp.streams import ( UnPartialStatedEventStream, UnPartialStatedRoomStream, ) -from synapse.replication.tcp.streams._base import ThreadSubscriptionsStream +from synapse.replication.tcp.streams._base import ( + StickyEventsStream, + ThreadSubscriptionsStream, +) from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, @@ -262,6 +265,12 @@ class ReplicationDataHandler: token, users=[row.user_id for row in rows], ) + elif stream_name == StickyEventsStream.NAME: + self.notifier.on_new_event( + StreamKeyType.STICKY_EVENTS, + token, + rooms=[row.room_id for row in rows], + ) await self._presence_handler.process_replication_rows( stream_name, instance_name, token, rows diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 087c87545e..ad9fed72dd 100644 --- a/synapse/replication/tcp/handler.py +++ 
b/synapse/replication/tcp/handler.py @@ -67,6 +67,7 @@ from synapse.replication.tcp.streams import ( ) from synapse.replication.tcp.streams._base import ( DeviceListsStream, + StickyEventsStream, ThreadSubscriptionsStream, ) from synapse.util.background_queue import BackgroundQueue @@ -217,6 +218,12 @@ class ReplicationCommandHandler: continue + if isinstance(stream, StickyEventsStream): + if hs.get_instance_name() in hs.config.worker.writers.events: + self._streams_to_replicate.append(stream) + + continue + if isinstance(stream, DeviceListsStream): if hs.get_instance_name() in hs.config.worker.writers.device_lists: self._streams_to_replicate.append(stream) diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 87ac0a5ae1..067847617f 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -40,6 +40,7 @@ from synapse.replication.tcp.streams._base import ( PushersStream, PushRulesStream, ReceiptsStream, + StickyEventsStream, Stream, ThreadSubscriptionsStream, ToDeviceStream, @@ -68,6 +69,7 @@ STREAMS_MAP = { ToDeviceStream, FederationStream, AccountDataStream, + StickyEventsStream, ThreadSubscriptionsStream, UnPartialStatedRoomStream, UnPartialStatedEventStream, @@ -90,6 +92,7 @@ __all__ = [ "ToDeviceStream", "FederationStream", "AccountDataStream", + "StickyEventsStream", "ThreadSubscriptionsStream", "UnPartialStatedRoomStream", "UnPartialStatedEventStream", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4fb2aac202..1ea6b4fa85 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -763,3 +763,48 @@ class ThreadSubscriptionsStream(_StreamFromIdGen): return [], to_token, False return rows, rows[-1][0], len(updates) == limit + + +@attr.s(slots=True, auto_attribs=True) +class StickyEventsStreamRow: + """Stream to inform workers about changes to sticky 
events.""" + + room_id: str + + event_id: str + """The sticky event ID""" + + +class StickyEventsStream(_StreamFromIdGen): + """A sticky event was changed.""" + + NAME = "sticky_events" + ROW_TYPE = StickyEventsStreamRow + + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastores().main + super().__init__( + hs.get_instance_name(), + self._update_function, + self.store._sticky_events_id_gen, + ) + + async def _update_function( + self, instance_name: str, from_token: int, to_token: int, limit: int + ) -> StreamUpdateResult: + updates = await self.store.get_updated_sticky_events( + from_id=from_token, to_id=to_token, limit=limit + ) + rows = [ + ( + update.stream_id, + # These are the args to `StickyEventsStreamRow` + (update.room_id, update.event_id), + ) + for update in updates + ] + + if not rows: + return [], to_token, False + + return rows, rows[-1][0], len(updates) == limit diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 5e7dcb0191..9172bfcb4e 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -34,7 +34,13 @@ from prometheus_client.core import Histogram from twisted.web.server import Request from synapse import event_auth -from synapse.api.constants import Direction, EventTypes, Membership +from synapse.api.constants import ( + Direction, + EventTypes, + Membership, + StickyEvent, + StickyEventField, +) from synapse.api.errors import ( AuthError, Codes, @@ -210,6 +216,7 @@ class RoomStateEventRestServlet(RestServlet): self.clock = hs.get_clock() self._max_event_delay_ms = hs.config.server.max_event_delay_ms self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker + self._msc4354_enabled = hs.config.experimental.msc4354_enabled def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/state/$eventtype @@ -331,6 +338,10 @@ class RoomStateEventRestServlet(RestServlet): if requester.app_service: origin_server_ts = parse_integer(request, "ts") + 
sticky_duration_ms: int | None = None + if self._msc4354_enabled: + sticky_duration_ms = parse_integer(request, StickyEvent.QUERY_PARAM_NAME) + delay = _parse_request_delay(request, self._max_event_delay_ms) if delay is not None: delay_id = await self.delayed_events_handler.add( @@ -341,6 +352,7 @@ class RoomStateEventRestServlet(RestServlet): origin_server_ts=origin_server_ts, content=content, delay=delay, + sticky_duration_ms=sticky_duration_ms, ) set_tag("delay_id", delay_id) @@ -368,6 +380,10 @@ class RoomStateEventRestServlet(RestServlet): "room_id": room_id, "sender": requester.user.to_string(), } + if sticky_duration_ms is not None: + event_dict[StickyEvent.EVENT_FIELD_NAME] = StickyEventField( + duration_ms=sticky_duration_ms + ) if state_key is not None: event_dict["state_key"] = state_key @@ -400,6 +416,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): self.delayed_events_handler = hs.get_delayed_events_handler() self.auth = hs.get_auth() self._max_event_delay_ms = hs.config.server.max_event_delay_ms + self._msc4354_enabled = hs.config.experimental.msc4354_enabled def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/send/$event_type[/$txn_id] @@ -420,6 +437,10 @@ class RoomSendEventRestServlet(TransactionRestServlet): if requester.app_service: origin_server_ts = parse_integer(request, "ts") + sticky_duration_ms: int | None = None + if self._msc4354_enabled: + sticky_duration_ms = parse_integer(request, StickyEvent.QUERY_PARAM_NAME) + delay = _parse_request_delay(request, self._max_event_delay_ms) if delay is not None: delay_id = await self.delayed_events_handler.add( @@ -430,6 +451,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): origin_server_ts=origin_server_ts, content=content, delay=delay, + sticky_duration_ms=sticky_duration_ms, ) set_tag("delay_id", delay_id) @@ -446,6 +468,11 @@ class RoomSendEventRestServlet(TransactionRestServlet): if origin_server_ts is not None: event_dict["origin_server_ts"] = 
origin_server_ts + if sticky_duration_ms is not None: + event_dict[StickyEvent.EVENT_FIELD_NAME] = StickyEventField( + duration_ms=sticky_duration_ms + ) + try: ( event, diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 75f27c98de..8945849531 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -182,6 +182,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc4306": self.config.experimental.msc4306_enabled, # MSC4169: Backwards-compatible redaction sending using `/send` "com.beeper.msc4169": self.config.experimental.msc4169_enabled, + # MSC4354: Sticky events + "org.matrix.msc4354": self.config.experimental.msc4354_enabled, # MSC4380: Invite blocking "org.matrix.msc4380": self.config.experimental.msc4380_enabled, }, diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 12593094f1..9f8d4debbe 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -34,6 +34,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.sliding_sync import SlidingSyncStore from synapse.storage.databases.main.stats import UserSortOrder +from synapse.storage.databases.main.sticky_events import StickyEventsWorkerStore from synapse.storage.databases.main.thread_subscriptions import ( ThreadSubscriptionsWorkerStore, ) @@ -144,6 +145,7 @@ class DataStore( TagsStore, AccountDataStore, ThreadSubscriptionsWorkerStore, + StickyEventsWorkerStore, PushRulesWorkerStore, StreamWorkerStore, OpenIdStore, diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 5547150515..1727f589e2 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -54,6 +54,7 @@ class EventDetails: origin_server_ts: Timestamp | None content: JsonDict device_id: DeviceID | None + sticky_duration_ms: int | 
None @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -122,6 +123,7 @@ class DelayedEventsStore(SQLBaseStore): origin_server_ts: int | None, content: JsonDict, delay: int, + sticky_duration_ms: int | None, ) -> tuple[DelayID, Timestamp]: """ Inserts a new delayed event in the DB. @@ -148,6 +150,7 @@ class DelayedEventsStore(SQLBaseStore): "state_key": state_key, "origin_server_ts": origin_server_ts, "content": json_encoder.encode(content), + "sticky_duration_ms": sticky_duration_ms, }, ) @@ -299,6 +302,7 @@ class DelayedEventsStore(SQLBaseStore): "send_ts", "content", "device_id", + "sticky_duration_ms", ) ) sql_update = "UPDATE delayed_events SET is_processed = TRUE" @@ -344,6 +348,7 @@ class DelayedEventsStore(SQLBaseStore): Timestamp(row[5] if row[5] is not None else row[6]), db_to_json(row[7]), DeviceID(row[8]) if row[8] is not None else None, + int(row[9]) if row[9] is not None else None, DelayID(row[0]), UserLocalpart(row[1]), ) @@ -392,6 +397,7 @@ class DelayedEventsStore(SQLBaseStore): origin_server_ts, content, device_id, + sticky_duration_ms, user_localpart """, (delay_id,), @@ -407,8 +413,9 @@ class DelayedEventsStore(SQLBaseStore): Timestamp(row[3]) if row[3] is not None else None, db_to_json(row[4]), DeviceID(row[5]) if row[5] is not None else None, + int(row[6]) if row[6] is not None else None, DelayID(delay_id), - UserLocalpart(row[6]), + UserLocalpart(row[7]), ) return event, self._get_next_delayed_event_send_ts_txn(txn) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 60fc884c3a..cb452dbc9b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -264,6 +264,7 @@ class PersistEventsStore: self.database_engine = db.engine self._clock = hs.get_clock() self._instance_name = hs.get_instance_name() + self._msc4354_enabled = hs.config.experimental.msc4354_enabled self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages 
self.is_mine_id = hs.is_mine_id @@ -1185,6 +1186,11 @@ class PersistEventsStore: sliding_sync_table_changes, ) + if self._msc4354_enabled: + self.store.insert_sticky_events_txn( + txn, [ev for ev, _ in events_and_contexts] + ) + # We only update the sliding sync tables for non-backfilled events. self._update_sliding_sync_tables_with_new_persisted_events_txn( txn, room_id, events_and_contexts @@ -2646,6 +2652,11 @@ class PersistEventsStore: # event isn't an outlier any more. self._update_backward_extremeties(txn, [event]) + if self._msc4354_enabled and event.sticky_duration(): + # The de-outliered event is sticky. Update the sticky events table to ensure + # we deliver this down /sync. + self.store.insert_sticky_events_txn(txn, [event]) + return [ec for ec in events_and_contexts if ec[0] not in to_remove] def _store_event_txn( diff --git a/synapse/storage/databases/main/sticky_events.py b/synapse/storage/databases/main/sticky_events.py new file mode 100644 index 0000000000..101306296e --- /dev/null +++ b/synapse/storage/databases/main/sticky_events.py @@ -0,0 +1,322 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . 
+import logging +import random +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, +) + +from twisted.internet.defer import Deferred + +from synapse.events import EventBase +from synapse.replication.tcp.streams._base import StickyEventsStream +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore +from synapse.storage.databases.main.state import StateGroupWorkerStore +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.util.duration import Duration + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +DELETE_EXPIRED_STICKY_EVENTS_INTERVAL = Duration(hours=1) +""" +Remove entries from the sticky_events table at this frequency. +Note: don't be misled, we still honour shorter expiration timeouts, +because readers of the sticky_events table filter out expired sticky events +themselves, even if they aren't deleted from the table yet. + +Currently just an arbitrary choice. +Frequent enough to clean up expired sticky events promptly, +especially given the short cap on the lifetime of sticky events. +""" + + +@dataclass(frozen=True) +class StickyEventUpdate: + stream_id: int + room_id: str + event_id: str + soft_failed: bool + + +class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._can_write_to_sticky_events = ( + self._instance_name in hs.config.worker.writers.events + ) + + # Technically this means we will cleanup N times, once per event persister, maybe put on master? 
+ if self._can_write_to_sticky_events: + # Start a looping call to clean up the `sticky_events` table + # + # Because this will run once per event persister (for now), + # randomly stagger the initial time so that they don't all + # coincide with each other if the workers are deployed at the + # same time. This allows each cleanup to be somewhat more effective + # than if they all started at the same time, as they would all be + # cleaning up the same thing whereas each worker gets to clean up a little + # throughout the hour when they're staggered. + # + # Concurrent execution of the same deletions could also lead to + # repeatable serialisation violations in the database transaction, + # meaning we'd have to retry the transaction several times. + # + # This staggering is not critical, it's just best-effort. + self.clock.call_later( + # random() is 0.0 to 1.0 + DELETE_EXPIRED_STICKY_EVENTS_INTERVAL * random.random(), + self.clock.looping_call, + self._run_background_cleanup, + DELETE_EXPIRED_STICKY_EVENTS_INTERVAL, + ) + + self._sticky_events_id_gen: MultiWriterIdGenerator = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="sticky_events", + server_name=self.server_name, + instance_name=self._instance_name, + tables=[ + ("sticky_events", "instance_name", "stream_id"), + ], + sequence_name="sticky_events_sequence", + writers=hs.config.worker.writers.events, + ) + + if hs.config.experimental.msc4354_enabled and isinstance( + self.database_engine, Sqlite3Engine + ): + import sqlite3 + + if sqlite3.sqlite_version_info < (3, 40, 0): + raise RuntimeError( + f"Experimental MSC4354 Sticky Events enabled but SQLite3 version is too old: {sqlite3.sqlite_version_info}, must be at least 3.40. Disable MSC4354 Sticky Events, switch to Postgres, or upgrade SQLite. 
See https://github.com/element-hq/synapse/issues/19428" + ) + + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == StickyEventsStream.NAME: + self._sticky_events_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + + def get_max_sticky_events_stream_id(self) -> int: + """Get the current maximum stream_id for thread subscriptions. + + Returns: + The maximum stream_id + """ + return self._sticky_events_id_gen.get_current_token() + + def get_sticky_events_stream_id_generator(self) -> MultiWriterIdGenerator: + return self._sticky_events_id_gen + + async def get_updated_sticky_events( + self, *, from_id: int, to_id: int, limit: int + ) -> list[StickyEventUpdate]: + """Get updates to sticky events between two stream IDs. + + Bounds: from_id < ... <= to_id + + Args: + from_id: The starting stream ID (exclusive) + to_id: The ending stream ID (inclusive) + limit: The maximum number of rows to return + + Returns: + list of StickyEventUpdate update rows + """ + + if not self.hs.config.experimental.msc4354_enabled: + # We need to prevent `_get_updated_sticky_events_txn` + # from running when MSC4354 is turned off, because the query used + # for SQLite is not compatible with Ubuntu 22.04 (as used in our CI olddeps run). + # It's technically out of support. 
+ # See: https://github.com/element-hq/synapse/issues/19428 + return [] + + return await self.db_pool.runInteraction( + "get_updated_sticky_events", + self._get_updated_sticky_events_txn, + from_id, + to_id, + limit, + ) + + def _get_updated_sticky_events_txn( + self, txn: LoggingTransaction, from_id: int, to_id: int, limit: int + ) -> list[StickyEventUpdate]: + if isinstance(self.database_engine, PostgresEngine): + expr_soft_failed = "COALESCE(((ej.internal_metadata::jsonb)->>'soft_failed')::boolean, FALSE)" + else: + expr_soft_failed = "COALESCE(ej.internal_metadata->>'soft_failed', FALSE)" + + txn.execute( + f""" + SELECT se.stream_id, se.room_id, se.event_id, + {expr_soft_failed} AS "soft_failed" + FROM sticky_events se + INNER JOIN event_json ej USING (event_id) + WHERE ? < stream_id AND stream_id <= ? + LIMIT ? + """, + (from_id, to_id, limit), + ) + + return [ + StickyEventUpdate( + stream_id=stream_id, + room_id=room_id, + event_id=event_id, + soft_failed=bool(soft_failed), + ) + for stream_id, room_id, event_id, soft_failed in txn + ] + + def insert_sticky_events_txn( + self, + txn: LoggingTransaction, + events: list[EventBase], + ) -> None: + """ + Insert events into the sticky_events table. + + Skips inserting events: + - if they are considered spammy by the policy server; + (unsure if correct, track: https://github.com/matrix-org/matrix-spec-proposals/pull/4354#discussion_r2727593350) + - if they are rejected; + - if they are outliers (they should be reconsidered for insertion when de-outliered); or + - if they are not sticky (e.g. if the stickiness expired). + + Skipping the insertion of these types of 'invalid' events is useful for performance reasons because + they would fill up the table yet we wouldn't show them to clients anyway. + + Since syncing clients can't (easily?) 
'skip over' sticky events (due to being in-order, reliably delivered), + tracking loads of invalid events in the table could make it expensive for servers to retrieve the sticky events that are actually valid. + + For instance, someone spamming 1000s of rejected or 'policy_server_spammy' events could clog up this table in a way that means we either + have to deliver empty payloads to syncing clients, or consider substantially more than 100 events in order to gather a 100-sized batch to send down. + """ + + now_ms = self.clock.time_msec() + # event, expires_at + sticky_events: list[tuple[EventBase, int]] = [] + for ev in events: + # MSC: Note: policy servers and other similar antispam techniques still apply to these events. + if ev.internal_metadata.policy_server_spammy: + continue + # We shouldn't be passed rejected events, but if we do, we filter them out too. + if ev.rejected_reason is not None: + continue + # We can't persist outlier sticky events as we don't know the room state at that event + if ev.internal_metadata.is_outlier(): + continue + sticky_duration = ev.sticky_duration() + if sticky_duration is None: + continue + # Calculate the end time as start_time + effecitve sticky duration + expires_at = min(ev.origin_server_ts, now_ms) + sticky_duration.as_millis() + # Filter out already expired sticky events + if expires_at <= now_ms: + continue + + sticky_events.append((ev, expires_at)) + + if len(sticky_events) == 0: + return + + logger.info( + "inserting %d sticky events in room %s", + len(sticky_events), + sticky_events[0][0].room_id, + ) + + # Generate stream_ids in one go + sticky_events_with_ids = zip( + sticky_events, + self._sticky_events_id_gen.get_next_mult_txn(txn, len(sticky_events)), + strict=True, + ) + + self.db_pool.simple_insert_many_txn( + txn, + "sticky_events", + keys=( + "instance_name", + "stream_id", + "room_id", + "event_id", + "event_stream_ordering", + "sender", + "expires_at", + ), + values=[ + ( + self._instance_name, + stream_id, 
+ ev.room_id, + ev.event_id, + ev.internal_metadata.stream_ordering, + ev.sender, + expires_at, + ) + for (ev, expires_at), stream_id in sticky_events_with_ids + ], + ) + + async def _delete_expired_sticky_events(self) -> None: + await self.db_pool.runInteraction( + "_delete_expired_sticky_events", + self._delete_expired_sticky_events_txn, + self.clock.time_msec(), + ) + + def _delete_expired_sticky_events_txn( + self, txn: LoggingTransaction, now: int + ) -> None: + """ + From the `sticky_events` table, deletes all entries whose expiry is in the past + (older than `now`). + + This is fine because we don't consider the events as sticky anymore when that's + happened. + """ + txn.execute( + """ + DELETE FROM sticky_events WHERE expires_at < ? + """, + (now,), + ) + + def _run_background_cleanup(self) -> Deferred: + return self.hs.run_as_background_process( + "delete_expired_sticky_events", + self._delete_expired_sticky_events, + ) diff --git a/synapse/storage/schema/main/delta/93/01_sticky_events.sql b/synapse/storage/schema/main/delta/93/01_sticky_events.sql new file mode 100644 index 0000000000..59fded5959 --- /dev/null +++ b/synapse/storage/schema/main/delta/93/01_sticky_events.sql @@ -0,0 +1,60 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +-- Tracks sticky events. +-- Excludes 'policy_server_spammy' events, outliers, rejected events. +-- +-- Skipping the insertion of these types of 'invalid' events is useful for performance reasons because +-- they would fill up the table yet we wouldn't show them to clients anyway. 
+-- +-- Since syncing clients can't (easily) 'skip over' sticky events (due to being in-order, reliably delivered), +-- tracking loads of invalid events in the table could make it expensive for servers to retrieve the sticky events that are actually valid. +-- +-- For instance, someone spamming 1000s of rejected or 'policy_server_spammy' events could clog up this table in a way that means we either +-- have to deliver empty payloads to syncing clients, or consider substantially more than 100 events in order to gather a 100-sized batch to send down. +-- +-- May contain sticky events that have expired since being inserted, +-- although they will be periodically cleaned up in the background. +CREATE TABLE sticky_events ( + -- Position in the sticky events stream + stream_id INTEGER NOT NULL PRIMARY KEY, + + -- Name of the worker sending this. (This makes the stream compatible with multiple writers.) + instance_name TEXT NOT NULL, + + -- The event ID of the sticky event itself. + event_id TEXT NOT NULL, + + -- The room ID that the sticky event is in. + -- Denormalised for performance. (Safe as it's an immutable property of the event.) + room_id TEXT NOT NULL, + + -- The stream_ordering of the event. + -- Denormalised for performance since we will want to sort these by stream_ordering + -- when fetching them. (Safe as it's an immutable property of the event.) + event_stream_ordering INTEGER NOT NULL UNIQUE, + + -- Sender of the sticky event. + -- Denormalised for performance so we can query only for sticky events originating + -- from our homeserver. (Safe as it's an immutable property of the event.) + sender TEXT NOT NULL, + + -- When the sticky event expires, in milliseconds since the Unix epoch. + expires_at BIGINT NOT NULL +); + +-- For pulling out sticky events by room at send time, obeying stream ordering range limits. 
+CREATE INDEX sticky_events_room_idx ON sticky_events (room_id, event_stream_ordering); + +-- An optional integer for combining sticky events with delayed events. Used at send time. +ALTER TABLE delayed_events ADD COLUMN sticky_duration_ms BIGINT; diff --git a/synapse/storage/schema/main/delta/93/01_sticky_events_seq.sql.postgres b/synapse/storage/schema/main/delta/93/01_sticky_events_seq.sql.postgres new file mode 100644 index 0000000000..9ba72856bc --- /dev/null +++ b/synapse/storage/schema/main/delta/93/01_sticky_events_seq.sql.postgres @@ -0,0 +1,18 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +CREATE SEQUENCE sticky_events_sequence; +-- Synapse streams start at 2, because the default position is 1 +-- so any item inserted at position 1 is ignored. 
+-- We have to use nextval not START WITH 2, see https://github.com/element-hq/synapse/issues/18712 +SELECT nextval('sticky_events_sequence'); diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 143f659499..d2720fb959 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -84,6 +84,7 @@ class EventSources: self._instance_name ) thread_subscriptions_key = self.store.get_max_thread_subscriptions_stream_id() + sticky_events_key = self.store.get_max_sticky_events_stream_id() token = StreamToken( room_key=self.sources.room.get_current_key(), @@ -98,6 +99,7 @@ class EventSources: groups_key=0, un_partial_stated_rooms_key=un_partial_stated_rooms_key, thread_subscriptions_key=thread_subscriptions_key, + sticky_events_key=sticky_events_key, ) return token @@ -125,6 +127,7 @@ class EventSources: StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(), StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(), StreamKeyType.THREAD_SUBSCRIPTIONS: self.store.get_thread_subscriptions_stream_id_generator(), + StreamKeyType.STICKY_EVENTS: self.store.get_sticky_events_stream_id_generator(), } for _, key in StreamKeyType.__members__.items(): diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 99eefb8acb..fb1f1192b7 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -1006,6 +1006,7 @@ class StreamKeyType(Enum): DEVICE_LIST = "device_list_key" UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key" THREAD_SUBSCRIPTIONS = "thread_subscriptions_key" + STICKY_EVENTS = "sticky_events_key" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -1027,6 +1028,7 @@ class StreamToken: 9. `groups_key`: `1` (note that this key is now unused) 10. `un_partial_stated_rooms_key`: `379` 11. `thread_subscriptions_key`: 4242 + 12. 
`sticky_events_key`: 4141 You can see how many of these keys correspond to the various fields in a "/sync" response: @@ -1086,6 +1088,7 @@ class StreamToken: groups_key: int un_partial_stated_rooms_key: int thread_subscriptions_key: int + sticky_events_key: int _SEPARATOR = "_" START: ClassVar["StreamToken"] @@ -1114,6 +1117,7 @@ class StreamToken: groups_key, un_partial_stated_rooms_key, thread_subscriptions_key, + sticky_events_key, ) = keys return cls( @@ -1130,6 +1134,7 @@ class StreamToken: groups_key=int(groups_key), un_partial_stated_rooms_key=int(un_partial_stated_rooms_key), thread_subscriptions_key=int(thread_subscriptions_key), + sticky_events_key=int(sticky_events_key), ) except CancelledError: raise @@ -1153,6 +1158,7 @@ class StreamToken: str(self.groups_key), str(self.un_partial_stated_rooms_key), str(self.thread_subscriptions_key), + str(self.sticky_events_key), ] ) @@ -1218,6 +1224,7 @@ class StreamToken: StreamKeyType.TYPING, StreamKeyType.UN_PARTIAL_STATED_ROOMS, StreamKeyType.THREAD_SUBSCRIPTIONS, + StreamKeyType.STICKY_EVENTS, ], ) -> int: ... @@ -1274,7 +1281,7 @@ class StreamToken: f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, " f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, " f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key}," - f"thread_subscriptions: {self.thread_subscriptions_key})" + f"thread_subscriptions: {self.thread_subscriptions_key}, sticky_events: {self.sticky_events_key})" ) @@ -1290,6 +1297,7 @@ StreamToken.START = StreamToken( groups_key=0, un_partial_stated_rooms_key=0, thread_subscriptions_key=0, + sticky_events_key=0, ) diff --git a/synapse/visibility.py b/synapse/visibility.py index 452a2d50fb..5ba2a14a24 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -237,6 +237,20 @@ async def filter_and_transform_events_for_client( # to the cache! 
cloned = clone_event(filtered) cloned.unsigned[EventUnsignedContentFields.MEMBERSHIP] = user_membership + if storage.main.config.experimental.msc4354_enabled: + sticky_duration = cloned.sticky_duration() + if sticky_duration: + now_ms = storage.main.clock.time_msec() + expires_at = ( + # min() ensures that the origin server can't lie about the time and + # send the event 'in the future', as that would allow them to exceed + # the 1 hour limit on stickiness duration. + min(cloned.origin_server_ts, now_ms) + sticky_duration.as_millis() + ) + if expires_at > now_ms: + cloned.unsigned[EventUnsignedContentFields.STICKY_TTL] = ( + expires_at - now_ms + ) return cloned diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 4070bcaeaa..b32665eb73 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -2545,7 +2545,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): def test_topo_token_is_accepted(self) -> None: """Test Topo Token is accepted.""" - token = "t1-0_0_0_0_0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), @@ -2559,7 +2559,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: """Test that stream token is accepted for forward pagination.""" - token = "s0_0_0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 926560afd6..f85c9939ce 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2245,7 +2245,7 @@ class RoomMessageListTestCase(RoomBase): self.room_id = self.helper.create_room_as(self.user_id) def test_topo_token_is_accepted(self) -> None: - token = "t1-0_0_0_0_0_0_0_0_0_0_0" + token = 
"t1-0_0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) @@ -2256,7 +2256,7 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("end" in channel.json_body) def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: - token = "s0_0_0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) diff --git a/tests/rest/client/test_sticky_events.py b/tests/rest/client/test_sticky_events.py new file mode 100644 index 0000000000..a6e704fe8c --- /dev/null +++ b/tests/rest/client/test_sticky_events.py @@ -0,0 +1,179 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . 
+# +# + +import sqlite3 + +from twisted.internet.testing import MemoryReactor + +from synapse.api.constants import EventTypes, EventUnsignedContentFields +from synapse.rest import admin +from synapse.rest.client import login, register, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util.clock import Clock +from synapse.util.duration import Duration + +from tests import unittest +from tests.utils import USE_POSTGRES_FOR_TESTS + + +class StickyEventsClientTestCase(unittest.HomeserverTestCase): + """ + Tests for the client-server API parts of MSC4354: Sticky Events + """ + + if not USE_POSTGRES_FOR_TESTS and sqlite3.sqlite_version_info < (3, 40, 0): + # We need the JSON functionality in SQLite + skip = f"SQLite version is too old to support sticky events: {sqlite3.sqlite_version_info} (See https://github.com/element-hq/synapse/issues/19428)" + + servlets = [ + room.register_servlets, + login.register_servlets, + register.register_servlets, + admin.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = {"msc4354_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + # Register an account + self.user_id = self.register_user("user1", "pass") + self.token = self.login(self.user_id, "pass") + + # Create a room + self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + def _assert_event_sticky_for(self, event_id: str, sticky_ttl: int) -> None: + channel = self.make_request( + "GET", + f"/rooms/{self.room_id}/event/{event_id}", + access_token=self.token, + ) + + self.assertEqual( + channel.code, 200, f"could not retrieve event {event_id}: {channel.result}" + ) + event = channel.json_body + + self.assertIn( + EventUnsignedContentFields.STICKY_TTL, + event["unsigned"], + f"No {EventUnsignedContentFields.STICKY_TTL} field in {event_id}; event not sticky: {event}", + ) + 
self.assertEqual( + event["unsigned"][EventUnsignedContentFields.STICKY_TTL], + sticky_ttl, + f"{event_id} had an unexpected sticky TTL: {event}", + ) + + def _assert_event_not_sticky(self, event_id: str) -> None: + channel = self.make_request( + "GET", + f"/rooms/{self.room_id}/event/{event_id}", + access_token=self.token, + ) + + self.assertEqual( + channel.code, 200, f"could not retrieve event {event_id}: {channel.result}" + ) + event = channel.json_body + + self.assertNotIn( + EventUnsignedContentFields.STICKY_TTL, + event["unsigned"], + f"{EventUnsignedContentFields.STICKY_TTL} field unexpectedly found in {event_id}: {event}", + ) + + def test_sticky_event_via_event_endpoint(self) -> None: + # Arrange: Send a sticky event with a specific duration + sticky_event_response = self.helper.send_sticky_event( + self.room_id, + EventTypes.Message, + duration=Duration(minutes=1), + content={"body": "sticky message", "msgtype": "m.text"}, + tok=self.token, + ) + event_id = sticky_event_response["event_id"] + + # If we request the event immediately, it will still have + # 1 minute of stickiness + # The other 100 ms is advanced in FakeChannel.await_result. + self._assert_event_sticky_for(event_id, 59_900) + + # But if we advance time by 59.799 seconds... + # we will get the event on its last millisecond of stickiness + # The other 100 ms is advanced in FakeChannel.await_result. + self.reactor.advance(59.799) + self._assert_event_sticky_for(event_id, 1) + + # Advancing time any more, the event is no longer sticky + self.reactor.advance(0.001) + self._assert_event_not_sticky(event_id) + + +class StickyEventsDisabledClientTestCase(unittest.HomeserverTestCase): + """ + Tests client-facing behaviour of MSC4354: Sticky Events when the feature is + disabled. 
+ """ + + servlets = [ + room.register_servlets, + login.register_servlets, + register.register_servlets, + admin.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + # Register an account + self.user_id = self.register_user("user1", "pass") + self.token = self.login(self.user_id, "pass") + + # Create a room + self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + def _assert_event_not_sticky(self, event_id: str) -> None: + channel = self.make_request( + "GET", + f"/rooms/{self.room_id}/event/{event_id}", + access_token=self.token, + ) + + self.assertEqual( + channel.code, 200, f"could not retrieve event {event_id}: {channel.result}" + ) + event = channel.json_body + + self.assertNotIn( + EventUnsignedContentFields.STICKY_TTL, + event["unsigned"], + f"{EventUnsignedContentFields.STICKY_TTL} field unexpectedly found in {event_id}: {event}", + ) + + def test_sticky_event_via_event_endpoint(self) -> None: + sticky_event_response = self.helper.send_sticky_event( + self.room_id, + EventTypes.Message, + duration=Duration(minutes=1), + content={"body": "sticky message", "msgtype": "m.text"}, + tok=self.token, + ) + event_id = sticky_event_response["event_id"] + + # Since the feature is disabled, the event isn't sticky + self._assert_event_not_sticky(event_id) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index b3808d75bb..bfa8e6f3d8 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -48,6 +48,7 @@ from synapse.api.constants import EventTypes, Membership, ReceiptTypes from synapse.api.errors import Codes from synapse.server import HomeServer from synapse.types import JsonDict +from synapse.util.duration import Duration from tests.server import FakeChannel, make_request from tests.test_utils.html_parsers import TestHtmlParser @@ -453,6 +454,44 @@ class RestHelper: return channel.json_body + def send_sticky_event( + self, + room_id: str, + type: str, 
+ *, + duration: Duration, + content: dict | None = None, + txn_id: str | None = None, + tok: str | None = None, + expect_code: int = HTTPStatus.OK, + custom_headers: Iterable[tuple[AnyStr, AnyStr]] | None = None, + ) -> JsonDict: + """ + Send an event that has a sticky duration according to MSC4354. + """ + + if txn_id is None: + txn_id = f"m{time.time()}" + + path = f"/_matrix/client/r0/rooms/{room_id}/send/{type}/{txn_id}?org.matrix.msc4354.sticky_duration_ms={duration.as_millis()}" + if tok: + path = path + f"&access_token={tok}" + + channel = make_request( + self.reactor, + self.site, + "PUT", + path, + content or {}, + custom_headers=custom_headers, + ) + + assert channel.code == expect_code, ( + f"Expected: {expect_code}, got: {channel.code}, resp: {channel.result['body']!r}" + ) + + return channel.json_body + def get_event( self, room_id: str, diff --git a/tests/storage/test_sticky_events.py b/tests/storage/test_sticky_events.py new file mode 100644 index 0000000000..60243cb2f4 --- /dev/null +++ b/tests/storage/test_sticky_events.py @@ -0,0 +1,278 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2026 Element Creations Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . 
+import sqlite3 + +from twisted.internet.testing import MemoryReactor + +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + StickyEvent, + StickyEventField, +) +from synapse.api.room_versions import RoomVersions +from synapse.rest import admin +from synapse.rest.client import login, register, room +from synapse.server import HomeServer +from synapse.types import JsonDict, create_requester +from synapse.util.clock import Clock +from synapse.util.duration import Duration + +from tests import unittest +from tests.utils import USE_POSTGRES_FOR_TESTS + + +class StickyEventsTestCase(unittest.HomeserverTestCase): + """ + Tests for the storage functions related to MSC4354: Sticky Events + """ + + if not USE_POSTGRES_FOR_TESTS and sqlite3.sqlite_version_info < (3, 40, 0): + # We need the JSON functionality in SQLite + skip = f"SQLite version is too old to support sticky events: {sqlite3.sqlite_version_info} (See https://github.com/element-hq/synapse/issues/19428)" + + servlets = [ + room.register_servlets, + login.register_servlets, + register.register_servlets, + admin.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = {"msc4354_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = self.hs.get_datastores().main + + # Register an account and create a room + self.user_id = self.register_user("user", "pass") + self.token = self.login(self.user_id, "pass") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + def test_get_updated_sticky_events(self) -> None: + """Test getting updated sticky events between stream IDs.""" + # Get the starting stream_id + start_id = self.store.get_max_sticky_events_stream_id() + + event_id_1 = self.helper.send_sticky_event( + self.room_id, + EventTypes.Message, + duration=Duration(minutes=1), + content={"body": "message 1", 
"msgtype": "m.text"}, + tok=self.token, + )["event_id"] + + mid_id = self.store.get_max_sticky_events_stream_id() + + event_id_2 = self.helper.send_sticky_event( + self.room_id, + EventTypes.Message, + duration=Duration(minutes=1), + content={"body": "message 2", "msgtype": "m.text"}, + tok=self.token, + )["event_id"] + + end_id = self.store.get_max_sticky_events_stream_id() + + # Get all updates + updates = self.get_success( + self.store.get_updated_sticky_events( + from_id=start_id, to_id=end_id, limit=10 + ) + ) + self.assertEqual(len(updates), 2) + self.assertEqual(updates[0].event_id, event_id_1) + self.assertEqual(updates[0].soft_failed, False) + self.assertEqual(updates[1].event_id, event_id_2) + self.assertEqual(updates[1].soft_failed, False) + + # Get only the second update + updates = self.get_success( + self.store.get_updated_sticky_events(from_id=mid_id, to_id=end_id, limit=10) + ) + self.assertEqual(len(updates), 1) + self.assertEqual(updates[0].event_id, event_id_2) + self.assertEqual(updates[0].soft_failed, False) + + def test_delete_expired_sticky_events(self) -> None: + """Test deletion of expired sticky events.""" + # Insert an expired event by advancing time past its duration + self.helper.send_sticky_event( + self.room_id, + EventTypes.Message, + duration=Duration(milliseconds=1), + content={"body": "expired message", "msgtype": "m.text"}, + tok=self.token, + ) + self.reactor.advance(0.002) + + # Insert a non-expired event + event_id_2 = self.helper.send_sticky_event( + self.room_id, + EventTypes.Message, + duration=Duration(minutes=1), + content={"body": "non-expired message", "msgtype": "m.text"}, + tok=self.token, + )["event_id"] + + end_id = self.store.get_max_sticky_events_stream_id() + + # Delete expired events + self.get_success(self.store._delete_expired_sticky_events()) + + # Check that only the non-expired event remains + sticky_events = self.get_success( + self.store.db_pool.simple_select_list( + table="sticky_events", keyvalues=None, 
retcols=("stream_id", "event_id") + ) + ) + self.assertEqual( + sticky_events, + [ + (end_id, event_id_2), + ], + ) + + def test_get_updated_sticky_events_with_limit(self) -> None: + """Test that the limit parameter works correctly.""" + # Get the starting stream_id + start_id = self.store.get_max_sticky_events_stream_id() + + event_id_1 = self.helper.send_sticky_event( + self.room_id, + EventTypes.Message, + duration=Duration(minutes=1), + content={"body": "message 1", "msgtype": "m.text"}, + tok=self.token, + )["event_id"] + + self.helper.send_sticky_event( + self.room_id, + EventTypes.Message, + duration=Duration(minutes=1), + content={"body": "message 2", "msgtype": "m.text"}, + tok=self.token, + ) + + # Get only the first update + updates = self.get_success( + self.store.get_updated_sticky_events( + from_id=start_id, to_id=start_id + 2, limit=1 + ) + ) + self.assertEqual(len(updates), 1) + self.assertEqual(updates[0].event_id, event_id_1) + + def test_outlier_events_not_in_table(self) -> None: + """ + Tests the behaviour of outliered and then de-outliered events in the + sticky_events table: they should only be added once they are de-outliered. + """ + persist_controller = self.hs.get_storage_controllers().persistence + assert persist_controller is not None + + user1_id = self.register_user("user1", "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + start_id = self.store.get_max_sticky_events_stream_id() + + room_id = self.helper.create_room_as( + user2_id, tok=user2_tok, room_version=RoomVersions.V10.identifier + ) + + # Create a membership event + event_dict = { + "type": EventTypes.Member, + "state_key": user1_id, + "sender": user1_id, + "room_id": room_id, + "content": {EventContentFields.MEMBERSHIP: Membership.JOIN}, + StickyEvent.EVENT_FIELD_NAME: StickyEventField( + duration_ms=Duration(hours=1).as_millis() + ), + } + + # Create the event twice: once as an outlier, once as a non-outlier. 
+ # It's not at all obvious, but event creation before is deterministic + # (provided we don't change the forward extremities of the room!), + # so these two events are actually the same event with the same event ID. + ( + event_outlier, + unpersisted_context_outlier, + ) = self.get_success( + self.hs.get_event_creation_handler().create_event( + requester=create_requester(user1_id), + event_dict=event_dict, + outlier=True, + ) + ) + ( + event_non_outlier, + unpersisted_context_non_outlier, + ) = self.get_success( + self.hs.get_event_creation_handler().create_event( + requester=create_requester(user1_id), + event_dict=event_dict, + outlier=False, + ) + ) + + # Safety check that we're testing what we think we are + self.assertEqual(event_outlier.event_id, event_non_outlier.event_id) + + # Now persist the event as an outlier first of all + # FIXME: Should we use an `EventContext.for_outlier(...)` here? + # Doesn't seem to matter for this test. + context_outlier = self.get_success( + unpersisted_context_outlier.persist(event_outlier) + ) + self.get_success( + persist_controller.persist_event( + event_outlier, + context_outlier, + ) + ) + + # Since the event is outliered, it won't show up in the sticky_events table... + sticky_events = self.get_success( + self.store.db_pool.simple_select_list( + table="sticky_events", keyvalues=None, retcols=("stream_id", "event_id") + ) + ) + self.assertEqual(len(sticky_events), 0) + + # Now persist the event properly so that it gets de-outliered. 
+ context_non_outlier = self.get_success( + unpersisted_context_non_outlier.persist(event_non_outlier) + ) + self.get_success( + persist_controller.persist_event( + event_non_outlier, + context_non_outlier, + ) + ) + + end_id = self.store.get_max_sticky_events_stream_id() + + # Check the event made it into the sticky_events table + updates = self.get_success( + self.store.get_updated_sticky_events( + from_id=start_id, to_id=end_id, limit=10 + ) + ) + self.assertEqual(len(updates), 1) + self.assertEqual(updates[0].event_id, event_non_outlier.event_id)