1
0

Get sticky events working with Simplified Sliding Sync

This commit is contained in:
Kegan Dougal
2025-09-23 11:02:20 +01:00
parent 0cfdd0d6b5
commit 7c8daf4ed9
8 changed files with 160 additions and 15 deletions

View File

@@ -54,6 +54,7 @@ from synapse.util.async_helpers import (
concurrently_execute,
gather_optional_coroutines,
)
from synapse.visibility import filter_events_for_client
_ThreadSubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
@@ -76,7 +77,10 @@ class SlidingSyncExtensionHandler:
self.event_sources = hs.get_event_sources()
self.device_handler = hs.get_device_handler()
self.push_rules_handler = hs.get_push_rules_handler()
self.clock = hs.get_clock()
self._storage_controllers = hs.get_storage_controllers()
self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
self._enable_sticky_events = hs.config.experimental.msc4354_enabled
@trace
async def get_extensions_response(
@@ -177,6 +181,19 @@ class SlidingSyncExtensionHandler:
from_token=from_token,
)
sticky_events_coro = None
if (
sync_config.extensions.sticky_events is not None
and self._enable_sticky_events
):
sticky_events_coro = self.get_sticky_events_extension_response(
sync_config=sync_config,
sticky_events_request=sync_config.extensions.sticky_events,
actual_room_ids=actual_room_ids,
to_token=to_token,
from_token=from_token,
)
(
to_device_response,
e2ee_response,
@@ -184,6 +201,7 @@ class SlidingSyncExtensionHandler:
receipts_response,
typing_response,
thread_subs_response,
sticky_events_response,
) = await gather_optional_coroutines(
to_device_coro,
e2ee_coro,
@@ -191,6 +209,7 @@ class SlidingSyncExtensionHandler:
receipts_coro,
typing_coro,
thread_subs_coro,
sticky_events_coro,
)
return SlidingSyncResult.Extensions(
@@ -200,6 +219,7 @@ class SlidingSyncExtensionHandler:
receipts=receipts_response,
typing=typing_response,
thread_subscriptions=thread_subs_response,
sticky_events=sticky_events_response,
)
def find_relevant_room_ids_for_extension(
@@ -970,3 +990,43 @@ class SlidingSyncExtensionHandler:
unsubscribed=unsubscribed_threads,
prev_batch=prev_batch,
)
async def get_sticky_events_extension_response(
self,
sync_config: SlidingSyncConfig,
sticky_events_request: SlidingSyncConfig.Extensions.StickyEventsExtension,
actual_room_ids: Set[str],
to_token: StreamToken,
from_token: Optional[SlidingSyncStreamToken],
) -> Optional[SlidingSyncResult.Extensions.StickyEventsExtension]:
if not sticky_events_request.enabled:
return None
now = self.clock.time_msec()
from_id = from_token.stream_token.sticky_events_key if from_token else 0
_, room_to_event_ids = await self.store.get_sticky_events_in_rooms(
actual_room_ids,
from_id,
to_token.sticky_events_key,
now,
)
all_sticky_event_ids = {
ev_id for evs in room_to_event_ids.values() for ev_id in evs
}
event_map = await self.store.get_events(all_sticky_event_ids)
filtered_events = await filter_events_for_client(
self._storage_controllers,
sync_config.user.to_string(),
list(event_map.values()),
always_include_ids=frozenset(all_sticky_event_ids),
)
event_map = {ev.event_id: ev for ev in filtered_events}
return SlidingSyncResult.Extensions.StickyEventsExtension(
room_id_to_sticky_events={
room_id: {
event_map[event_id]
for event_id in sticky_event_ids
if event_id in event_map
}
for room_id, sticky_event_ids in room_to_event_ids.items()
}
)

View File

