From 7c8daf4ed9282f61deea707e201b7b216f187fda Mon Sep 17 00:00:00 2001 From: Kegan Dougal <7190048+kegsay@users.noreply.github.com> Date: Tue, 23 Sep 2025 11:02:20 +0100 Subject: [PATCH] Get sticky events working with Simplified Sliding Sync --- synapse/handlers/sliding_sync/extensions.py | 60 +++++++++++++++++++ synapse/handlers/sync.py | 4 +- synapse/rest/client/sync.py | 30 +++++++++- .../storage/databases/main/sticky_events.py | 17 ++++-- .../93/01_sticky_events_seq.sql.postgres | 11 ++-- synapse/types/handlers/sliding_sync.py | 16 +++++ synapse/types/rest/client/__init__.py | 12 ++++ synapse/util/async_helpers.py | 25 ++++++++ 8 files changed, 160 insertions(+), 15 deletions(-) diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py index 25ee954b7f..1bb7ff0d87 100644 --- a/synapse/handlers/sliding_sync/extensions.py +++ b/synapse/handlers/sliding_sync/extensions.py @@ -54,6 +54,7 @@ from synapse.util.async_helpers import ( concurrently_execute, gather_optional_coroutines, ) +from synapse.visibility import filter_events_for_client _ThreadSubscription: TypeAlias = ( SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription @@ -76,7 +77,10 @@ class SlidingSyncExtensionHandler: self.event_sources = hs.get_event_sources() self.device_handler = hs.get_device_handler() self.push_rules_handler = hs.get_push_rules_handler() + self.clock = hs.get_clock() + self._storage_controllers = hs.get_storage_controllers() self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled + self._enable_sticky_events = hs.config.experimental.msc4354_enabled @trace async def get_extensions_response( @@ -177,6 +181,19 @@ class SlidingSyncExtensionHandler: from_token=from_token, ) + sticky_events_coro = None + if ( + sync_config.extensions.sticky_events is not None + and self._enable_sticky_events + ): + sticky_events_coro = self.get_sticky_events_extension_response( + sync_config=sync_config, + sticky_events_request=sync_config.extensions.sticky_events, + actual_room_ids=actual_room_ids, + to_token=to_token, + from_token=from_token, + ) + ( to_device_response, e2ee_response, @@ -184,6 +201,7 @@ class SlidingSyncExtensionHandler: receipts_response, typing_response, thread_subs_response, + sticky_events_response, ) = await gather_optional_coroutines( to_device_coro, e2ee_coro, @@ -191,6 +209,7 @@ class SlidingSyncExtensionHandler: receipts_coro, typing_coro, thread_subs_coro, + sticky_events_coro, ) return SlidingSyncResult.Extensions( @@ -200,6 +219,7 @@ class SlidingSyncExtensionHandler: receipts=receipts_response, typing=typing_response, thread_subscriptions=thread_subs_response, + sticky_events=sticky_events_response, ) def find_relevant_room_ids_for_extension( @@ -970,3 +990,43 @@ class SlidingSyncExtensionHandler: unsubscribed=unsubscribed_threads, prev_batch=prev_batch, ) + + async def get_sticky_events_extension_response( + self, + sync_config: SlidingSyncConfig, + sticky_events_request: SlidingSyncConfig.Extensions.StickyEventsExtension, + actual_room_ids: Set[str], + to_token: StreamToken, + from_token: Optional[SlidingSyncStreamToken], + ) -> Optional[SlidingSyncResult.Extensions.StickyEventsExtension]: + if not sticky_events_request.enabled: + return None + now = self.clock.time_msec() + from_id = from_token.stream_token.sticky_events_key if from_token else 0 + _, room_to_event_ids = await self.store.get_sticky_events_in_rooms( + actual_room_ids, + from_id, + to_token.sticky_events_key, + now, + ) + all_sticky_event_ids = { + ev_id for evs in room_to_event_ids.values() for ev_id in evs + } + event_map = await self.store.get_events(all_sticky_event_ids) + filtered_events = await filter_events_for_client( + self._storage_controllers, + sync_config.user.to_string(), + list(event_map.values()), + always_include_ids=frozenset(all_sticky_event_ids), + ) + event_map = {ev.event_id: ev for ev in filtered_events} + return SlidingSyncResult.Extensions.StickyEventsExtension( + room_id_to_sticky_events={ + room_id: { + event_map[event_id] + for event_id in sticky_event_ids + if event_id in event_map + } + for room_id, sticky_event_ids in room_to_event_ids.items() + } + ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6f8010fe82..c26be23e19 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -620,8 +620,7 @@ class SyncHandler: Args: sync_result_builder now_token: Where the server is currently up to. - since_token: Where the server was when the client - last synced. + since_token: Where the server was when the client last synced. Returns: A tuple of the now StreamToken, updated to reflect the which sticky events are included, and a dict mapping from room_id to a list of @@ -638,6 +637,7 @@ class SyncHandler: to_id, sticky_by_room = await self.store.get_sticky_events_in_rooms( room_ids, from_id, + now_token.sticky_events_key, now, ) now_token = now_token.copy_and_replace(StreamKeyType.STICKY_EVENTS, to_id) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index d1ef1f6193..7bafcf475e 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -623,7 +623,7 @@ class SyncRestServlet(RestServlet): serialized_sticky = await self._event_serializer.serialize_events( room.sticky, time_now, config=serialize_options ) - result["sticky"] = {"events": serialized_sticky} + result["msc4354_sticky"] = {"events": serialized_sticky} if room.unread_thread_notifications: result["unread_thread_notifications"] = room.unread_thread_notifications if self._msc3773_enabled: @@ -653,6 +653,7 @@ class SlidingSyncRestServlet(RestServlet): - receipts (MSC3960) - account data (MSC3959) - thread subscriptions (MSC4308) + - sticky events (MSC4354) Request query parameters: timeout: How long to wait for new events in milliseconds. @@ -1096,8 +1097,35 @@ class SlidingSyncRestServlet(RestServlet): _serialise_thread_subscriptions(extensions.thread_subscriptions) ) + if extensions.sticky_events: + serialized_extensions[ + "org.matrix.msc4354.sticky_events" + ] = await self._serialise_sticky_events(requester, extensions.sticky_events) + return serialized_extensions + async def _serialise_sticky_events( + self, + requester: Requester, + sticky_events: SlidingSyncResult.Extensions.StickyEventsExtension, + ) -> JsonDict: + time_now = self.clock.time_msec() + # Same as SSS timelines. TODO: support more options like /sync does. + serialize_options = SerializeEventConfig( + event_format=format_event_for_client_v2_without_room_id, + requester=requester, + ) + return { + "rooms": { + room_id: { + "events": await self.event_serializer.serialize_events( + sticky_events, time_now, config=serialize_options + ) + } + for room_id, sticky_events in sticky_events.room_id_to_sticky_events.items() + }, + } + def _serialise_thread_subscriptions( thread_subscriptions: SlidingSyncResult.Extensions.ThreadSubscriptionsExtension, diff --git a/synapse/storage/databases/main/sticky_events.py b/synapse/storage/databases/main/sticky_events.py index f8ca147431..130bd388df 100644 --- a/synapse/storage/databases/main/sticky_events.py +++ b/synapse/storage/databases/main/sticky_events.py @@ -118,6 +118,7 @@ class StickyEventsWorkerStore(CacheInvalidationWorkerStore): self, room_ids: Collection[str], from_id: int, + to_id: int, now: int, ) -> Tuple[int, Dict[str, Set[str]]]: """ @@ -125,7 +126,8 @@ class StickyEventsWorkerStore(CacheInvalidationWorkerStore): Args: room_ids: The room IDs to return sticky events in. - from_id: The sticky stream ID that sticky events should be returned from. + from_id: The sticky stream ID that sticky events should be returned from (exclusive). + to_id: The sticky stream ID that sticky events should end at (inclusive). now: The current time in unix millis, used for skipping expired events. Returns: A tuple of (to_id, map[room_id, event_ids]) @@ -135,22 +137,24 @@ class StickyEventsWorkerStore(CacheInvalidationWorkerStore): self._get_sticky_events_in_rooms_txn, room_ids, from_id, + to_id, now, ) - to_id = from_id + new_to_id = from_id room_to_events: Dict[str, Set[str]] = {} for stream_id, room_id, event_id in sticky_events_rows: - to_id = max(to_id, stream_id) + new_to_id = max(new_to_id, stream_id) events = room_to_events.get(room_id, set()) events.add(event_id) room_to_events[room_id] = events - return (to_id, room_to_events) + return (new_to_id, room_to_events) def _get_sticky_events_in_rooms_txn( self, txn: LoggingTransaction, room_ids: Collection[str], from_id: int, + to_id: int, now: int, ) -> List[Tuple[int, str, str]]: if len(room_ids) == 0: @@ -160,9 +164,10 @@ class StickyEventsWorkerStore(CacheInvalidationWorkerStore): ) txn.execute( f""" - SELECT stream_id, room_id, event_id FROM sticky_events WHERE soft_failed=FALSE AND expires_at > ? AND stream_id > ? AND {clause} + SELECT stream_id, room_id, event_id FROM sticky_events + WHERE soft_failed=FALSE AND expires_at > ? AND stream_id > ? AND stream_id <= ? AND {clause} """, - (now, from_id, room_id_values), + (now, from_id, to_id, room_id_values), ) return cast(List[Tuple[int, str, str]], txn.fetchall()) diff --git a/synapse/storage/schema/main/delta/93/01_sticky_events_seq.sql.postgres b/synapse/storage/schema/main/delta/93/01_sticky_events_seq.sql.postgres index e4f4ff5798..5a28a309d9 100644 --- a/synapse/storage/schema/main/delta/93/01_sticky_events_seq.sql.postgres +++ b/synapse/storage/schema/main/delta/93/01_sticky_events_seq.sql.postgres @@ -11,9 +11,8 @@ -- See the GNU Affero General Public License for more details: -- . -CREATE SEQUENCE sticky_events_sequence - -- Synapse streams start at 2, because the default position is 1 - -- so any item inserted at position 1 is ignored. - -- This is also what existing streams do, except they use `setval(..., 1)` - -- which is semantically the same except less obvious. - START WITH 2; +CREATE SEQUENCE sticky_events_sequence; +-- Synapse streams start at 2, because the default position is 1 +-- so any item inserted at position 1 is ignored. +-- We have to use nextval not START WITH 2, see https://github.com/element-hq/synapse/issues/18712 +SELECT nextval('thread_subscriptions_sequence'); diff --git a/synapse/types/handlers/sliding_sync.py b/synapse/types/handlers/sliding_sync.py index b7bc565464..4f7d1b895c 100644 --- a/synapse/types/handlers/sliding_sync.py +++ b/synapse/types/handlers/sliding_sync.py @@ -21,6 +21,7 @@ from typing import ( AbstractSet, Any, Callable, + Collection, Dict, Final, Generic, @@ -396,12 +397,26 @@ class SlidingSyncResult: or bool(self.prev_batch) ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class StickyEventsExtension: + """The Sticky Events extension (MSC4354) + + Attributes: + room_id_to_sticky_events: map (room_id -> [unexpired_sticky_events]) + """ + + room_id_to_sticky_events: Mapping[str, Collection[EventBase]] + + def __bool__(self) -> bool: + return bool(self.room_id_to_sticky_events) + to_device: Optional[ToDeviceExtension] = None e2ee: Optional[E2eeExtension] = None account_data: Optional[AccountDataExtension] = None receipts: Optional[ReceiptsExtension] = None typing: Optional[TypingExtension] = None thread_subscriptions: Optional[ThreadSubscriptionsExtension] = None + sticky_events: Optional[StickyEventsExtension] = None def __bool__(self) -> bool: return bool( @@ -411,6 +426,7 @@ class SlidingSyncResult: or self.receipts or self.typing or self.thread_subscriptions + or self.sticky_events ) next_pos: SlidingSyncStreamToken diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py index 11d7e59b43..39d78e66a0 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py @@ -376,6 +376,15 @@ class SlidingSyncBody(RequestBodyModel): enabled: Optional[StrictBool] = False limit: StrictInt = 100 + class StickyEventsExtension(RequestBodyModel): + """The Sticky Events extension (MSC4354) + + Attributes: + enabled + """ + + enabled: Optional[StrictBool] = False + to_device: Optional[ToDeviceExtension] = None e2ee: Optional[E2eeExtension] = None account_data: Optional[AccountDataExtension] = None @@ -384,6 +393,9 @@ class SlidingSyncBody(RequestBodyModel): thread_subscriptions: Optional[ThreadSubscriptionsExtension] = Field( alias="io.element.msc4308.thread_subscriptions" ) + sticky_events: Optional[StickyEventsExtension] = Field( + alias="org.matrix.msc4354.sticky_events" + ) conn_id: Optional[StrictStr] diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index c21b7887f9..7b766c54aa 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -348,6 +348,7 @@ T3 = TypeVar("T3") T4 = TypeVar("T4") T5 = TypeVar("T5") T6 = TypeVar("T6") +T7 = TypeVar("T7") @overload @@ -479,6 +480,30 @@ async def gather_optional_coroutines( ]: ... +@overload +async def gather_optional_coroutines( + *coroutines: Unpack[ + Tuple[ + Optional[Coroutine[Any, Any, T1]], + Optional[Coroutine[Any, Any, T2]], + Optional[Coroutine[Any, Any, T3]], + Optional[Coroutine[Any, Any, T4]], + Optional[Coroutine[Any, Any, T5]], + Optional[Coroutine[Any, Any, T6]], + Optional[Coroutine[Any, Any, T7]], + ] + ], +) -> Tuple[ + Optional[T1], + Optional[T2], + Optional[T3], + Optional[T4], + Optional[T5], + Optional[T6], + Optional[T7], +]: ... + + async def gather_optional_coroutines( *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]], ) -> Tuple[Optional[T1], ...]: