
Compare commits


26 Commits

Author  SHA1  Message  Date
Devon Hudson  6fa43cb0b4  Comment cleanup  2025-11-09 12:43:41 -07:00
Devon Hudson  f778ac32c1  Update docstring  2025-11-09 12:37:04 -07:00
Devon Hudson  003fc725db  Merge branch 'develop' into devon/ssext_threads  2025-11-09 12:33:55 -07:00
Devon Hudson  934f99a694  Add wait_for_new_data tests  2025-11-09 12:09:56 -07:00
Devon Hudson  78e8ec6161  Add test for room list filtering  2025-11-09 09:44:52 -07:00
Devon Hudson  f59419377d  Refactor for clarity  2025-11-09 09:35:11 -07:00
Devon Hudson  a3b34dfafd  Run linter  2025-11-09 09:30:44 -07:00
Devon Hudson  cb82a4a687  Handle user leave/ban rooms to prevent leaking data  2025-11-09 08:45:52 -07:00
Devon Hudson  0c0ece9612  Fix next_token logic  2025-11-09 08:29:49 -07:00
Devon Hudson  46e3f6756c  Cleanup logic  2025-11-08 10:07:46 -07:00
Devon Hudson  dedd6e35e6  Rejig thread updates to use room lists  2025-11-08 09:12:37 -07:00
Devon Hudson  a3c7b3ecb9  Don't fetch bundled aggregations if we don't have to  2025-10-16 18:06:26 -06:00
Devon Hudson  bf594a28a8  Move constants to designated file  2025-10-16 17:37:01 -06:00
Devon Hudson  c757969597  Add indexes to improve threads query performance  2025-10-10 15:39:27 -06:00
Devon Hudson  4cb0eeabdf  Allow SlidingSyncStreamToken in /relations  2025-10-09 11:28:33 -06:00
Devon Hudson  4d7826b006  Filter events from extension if in timeline  2025-10-08 17:01:40 -06:00
Devon Hudson  ab7e5a2b17  Properly return prev_batch tokens for threads extension  2025-10-08 16:12:46 -06:00
Devon Hudson  4c51247cb3  Only return rooms where user is currently joined  2025-10-07 12:49:32 -06:00
Devon Hudson  4dd82e581a  Add newsfile  2025-10-03 16:16:04 -06:00
Devon Hudson  6e69338abc  Fix linter error  2025-10-03 16:15:11 -06:00
Devon Hudson  79ea4bed33  Add thread_root events to threads extension response  2025-10-03 15:57:13 -06:00
Devon Hudson  9ef4ca173e  Add user room filtering for threads extension  2025-10-03 14:01:16 -06:00
Devon Hudson  24b38733df  Don't return empty fields in response  2025-10-02 17:23:30 -06:00
Devon Hudson  4602b56643  Stub in early db queries to get tests going  2025-10-02 17:11:14 -06:00
Devon Hudson  6c460b3eae  Stub in threads extension tests  2025-10-01 10:53:11 -06:00
Devon Hudson  cd4f4223de  Stub in threads sliding sync extension  2025-10-01 10:04:29 -06:00
14 changed files with 1904 additions and 16 deletions

View File

@@ -0,0 +1 @@
Add experimental support for MSC4360: Sliding Sync Threads Extension.

View File

@@ -272,6 +272,9 @@ class EventContentFields:
M_TOPIC: Final = "m.topic"
M_TEXT: Final = "m.text"
# Event relations
RELATIONS: Final = "m.relates_to"
class EventUnsignedContentFields:
"""Fields found inside the 'unsigned' data on events"""
@@ -360,3 +363,10 @@ class Direction(enum.Enum):
class ProfileFields:
DISPLAYNAME: Final = "displayname"
AVATAR_URL: Final = "avatar_url"
class MRelatesToFields:
"""Fields found inside m.relates_to content blocks."""
EVENT_ID: Final = "event_id"
REL_TYPE: Final = "rel_type"

View File

@@ -593,3 +593,6 @@ class ExperimentalConfig(Config):
# MSC4306: Thread Subscriptions
# (and MSC4308: Thread Subscriptions extension to Sliding Sync)
self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False)
# MSC4360: Threads Extension to Sliding Sync
self.msc4360_enabled: bool = experimental.get("msc4360_enabled", False)
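
For anyone trying the extension out: the flag lives under experimental_features in homeserver.yaml (a sketch based on the option name read above):

experimental_features:
  msc4360_enabled: true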

View File

@@ -105,8 +105,6 @@ class RelationsHandler:
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.
TODO Accept a PaginationConfig instead of individual pagination parameters.
Args:
requester: The user requesting the relations.
event_id: Fetch events that relate to this event ID.

View File

@@ -305,6 +305,7 @@ class SlidingSyncHandler:
# account data, read receipts, typing indicators, to-device messages, etc).
actual_room_ids=set(relevant_room_map.keys()),
actual_room_response_map=rooms,
room_membership_for_user_at_to_token_map=room_membership_for_user_map,
from_token=from_token,
to_token=to_token,
)

View File

@@ -14,6 +14,7 @@
import itertools
import logging
from collections import defaultdict
from typing import (
TYPE_CHECKING,
AbstractSet,
@@ -26,16 +27,28 @@ from typing import (
from typing_extensions import TypeAlias, assert_never
from synapse.api.constants import AccountDataTypes, EduTypes
from synapse.api.constants import (
AccountDataTypes,
EduTypes,
EventContentFields,
Membership,
MRelatesToFields,
RelationTypes,
)
from synapse.events import EventBase
from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.receipts import ReceiptInRoom
from synapse.storage.databases.main.relations import ThreadUpdateInfo
from synapse.types import (
DeviceListUpdates,
JsonMapping,
MultiWriterStreamToken,
RoomStreamToken,
SlidingSyncStreamToken,
StrCollection,
StreamKeyType,
StreamToken,
ThreadSubscriptionsToken,
)
@@ -51,6 +64,7 @@ from synapse.util.async_helpers import (
concurrently_execute,
gather_optional_coroutines,
)
from synapse.visibility import filter_events_for_client
_ThreadSubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
@@ -58,6 +72,7 @@ _ThreadSubscription: TypeAlias = (
_ThreadUnsubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
)
_ThreadUpdate: TypeAlias = SlidingSyncResult.Extensions.ThreadsExtension.ThreadUpdate
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -73,7 +88,10 @@ class SlidingSyncExtensionHandler:
self.event_sources = hs.get_event_sources()
self.device_handler = hs.get_device_handler()
self.push_rules_handler = hs.get_push_rules_handler()
self.relations_handler = hs.get_relations_handler()
self._storage_controllers = hs.get_storage_controllers()
self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
self._enable_threads_ext = hs.config.experimental.msc4360_enabled
@trace
async def get_extensions_response(
@@ -84,6 +102,7 @@ class SlidingSyncExtensionHandler:
actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList],
actual_room_ids: set[str],
actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
room_membership_for_user_at_to_token_map: Mapping[str, RoomsForUserType],
to_token: StreamToken,
from_token: SlidingSyncStreamToken | None,
) -> SlidingSyncResult.Extensions:
@@ -99,6 +118,8 @@ class SlidingSyncExtensionHandler:
actual_room_ids: The actual room IDs in the Sliding Sync response.
actual_room_response_map: A map of room ID to room results in the
Sliding Sync response.
room_membership_for_user_at_to_token_map: A map of room ID to the membership
information for the user in the room at the time of `to_token`.
to_token: The latest point in the stream to sync up to.
from_token: The point in the stream to sync from.
"""
@@ -174,6 +195,18 @@ class SlidingSyncExtensionHandler:
from_token=from_token,
)
threads_coro = None
if sync_config.extensions.threads is not None and self._enable_threads_ext:
threads_coro = self.get_threads_extension_response(
sync_config=sync_config,
threads_request=sync_config.extensions.threads,
actual_room_ids=actual_room_ids,
actual_room_response_map=actual_room_response_map,
room_membership_for_user_at_to_token_map=room_membership_for_user_at_to_token_map,
to_token=to_token,
from_token=from_token,
)
(
to_device_response,
e2ee_response,
@@ -181,6 +214,7 @@ class SlidingSyncExtensionHandler:
receipts_response,
typing_response,
thread_subs_response,
threads_response,
) = await gather_optional_coroutines(
to_device_coro,
e2ee_coro,
@@ -188,6 +222,7 @@ class SlidingSyncExtensionHandler:
receipts_coro,
typing_coro,
thread_subs_coro,
threads_coro,
)
return SlidingSyncResult.Extensions(
@@ -197,6 +232,7 @@ class SlidingSyncExtensionHandler:
receipts=receipts_response,
typing=typing_response,
thread_subscriptions=thread_subs_response,
threads=threads_response,
)
def find_relevant_room_ids_for_extension(
@@ -967,3 +1003,273 @@ class SlidingSyncExtensionHandler:
unsubscribed=unsubscribed_threads,
prev_batch=prev_batch,
)
def _extract_thread_id_from_event(self, event: EventBase) -> str | None:
"""Extract thread ID from event if it's a thread reply.
Args:
event: The event to check.
Returns:
The thread ID if the event is a thread reply, None otherwise.
"""
relates_to = event.content.get(EventContentFields.RELATIONS)
if isinstance(relates_to, dict):
if relates_to.get(MRelatesToFields.REL_TYPE) == RelationTypes.THREAD:
return relates_to.get(MRelatesToFields.EVENT_ID)
return None
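
For reference, this helper matches the standard Matrix thread-relation shape; a minimal sketch of the event content it accepts (placeholder IDs):

content = {
    "msgtype": "m.text",
    "body": "A reply inside a thread",
    "m.relates_to": {               # EventContentFields.RELATIONS
        "rel_type": "m.thread",     # RelationTypes.THREAD
        "event_id": "$thread_root", # MRelatesToFields.EVENT_ID, returned as the thread ID
    },
}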
def _find_threads_in_timeline(
self,
actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
) -> set[str]:
"""Find all thread IDs that have events in room timelines.
Args:
actual_room_response_map: A map of room ID to room results.
Returns:
A set of thread IDs (thread root event IDs) that appear in the timeline.
"""
threads_in_timeline: set[str] = set()
for room_result in actual_room_response_map.values():
if room_result.timeline_events:
for event in room_result.timeline_events:
thread_id = self._extract_thread_id_from_event(event)
if thread_id:
threads_in_timeline.add(thread_id)
return threads_in_timeline
def _merge_prev_batch_token(
self,
current_token: StreamToken | None,
new_token: StreamToken | None,
) -> StreamToken | None:
"""Merge two prev_batch tokens, taking the maximum (latest) for backwards pagination.
Args:
current_token: The current prev_batch token (may be None)
new_token: The new prev_batch token to merge (may be None)
Returns:
The merged token (maximum of the two, or None if both are None)
"""
if new_token is None:
return current_token
if current_token is None:
return new_token
if new_token.room_key.stream > current_token.room_key.stream:
return new_token
return current_token
def _merge_thread_updates(
self,
target: dict[str, list[ThreadUpdateInfo]],
source: dict[str, list[ThreadUpdateInfo]],
) -> None:
"""Merge thread updates from source into target.
Args:
target: The target dict to merge into (modified in place)
source: The source dict to merge from
"""
for thread_id, updates in source.items():
target.setdefault(thread_id, []).extend(updates)
async def get_threads_extension_response(
self,
sync_config: SlidingSyncConfig,
threads_request: SlidingSyncConfig.Extensions.ThreadsExtension,
actual_room_ids: set[str],
actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
room_membership_for_user_at_to_token_map: Mapping[str, RoomsForUserType],
to_token: StreamToken,
from_token: SlidingSyncStreamToken | None,
) -> SlidingSyncResult.Extensions.ThreadsExtension | None:
"""Handle Threads extension (MSC4360)
Args:
sync_config: Sync configuration.
threads_request: The threads extension from the request.
actual_room_ids: The actual room IDs in the Sliding Sync response.
actual_room_response_map: A map of room ID to room results in the
sliding sync response. Used to determine which threads already have
events in the room timeline.
room_membership_for_user_at_to_token_map: A map of room ID to the membership
information for the user in the room at the time of `to_token`.
to_token: The point in the stream to sync up to.
from_token: The point in the stream to sync from.
Returns:
the response (None if empty or threads extension is disabled)
"""
if not threads_request.enabled:
return None
# Identify which threads already have events in the room timelines.
# If include_roots=False, we'll exclude these threads from the DB query
# since the client already sees the thread activity in the timeline.
# If include_roots=True, we fetch all threads regardless, because the client
# wants the thread root events.
threads_to_exclude: set[str] | None = None
if not threads_request.include_roots:
threads_to_exclude = self._find_threads_in_timeline(
actual_room_response_map
)
# Separate rooms into groups based on membership status.
# For LEAVE/BAN rooms, we need to bound the to_token to prevent leaking events
# that occurred after the user left/was banned.
leave_ban_rooms: set[str] = set()
other_rooms: set[str] = set()
for room_id in actual_room_ids:
membership_info = room_membership_for_user_at_to_token_map.get(room_id)
if membership_info and membership_info.membership in (
Membership.LEAVE,
Membership.BAN,
):
leave_ban_rooms.add(room_id)
else:
other_rooms.add(room_id)
# Fetch thread updates, handling LEAVE/BAN rooms separately to avoid data leaks.
all_thread_updates: dict[str, list[ThreadUpdateInfo]] = {}
prev_batch_token: StreamToken | None = None
remaining_limit = threads_request.limit
# Query for rooms where the user has left or been banned, using their leave/ban
# event position as the upper bound to prevent seeing events after they left.
if leave_ban_rooms:
for room_id in leave_ban_rooms:
if remaining_limit <= 0:
# We've already fetched enough updates, but we still need to set
# prev_batch to indicate there are more results.
prev_batch_token = to_token
break
membership_info = room_membership_for_user_at_to_token_map[room_id]
bounded_to_token = membership_info.event_pos.to_room_stream_token()
(
room_thread_updates,
room_prev_batch,
) = await self.store.get_thread_updates_for_rooms(
room_ids={room_id},
from_token=from_token.stream_token.room_key if from_token else None,
to_token=bounded_to_token,
limit=remaining_limit,
exclude_thread_ids=threads_to_exclude,
)
# Count how many updates we fetched and reduce the remaining limit
num_updates = sum(
len(updates) for updates in room_thread_updates.values()
)
remaining_limit -= num_updates
self._merge_thread_updates(all_thread_updates, room_thread_updates)
prev_batch_token = self._merge_prev_batch_token(
prev_batch_token, room_prev_batch
)
# Query for rooms where the user is joined, invited, or knocking, using the
# normal to_token as the upper bound.
if other_rooms and remaining_limit > 0:
(
other_thread_updates,
other_prev_batch,
) = await self.store.get_thread_updates_for_rooms(
room_ids=other_rooms,
from_token=from_token.stream_token.room_key if from_token else None,
to_token=to_token.room_key,
limit=remaining_limit,
exclude_thread_ids=threads_to_exclude,
)
self._merge_thread_updates(all_thread_updates, other_thread_updates)
prev_batch_token = self._merge_prev_batch_token(
prev_batch_token, other_prev_batch
)
if len(all_thread_updates) == 0:
return None
# Build a mapping of event_id -> (thread_id, update) for efficient lookup
# during visibility filtering.
event_to_thread_map: dict[str, tuple[str, ThreadUpdateInfo]] = {}
for thread_id, updates in all_thread_updates.items():
for update in updates:
event_to_thread_map[update.event_id] = (thread_id, update)
# Fetch and filter events for visibility
all_events = await self.store.get_events_as_list(event_to_thread_map.keys())
filtered_events = await filter_events_for_client(
self._storage_controllers, sync_config.user.to_string(), all_events
)
# Rebuild thread updates from filtered events
filtered_updates: dict[str, list[ThreadUpdateInfo]] = defaultdict(list)
for event in filtered_events:
if event.event_id in event_to_thread_map:
thread_id, update = event_to_thread_map[event.event_id]
filtered_updates[thread_id].append(update)
if not filtered_updates:
return None
# Note: Updates are already sorted by stream_ordering DESC from the database query,
# and filter_events_for_client preserves order, so updates[0] is guaranteed to be
# the latest event for each thread.
# Optionally fetch thread root events and their bundled aggregations
thread_root_event_map = {}
aggregations_map = {}
if threads_request.include_roots:
thread_root_events = await self.store.get_events_as_list(
filtered_updates.keys()
)
thread_root_event_map = {e.event_id: e for e in thread_root_events}
if thread_root_event_map:
aggregations_map = (
await self.relations_handler.get_bundled_aggregations(
thread_root_event_map.values(),
sync_config.user.to_string(),
)
)
thread_updates: dict[str, dict[str, _ThreadUpdate]] = {}
for thread_root, updates in filtered_updates.items():
# We only care about the latest update for the thread.
# As noted above, updates are already sorted newest-first, so updates[0] is
# guaranteed to be the latest (highest stream_ordering).
latest_update = updates[0]
# Generate per-thread prev_batch token if this thread has multiple visible updates.
# When we hit the global limit, we generate prev_batch tokens for all threads, even if
# we only saw 1 update for them. This is to cover the case where we only saw
# a single update for a given thread, but the global limit prevents us from
# obtaining other updates which would have otherwise been included in the
# range.
per_thread_prev_batch = None
if len(updates) > 1 or prev_batch_token is not None:
# Create a token pointing to one position before the latest event's stream position.
# This makes it exclusive - /relations with dir=b won't return the latest event again.
# Use StreamToken.START as base (all other streams at 0) since only room position matters.
per_thread_prev_batch = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM,
RoomStreamToken(stream=latest_update.stream_ordering - 1),
)
thread_updates.setdefault(latest_update.room_id, {})[thread_root] = (
_ThreadUpdate(
thread_root=thread_root_event_map.get(thread_root),
prev_batch=per_thread_prev_batch,
bundled_aggregations=aggregations_map.get(thread_root),
)
)
return SlidingSyncResult.Extensions.ThreadsExtension(
updates=thread_updates,
prev_batch=prev_batch_token,
)
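
A note on how clients are expected to consume the per-thread tokens: since the /relations change later in this diff also accepts sliding sync pos tokens, a client can page backwards through a single thread with the standard relations endpoint (illustrative request; the room ID, event ID, and token are placeholders):

GET /_matrix/client/v1/rooms/!room:example.org/relations/$thread_root/m.thread?dir=b&from={prev_batch}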

View File

@@ -31,6 +31,7 @@ from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.ratelimiting import Ratelimiter
from synapse.events.utils import (
EventClientSerializer,
SerializeEventConfig,
format_event_for_client_v2_without_room_id,
format_event_raw,
@@ -56,6 +57,7 @@ from synapse.http.servlet import (
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken
from synapse.types.rest.client import SlidingSyncBody
from synapse.util.caches.lrucache import LruCache
@@ -646,6 +648,7 @@ class SlidingSyncRestServlet(RestServlet):
- receipts (MSC3960)
- account data (MSC3959)
- thread subscriptions (MSC4308)
- threads (MSC4360)
Request query parameters:
timeout: How long to wait for new events in milliseconds.
@@ -849,7 +852,10 @@ class SlidingSyncRestServlet(RestServlet):
logger.info("Client has disconnected; not serializing response.")
return 200, {}
response_content = await self.encode_response(requester, sliding_sync_results)
time_now = self.clock.time_msec()
response_content = await self.encode_response(
requester, sliding_sync_results, time_now
)
return 200, response_content
@@ -858,6 +864,7 @@ class SlidingSyncRestServlet(RestServlet):
self,
requester: Requester,
sliding_sync_result: SlidingSyncResult,
time_now: int,
) -> JsonDict:
response: JsonDict = defaultdict(dict)
@@ -866,10 +873,10 @@ class SlidingSyncRestServlet(RestServlet):
if serialized_lists:
response["lists"] = serialized_lists
response["rooms"] = await self.encode_rooms(
requester, sliding_sync_result.rooms
requester, sliding_sync_result.rooms, time_now
)
response["extensions"] = await self.encode_extensions(
requester, sliding_sync_result.extensions
requester, sliding_sync_result.extensions, time_now
)
return response
@@ -901,9 +908,8 @@ class SlidingSyncRestServlet(RestServlet):
self,
requester: Requester,
rooms: dict[str, SlidingSyncResult.RoomResult],
time_now: int,
) -> JsonDict:
time_now = self.clock.time_msec()
serialize_options = SerializeEventConfig(
event_format=format_event_for_client_v2_without_room_id,
requester=requester,
@@ -1019,7 +1025,10 @@ class SlidingSyncRestServlet(RestServlet):
@trace_with_opname("sliding_sync.encode_extensions")
async def encode_extensions(
self, requester: Requester, extensions: SlidingSyncResult.Extensions
self,
requester: Requester,
extensions: SlidingSyncResult.Extensions,
time_now: int,
) -> JsonDict:
serialized_extensions: JsonDict = {}
@@ -1089,6 +1098,17 @@ class SlidingSyncRestServlet(RestServlet):
_serialise_thread_subscriptions(extensions.thread_subscriptions)
)
# excludes both None and falsy `threads`
if extensions.threads:
serialized_extensions[
"io.element.msc4360.threads"
] = await _serialise_threads(
self.event_serializer,
time_now,
extensions.threads,
self.store,
)
return serialized_extensions
@@ -1125,6 +1145,72 @@ def _serialise_thread_subscriptions(
return out
async def _serialise_threads(
event_serializer: EventClientSerializer,
time_now: int,
threads: SlidingSyncResult.Extensions.ThreadsExtension,
store: "DataStore",
) -> JsonDict:
"""
Serialize the threads extension response for sliding sync.
Args:
event_serializer: The event serializer to use for serializing thread root events.
time_now: The current time in milliseconds, used for event serialization.
threads: The threads extension data containing thread updates and pagination tokens.
store: The datastore, needed for serializing stream tokens.
Returns:
    A JSON-serializable dict containing:
    - "updates": A nested dict mapping room_id -> thread_root_id -> thread update.
      Each thread update may contain:
        - "thread_root": The serialized thread root event (if include_roots was True),
          with bundled aggregations including the latest_event in unsigned.m.relations.m.thread.
        - "prev_batch": A pagination token for fetching older events in that thread.
    - "prev_batch": A top-level pagination token for fetching older thread updates (if available).
"""
out: JsonDict = {}
if threads.updates:
updates_dict: JsonDict = {}
for room_id, thread_updates in threads.updates.items():
room_updates: JsonDict = {}
for thread_root_id, update in thread_updates.items():
# Serialize the update
update_dict: JsonDict = {}
# Serialize the thread_root event if present
if update.thread_root is not None:
# Create a mapping of event_id to bundled_aggregations
bundle_aggs_map = (
{thread_root_id: update.bundled_aggregations}
if update.bundled_aggregations
else None
)
serialized_events = await event_serializer.serialize_events(
[update.thread_root],
time_now,
bundle_aggregations=bundle_aggs_map,
)
if serialized_events:
update_dict["thread_root"] = serialized_events[0]
# Add prev_batch if present
if update.prev_batch is not None:
update_dict["prev_batch"] = await update.prev_batch.to_string(store)
room_updates[thread_root_id] = update_dict
updates_dict[room_id] = room_updates
out["updates"] = updates_dict
if threads.prev_batch:
out["prev_batch"] = await threads.prev_batch.to_string(store)
return out
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)
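
Putting the serializer together, the extension payload comes out with roughly this shape (a hand-written illustration with placeholder IDs and abbreviated tokens, not captured output):

"io.element.msc4360.threads": {
    "updates": {
        "!room:example.org": {
            "$thread_root_id": {
                "thread_root": { "event_id": "$thread_root_id", "type": "m.room.message" },
                "prev_batch": "s123_..."
            }
        }
    },
    "prev_batch": "s456_..."
}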

View File

@@ -19,6 +19,7 @@
#
import logging
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Collection,
@@ -40,13 +41,19 @@ from synapse.storage.database import (
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.stream import (
generate_next_token,
generate_pagination_bounds,
generate_pagination_where_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, StreamKeyType, StreamToken
from synapse.types import (
JsonDict,
RoomStreamToken,
StreamKeyType,
StreamToken,
)
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -88,7 +95,23 @@ class _RelatedEvent:
sender: str
class RelationsWorkerStore(SQLBaseStore):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadUpdateInfo:
"""
Information about a thread update for the sliding sync threads extension.
Attributes:
event_id: The event ID of the event in the thread.
room_id: The room ID where this thread exists.
stream_ordering: The stream ordering of this event.
"""
event_id: str
room_id: str
stream_ordering: int
class RelationsWorkerStore(EventsWorkerStore, SQLBaseStore):
def __init__(
self,
database: DatabasePool,
@@ -584,14 +607,18 @@ class RelationsWorkerStore(SQLBaseStore):
"get_applicable_edits", _get_applicable_edits_txn
)
edits = await self.get_events(edit_ids.values()) # type: ignore[attr-defined]
edits = await self.get_events(edit_ids.values())
# Map to the original event IDs to the edit events.
#
# There might not be an edit event due to there being no edits or
# due to the event not being known, either case is treated the same.
return {
original_event_id: edits.get(edit_ids.get(original_event_id))
original_event_id: (
edits.get(edit_id)
if (edit_id := edit_ids.get(original_event_id))
else None
)
for original_event_id in event_ids
}
@@ -699,7 +726,7 @@ class RelationsWorkerStore(SQLBaseStore):
"get_thread_summaries", _get_thread_summaries_txn
)
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
latest_events = await self.get_events(latest_event_ids.values())
# Map to the event IDs to the thread summary.
#
@@ -1111,6 +1138,148 @@ class RelationsWorkerStore(SQLBaseStore):
"get_related_thread_id", _get_related_thread_id
)
async def get_thread_updates_for_rooms(
self,
*,
room_ids: Collection[str],
from_token: RoomStreamToken | None = None,
to_token: RoomStreamToken | None = None,
limit: int = 5,
exclude_thread_ids: Collection[str] | None = None,
) -> tuple[dict[str, list[ThreadUpdateInfo]], StreamToken | None]:
"""Get a list of updated threads, ordered by stream ordering of their
latest reply, filtered to only include threads in rooms where the user
is currently joined.
Args:
room_ids: The room IDs to fetch thread updates for.
from_token: The lower bound (exclusive) for thread updates. If None,
fetch from the start of the room timeline.
to_token: The upper bound (inclusive) for thread updates. If None,
fetch up to the current position in the room timeline.
limit: Maximum number of thread updates to return.
exclude_thread_ids: Optional collection of thread root event IDs to exclude
from the results. Useful for filtering out threads already visible
in the room timeline.
Returns:
A tuple of:
A dict mapping thread_id to list of ThreadUpdateInfo objects,
ordered by stream_ordering descending (most recent first).
A prev_batch StreamToken (exclusive) if there are more results available,
None otherwise.
"""
# Ensure bad limits aren't being passed in.
assert limit > 0
if len(room_ids) == 0:
return {}, None
def _get_thread_updates_for_user_txn(
txn: LoggingTransaction,
) -> tuple[list[tuple[str, str, str, int]], int | None]:
room_clause, room_id_values = make_in_list_sql_clause(
txn.database_engine, "e.room_id", room_ids
)
# Generate the pagination clause, if necessary.
pagination_clause = ""
pagination_args: list[str] = []
if from_token:
from_bound = from_token.stream
pagination_clause += " AND stream_ordering > ?"
pagination_args.append(str(from_bound))
if to_token:
to_bound = to_token.stream
pagination_clause += " AND stream_ordering <= ?"
pagination_args.append(str(to_bound))
# Generate the exclusion clause for thread IDs, if necessary.
exclusion_clause = ""
exclusion_args: list[str] = []
if exclude_thread_ids:
exclusion_clause, exclusion_args = make_in_list_sql_clause(
txn.database_engine,
"er.relates_to_id",
exclude_thread_ids,
negative=True,
)
exclusion_clause = f" AND {exclusion_clause}"
# TODO: multiple updates for the same thread all count towards the limit,
# so the sliding sync response may contain fewer distinct threads than
# the limit would otherwise allow.
# Find any thread events between the stream ordering bounds.
sql = f"""
SELECT e.event_id, er.relates_to_id, e.room_id, e.stream_ordering
FROM event_relations AS er
INNER JOIN events AS e ON er.event_id = e.event_id
WHERE er.relation_type = '{RelationTypes.THREAD}'
AND {room_clause}
{exclusion_clause}
{pagination_clause}
ORDER BY stream_ordering DESC
LIMIT ?
"""
# Fetch `limit + 1` rows as a way to detect if there are more results beyond
# what we're returning. If we get exactly `limit + 1` rows back, we know there
# are more results available and we can set `next_token`. We only return the
# first `limit` rows to the caller. This avoids needing a separate COUNT query.
txn.execute(
sql,
(
*room_id_values,
*exclusion_args,
*pagination_args,
limit + 1,
),
)
# SQL returns: event_id, thread_id, room_id, stream_ordering
rows = cast(list[tuple[str, str, str, int]], txn.fetchall())
# If there are more events, generate the next pagination key from the
# last thread which will be returned.
next_token = None
if len(rows) > limit:
# Set the next_token to be the second last row in the result set since
# that will be the last row we return from this function.
# This works as an exclusive bound that can be backpaginated from.
# Use the stream_ordering field (index 3 in each row).
next_token = rows[-2][3]
return rows[:limit], next_token
thread_infos, next_token_int = await self.db_pool.runInteraction(
"get_thread_updates_for_user", _get_thread_updates_for_user_txn
)
# Convert the next_token int (stream ordering) to a StreamToken.
# Use StreamToken.START as base (all other streams at 0) since only room
# position matters.
# Subtract 1 to make it exclusive - the client can paginate from this point without
# receiving the last thread update that was already returned.
next_token = None
if next_token_int is not None:
next_token = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM, RoomStreamToken(stream=next_token_int - 1)
)
# Build ThreadUpdateInfo objects.
thread_update_infos: dict[str, list[ThreadUpdateInfo]] = defaultdict(list)
for event_id, thread_id, room_id, stream_ordering in thread_infos:
thread_update_infos[thread_id].append(
ThreadUpdateInfo(
event_id=event_id,
room_id=room_id,
stream_ordering=stream_ordering,
)
)
return (thread_update_infos, next_token)
class RelationsStore(RelationsWorkerStore):
pass
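
The fetch-limit-plus-one trick above is self-contained enough to demo in isolation; a minimal stand-alone sketch of the same pattern (illustrative names, not Synapse code):

def paginate(rows: list[int], limit: int) -> tuple[list[int], int | None]:
    """Return up to `limit` rows plus an exclusive next-token (or None).

    `rows` stands in for stream orderings already sorted descending, fetched
    with LIMIT `limit + 1` so that an extra row signals more results.
    """
    assert limit > 0
    if len(rows) > limit:
        page = rows[:limit]
        # The last returned ordering, minus one, becomes the exclusive bound
        # for the next backwards page (mirrors `rows[-2][3]` then `- 1` above).
        return page, page[-1] - 1
    return rows, None

assert paginate([9, 8, 7, 6, 5, 4, 3], 5) == ([9, 8, 7, 6, 5], 4)
assert paginate([9, 8], 5) == ([9, 8], None)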

View File

@@ -0,0 +1,33 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Add indexes to improve performance of the thread_updates endpoint and
-- sliding sync threads extension (MSC4360).
-- Index for efficiently finding all events that relate to a specific event
-- (e.g., all replies to a thread root). This is used by the correlated subquery
-- in get_thread_updates_for_user that counts thread updates.
-- Also useful for other relation queries (edits, reactions, etc.).
CREATE INDEX IF NOT EXISTS event_relations_relates_to_id_type
ON event_relations(relates_to_id, relation_type);
-- Index for the /thread_updates endpoint's cross-room query.
-- Allows efficient descending ordering and range filtering of threads
-- by stream_ordering across all rooms.
CREATE INDEX IF NOT EXISTS threads_stream_ordering_desc
ON threads(stream_ordering DESC);
-- Index for the EXISTS clause that filters threads to only joined rooms.
-- Allows efficient lookup of a user's current room memberships.
CREATE INDEX IF NOT EXISTS local_current_membership_user_room
ON local_current_membership(user_id, membership, room_id);
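
As a sketch, the first index serves lookups of this shape (illustrative values):

SELECT event_id
FROM event_relations
WHERE relates_to_id = '$thread_root'
  AND relation_type = 'm.thread';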

View File

@@ -57,19 +57,41 @@ class PaginationConfig:
from_tok_str = parse_string(request, "from")
to_tok_str = parse_string(request, "to")
# Helper function to extract StreamToken from either StreamToken or SlidingSyncStreamToken format
def extract_stream_token(token_str: str) -> str:
"""
Extract the StreamToken portion from a token string.
Handles both:
- StreamToken format: "s123_456_..."
- SlidingSyncStreamToken format: "5/s123_456_..." (extracts part after /)
This allows clients using sliding sync to use their pos tokens
with endpoints like /relations and /messages.
"""
if "/" in token_str:
# SlidingSyncStreamToken format: "connection_position/stream_token"
# Split and return just the stream_token part
parts = token_str.split("/", 1)
if len(parts) == 2:
return parts[1]
return token_str
try:
from_tok = None
if from_tok_str == "END":
from_tok = None # For backwards compat.
elif from_tok_str:
from_tok = await StreamToken.from_string(store, from_tok_str)
stream_token_str = extract_stream_token(from_tok_str)
from_tok = await StreamToken.from_string(store, stream_token_str)
except Exception:
raise SynapseError(400, "'from' parameter is invalid")
try:
to_tok = None
if to_tok_str:
to_tok = await StreamToken.from_string(store, to_tok_str)
stream_token_str = extract_stream_token(to_tok_str)
to_tok = await StreamToken.from_string(store, stream_token_str)
except Exception:
raise SynapseError(400, "'to' parameter is invalid")
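
A quick illustration of the helper's behaviour on the two token formats (values made up):

extract_stream_token("s123_456")    # -> "s123_456" (plain StreamToken)
extract_stream_token("5/s123_456")  # -> "s123_456" (SlidingSyncStreamToken pos)
extract_stream_token("a/b/c")       # -> "b/c" (splits on the first "/" only)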

View File

@@ -35,6 +35,9 @@ from pydantic import ConfigDict
from synapse.api.constants import EventTypes
from synapse.events import EventBase
if TYPE_CHECKING:
from synapse.handlers.relations import BundledAggregations
from synapse.types import (
DeviceListUpdates,
JsonDict,
@@ -388,12 +391,60 @@ class SlidingSyncResult:
or bool(self.prev_batch)
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadsExtension:
"""The Threads extension (MSC4360)
Provides thread updates for threads that have new activity across all of the
user's joined rooms within the sync window.
Attributes:
updates: A nested mapping of room_id -> thread_root_id -> ThreadUpdate.
Each ThreadUpdate contains information about a thread that has new activity,
including the thread root event (if requested) and a pagination token
for fetching older events in that specific thread.
prev_batch: A pagination token for fetching more thread updates across all rooms.
If present, indicates there are more thread updates available beyond what
was returned in this response. This token can be used with a future request
to paginate through older thread updates.
"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadUpdate:
"""Information about a single thread that has new activity.
Attributes:
thread_root: The thread root event, if requested via include_roots in the
request. This is the event that started the thread.
prev_batch: A pagination token (exclusive) for fetching older events in this
specific thread. Only present if the thread has multiple updates in the
sync window. This token can be used with the /relations endpoint with
dir=b to paginate backwards through the thread's history.
bundled_aggregations: Bundled aggregations for the thread root event,
including the latest_event in the thread (found in
unsigned.m.relations.m.thread). Only present if thread_root is included.
"""
thread_root: EventBase | None
prev_batch: StreamToken | None
bundled_aggregations: "BundledAggregations | None" = None
def __bool__(self) -> bool:
return bool(self.thread_root) or bool(self.prev_batch)
updates: Mapping[str, Mapping[str, ThreadUpdate]] | None
prev_batch: StreamToken | None
def __bool__(self) -> bool:
return bool(self.updates) or bool(self.prev_batch)
to_device: ToDeviceExtension | None = None
e2ee: E2eeExtension | None = None
account_data: AccountDataExtension | None = None
receipts: ReceiptsExtension | None = None
typing: TypingExtension | None = None
thread_subscriptions: ThreadSubscriptionsExtension | None = None
threads: ThreadsExtension | None = None
def __bool__(self) -> bool:
return bool(
@@ -403,6 +454,7 @@ class SlidingSyncResult:
or self.receipts
or self.typing
or self.thread_subscriptions
or self.threads
)
next_pos: SlidingSyncStreamToken
@@ -852,6 +904,7 @@ class PerConnectionState:
Attributes:
rooms: The status of each room for the events stream.
receipts: The status of each room for the receipts stream.
account_data: The status of each room for the account data stream.
room_configs: Map from room_id to the `RoomSyncConfig` of all
rooms that we have previously sent down.
"""

View File

@@ -383,6 +383,19 @@ class SlidingSyncBody(RequestBodyModel):
enabled: StrictBool | None = False
limit: StrictInt = 100
class ThreadsExtension(RequestBodyModel):
"""The Threads extension (MSC4360)
Attributes:
enabled: Whether the threads extension is enabled.
include_roots: whether to include thread root events in the extension response.
limit: maximum number of thread updates to return across all joined rooms.
"""
enabled: StrictBool | None = False
include_roots: StrictBool = False
limit: StrictInt = 100
to_device: ToDeviceExtension | None = None
e2ee: E2eeExtension | None = None
account_data: AccountDataExtension | None = None
@@ -391,6 +404,9 @@ class SlidingSyncBody(RequestBodyModel):
thread_subscriptions: ThreadSubscriptionsExtension | None = Field(
None, alias="io.element.msc4308.thread_subscriptions"
)
threads: ThreadsExtension | None = Field(
None, alias="io.element.msc4360.threads"
)
conn_id: StrictStr | None = None
lists: (

View File

@@ -340,6 +340,7 @@ T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
T6 = TypeVar("T6")
T7 = TypeVar("T7")
@overload
@@ -469,6 +470,30 @@ async def gather_optional_coroutines(
) -> tuple[T1 | None, T2 | None, T3 | None, T4 | None, T5 | None, T6 | None]: ...
@overload
async def gather_optional_coroutines(
*coroutines: Unpack[
tuple[
Coroutine[Any, Any, T1] | None,
Coroutine[Any, Any, T2] | None,
Coroutine[Any, Any, T3] | None,
Coroutine[Any, Any, T4] | None,
Coroutine[Any, Any, T5] | None,
Coroutine[Any, Any, T6] | None,
Coroutine[Any, Any, T7] | None,
]
],
) -> tuple[
T1 | None,
T2 | None,
T3 | None,
T4 | None,
T5 | None,
T6 | None,
T7 | None,
]: ...
async def gather_optional_coroutines(
*coroutines: Unpack[tuple[Coroutine[Any, Any, T1] | None, ...]],
) -> tuple[T1 | None, ...]:
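
The None-skipping contract is easy to demonstrate with a plain-asyncio stand-in (Synapse's real helper runs on Twisted; this is only a behavioural sketch):

import asyncio
from collections.abc import Coroutine
from typing import Any


async def gather_optional(*coros: Coroutine[Any, Any, Any] | None) -> tuple[Any, ...]:
    # Schedule every non-None coroutine up front so they run concurrently,
    # then await them in order; None slots pass through as None.
    tasks = [asyncio.ensure_future(c) if c is not None else None for c in coros]
    out = []
    for t in tasks:
        out.append(await t if t is not None else None)
    return tuple(out)


async def double(x: int) -> int:
    return x * 2


async def main() -> None:
    # Mirrors the handler above: threads_coro is None when the extension is off.
    a, b = await gather_optional(double(2), None)
    assert (a, b) == (4, None)


asyncio.run(main())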

File diff suppressed because it is too large.