@@ -620,8 +620,7 @@ class SyncHandler:
Args:
sync_result_builder
now_token: Where the server is currently up to.
since_token: Where the server was when the client
last synced.
since_token: Where the server was when the client last synced.
Returns:
A tuple of the now StreamToken, updated to reflect the which sticky
events are included, and a dict mapping from room_id to a list of
@@ -638,6 +637,7 @@ class SyncHandler:
to_id, sticky_by_room = await self.store.get_sticky_events_in_rooms(
room_ids,
from_id,
now_token.sticky_events_key,
now,
)
now_token = now_token.copy_and_replace(StreamKeyType.STICKY_EVENTS, to_id)

View File

@@ -623,7 +623,7 @@ class SyncRestServlet(RestServlet):
serialized_sticky = await self._event_serializer.serialize_events(
room.sticky, time_now, config=serialize_options
)
result["sticky"] = {"events": serialized_sticky}
result["msc4354_sticky"] = {"events": serialized_sticky}
if room.unread_thread_notifications:
result["unread_thread_notifications"] = room.unread_thread_notifications
if self._msc3773_enabled:
@@ -653,6 +653,7 @@ class SlidingSyncRestServlet(RestServlet):
- receipts (MSC3960)
- account data (MSC3959)
- thread subscriptions (MSC4308)
- sticky events (MSC4354)
Request query parameters:
timeout: How long to wait for new events in milliseconds.
@@ -1096,8 +1097,35 @@ class SlidingSyncRestServlet(RestServlet):
_serialise_thread_subscriptions(extensions.thread_subscriptions)
)
if extensions.sticky_events:
serialized_extensions[
"org.matrix.msc4354.sticky_events"
] = await self._serialise_sticky_events(requester, extensions.sticky_events)
return serialized_extensions
async def _serialise_sticky_events(
self,
requester: Requester,
sticky_events: SlidingSyncResult.Extensions.StickyEventsExtension,
) -> JsonDict:
time_now = self.clock.time_msec()
# Same as SSS timelines. TODO: support more options like /sync does.
serialize_options = SerializeEventConfig(
event_format=format_event_for_client_v2_without_room_id,
requester=requester,
)
return {
"rooms": {
room_id: {
"events": await self.event_serializer.serialize_events(
sticky_events, time_now, config=serialize_options
)
}
for room_id, sticky_events in sticky_events.room_id_to_sticky_events.items()
},
}
def _serialise_thread_subscriptions(
thread_subscriptions: SlidingSyncResult.Extensions.ThreadSubscriptionsExtension,

View File

@@ -118,6 +118,7 @@ class StickyEventsWorkerStore(CacheInvalidationWorkerStore):
self,
room_ids: Collection[str],
from_id: int,
to_id: int,
now: int,
) -> Tuple[int, Dict[str, Set[str]]]:
"""
@@ -125,7 +126,8 @@ class StickyEventsWorkerStore(CacheInvalidationWorkerStore):
Args:
room_ids: The room IDs to return sticky events in.
from_id: The sticky stream ID that sticky events should be returned from.
from_id: The sticky stream ID that sticky events should be returned from (exclusive).
to_id: The sticky stream ID that sticky events should end at (inclusive).
now: The current time in unix millis, used for skipping expired events.
Returns:
A tuple of (to_id, map[room_id, event_ids])
@@ -135,22 +137,24 @@ class StickyEventsWorkerStore(CacheInvalidationWorkerStore):
self._get_sticky_events_in_rooms_txn,
room_ids,
from_id,
to_id,
now,
)
to_id = from_id
new_to_id = from_id
room_to_events: Dict[str, Set[str]] = {}
for stream_id, room_id, event_id in sticky_events_rows:
to_id = max(to_id, stream_id)
new_to_id = max(new_to_id, stream_id)
events = room_to_events.get(room_id, set())
events.add(event_id)
room_to_events[room_id] = events
return (to_id, room_to_events)
return (new_to_id, room_to_events)
def _get_sticky_events_in_rooms_txn(
self,
txn: LoggingTransaction,
room_ids: Collection[str],
from_id: int,
to_id: int,
now: int,
) -> List[Tuple[int, str, str]]:
if len(room_ids) == 0:
@@ -160,9 +164,10 @@ class StickyEventsWorkerStore(CacheInvalidationWorkerStore):
)
txn.execute(
f"""
SELECT stream_id, room_id, event_id FROM sticky_events WHERE soft_failed=FALSE AND expires_at > ? AND stream_id > ? AND {clause}
SELECT stream_id, room_id, event_id FROM sticky_events
WHERE soft_failed=FALSE AND expires_at > ? AND stream_id > ? AND stream_id <= ? AND {clause}
""",
(now, from_id, room_id_values),
(now, from_id, to_id, room_id_values),
)
return cast(List[Tuple[int, str, str]], txn.fetchall())

View File

@@ -11,9 +11,8 @@
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
CREATE SEQUENCE sticky_events_sequence
-- Synapse streams start at 2, because the default position is 1
-- so any item inserted at position 1 is ignored.
-- This is also what existing streams do, except they use `setval(..., 1)`
-- which is semantically the same except less obvious.
START WITH 2;
CREATE SEQUENCE sticky_events_sequence;
-- Synapse streams start at 2, because the default position is 1
-- so any item inserted at position 1 is ignored.
-- We have to use nextval not START WITH 2, see https://github.com/element-hq/synapse/issues/18712
SELECT nextval('thread_subscriptions_sequence');

View File

@@ -21,6 +21,7 @@ from typing import (
AbstractSet,
Any,
Callable,
Collection,
Dict,
Final,
Generic,
@@ -396,12 +397,26 @@ class SlidingSyncResult:
or bool(self.prev_batch)
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class StickyEventsExtension:
"""The Sticky Events extension (MSC4354)
Attributes:
room_id_to_sticky_events: map (room_id -> [unexpired_sticky_events])
"""
room_id_to_sticky_events: Mapping[str, Collection[EventBase]]
def __bool__(self) -> bool:
return bool(self.room_id_to_sticky_events)
to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None
account_data: Optional[AccountDataExtension] = None
receipts: Optional[ReceiptsExtension] = None
typing: Optional[TypingExtension] = None
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = None
sticky_events: Optional[StickyEventsExtension] = None
def __bool__(self) -> bool:
return bool(
@@ -411,6 +426,7 @@ class SlidingSyncResult:
or self.receipts
or self.typing
or self.thread_subscriptions
or self.sticky_events
)
next_pos: SlidingSyncStreamToken

View File

@@ -376,6 +376,15 @@ class SlidingSyncBody(RequestBodyModel):
enabled: Optional[StrictBool] = False
limit: StrictInt = 100
class StickyEventsExtension(RequestBodyModel):
"""The Sticky Events extension (MSC4354)
Attributes:
enabled
"""
enabled: Optional[StrictBool] = False
to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None
account_data: Optional[AccountDataExtension] = None
@@ -384,6 +393,9 @@ class SlidingSyncBody(RequestBodyModel):
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = Field(
alias="io.element.msc4308.thread_subscriptions"
)
sticky_events: Optional[StickyEventsExtension] = Field(
alias="org.matrix.msc4354.sticky_events"
)
conn_id: Optional[StrictStr]

View File

@@ -348,6 +348,7 @@ T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
T6 = TypeVar("T6")
T7 = TypeVar("T7")
@overload
@@ -479,6 +480,30 @@ async def gather_optional_coroutines(
]: ...
@overload
async def gather_optional_coroutines(
*coroutines: Unpack[
Tuple[
Optional[Coroutine[Any, Any, T1]],
Optional[Coroutine[Any, Any, T2]],
Optional[Coroutine[Any, Any, T3]],
Optional[Coroutine[Any, Any, T4]],
Optional[Coroutine[Any, Any, T5]],
Optional[Coroutine[Any, Any, T6]],
Optional[Coroutine[Any, Any, T7]],
]
],
) -> Tuple[
Optional[T1],
Optional[T2],
Optional[T3],
Optional[T4],
Optional[T5],
Optional[T6],
Optional[T7],
]: ...
async def gather_optional_coroutines(
*coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]],
) -> Tuple[Optional[T1], ...]: