Compare commits
26 Commits
develop
...
devon/ssex
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6fa43cb0b4 | ||
|
|
f778ac32c1 | ||
|
|
003fc725db | ||
|
|
934f99a694 | ||
|
|
78e8ec6161 | ||
|
|
f59419377d | ||
|
|
a3b34dfafd | ||
|
|
cb82a4a687 | ||
|
|
0c0ece9612 | ||
|
|
46e3f6756c | ||
|
|
dedd6e35e6 | ||
|
|
a3c7b3ecb9 | ||
|
|
bf594a28a8 | ||
|
|
c757969597 | ||
|
|
4cb0eeabdf | ||
|
|
4d7826b006 | ||
|
|
ab7e5a2b17 | ||
|
|
4c51247cb3 | ||
|
|
4dd82e581a | ||
|
|
6e69338abc | ||
|
|
79ea4bed33 | ||
|
|
9ef4ca173e | ||
|
|
24b38733df | ||
|
|
4602b56643 | ||
|
|
6c460b3eae | ||
|
|
cd4f4223de |
1
changelog.d/19005.feature
Normal file
1
changelog.d/19005.feature
Normal file
@@ -0,0 +1 @@
|
|||||||
|
Add experimental support for MSC4360: Sliding Sync Threads Extension.
|
||||||
@@ -272,6 +272,9 @@ class EventContentFields:
|
|||||||
M_TOPIC: Final = "m.topic"
|
M_TOPIC: Final = "m.topic"
|
||||||
M_TEXT: Final = "m.text"
|
M_TEXT: Final = "m.text"
|
||||||
|
|
||||||
|
# Event relations
|
||||||
|
RELATIONS: Final = "m.relates_to"
|
||||||
|
|
||||||
|
|
||||||
class EventUnsignedContentFields:
|
class EventUnsignedContentFields:
|
||||||
"""Fields found inside the 'unsigned' data on events"""
|
"""Fields found inside the 'unsigned' data on events"""
|
||||||
@@ -360,3 +363,10 @@ class Direction(enum.Enum):
|
|||||||
class ProfileFields:
|
class ProfileFields:
|
||||||
DISPLAYNAME: Final = "displayname"
|
DISPLAYNAME: Final = "displayname"
|
||||||
AVATAR_URL: Final = "avatar_url"
|
AVATAR_URL: Final = "avatar_url"
|
||||||
|
|
||||||
|
|
||||||
|
class MRelatesToFields:
|
||||||
|
"""Fields found inside m.relates_to content blocks."""
|
||||||
|
|
||||||
|
EVENT_ID: Final = "event_id"
|
||||||
|
REL_TYPE: Final = "rel_type"
|
||||||
|
|||||||
@@ -593,3 +593,6 @@ class ExperimentalConfig(Config):
|
|||||||
# MSC4306: Thread Subscriptions
|
# MSC4306: Thread Subscriptions
|
||||||
# (and MSC4308: Thread Subscriptions extension to Sliding Sync)
|
# (and MSC4308: Thread Subscriptions extension to Sliding Sync)
|
||||||
self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False)
|
self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False)
|
||||||
|
|
||||||
|
# MSC4360: Threads Extension to Sliding Sync
|
||||||
|
self.msc4360_enabled: bool = experimental.get("msc4360_enabled", False)
|
||||||
|
|||||||
@@ -105,8 +105,6 @@ class RelationsHandler:
|
|||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
"""Get related events of a event, ordered by topological ordering.
|
"""Get related events of a event, ordered by topological ordering.
|
||||||
|
|
||||||
TODO Accept a PaginationConfig instead of individual pagination parameters.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
requester: The user requesting the relations.
|
requester: The user requesting the relations.
|
||||||
event_id: Fetch events that relate to this event ID.
|
event_id: Fetch events that relate to this event ID.
|
||||||
|
|||||||
@@ -305,6 +305,7 @@ class SlidingSyncHandler:
|
|||||||
# account data, read receipts, typing indicators, to-device messages, etc).
|
# account data, read receipts, typing indicators, to-device messages, etc).
|
||||||
actual_room_ids=set(relevant_room_map.keys()),
|
actual_room_ids=set(relevant_room_map.keys()),
|
||||||
actual_room_response_map=rooms,
|
actual_room_response_map=rooms,
|
||||||
|
room_membership_for_user_at_to_token_map=room_membership_for_user_map,
|
||||||
from_token=from_token,
|
from_token=from_token,
|
||||||
to_token=to_token,
|
to_token=to_token,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
AbstractSet,
|
AbstractSet,
|
||||||
@@ -26,16 +27,28 @@ from typing import (
|
|||||||
|
|
||||||
from typing_extensions import TypeAlias, assert_never
|
from typing_extensions import TypeAlias, assert_never
|
||||||
|
|
||||||
from synapse.api.constants import AccountDataTypes, EduTypes
|
from synapse.api.constants import (
|
||||||
|
AccountDataTypes,
|
||||||
|
EduTypes,
|
||||||
|
EventContentFields,
|
||||||
|
Membership,
|
||||||
|
MRelatesToFields,
|
||||||
|
RelationTypes,
|
||||||
|
)
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.handlers.receipts import ReceiptEventSource
|
from synapse.handlers.receipts import ReceiptEventSource
|
||||||
|
from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
|
||||||
from synapse.logging.opentracing import trace
|
from synapse.logging.opentracing import trace
|
||||||
from synapse.storage.databases.main.receipts import ReceiptInRoom
|
from synapse.storage.databases.main.receipts import ReceiptInRoom
|
||||||
|
from synapse.storage.databases.main.relations import ThreadUpdateInfo
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
DeviceListUpdates,
|
DeviceListUpdates,
|
||||||
JsonMapping,
|
JsonMapping,
|
||||||
MultiWriterStreamToken,
|
MultiWriterStreamToken,
|
||||||
|
RoomStreamToken,
|
||||||
SlidingSyncStreamToken,
|
SlidingSyncStreamToken,
|
||||||
StrCollection,
|
StrCollection,
|
||||||
|
StreamKeyType,
|
||||||
StreamToken,
|
StreamToken,
|
||||||
ThreadSubscriptionsToken,
|
ThreadSubscriptionsToken,
|
||||||
)
|
)
|
||||||
@@ -51,6 +64,7 @@ from synapse.util.async_helpers import (
|
|||||||
concurrently_execute,
|
concurrently_execute,
|
||||||
gather_optional_coroutines,
|
gather_optional_coroutines,
|
||||||
)
|
)
|
||||||
|
from synapse.visibility import filter_events_for_client
|
||||||
|
|
||||||
_ThreadSubscription: TypeAlias = (
|
_ThreadSubscription: TypeAlias = (
|
||||||
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
|
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
|
||||||
@@ -58,6 +72,7 @@ _ThreadSubscription: TypeAlias = (
|
|||||||
_ThreadUnsubscription: TypeAlias = (
|
_ThreadUnsubscription: TypeAlias = (
|
||||||
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
|
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
|
||||||
)
|
)
|
||||||
|
_ThreadUpdate: TypeAlias = SlidingSyncResult.Extensions.ThreadsExtension.ThreadUpdate
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
@@ -73,7 +88,10 @@ class SlidingSyncExtensionHandler:
|
|||||||
self.event_sources = hs.get_event_sources()
|
self.event_sources = hs.get_event_sources()
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
self.push_rules_handler = hs.get_push_rules_handler()
|
self.push_rules_handler = hs.get_push_rules_handler()
|
||||||
|
self.relations_handler = hs.get_relations_handler()
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
|
self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
|
||||||
|
self._enable_threads_ext = hs.config.experimental.msc4360_enabled
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def get_extensions_response(
|
async def get_extensions_response(
|
||||||
@@ -84,6 +102,7 @@ class SlidingSyncExtensionHandler:
|
|||||||
actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList],
|
actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList],
|
||||||
actual_room_ids: set[str],
|
actual_room_ids: set[str],
|
||||||
actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
|
actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
|
||||||
|
room_membership_for_user_at_to_token_map: Mapping[str, RoomsForUserType],
|
||||||
to_token: StreamToken,
|
to_token: StreamToken,
|
||||||
from_token: SlidingSyncStreamToken | None,
|
from_token: SlidingSyncStreamToken | None,
|
||||||
) -> SlidingSyncResult.Extensions:
|
) -> SlidingSyncResult.Extensions:
|
||||||
@@ -99,6 +118,8 @@ class SlidingSyncExtensionHandler:
|
|||||||
actual_room_ids: The actual room IDs in the the Sliding Sync response.
|
actual_room_ids: The actual room IDs in the the Sliding Sync response.
|
||||||
actual_room_response_map: A map of room ID to room results in the the
|
actual_room_response_map: A map of room ID to room results in the the
|
||||||
Sliding Sync response.
|
Sliding Sync response.
|
||||||
|
room_membership_for_user_at_to_token_map: A map of room ID to the membership
|
||||||
|
information for the user in the room at the time of `to_token`.
|
||||||
to_token: The latest point in the stream to sync up to.
|
to_token: The latest point in the stream to sync up to.
|
||||||
from_token: The point in the stream to sync from.
|
from_token: The point in the stream to sync from.
|
||||||
"""
|
"""
|
||||||
@@ -174,6 +195,18 @@ class SlidingSyncExtensionHandler:
|
|||||||
from_token=from_token,
|
from_token=from_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
threads_coro = None
|
||||||
|
if sync_config.extensions.threads is not None and self._enable_threads_ext:
|
||||||
|
threads_coro = self.get_threads_extension_response(
|
||||||
|
sync_config=sync_config,
|
||||||
|
threads_request=sync_config.extensions.threads,
|
||||||
|
actual_room_ids=actual_room_ids,
|
||||||
|
actual_room_response_map=actual_room_response_map,
|
||||||
|
room_membership_for_user_at_to_token_map=room_membership_for_user_at_to_token_map,
|
||||||
|
to_token=to_token,
|
||||||
|
from_token=from_token,
|
||||||
|
)
|
||||||
|
|
||||||
(
|
(
|
||||||
to_device_response,
|
to_device_response,
|
||||||
e2ee_response,
|
e2ee_response,
|
||||||
@@ -181,6 +214,7 @@ class SlidingSyncExtensionHandler:
|
|||||||
receipts_response,
|
receipts_response,
|
||||||
typing_response,
|
typing_response,
|
||||||
thread_subs_response,
|
thread_subs_response,
|
||||||
|
threads_response,
|
||||||
) = await gather_optional_coroutines(
|
) = await gather_optional_coroutines(
|
||||||
to_device_coro,
|
to_device_coro,
|
||||||
e2ee_coro,
|
e2ee_coro,
|
||||||
@@ -188,6 +222,7 @@ class SlidingSyncExtensionHandler:
|
|||||||
receipts_coro,
|
receipts_coro,
|
||||||
typing_coro,
|
typing_coro,
|
||||||
thread_subs_coro,
|
thread_subs_coro,
|
||||||
|
threads_coro,
|
||||||
)
|
)
|
||||||
|
|
||||||
return SlidingSyncResult.Extensions(
|
return SlidingSyncResult.Extensions(
|
||||||
@@ -197,6 +232,7 @@ class SlidingSyncExtensionHandler:
|
|||||||
receipts=receipts_response,
|
receipts=receipts_response,
|
||||||
typing=typing_response,
|
typing=typing_response,
|
||||||
thread_subscriptions=thread_subs_response,
|
thread_subscriptions=thread_subs_response,
|
||||||
|
threads=threads_response,
|
||||||
)
|
)
|
||||||
|
|
||||||
def find_relevant_room_ids_for_extension(
|
def find_relevant_room_ids_for_extension(
|
||||||
@@ -967,3 +1003,273 @@ class SlidingSyncExtensionHandler:
|
|||||||
unsubscribed=unsubscribed_threads,
|
unsubscribed=unsubscribed_threads,
|
||||||
prev_batch=prev_batch,
|
prev_batch=prev_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _extract_thread_id_from_event(self, event: EventBase) -> str | None:
|
||||||
|
"""Extract thread ID from event if it's a thread reply.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The event to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The thread ID if the event is a thread reply, None otherwise.
|
||||||
|
"""
|
||||||
|
relates_to = event.content.get(EventContentFields.RELATIONS)
|
||||||
|
if isinstance(relates_to, dict):
|
||||||
|
if relates_to.get(MRelatesToFields.REL_TYPE) == RelationTypes.THREAD:
|
||||||
|
return relates_to.get(MRelatesToFields.EVENT_ID)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _find_threads_in_timeline(
|
||||||
|
self,
|
||||||
|
actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
|
||||||
|
) -> set[str]:
|
||||||
|
"""Find all thread IDs that have events in room timelines.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actual_room_response_map: A map of room ID to room results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A set of thread IDs (thread root event IDs) that appear in the timeline.
|
||||||
|
"""
|
||||||
|
threads_in_timeline: set[str] = set()
|
||||||
|
for room_result in actual_room_response_map.values():
|
||||||
|
if room_result.timeline_events:
|
||||||
|
for event in room_result.timeline_events:
|
||||||
|
thread_id = self._extract_thread_id_from_event(event)
|
||||||
|
if thread_id:
|
||||||
|
threads_in_timeline.add(thread_id)
|
||||||
|
return threads_in_timeline
|
||||||
|
|
||||||
|
def _merge_prev_batch_token(
|
||||||
|
self,
|
||||||
|
current_token: StreamToken | None,
|
||||||
|
new_token: StreamToken | None,
|
||||||
|
) -> StreamToken | None:
|
||||||
|
"""Merge two prev_batch tokens, taking the maximum (latest) for backwards pagination.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_token: The current prev_batch token (may be None)
|
||||||
|
new_token: The new prev_batch token to merge (may be None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The merged token (maximum of the two, or None if both are None)
|
||||||
|
"""
|
||||||
|
if new_token is None:
|
||||||
|
return current_token
|
||||||
|
if current_token is None:
|
||||||
|
return new_token
|
||||||
|
if new_token.room_key.stream > current_token.room_key.stream:
|
||||||
|
return new_token
|
||||||
|
return current_token
|
||||||
|
|
||||||
|
def _merge_thread_updates(
|
||||||
|
self,
|
||||||
|
target: dict[str, list[ThreadUpdateInfo]],
|
||||||
|
source: dict[str, list[ThreadUpdateInfo]],
|
||||||
|
) -> None:
|
||||||
|
"""Merge thread updates from source into target.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target: The target dict to merge into (modified in place)
|
||||||
|
source: The source dict to merge from
|
||||||
|
"""
|
||||||
|
for thread_id, updates in source.items():
|
||||||
|
target.setdefault(thread_id, []).extend(updates)
|
||||||
|
|
||||||
|
async def get_threads_extension_response(
|
||||||
|
self,
|
||||||
|
sync_config: SlidingSyncConfig,
|
||||||
|
threads_request: SlidingSyncConfig.Extensions.ThreadsExtension,
|
||||||
|
actual_room_ids: set[str],
|
||||||
|
actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
|
||||||
|
room_membership_for_user_at_to_token_map: Mapping[str, RoomsForUserType],
|
||||||
|
to_token: StreamToken,
|
||||||
|
from_token: SlidingSyncStreamToken | None,
|
||||||
|
) -> SlidingSyncResult.Extensions.ThreadsExtension | None:
|
||||||
|
"""Handle Threads extension (MSC4360)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sync_config: Sync configuration.
|
||||||
|
threads_request: The threads extension from the request.
|
||||||
|
actual_room_ids: The actual room IDs in the the Sliding Sync response.
|
||||||
|
actual_room_response_map: A map of room ID to room results in the
|
||||||
|
sliding sync response. Used to determine which threads already have
|
||||||
|
events in the room timeline.
|
||||||
|
room_membership_for_user_at_to_token_map: A map of room ID to the membership
|
||||||
|
information for the user in the room at the time of `to_token`.
|
||||||
|
to_token: The point in the stream to sync up to.
|
||||||
|
from_token: The point in the stream to sync from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the response (None if empty or threads extension is disabled)
|
||||||
|
"""
|
||||||
|
if not threads_request.enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Identify which threads already have events in the room timelines.
|
||||||
|
# If include_roots=False, we'll exclude these threads from the DB query
|
||||||
|
# since the client already sees the thread activity in the timeline.
|
||||||
|
# If include_roots=True, we fetch all threads regardless, because the client
|
||||||
|
# wants the thread root events.
|
||||||
|
threads_to_exclude: set[str] | None = None
|
||||||
|
if not threads_request.include_roots:
|
||||||
|
threads_to_exclude = self._find_threads_in_timeline(
|
||||||
|
actual_room_response_map
|
||||||
|
)
|
||||||
|
|
||||||
|
# Separate rooms into groups based on membership status.
|
||||||
|
# For LEAVE/BAN rooms, we need to bound the to_token to prevent leaking events
|
||||||
|
# that occurred after the user left/was banned.
|
||||||
|
leave_ban_rooms: set[str] = set()
|
||||||
|
other_rooms: set[str] = set()
|
||||||
|
|
||||||
|
for room_id in actual_room_ids:
|
||||||
|
membership_info = room_membership_for_user_at_to_token_map.get(room_id)
|
||||||
|
if membership_info and membership_info.membership in (
|
||||||
|
Membership.LEAVE,
|
||||||
|
Membership.BAN,
|
||||||
|
):
|
||||||
|
leave_ban_rooms.add(room_id)
|
||||||
|
else:
|
||||||
|
other_rooms.add(room_id)
|
||||||
|
|
||||||
|
# Fetch thread updates, handling LEAVE/BAN rooms separately to avoid data leaks.
|
||||||
|
all_thread_updates: dict[str, list[ThreadUpdateInfo]] = {}
|
||||||
|
prev_batch_token: StreamToken | None = None
|
||||||
|
remaining_limit = threads_request.limit
|
||||||
|
|
||||||
|
# Query for rooms where the user has left or been banned, using their leave/ban
|
||||||
|
# event position as the upper bound to prevent seeing events after they left.
|
||||||
|
if leave_ban_rooms:
|
||||||
|
for room_id in leave_ban_rooms:
|
||||||
|
if remaining_limit <= 0:
|
||||||
|
# We've already fetched enough updates, but we still need to set
|
||||||
|
# prev_batch to indicate there are more results.
|
||||||
|
prev_batch_token = to_token
|
||||||
|
break
|
||||||
|
|
||||||
|
membership_info = room_membership_for_user_at_to_token_map[room_id]
|
||||||
|
bounded_to_token = membership_info.event_pos.to_room_stream_token()
|
||||||
|
|
||||||
|
(
|
||||||
|
room_thread_updates,
|
||||||
|
room_prev_batch,
|
||||||
|
) = await self.store.get_thread_updates_for_rooms(
|
||||||
|
room_ids={room_id},
|
||||||
|
from_token=from_token.stream_token.room_key if from_token else None,
|
||||||
|
to_token=bounded_to_token,
|
||||||
|
limit=remaining_limit,
|
||||||
|
exclude_thread_ids=threads_to_exclude,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count how many updates we fetched and reduce the remaining limit
|
||||||
|
num_updates = sum(
|
||||||
|
len(updates) for updates in room_thread_updates.values()
|
||||||
|
)
|
||||||
|
remaining_limit -= num_updates
|
||||||
|
|
||||||
|
self._merge_thread_updates(all_thread_updates, room_thread_updates)
|
||||||
|
prev_batch_token = self._merge_prev_batch_token(
|
||||||
|
prev_batch_token, room_prev_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
# Query for rooms where the user is joined, invited, or knocking, using the
|
||||||
|
# normal to_token as the upper bound.
|
||||||
|
if other_rooms and remaining_limit > 0:
|
||||||
|
(
|
||||||
|
other_thread_updates,
|
||||||
|
other_prev_batch,
|
||||||
|
) = await self.store.get_thread_updates_for_rooms(
|
||||||
|
room_ids=other_rooms,
|
||||||
|
from_token=from_token.stream_token.room_key if from_token else None,
|
||||||
|
to_token=to_token.room_key,
|
||||||
|
limit=remaining_limit,
|
||||||
|
exclude_thread_ids=threads_to_exclude,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._merge_thread_updates(all_thread_updates, other_thread_updates)
|
||||||
|
prev_batch_token = self._merge_prev_batch_token(
|
||||||
|
prev_batch_token, other_prev_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(all_thread_updates) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Build a mapping of event_id -> (thread_id, update) for efficient lookup
|
||||||
|
# during visibility filtering.
|
||||||
|
event_to_thread_map: dict[str, tuple[str, ThreadUpdateInfo]] = {}
|
||||||
|
for thread_id, updates in all_thread_updates.items():
|
||||||
|
for update in updates:
|
||||||
|
event_to_thread_map[update.event_id] = (thread_id, update)
|
||||||
|
|
||||||
|
# Fetch and filter events for visibility
|
||||||
|
all_events = await self.store.get_events_as_list(event_to_thread_map.keys())
|
||||||
|
filtered_events = await filter_events_for_client(
|
||||||
|
self._storage_controllers, sync_config.user.to_string(), all_events
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rebuild thread updates from filtered events
|
||||||
|
filtered_updates: dict[str, list[ThreadUpdateInfo]] = defaultdict(list)
|
||||||
|
for event in filtered_events:
|
||||||
|
if event.event_id in event_to_thread_map:
|
||||||
|
thread_id, update = event_to_thread_map[event.event_id]
|
||||||
|
filtered_updates[thread_id].append(update)
|
||||||
|
|
||||||
|
if not filtered_updates:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Note: Updates are already sorted by stream_ordering DESC from the database query,
|
||||||
|
# and filter_events_for_client preserves order, so updates[0] is guaranteed to be
|
||||||
|
# the latest event for each thread.
|
||||||
|
|
||||||
|
# Optionally fetch thread root events and their bundled aggregations
|
||||||
|
thread_root_event_map = {}
|
||||||
|
aggregations_map = {}
|
||||||
|
if threads_request.include_roots:
|
||||||
|
thread_root_events = await self.store.get_events_as_list(
|
||||||
|
filtered_updates.keys()
|
||||||
|
)
|
||||||
|
thread_root_event_map = {e.event_id: e for e in thread_root_events}
|
||||||
|
|
||||||
|
if thread_root_event_map:
|
||||||
|
aggregations_map = (
|
||||||
|
await self.relations_handler.get_bundled_aggregations(
|
||||||
|
thread_root_event_map.values(),
|
||||||
|
sync_config.user.to_string(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
thread_updates: dict[str, dict[str, _ThreadUpdate]] = {}
|
||||||
|
for thread_root, updates in filtered_updates.items():
|
||||||
|
# We only care about the latest update for the thread.
|
||||||
|
# After sorting above, updates[0] is guaranteed to be the latest (highest stream_ordering).
|
||||||
|
latest_update = updates[0]
|
||||||
|
|
||||||
|
# Generate per-thread prev_batch token if this thread has multiple visible updates.
|
||||||
|
# When we hit the global limit, we generate prev_batch tokens for all threads, even if
|
||||||
|
# we only saw 1 update for them. This is to cover the case where we only saw
|
||||||
|
# a single update for a given thread, but the global limit prevents us from
|
||||||
|
# obtaining other updates which would have otherwise been included in the
|
||||||
|
# range.
|
||||||
|
per_thread_prev_batch = None
|
||||||
|
if len(updates) > 1 or prev_batch_token is not None:
|
||||||
|
# Create a token pointing to one position before the latest event's stream position.
|
||||||
|
# This makes it exclusive - /relations with dir=b won't return the latest event again.
|
||||||
|
# Use StreamToken.START as base (all other streams at 0) since only room position matters.
|
||||||
|
per_thread_prev_batch = StreamToken.START.copy_and_replace(
|
||||||
|
StreamKeyType.ROOM,
|
||||||
|
RoomStreamToken(stream=latest_update.stream_ordering - 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
thread_updates.setdefault(latest_update.room_id, {})[thread_root] = (
|
||||||
|
_ThreadUpdate(
|
||||||
|
thread_root=thread_root_event_map.get(thread_root),
|
||||||
|
prev_batch=per_thread_prev_batch,
|
||||||
|
bundled_aggregations=aggregations_map.get(thread_root),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return SlidingSyncResult.Extensions.ThreadsExtension(
|
||||||
|
updates=thread_updates,
|
||||||
|
prev_batch=prev_batch_token,
|
||||||
|
)
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from synapse.api.filtering import FilterCollection
|
|||||||
from synapse.api.presence import UserPresenceState
|
from synapse.api.presence import UserPresenceState
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.events.utils import (
|
from synapse.events.utils import (
|
||||||
|
EventClientSerializer,
|
||||||
SerializeEventConfig,
|
SerializeEventConfig,
|
||||||
format_event_for_client_v2_without_room_id,
|
format_event_for_client_v2_without_room_id,
|
||||||
format_event_raw,
|
format_event_raw,
|
||||||
@@ -56,6 +57,7 @@ from synapse.http.servlet import (
|
|||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
|
from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
|
||||||
from synapse.rest.admin.experimental_features import ExperimentalFeature
|
from synapse.rest.admin.experimental_features import ExperimentalFeature
|
||||||
|
from synapse.storage.databases.main import DataStore
|
||||||
from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken
|
from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken
|
||||||
from synapse.types.rest.client import SlidingSyncBody
|
from synapse.types.rest.client import SlidingSyncBody
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
@@ -646,6 +648,7 @@ class SlidingSyncRestServlet(RestServlet):
|
|||||||
- receipts (MSC3960)
|
- receipts (MSC3960)
|
||||||
- account data (MSC3959)
|
- account data (MSC3959)
|
||||||
- thread subscriptions (MSC4308)
|
- thread subscriptions (MSC4308)
|
||||||
|
- threads (MSC4360)
|
||||||
|
|
||||||
Request query parameters:
|
Request query parameters:
|
||||||
timeout: How long to wait for new events in milliseconds.
|
timeout: How long to wait for new events in milliseconds.
|
||||||
@@ -849,7 +852,10 @@ class SlidingSyncRestServlet(RestServlet):
|
|||||||
logger.info("Client has disconnected; not serializing response.")
|
logger.info("Client has disconnected; not serializing response.")
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
response_content = await self.encode_response(requester, sliding_sync_results)
|
time_now = self.clock.time_msec()
|
||||||
|
response_content = await self.encode_response(
|
||||||
|
requester, sliding_sync_results, time_now
|
||||||
|
)
|
||||||
|
|
||||||
return 200, response_content
|
return 200, response_content
|
||||||
|
|
||||||
@@ -858,6 +864,7 @@ class SlidingSyncRestServlet(RestServlet):
|
|||||||
self,
|
self,
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
sliding_sync_result: SlidingSyncResult,
|
sliding_sync_result: SlidingSyncResult,
|
||||||
|
time_now: int,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
response: JsonDict = defaultdict(dict)
|
response: JsonDict = defaultdict(dict)
|
||||||
|
|
||||||
@@ -866,10 +873,10 @@ class SlidingSyncRestServlet(RestServlet):
|
|||||||
if serialized_lists:
|
if serialized_lists:
|
||||||
response["lists"] = serialized_lists
|
response["lists"] = serialized_lists
|
||||||
response["rooms"] = await self.encode_rooms(
|
response["rooms"] = await self.encode_rooms(
|
||||||
requester, sliding_sync_result.rooms
|
requester, sliding_sync_result.rooms, time_now
|
||||||
)
|
)
|
||||||
response["extensions"] = await self.encode_extensions(
|
response["extensions"] = await self.encode_extensions(
|
||||||
requester, sliding_sync_result.extensions
|
requester, sliding_sync_result.extensions, time_now
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
@@ -901,9 +908,8 @@ class SlidingSyncRestServlet(RestServlet):
|
|||||||
self,
|
self,
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
rooms: dict[str, SlidingSyncResult.RoomResult],
|
rooms: dict[str, SlidingSyncResult.RoomResult],
|
||||||
|
time_now: int,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
time_now = self.clock.time_msec()
|
|
||||||
|
|
||||||
serialize_options = SerializeEventConfig(
|
serialize_options = SerializeEventConfig(
|
||||||
event_format=format_event_for_client_v2_without_room_id,
|
event_format=format_event_for_client_v2_without_room_id,
|
||||||
requester=requester,
|
requester=requester,
|
||||||
@@ -1019,7 +1025,10 @@ class SlidingSyncRestServlet(RestServlet):
|
|||||||
|
|
||||||
@trace_with_opname("sliding_sync.encode_extensions")
|
@trace_with_opname("sliding_sync.encode_extensions")
|
||||||
async def encode_extensions(
|
async def encode_extensions(
|
||||||
self, requester: Requester, extensions: SlidingSyncResult.Extensions
|
self,
|
||||||
|
requester: Requester,
|
||||||
|
extensions: SlidingSyncResult.Extensions,
|
||||||
|
time_now: int,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
serialized_extensions: JsonDict = {}
|
serialized_extensions: JsonDict = {}
|
||||||
|
|
||||||
@@ -1089,6 +1098,17 @@ class SlidingSyncRestServlet(RestServlet):
|
|||||||
_serialise_thread_subscriptions(extensions.thread_subscriptions)
|
_serialise_thread_subscriptions(extensions.thread_subscriptions)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# excludes both None and falsy `threads`
|
||||||
|
if extensions.threads:
|
||||||
|
serialized_extensions[
|
||||||
|
"io.element.msc4360.threads"
|
||||||
|
] = await _serialise_threads(
|
||||||
|
self.event_serializer,
|
||||||
|
time_now,
|
||||||
|
extensions.threads,
|
||||||
|
self.store,
|
||||||
|
)
|
||||||
|
|
||||||
return serialized_extensions
|
return serialized_extensions
|
||||||
|
|
||||||
|
|
||||||
@@ -1125,6 +1145,72 @@ def _serialise_thread_subscriptions(
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
async def _serialise_threads(
|
||||||
|
event_serializer: EventClientSerializer,
|
||||||
|
time_now: int,
|
||||||
|
threads: SlidingSyncResult.Extensions.ThreadsExtension,
|
||||||
|
store: "DataStore",
|
||||||
|
) -> JsonDict:
|
||||||
|
"""
|
||||||
|
Serialize the threads extension response for sliding sync.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_serializer: The event serializer to use for serializing thread root events.
|
||||||
|
time_now: The current time in milliseconds, used for event serialization.
|
||||||
|
threads: The threads extension data containing thread updates and pagination tokens.
|
||||||
|
store: The datastore, needed for serializing stream tokens.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON-serializable dict containing:
|
||||||
|
- "updates": A nested dict mapping room_id -> thread_root_id -> thread update.
|
||||||
|
Each thread update may contain:
|
||||||
|
- "thread_root": The serialized thread root event (if include_roots was True),
|
||||||
|
with bundled aggregations including the latest_event in unsigned.m.relations.m.thread.
|
||||||
|
- "prev_batch": A pagination token for fetching older events in the thread.
|
||||||
|
- "prev_batch": A pagination token for fetching older thread updates (if available).
|
||||||
|
"""
|
||||||
|
out: JsonDict = {}
|
||||||
|
|
||||||
|
if threads.updates:
|
||||||
|
updates_dict: JsonDict = {}
|
||||||
|
for room_id, thread_updates in threads.updates.items():
|
||||||
|
room_updates: JsonDict = {}
|
||||||
|
for thread_root_id, update in thread_updates.items():
|
||||||
|
# Serialize the update
|
||||||
|
update_dict: JsonDict = {}
|
||||||
|
|
||||||
|
# Serialize the thread_root event if present
|
||||||
|
if update.thread_root is not None:
|
||||||
|
# Create a mapping of event_id to bundled_aggregations
|
||||||
|
bundle_aggs_map = (
|
||||||
|
{thread_root_id: update.bundled_aggregations}
|
||||||
|
if update.bundled_aggregations
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
serialized_events = await event_serializer.serialize_events(
|
||||||
|
[update.thread_root],
|
||||||
|
time_now,
|
||||||
|
bundle_aggregations=bundle_aggs_map,
|
||||||
|
)
|
||||||
|
if serialized_events:
|
||||||
|
update_dict["thread_root"] = serialized_events[0]
|
||||||
|
|
||||||
|
# Add prev_batch if present
|
||||||
|
if update.prev_batch is not None:
|
||||||
|
update_dict["prev_batch"] = await update.prev_batch.to_string(store)
|
||||||
|
|
||||||
|
room_updates[thread_root_id] = update_dict
|
||||||
|
|
||||||
|
updates_dict[room_id] = room_updates
|
||||||
|
|
||||||
|
out["updates"] = updates_dict
|
||||||
|
|
||||||
|
if threads.prev_batch:
|
||||||
|
out["prev_batch"] = await threads.prev_batch.to_string(store)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
SyncRestServlet(hs).register(http_server)
|
SyncRestServlet(hs).register(http_server)
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Collection,
|
Collection,
|
||||||
@@ -40,13 +41,19 @@ from synapse.storage.database import (
|
|||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
make_in_list_sql_clause,
|
make_in_list_sql_clause,
|
||||||
)
|
)
|
||||||
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
from synapse.storage.databases.main.stream import (
|
from synapse.storage.databases.main.stream import (
|
||||||
generate_next_token,
|
generate_next_token,
|
||||||
generate_pagination_bounds,
|
generate_pagination_bounds,
|
||||||
generate_pagination_where_clause,
|
generate_pagination_where_clause,
|
||||||
)
|
)
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.types import JsonDict, StreamKeyType, StreamToken
|
from synapse.types import (
|
||||||
|
JsonDict,
|
||||||
|
RoomStreamToken,
|
||||||
|
StreamKeyType,
|
||||||
|
StreamToken,
|
||||||
|
)
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -88,7 +95,23 @@ class _RelatedEvent:
|
|||||||
sender: str
|
sender: str
|
||||||
|
|
||||||
|
|
||||||
class RelationsWorkerStore(SQLBaseStore):
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class ThreadUpdateInfo:
|
||||||
|
"""
|
||||||
|
Information about a thread update for the sliding sync threads extension.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
event_id: The event ID of the event in the thread.
|
||||||
|
room_id: The room ID where this thread exists.
|
||||||
|
stream_ordering: The stream ordering of this event.
|
||||||
|
"""
|
||||||
|
|
||||||
|
event_id: str
|
||||||
|
room_id: str
|
||||||
|
stream_ordering: int
|
||||||
|
|
||||||
|
|
||||||
|
class RelationsWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
database: DatabasePool,
|
database: DatabasePool,
|
||||||
@@ -584,14 +607,18 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||||||
"get_applicable_edits", _get_applicable_edits_txn
|
"get_applicable_edits", _get_applicable_edits_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
edits = await self.get_events(edit_ids.values()) # type: ignore[attr-defined]
|
edits = await self.get_events(edit_ids.values())
|
||||||
|
|
||||||
# Map to the original event IDs to the edit events.
|
# Map to the original event IDs to the edit events.
|
||||||
#
|
#
|
||||||
# There might not be an edit event due to there being no edits or
|
# There might not be an edit event due to there being no edits or
|
||||||
# due to the event not being known, either case is treated the same.
|
# due to the event not being known, either case is treated the same.
|
||||||
return {
|
return {
|
||||||
original_event_id: edits.get(edit_ids.get(original_event_id))
|
original_event_id: (
|
||||||
|
edits.get(edit_id)
|
||||||
|
if (edit_id := edit_ids.get(original_event_id))
|
||||||
|
else None
|
||||||
|
)
|
||||||
for original_event_id in event_ids
|
for original_event_id in event_ids
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -699,7 +726,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||||||
"get_thread_summaries", _get_thread_summaries_txn
|
"get_thread_summaries", _get_thread_summaries_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
|
latest_events = await self.get_events(latest_event_ids.values())
|
||||||
|
|
||||||
# Map to the event IDs to the thread summary.
|
# Map to the event IDs to the thread summary.
|
||||||
#
|
#
|
||||||
@@ -1111,6 +1138,148 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||||||
"get_related_thread_id", _get_related_thread_id
|
"get_related_thread_id", _get_related_thread_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_thread_updates_for_rooms(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
room_ids: Collection[str],
|
||||||
|
from_token: RoomStreamToken | None = None,
|
||||||
|
to_token: RoomStreamToken | None = None,
|
||||||
|
limit: int = 5,
|
||||||
|
exclude_thread_ids: Collection[str] | None = None,
|
||||||
|
) -> tuple[dict[str, list[ThreadUpdateInfo]], StreamToken | None]:
|
||||||
|
"""Get a list of updated threads, ordered by stream ordering of their
|
||||||
|
latest reply, filtered to only include threads in rooms where the user
|
||||||
|
is currently joined.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_ids: The room IDs to fetch thread updates for.
|
||||||
|
from_token: The lower bound (exclusive) for thread updates. If None,
|
||||||
|
fetch from the start of the room timeline.
|
||||||
|
to_token: The upper bound (inclusive) for thread updates. If None,
|
||||||
|
fetch up to the current position in the room timeline.
|
||||||
|
limit: Maximum number of thread updates to return.
|
||||||
|
exclude_thread_ids: Optional collection of thread root event IDs to exclude
|
||||||
|
from the results. Useful for filtering out threads already visible
|
||||||
|
in the room timeline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of:
|
||||||
|
A dict mapping thread_id to list of ThreadUpdateInfo objects,
|
||||||
|
ordered by stream_ordering descending (most recent first).
|
||||||
|
A prev_batch StreamToken (exclusive) if there are more results available,
|
||||||
|
None otherwise.
|
||||||
|
"""
|
||||||
|
# Ensure bad limits aren't being passed in.
|
||||||
|
assert limit > 0
|
||||||
|
|
||||||
|
if len(room_ids) == 0:
|
||||||
|
return ({}), None
|
||||||
|
|
||||||
|
def _get_thread_updates_for_user_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> tuple[list[tuple[str, str, str, int]], int | None]:
|
||||||
|
room_clause, room_id_values = make_in_list_sql_clause(
|
||||||
|
txn.database_engine, "e.room_id", room_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate the pagination clause, if necessary.
|
||||||
|
pagination_clause = ""
|
||||||
|
pagination_args: list[str] = []
|
||||||
|
if from_token:
|
||||||
|
from_bound = from_token.stream
|
||||||
|
pagination_clause += " AND stream_ordering > ?"
|
||||||
|
pagination_args.append(str(from_bound))
|
||||||
|
|
||||||
|
if to_token:
|
||||||
|
to_bound = to_token.stream
|
||||||
|
pagination_clause += " AND stream_ordering <= ?"
|
||||||
|
pagination_args.append(str(to_bound))
|
||||||
|
|
||||||
|
# Generate the exclusion clause for thread IDs, if necessary.
|
||||||
|
exclusion_clause = ""
|
||||||
|
exclusion_args: list[str] = []
|
||||||
|
if exclude_thread_ids:
|
||||||
|
exclusion_clause, exclusion_args = make_in_list_sql_clause(
|
||||||
|
txn.database_engine,
|
||||||
|
"er.relates_to_id",
|
||||||
|
exclude_thread_ids,
|
||||||
|
negative=True,
|
||||||
|
)
|
||||||
|
exclusion_clause = f" AND {exclusion_clause}"
|
||||||
|
|
||||||
|
# TODO: improve the fact that multiple hits for the same thread means we
|
||||||
|
# won't get as many overall updates for the sss response
|
||||||
|
|
||||||
|
# Find any thread events between the stream ordering bounds.
|
||||||
|
sql = f"""
|
||||||
|
SELECT e.event_id, er.relates_to_id, e.room_id, e.stream_ordering
|
||||||
|
FROM event_relations AS er
|
||||||
|
INNER JOIN events AS e ON er.event_id = e.event_id
|
||||||
|
WHERE er.relation_type = '{RelationTypes.THREAD}'
|
||||||
|
AND {room_clause}
|
||||||
|
{exclusion_clause}
|
||||||
|
{pagination_clause}
|
||||||
|
ORDER BY stream_ordering DESC
|
||||||
|
LIMIT ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Fetch `limit + 1` rows as a way to detect if there are more results beyond
|
||||||
|
# what we're returning. If we get exactly `limit + 1` rows back, we know there
|
||||||
|
# are more results available and we can set `next_token`. We only return the
|
||||||
|
# first `limit` rows to the caller. This avoids needing a separate COUNT query.
|
||||||
|
txn.execute(
|
||||||
|
sql,
|
||||||
|
(
|
||||||
|
*room_id_values,
|
||||||
|
*exclusion_args,
|
||||||
|
*pagination_args,
|
||||||
|
limit + 1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# SQL returns: event_id, thread_id, room_id, stream_ordering
|
||||||
|
rows = cast(list[tuple[str, str, str, int]], txn.fetchall())
|
||||||
|
|
||||||
|
# If there are more events, generate the next pagination key from the
|
||||||
|
# last thread which will be returned.
|
||||||
|
next_token = None
|
||||||
|
if len(rows) > limit:
|
||||||
|
# Set the next_token to be the second last row in the result set since
|
||||||
|
# that will be the last row we return from this function.
|
||||||
|
# This works as an exclusive bound that can be backpaginated from.
|
||||||
|
# Use the stream_ordering field (index 2 in original rows)
|
||||||
|
next_token = rows[-2][3]
|
||||||
|
|
||||||
|
return rows[:limit], next_token
|
||||||
|
|
||||||
|
thread_infos, next_token_int = await self.db_pool.runInteraction(
|
||||||
|
"get_thread_updates_for_user", _get_thread_updates_for_user_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert the next_token int (stream ordering) to a StreamToken.
|
||||||
|
# Use StreamToken.START as base (all other streams at 0) since only room
|
||||||
|
# position matters.
|
||||||
|
# Subtract 1 to make it exclusive - the client can paginate from this point without
|
||||||
|
# receiving the last thread update that was already returned.
|
||||||
|
next_token = None
|
||||||
|
if next_token_int is not None:
|
||||||
|
next_token = StreamToken.START.copy_and_replace(
|
||||||
|
StreamKeyType.ROOM, RoomStreamToken(stream=next_token_int - 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build ThreadUpdateInfo objects.
|
||||||
|
thread_update_infos: dict[str, list[ThreadUpdateInfo]] = defaultdict(list)
|
||||||
|
for event_id, thread_id, room_id, stream_ordering in thread_infos:
|
||||||
|
thread_update_infos[thread_id].append(
|
||||||
|
ThreadUpdateInfo(
|
||||||
|
event_id=event_id,
|
||||||
|
room_id=room_id,
|
||||||
|
stream_ordering=stream_ordering,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return (thread_update_infos, next_token)
|
||||||
|
|
||||||
|
|
||||||
class RelationsStore(RelationsWorkerStore):
|
class RelationsStore(RelationsWorkerStore):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -0,0 +1,33 @@
|
|||||||
|
--
|
||||||
|
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
--
|
||||||
|
-- Copyright (C) 2025 New Vector, Ltd
|
||||||
|
--
|
||||||
|
-- This program is free software: you can redistribute it and/or modify
|
||||||
|
-- it under the terms of the GNU Affero General Public License as
|
||||||
|
-- published by the Free Software Foundation, either version 3 of the
|
||||||
|
-- License, or (at your option) any later version.
|
||||||
|
--
|
||||||
|
-- See the GNU Affero General Public License for more details:
|
||||||
|
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
|
||||||
|
-- Add indexes to improve performance of the thread_updates endpoint and
|
||||||
|
-- sliding sync threads extension (MSC4360).
|
||||||
|
|
||||||
|
-- Index for efficiently finding all events that relate to a specific event
|
||||||
|
-- (e.g., all replies to a thread root). This is used by the correlated subquery
|
||||||
|
-- in get_thread_updates_for_user that counts thread updates.
|
||||||
|
-- Also useful for other relation queries (edits, reactions, etc.).
|
||||||
|
CREATE INDEX IF NOT EXISTS event_relations_relates_to_id_type
|
||||||
|
ON event_relations(relates_to_id, relation_type);
|
||||||
|
|
||||||
|
-- Index for the /thread_updates endpoint's cross-room query.
|
||||||
|
-- Allows efficient descending ordering and range filtering of threads
|
||||||
|
-- by stream_ordering across all rooms.
|
||||||
|
CREATE INDEX IF NOT EXISTS threads_stream_ordering_desc
|
||||||
|
ON threads(stream_ordering DESC);
|
||||||
|
|
||||||
|
-- Index for the EXISTS clause that filters threads to only joined rooms.
|
||||||
|
-- Allows efficient lookup of a user's current room memberships.
|
||||||
|
CREATE INDEX IF NOT EXISTS local_current_membership_user_room
|
||||||
|
ON local_current_membership(user_id, membership, room_id);
|
||||||
@@ -57,19 +57,41 @@ class PaginationConfig:
|
|||||||
from_tok_str = parse_string(request, "from")
|
from_tok_str = parse_string(request, "from")
|
||||||
to_tok_str = parse_string(request, "to")
|
to_tok_str = parse_string(request, "to")
|
||||||
|
|
||||||
|
# Helper function to extract StreamToken from either StreamToken or SlidingSyncStreamToken format
|
||||||
|
def extract_stream_token(token_str: str) -> str:
|
||||||
|
"""
|
||||||
|
Extract the StreamToken portion from a token string.
|
||||||
|
|
||||||
|
Handles both:
|
||||||
|
- StreamToken format: "s123_456_..."
|
||||||
|
- SlidingSyncStreamToken format: "5/s123_456_..." (extracts part after /)
|
||||||
|
|
||||||
|
This allows clients using sliding sync to use their pos tokens
|
||||||
|
with endpoints like /relations and /messages.
|
||||||
|
"""
|
||||||
|
if "/" in token_str:
|
||||||
|
# SlidingSyncStreamToken format: "connection_position/stream_token"
|
||||||
|
# Split and return just the stream_token part
|
||||||
|
parts = token_str.split("/", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
return parts[1]
|
||||||
|
return token_str
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from_tok = None
|
from_tok = None
|
||||||
if from_tok_str == "END":
|
if from_tok_str == "END":
|
||||||
from_tok = None # For backwards compat.
|
from_tok = None # For backwards compat.
|
||||||
elif from_tok_str:
|
elif from_tok_str:
|
||||||
from_tok = await StreamToken.from_string(store, from_tok_str)
|
stream_token_str = extract_stream_token(from_tok_str)
|
||||||
|
from_tok = await StreamToken.from_string(store, stream_token_str)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise SynapseError(400, "'from' parameter is invalid")
|
raise SynapseError(400, "'from' parameter is invalid")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
to_tok = None
|
to_tok = None
|
||||||
if to_tok_str:
|
if to_tok_str:
|
||||||
to_tok = await StreamToken.from_string(store, to_tok_str)
|
stream_token_str = extract_stream_token(to_tok_str)
|
||||||
|
to_tok = await StreamToken.from_string(store, stream_token_str)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise SynapseError(400, "'to' parameter is invalid")
|
raise SynapseError(400, "'to' parameter is invalid")
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,9 @@ from pydantic import ConfigDict
|
|||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.handlers.relations import BundledAggregations
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
DeviceListUpdates,
|
DeviceListUpdates,
|
||||||
JsonDict,
|
JsonDict,
|
||||||
@@ -388,12 +391,60 @@ class SlidingSyncResult:
|
|||||||
or bool(self.prev_batch)
|
or bool(self.prev_batch)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class ThreadsExtension:
|
||||||
|
"""The Threads extension (MSC4360)
|
||||||
|
|
||||||
|
Provides thread updates for threads that have new activity across all of the
|
||||||
|
user's joined rooms within the sync window.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
updates: A nested mapping of room_id -> thread_root_id -> ThreadUpdate.
|
||||||
|
Each ThreadUpdate contains information about a thread that has new activity,
|
||||||
|
including the thread root event (if requested) and a pagination token
|
||||||
|
for fetching older events in that specific thread.
|
||||||
|
prev_batch: A pagination token for fetching more thread updates across all rooms.
|
||||||
|
If present, indicates there are more thread updates available beyond what
|
||||||
|
was returned in this response. This token can be used with a future request
|
||||||
|
to paginate through older thread updates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class ThreadUpdate:
|
||||||
|
"""Information about a single thread that has new activity.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
thread_root: The thread root event, if requested via include_roots in the
|
||||||
|
request. This is the event that started the thread.
|
||||||
|
prev_batch: A pagination token (exclusive) for fetching older events in this
|
||||||
|
specific thread. Only present if the thread has multiple updates in the
|
||||||
|
sync window. This token can be used with the /relations endpoint with
|
||||||
|
dir=b to paginate backwards through the thread's history.
|
||||||
|
bundled_aggregations: Bundled aggregations for the thread root event,
|
||||||
|
including the latest_event in the thread (found in
|
||||||
|
unsigned.m.relations.m.thread). Only present if thread_root is included.
|
||||||
|
"""
|
||||||
|
|
||||||
|
thread_root: EventBase | None
|
||||||
|
prev_batch: StreamToken | None
|
||||||
|
bundled_aggregations: "BundledAggregations | None" = None
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return bool(self.thread_root) or bool(self.prev_batch)
|
||||||
|
|
||||||
|
updates: Mapping[str, Mapping[str, ThreadUpdate]] | None
|
||||||
|
prev_batch: StreamToken | None
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return bool(self.updates) or bool(self.prev_batch)
|
||||||
|
|
||||||
to_device: ToDeviceExtension | None = None
|
to_device: ToDeviceExtension | None = None
|
||||||
e2ee: E2eeExtension | None = None
|
e2ee: E2eeExtension | None = None
|
||||||
account_data: AccountDataExtension | None = None
|
account_data: AccountDataExtension | None = None
|
||||||
receipts: ReceiptsExtension | None = None
|
receipts: ReceiptsExtension | None = None
|
||||||
typing: TypingExtension | None = None
|
typing: TypingExtension | None = None
|
||||||
thread_subscriptions: ThreadSubscriptionsExtension | None = None
|
thread_subscriptions: ThreadSubscriptionsExtension | None = None
|
||||||
|
threads: ThreadsExtension | None = None
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
def __bool__(self) -> bool:
|
||||||
return bool(
|
return bool(
|
||||||
@@ -403,6 +454,7 @@ class SlidingSyncResult:
|
|||||||
or self.receipts
|
or self.receipts
|
||||||
or self.typing
|
or self.typing
|
||||||
or self.thread_subscriptions
|
or self.thread_subscriptions
|
||||||
|
or self.threads
|
||||||
)
|
)
|
||||||
|
|
||||||
next_pos: SlidingSyncStreamToken
|
next_pos: SlidingSyncStreamToken
|
||||||
@@ -852,6 +904,7 @@ class PerConnectionState:
|
|||||||
Attributes:
|
Attributes:
|
||||||
rooms: The status of each room for the events stream.
|
rooms: The status of each room for the events stream.
|
||||||
receipts: The status of each room for the receipts stream.
|
receipts: The status of each room for the receipts stream.
|
||||||
|
account_data: The status of each room for the account data stream.
|
||||||
room_configs: Map from room_id to the `RoomSyncConfig` of all
|
room_configs: Map from room_id to the `RoomSyncConfig` of all
|
||||||
rooms that we have previously sent down.
|
rooms that we have previously sent down.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -383,6 +383,19 @@ class SlidingSyncBody(RequestBodyModel):
|
|||||||
enabled: StrictBool | None = False
|
enabled: StrictBool | None = False
|
||||||
limit: StrictInt = 100
|
limit: StrictInt = 100
|
||||||
|
|
||||||
|
class ThreadsExtension(RequestBodyModel):
|
||||||
|
"""The Threads extension (MSC4360)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
enabled: Whether the threads extension is enabled.
|
||||||
|
include_roots: whether to include thread root events in the extension response.
|
||||||
|
limit: maximum number of thread updates to return across all joined rooms.
|
||||||
|
"""
|
||||||
|
|
||||||
|
enabled: StrictBool | None = False
|
||||||
|
include_roots: StrictBool = False
|
||||||
|
limit: StrictInt = 100
|
||||||
|
|
||||||
to_device: ToDeviceExtension | None = None
|
to_device: ToDeviceExtension | None = None
|
||||||
e2ee: E2eeExtension | None = None
|
e2ee: E2eeExtension | None = None
|
||||||
account_data: AccountDataExtension | None = None
|
account_data: AccountDataExtension | None = None
|
||||||
@@ -391,6 +404,9 @@ class SlidingSyncBody(RequestBodyModel):
|
|||||||
thread_subscriptions: ThreadSubscriptionsExtension | None = Field(
|
thread_subscriptions: ThreadSubscriptionsExtension | None = Field(
|
||||||
None, alias="io.element.msc4308.thread_subscriptions"
|
None, alias="io.element.msc4308.thread_subscriptions"
|
||||||
)
|
)
|
||||||
|
threads: ThreadsExtension | None = Field(
|
||||||
|
None, alias="io.element.msc4360.threads"
|
||||||
|
)
|
||||||
|
|
||||||
conn_id: StrictStr | None = None
|
conn_id: StrictStr | None = None
|
||||||
lists: (
|
lists: (
|
||||||
|
|||||||
@@ -340,6 +340,7 @@ T3 = TypeVar("T3")
|
|||||||
T4 = TypeVar("T4")
|
T4 = TypeVar("T4")
|
||||||
T5 = TypeVar("T5")
|
T5 = TypeVar("T5")
|
||||||
T6 = TypeVar("T6")
|
T6 = TypeVar("T6")
|
||||||
|
T7 = TypeVar("T7")
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@@ -469,6 +470,30 @@ async def gather_optional_coroutines(
|
|||||||
) -> tuple[T1 | None, T2 | None, T3 | None, T4 | None, T5 | None, T6 | None]: ...
|
) -> tuple[T1 | None, T2 | None, T3 | None, T4 | None, T5 | None, T6 | None]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def gather_optional_coroutines(
|
||||||
|
*coroutines: Unpack[
|
||||||
|
tuple[
|
||||||
|
Coroutine[Any, Any, T1] | None,
|
||||||
|
Coroutine[Any, Any, T2] | None,
|
||||||
|
Coroutine[Any, Any, T3] | None,
|
||||||
|
Coroutine[Any, Any, T4] | None,
|
||||||
|
Coroutine[Any, Any, T5] | None,
|
||||||
|
Coroutine[Any, Any, T6] | None,
|
||||||
|
Coroutine[Any, Any, T7] | None,
|
||||||
|
]
|
||||||
|
],
|
||||||
|
) -> tuple[
|
||||||
|
T1 | None,
|
||||||
|
T2 | None,
|
||||||
|
T3 | None,
|
||||||
|
T4 | None,
|
||||||
|
T5 | None,
|
||||||
|
T6 | None,
|
||||||
|
T7 | None,
|
||||||
|
]: ...
|
||||||
|
|
||||||
|
|
||||||
async def gather_optional_coroutines(
|
async def gather_optional_coroutines(
|
||||||
*coroutines: Unpack[tuple[Coroutine[Any, Any, T1] | None, ...]],
|
*coroutines: Unpack[tuple[Coroutine[Any, Any, T1] | None, ...]],
|
||||||
) -> tuple[T1 | None, ...]:
|
) -> tuple[T1 | None, ...]:
|
||||||
|
|||||||
1165
tests/rest/client/sliding_sync/test_extension_threads.py
Normal file
1165
tests/rest/client/sliding_sync/test_extension_threads.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user