Get sticky events working with Simplified Sliding Sync
This commit is contained in:
@@ -54,6 +54,7 @@ from synapse.util.async_helpers import (
|
||||
concurrently_execute,
|
||||
gather_optional_coroutines,
|
||||
)
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
_ThreadSubscription: TypeAlias = (
|
||||
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
|
||||
@@ -76,7 +77,10 @@ class SlidingSyncExtensionHandler:
|
||||
self.event_sources = hs.get_event_sources()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.push_rules_handler = hs.get_push_rules_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
|
||||
self._enable_sticky_events = hs.config.experimental.msc4354_enabled
|
||||
|
||||
@trace
|
||||
async def get_extensions_response(
|
||||
@@ -177,6 +181,19 @@ class SlidingSyncExtensionHandler:
|
||||
from_token=from_token,
|
||||
)
|
||||
|
||||
sticky_events_coro = None
|
||||
if (
|
||||
sync_config.extensions.sticky_events is not None
|
||||
and self._enable_sticky_events
|
||||
):
|
||||
sticky_events_coro = self.get_sticky_events_extension_response(
|
||||
sync_config=sync_config,
|
||||
sticky_events_request=sync_config.extensions.sticky_events,
|
||||
actual_room_ids=actual_room_ids,
|
||||
to_token=to_token,
|
||||
from_token=from_token,
|
||||
)
|
||||
|
||||
(
|
||||
to_device_response,
|
||||
e2ee_response,
|
||||
@@ -184,6 +201,7 @@ class SlidingSyncExtensionHandler:
|
||||
receipts_response,
|
||||
typing_response,
|
||||
thread_subs_response,
|
||||
sticky_events_response,
|
||||
) = await gather_optional_coroutines(
|
||||
to_device_coro,
|
||||
e2ee_coro,
|
||||
@@ -191,6 +209,7 @@ class SlidingSyncExtensionHandler:
|
||||
receipts_coro,
|
||||
typing_coro,
|
||||
thread_subs_coro,
|
||||
sticky_events_coro,
|
||||
)
|
||||
|
||||
return SlidingSyncResult.Extensions(
|
||||
@@ -200,6 +219,7 @@ class SlidingSyncExtensionHandler:
|
||||
receipts=receipts_response,
|
||||
typing=typing_response,
|
||||
thread_subscriptions=thread_subs_response,
|
||||
sticky_events=sticky_events_response,
|
||||
)
|
||||
|
||||
def find_relevant_room_ids_for_extension(
|
||||
@@ -970,3 +990,43 @@ class SlidingSyncExtensionHandler:
|
||||
unsubscribed=unsubscribed_threads,
|
||||
prev_batch=prev_batch,
|
||||
)
|
||||
|
||||
async def get_sticky_events_extension_response(
|
||||
self,
|
||||
sync_config: SlidingSyncConfig,
|
||||
sticky_events_request: SlidingSyncConfig.Extensions.StickyEventsExtension,
|
||||
actual_room_ids: Set[str],
|
||||
to_token: StreamToken,
|
||||
from_token: Optional[SlidingSyncStreamToken],
|
||||
) -> Optional[SlidingSyncResult.Extensions.StickyEventsExtension]:
|
||||
if not sticky_events_request.enabled:
|
||||
return None
|
||||
now = self.clock.time_msec()
|
||||
from_id = from_token.stream_token.sticky_events_key if from_token else 0
|
||||
_, room_to_event_ids = await self.store.get_sticky_events_in_rooms(
|
||||
actual_room_ids,
|
||||
from_id,
|
||||
to_token.sticky_events_key,
|
||||
now,
|
||||
)
|
||||
all_sticky_event_ids = {
|
||||
ev_id for evs in room_to_event_ids.values() for ev_id in evs
|
||||
}
|
||||
event_map = await self.store.get_events(all_sticky_event_ids)
|
||||
filtered_events = await filter_events_for_client(
|
||||
self._storage_controllers,
|
||||
sync_config.user.to_string(),
|
||||
list(event_map.values()),
|
||||
always_include_ids=frozenset(all_sticky_event_ids),
|
||||
)
|
||||
event_map = {ev.event_id: ev for ev in filtered_events}
|
||||
return SlidingSyncResult.Extensions.StickyEventsExtension(
|
||||
room_id_to_sticky_events={
|
||||
room_id: {
|
||||
event_map[event_id]
|
||||
for event_id in sticky_event_ids
|
||||
if event_id in event_map
|
||||
}
|
||||
for room_id, sticky_event_ids in room_to_event_ids.items()
|
||||
}
|
||||
)
|
||||
|
||||
@@ -620,8 +620,7 @@ class SyncHandler:
|
||||
Args:
|
||||
sync_result_builder
|
||||
now_token: Where the server is currently up to.
|
||||
since_token: Where the server was when the client
|
||||
last synced.
|
||||
since_token: Where the server was when the client last synced.
|
||||
Returns:
|
||||
A tuple of the now StreamToken, updated to reflect the which sticky
|
||||
events are included, and a dict mapping from room_id to a list of
|
||||
@@ -638,6 +637,7 @@ class SyncHandler:
|
||||
to_id, sticky_by_room = await self.store.get_sticky_events_in_rooms(
|
||||
room_ids,
|
||||
from_id,
|
||||
now_token.sticky_events_key,
|
||||
now,
|
||||
)
|
||||
now_token = now_token.copy_and_replace(StreamKeyType.STICKY_EVENTS, to_id)
|
||||
|
||||
@@ -623,7 +623,7 @@ class SyncRestServlet(RestServlet):
|
||||
serialized_sticky = await self._event_serializer.serialize_events(
|
||||
room.sticky, time_now, config=serialize_options
|
||||
)
|
||||
result["sticky"] = {"events": serialized_sticky}
|
||||
result["msc4354_sticky"] = {"events": serialized_sticky}
|
||||
if room.unread_thread_notifications:
|
||||
result["unread_thread_notifications"] = room.unread_thread_notifications
|
||||
if self._msc3773_enabled:
|
||||
@@ -653,6 +653,7 @@ class SlidingSyncRestServlet(RestServlet):
|
||||
- receipts (MSC3960)
|
||||
- account data (MSC3959)
|
||||
- thread subscriptions (MSC4308)
|
||||
- sticky events (MSC4354)
|
||||
|
||||
Request query parameters:
|
||||
timeout: How long to wait for new events in milliseconds.
|
||||
@@ -1096,8 +1097,35 @@ class SlidingSyncRestServlet(RestServlet):
|
||||
_serialise_thread_subscriptions(extensions.thread_subscriptions)
|
||||
)
|
||||
|
||||
if extensions.sticky_events:
|
||||
serialized_extensions[
|
||||
"org.matrix.msc4354.sticky_events"
|
||||
] = await self._serialise_sticky_events(requester, extensions.sticky_events)
|
||||
|
||||
return serialized_extensions
|
||||
|
||||
async def _serialise_sticky_events(
|
||||
self,
|
||||
requester: Requester,
|
||||
sticky_events: SlidingSyncResult.Extensions.StickyEventsExtension,
|
||||
) -> JsonDict:
|
||||
time_now = self.clock.time_msec()
|
||||
# Same as SSS timelines. TODO: support more options like /sync does.
|
||||
serialize_options = SerializeEventConfig(
|
||||
event_format=format_event_for_client_v2_without_room_id,
|
||||
requester=requester,
|
||||
)
|
||||
return {
|
||||
"rooms": {
|
||||
room_id: {
|
||||
"events": await self.event_serializer.serialize_events(
|
||||
sticky_events, time_now, config=serialize_options
|
||||
)
|
||||
}
|
||||
for room_id, sticky_events in sticky_events.room_id_to_sticky_events.items()
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _serialise_thread_subscriptions(
|
||||
thread_subscriptions: SlidingSyncResult.Extensions.ThreadSubscriptionsExtension,
|
||||
|
||||
@@ -118,6 +118,7 @@ class StickyEventsWorkerStore(CacheInvalidationWorkerStore):
|
||||
self,
|
||||
room_ids: Collection[str],
|
||||
from_id: int,
|
||||
to_id: int,
|
||||
now: int,
|
||||
) -> Tuple[int, Dict[str, Set[str]]]:
|
||||
"""
|
||||
@@ -125,7 +126,8 @@ class StickyEventsWorkerStore(CacheInvalidationWorkerStore):
|
||||
|
||||
Args:
|
||||
room_ids: The room IDs to return sticky events in.
|
||||
from_id: The sticky stream ID that sticky events should be returned from.
|
||||
from_id: The sticky stream ID that sticky events should be returned from (exclusive).
|
||||
to_id: The sticky stream ID that sticky events should end at (inclusive).
|
||||
now: The current time in unix millis, used for skipping expired events.
|
||||
Returns:
|
||||
A tuple of (to_id, map[room_id, event_ids])
|
||||
@@ -135,22 +137,24 @@ class StickyEventsWorkerStore(CacheInvalidationWorkerStore):
|
||||
self._get_sticky_events_in_rooms_txn,
|
||||
room_ids,
|
||||
from_id,
|
||||
to_id,
|
||||
now,
|
||||
)
|
||||
to_id = from_id
|
||||
new_to_id = from_id
|
||||
room_to_events: Dict[str, Set[str]] = {}
|
||||
for stream_id, room_id, event_id in sticky_events_rows:
|
||||
to_id = max(to_id, stream_id)
|
||||
new_to_id = max(new_to_id, stream_id)
|
||||
events = room_to_events.get(room_id, set())
|
||||
events.add(event_id)
|
||||
room_to_events[room_id] = events
|
||||
return (to_id, room_to_events)
|
||||
return (new_to_id, room_to_events)
|
||||
|
||||
def _get_sticky_events_in_rooms_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
room_ids: Collection[str],
|
||||
from_id: int,
|
||||
to_id: int,
|
||||
now: int,
|
||||
) -> List[Tuple[int, str, str]]:
|
||||
if len(room_ids) == 0:
|
||||
@@ -160,9 +164,10 @@ class StickyEventsWorkerStore(CacheInvalidationWorkerStore):
|
||||
)
|
||||
txn.execute(
|
||||
f"""
|
||||
SELECT stream_id, room_id, event_id FROM sticky_events WHERE soft_failed=FALSE AND expires_at > ? AND stream_id > ? AND {clause}
|
||||
SELECT stream_id, room_id, event_id FROM sticky_events
|
||||
WHERE soft_failed=FALSE AND expires_at > ? AND stream_id > ? AND stream_id <= ? AND {clause}
|
||||
""",
|
||||
(now, from_id, room_id_values),
|
||||
(now, from_id, to_id, room_id_values),
|
||||
)
|
||||
return cast(List[Tuple[int, str, str]], txn.fetchall())
|
||||
|
||||
|
||||
@@ -11,9 +11,8 @@
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
CREATE SEQUENCE sticky_events_sequence
|
||||
-- Synapse streams start at 2, because the default position is 1
|
||||
-- so any item inserted at position 1 is ignored.
|
||||
-- This is also what existing streams do, except they use `setval(..., 1)`
|
||||
-- which is semantically the same except less obvious.
|
||||
START WITH 2;
|
||||
CREATE SEQUENCE sticky_events_sequence;
|
||||
-- Synapse streams start at 2, because the default position is 1
|
||||
-- so any item inserted at position 1 is ignored.
|
||||
-- We have to use nextval not START WITH 2, see https://github.com/element-hq/synapse/issues/18712
|
||||
SELECT nextval('thread_subscriptions_sequence');
|
||||
|
||||
@@ -21,6 +21,7 @@ from typing import (
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Final,
|
||||
Generic,
|
||||
@@ -396,12 +397,26 @@ class SlidingSyncResult:
|
||||
or bool(self.prev_batch)
|
||||
)
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class StickyEventsExtension:
|
||||
"""The Sticky Events extension (MSC4354)
|
||||
|
||||
Attributes:
|
||||
room_id_to_sticky_events: map (room_id -> [unexpired_sticky_events])
|
||||
"""
|
||||
|
||||
room_id_to_sticky_events: Mapping[str, Collection[EventBase]]
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.room_id_to_sticky_events)
|
||||
|
||||
to_device: Optional[ToDeviceExtension] = None
|
||||
e2ee: Optional[E2eeExtension] = None
|
||||
account_data: Optional[AccountDataExtension] = None
|
||||
receipts: Optional[ReceiptsExtension] = None
|
||||
typing: Optional[TypingExtension] = None
|
||||
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = None
|
||||
sticky_events: Optional[StickyEventsExtension] = None
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(
|
||||
@@ -411,6 +426,7 @@ class SlidingSyncResult:
|
||||
or self.receipts
|
||||
or self.typing
|
||||
or self.thread_subscriptions
|
||||
or self.sticky_events
|
||||
)
|
||||
|
||||
next_pos: SlidingSyncStreamToken
|
||||
|
||||
@@ -376,6 +376,15 @@ class SlidingSyncBody(RequestBodyModel):
|
||||
enabled: Optional[StrictBool] = False
|
||||
limit: StrictInt = 100
|
||||
|
||||
class StickyEventsExtension(RequestBodyModel):
|
||||
"""The Sticky Events extension (MSC4354)
|
||||
|
||||
Attributes:
|
||||
enabled
|
||||
"""
|
||||
|
||||
enabled: Optional[StrictBool] = False
|
||||
|
||||
to_device: Optional[ToDeviceExtension] = None
|
||||
e2ee: Optional[E2eeExtension] = None
|
||||
account_data: Optional[AccountDataExtension] = None
|
||||
@@ -384,6 +393,9 @@ class SlidingSyncBody(RequestBodyModel):
|
||||
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = Field(
|
||||
alias="io.element.msc4308.thread_subscriptions"
|
||||
)
|
||||
sticky_events: Optional[StickyEventsExtension] = Field(
|
||||
alias="org.matrix.msc4354.sticky_events"
|
||||
)
|
||||
|
||||
conn_id: Optional[StrictStr]
|
||||
|
||||
|
||||
@@ -348,6 +348,7 @@ T3 = TypeVar("T3")
|
||||
T4 = TypeVar("T4")
|
||||
T5 = TypeVar("T5")
|
||||
T6 = TypeVar("T6")
|
||||
T7 = TypeVar("T7")
|
||||
|
||||
|
||||
@overload
|
||||
@@ -479,6 +480,30 @@ async def gather_optional_coroutines(
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def gather_optional_coroutines(
|
||||
*coroutines: Unpack[
|
||||
Tuple[
|
||||
Optional[Coroutine[Any, Any, T1]],
|
||||
Optional[Coroutine[Any, Any, T2]],
|
||||
Optional[Coroutine[Any, Any, T3]],
|
||||
Optional[Coroutine[Any, Any, T4]],
|
||||
Optional[Coroutine[Any, Any, T5]],
|
||||
Optional[Coroutine[Any, Any, T6]],
|
||||
Optional[Coroutine[Any, Any, T7]],
|
||||
]
|
||||
],
|
||||
) -> Tuple[
|
||||
Optional[T1],
|
||||
Optional[T2],
|
||||
Optional[T3],
|
||||
Optional[T4],
|
||||
Optional[T5],
|
||||
Optional[T6],
|
||||
Optional[T7],
|
||||
]: ...
|
||||
|
||||
|
||||
async def gather_optional_coroutines(
|
||||
*coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]],
|
||||
) -> Tuple[Optional[T1], ...]:
|
||||
|
||||
Reference in New Issue
Block a user