1
0

Compare commits

...

35 Commits

Author SHA1 Message Date
Devon Hudson
d65fc3861b Fix lint 2025-11-09 20:54:21 -07:00
Devon Hudson
eff2503adc Move thread_update request body next to handler 2025-11-09 20:52:27 -07:00
Devon Hudson
ba3c8a5f3e Refactor thread updates to use the same logic between endpoint and extension 2025-11-09 20:45:32 -07:00
Devon Hudson
c826157c52 Fix linter errors 2025-11-09 16:21:56 -07:00
Devon Hudson
57c884ec83 Merge branch 'devon/ssext_threads' into devon/ssext_threads_companion 2025-11-09 16:15:20 -07:00
Devon Hudson
6fa43cb0b4 Comment cleanup 2025-11-09 12:43:41 -07:00
Devon Hudson
f778ac32c1 Update docstring 2025-11-09 12:37:04 -07:00
Devon Hudson
003fc725db Merge branch 'develop' into devon/ssext_threads 2025-11-09 12:33:55 -07:00
Devon Hudson
934f99a694 Add wait_for_new_data tests 2025-11-09 12:09:56 -07:00
Devon Hudson
78e8ec6161 Add test for room list filtering 2025-11-09 09:44:52 -07:00
Devon Hudson
f59419377d Refactor for clarity 2025-11-09 09:35:11 -07:00
Devon Hudson
a3b34dfafd Run linter 2025-11-09 09:30:44 -07:00
Devon Hudson
cb82a4a687 Handle user leave/ban rooms to prevent leaking data 2025-11-09 08:45:52 -07:00
Devon Hudson
0c0ece9612 Fix next_token logic 2025-11-09 08:29:49 -07:00
Devon Hudson
46e3f6756c Cleanup logic 2025-11-08 10:07:46 -07:00
Devon Hudson
dedd6e35e6 Rejig thread updates to use room lists 2025-11-08 09:12:37 -07:00
Devon Hudson
a3c7b3ecb9 Don't fetch bundled aggregations if we don't have to 2025-10-16 18:06:26 -06:00
Devon Hudson
bf594a28a8 Move constants to designated file 2025-10-16 17:37:01 -06:00
Devon Hudson
89f75cc70f Add newsfile 2025-10-10 16:20:03 -06:00
Devon Hudson
2f8568866e Remove unnecessary bits 2025-10-10 16:15:21 -06:00
Devon Hudson
af992dd0e2 Merge branch 'devon/ssext_threads' into devon/ssext_threads_companion 2025-10-10 15:40:06 -06:00
Devon Hudson
c757969597 Add indexes to improve threads query performance 2025-10-10 15:39:27 -06:00
Devon Hudson
87e9fe8b38 Add implementation for /thread_updates MSC4360 companion endpoint 2025-10-10 15:09:01 -06:00
Devon Hudson
4cb0eeabdf Allow SlidingSyncStreamToken in /relations 2025-10-09 11:28:33 -06:00
Devon Hudson
4d7826b006 Filter events from extension if in timeline 2025-10-08 17:01:40 -06:00
Devon Hudson
ab7e5a2b17 Properly return prev_batch tokens for threads extension 2025-10-08 16:12:46 -06:00
Devon Hudson
4c51247cb3 Only return rooms where user is currently joined 2025-10-07 12:49:32 -06:00
Devon Hudson
4dd82e581a Add newsfile 2025-10-03 16:16:04 -06:00
Devon Hudson
6e69338abc Fix linter error 2025-10-03 16:15:11 -06:00
Devon Hudson
79ea4bed33 Add thread_root events to threads extension response 2025-10-03 15:57:13 -06:00
Devon Hudson
9ef4ca173e Add user room filtering for threads extension 2025-10-03 14:01:16 -06:00
Devon Hudson
24b38733df Don't return empty fields in response 2025-10-02 17:23:30 -06:00
Devon Hudson
4602b56643 Stub in early db queries to get tests going 2025-10-02 17:11:14 -06:00
Devon Hudson
6c460b3eae Stub in threads extension tests 2025-10-01 10:53:11 -06:00
Devon Hudson
cd4f4223de Stub in threads sliding sync extension 2025-10-01 10:04:29 -06:00
17 changed files with 3515 additions and 24 deletions

View File

@@ -0,0 +1 @@
Add experimental support for MSC4360: Sliding Sync Threads Extension.

View File

@@ -0,0 +1 @@
Add companion endpoint for MSC4360: Sliding Sync Threads Extension.

View File

@@ -272,6 +272,9 @@ class EventContentFields:
M_TOPIC: Final = "m.topic"
M_TEXT: Final = "m.text"
# Event relations
RELATIONS: Final = "m.relates_to"
class EventUnsignedContentFields:
"""Fields found inside the 'unsigned' data on events"""
@@ -360,3 +363,10 @@ class Direction(enum.Enum):
class ProfileFields:
DISPLAYNAME: Final = "displayname"
AVATAR_URL: Final = "avatar_url"
class MRelatesToFields:
"""Fields found inside m.relates_to content blocks."""
EVENT_ID: Final = "event_id"
REL_TYPE: Final = "rel_type"

View File

@@ -593,3 +593,6 @@ class ExperimentalConfig(Config):
# MSC4306: Thread Subscriptions
# (and MSC4308: Thread Subscriptions extension to Sliding Sync)
self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False)
# MSC4360: Threads Extension to Sliding Sync
self.msc4360_enabled: bool = experimental.get("msc4360_enabled", False)

View File

@@ -20,6 +20,7 @@
#
import enum
import logging
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Collection,
@@ -30,24 +31,59 @@ from typing import (
import attr
from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.constants import Direction, EventTypes, Membership, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
from synapse.events.utils import SerializeEventConfig
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
from synapse.storage.databases.main.relations import (
ThreadsNextBatch,
ThreadUpdateInfo,
_RelatedEvent,
)
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, UserID
from synapse.types import (
JsonDict,
Requester,
RoomStreamToken,
StreamKeyType,
StreamToken,
UserID,
)
from synapse.util.async_helpers import gather_results
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.events.utils import EventClientSerializer
from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
# Type aliases for thread update processing
ThreadUpdatesMap = dict[str, list[ThreadUpdateInfo]]
ThreadRootsMap = dict[str, EventBase]
AggregationsMap = dict[str, "BundledAggregations"]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadUpdate:
"""
Data for a single thread update.
Attributes:
thread_root: The thread root event, or None if not requested/not visible
prev_batch: Per-thread pagination token for fetching older events in this thread
bundled_aggregations: Bundled aggregations for the thread root event
"""
thread_root: EventBase | None
prev_batch: StreamToken | None
bundled_aggregations: "BundledAggregations | None" = None
class ThreadsListInclude(str, enum.Enum):
"""Valid values for the 'include' flag of /threads."""
@@ -105,8 +141,6 @@ class RelationsHandler:
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.
TODO Accept a PaginationConfig instead of individual pagination parameters.
Args:
requester: The user requesting the relations.
event_id: Fetch events that relate to this event ID.
@@ -546,6 +580,367 @@ class RelationsHandler:
return results
async def _filter_thread_updates_for_user(
self,
all_thread_updates: ThreadUpdatesMap,
user_id: str,
) -> ThreadUpdatesMap:
"""Process thread updates by filtering for visibility.
Takes raw thread updates from storage and filters them based on whether the
user can see the events. Preserves the ordering of updates within each thread.
Args:
all_thread_updates: Map of thread_id to list of ThreadUpdateInfo objects
user_id: The user ID to filter events for
Returns:
Filtered map of thread_id to list of ThreadUpdateInfo objects, containing
only updates for events the user can see.
"""
# Build a mapping of event_id -> (thread_id, update) for efficient lookup
# during visibility filtering.
event_to_thread_map: dict[str, tuple[str, ThreadUpdateInfo]] = {}
for thread_id, updates in all_thread_updates.items():
for update in updates:
event_to_thread_map[update.event_id] = (thread_id, update)
# Fetch and filter events for visibility
all_events = await self._main_store.get_events_as_list(
event_to_thread_map.keys()
)
filtered_events = await filter_events_for_client(
self._storage_controllers, user_id, all_events
)
# Rebuild thread updates from filtered events
filtered_updates: ThreadUpdatesMap = defaultdict(list)
for event in filtered_events:
if event.event_id in event_to_thread_map:
thread_id, update = event_to_thread_map[event.event_id]
filtered_updates[thread_id].append(update)
return filtered_updates
def _build_thread_updates_response(
self,
filtered_updates: ThreadUpdatesMap,
thread_root_event_map: ThreadRootsMap,
aggregations_map: AggregationsMap,
global_prev_batch_token: StreamToken | None,
) -> dict[str, dict[str, ThreadUpdate]]:
"""Build thread update response structure with per-thread prev_batch tokens.
Args:
filtered_updates: Map of thread_root_id to list of ThreadUpdateInfo
thread_root_event_map: Map of thread_root_id to EventBase
aggregations_map: Map of thread_root_id to BundledAggregations
global_prev_batch_token: Global pagination token, or None if no more results
Returns:
Map of room_id to thread_root_id to ThreadUpdate
"""
thread_updates: dict[str, dict[str, ThreadUpdate]] = {}
for thread_root_id, updates in filtered_updates.items():
# We only care about the latest update for the thread
# Updates are already sorted by stream_ordering DESC from the database query,
# and filter_events_for_client preserves order, so updates[0] is guaranteed to be
# the latest event for each thread.
latest_update = updates[0]
room_id = latest_update.room_id
# Generate per-thread prev_batch token if this thread has multiple visible updates
# or if we hit the global limit.
# When we hit the global limit, we generate prev_batch tokens for all threads, even if
# we only saw 1 update for them. This is to cover the case where we only saw
# a single update for a given thread, but the global limit prevents us from
# obtaining other updates which would have otherwise been included in the range.
per_thread_prev_batch = None
if len(updates) > 1 or global_prev_batch_token is not None:
# Create a token pointing to one position before the latest event's stream position.
# This makes it exclusive - /relations with dir=b won't return the latest event again.
# Use StreamToken.START as base (all other streams at 0) since only room position matters.
per_thread_prev_batch = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM,
RoomStreamToken(stream=latest_update.stream_ordering - 1),
)
if room_id not in thread_updates:
thread_updates[room_id] = {}
thread_updates[room_id][thread_root_id] = ThreadUpdate(
thread_root=thread_root_event_map.get(thread_root_id),
prev_batch=per_thread_prev_batch,
bundled_aggregations=aggregations_map.get(thread_root_id),
)
return thread_updates
async def _fetch_thread_updates(
self,
room_ids: frozenset[str],
room_membership_map: Mapping[str, "RoomsForUserType"],
from_token: StreamToken | None,
to_token: StreamToken,
limit: int,
exclude_thread_ids: set[str] | None = None,
) -> tuple[ThreadUpdatesMap, StreamToken | None]:
"""Fetch thread updates across multiple rooms, handling membership states properly.
This method separates rooms based on membership status (LEAVE/BAN vs others)
and queries them appropriately to prevent data leaks. For rooms where the user
has left or been banned, we bound the query to their leave/ban event position.
Args:
room_ids: The set of room IDs to fetch thread updates for
room_membership_map: Map of room_id to RoomsForUserType containing membership info
from_token: Lower bound (exclusive) for the query, or None for no lower bound
to_token: Upper bound for the query (for joined/invited/knocking rooms)
limit: Maximum number of thread updates to return across all rooms
exclude_thread_ids: Optional set of thread IDs to exclude from results
Returns:
A tuple of:
- Map of thread_id to list of ThreadUpdateInfo objects
- Global prev_batch token if there are more results, None otherwise
"""
# Separate rooms based on membership to handle LEAVE/BAN rooms specially
leave_ban_rooms: set[str] = set()
other_rooms: set[str] = set()
for room_id in room_ids:
membership_info = room_membership_map.get(room_id)
if membership_info and membership_info.membership in (
Membership.LEAVE,
Membership.BAN,
):
leave_ban_rooms.add(room_id)
else:
other_rooms.add(room_id)
# Fetch thread updates from storage, handling LEAVE/BAN rooms separately
all_thread_updates: ThreadUpdatesMap = {}
prev_batch_token: StreamToken | None = None
remaining_limit = limit
# Query LEAVE/BAN rooms with bounded to_token to prevent data leaks
if leave_ban_rooms:
for room_id in leave_ban_rooms:
if remaining_limit <= 0:
# We've hit the limit, set prev_batch to indicate more results
prev_batch_token = to_token
break
membership_info = room_membership_map[room_id]
bounded_to_token = membership_info.event_pos.to_room_stream_token()
(
room_thread_updates,
room_prev_batch,
) = await self._main_store.get_thread_updates_for_rooms(
room_ids={room_id},
from_token=from_token.room_key if from_token else None,
to_token=bounded_to_token,
limit=remaining_limit,
exclude_thread_ids=exclude_thread_ids,
)
# Count updates and reduce remaining limit
num_updates = sum(
len(updates) for updates in room_thread_updates.values()
)
remaining_limit -= num_updates
# Merge updates
for thread_id, updates in room_thread_updates.items():
all_thread_updates.setdefault(thread_id, []).extend(updates)
# Merge prev_batch tokens (take the maximum for backward pagination)
if room_prev_batch is not None:
if prev_batch_token is None:
prev_batch_token = room_prev_batch
elif (
room_prev_batch.room_key.stream
> prev_batch_token.room_key.stream
):
prev_batch_token = room_prev_batch
# Query other rooms (joined/invited/knocking) with normal to_token
if other_rooms and remaining_limit > 0:
(
other_thread_updates,
other_prev_batch,
) = await self._main_store.get_thread_updates_for_rooms(
room_ids=other_rooms,
from_token=from_token.room_key if from_token else None,
to_token=to_token.room_key,
limit=remaining_limit,
exclude_thread_ids=exclude_thread_ids,
)
# Merge updates
for thread_id, updates in other_thread_updates.items():
all_thread_updates.setdefault(thread_id, []).extend(updates)
# Merge prev_batch tokens
if other_prev_batch is not None:
if prev_batch_token is None:
prev_batch_token = other_prev_batch
elif (
other_prev_batch.room_key.stream > prev_batch_token.room_key.stream
):
prev_batch_token = other_prev_batch
return all_thread_updates, prev_batch_token
async def get_thread_updates_for_rooms(
self,
room_ids: frozenset[str],
room_membership_map: Mapping[str, "RoomsForUserType"],
user_id: str,
from_token: StreamToken | None,
to_token: StreamToken,
limit: int,
include_roots: bool = False,
exclude_thread_ids: set[str] | None = None,
) -> tuple[dict[str, dict[str, ThreadUpdate]], StreamToken | None]:
"""Get thread updates across multiple rooms with full processing pipeline.
This is the main entry point for fetching thread updates. It handles:
- Fetching updates with membership-based security
- Filtering for visibility
- Optionally fetching thread roots and aggregations
- Building the response structure
Args:
room_ids: The set of room IDs to fetch updates for
room_membership_map: Map of room_id to RoomsForUserType for membership info
user_id: The user requesting the updates
from_token: Lower bound (exclusive) for the query
to_token: Upper bound for the query
limit: Maximum number of updates to return
include_roots: Whether to fetch and include thread root events (default: False)
exclude_thread_ids: Optional set of thread IDs to exclude
Returns:
A tuple of:
- Map of room_id to thread_root_id to ThreadUpdate
- Global prev_batch token if there are more results, None otherwise
"""
# Fetch thread updates with membership handling
all_thread_updates, prev_batch_token = await self._fetch_thread_updates(
room_ids=room_ids,
room_membership_map=room_membership_map,
from_token=from_token,
to_token=to_token,
limit=limit,
exclude_thread_ids=exclude_thread_ids,
)
if not all_thread_updates:
return {}, prev_batch_token
# Filter thread updates for visibility
filtered_updates = await self._filter_thread_updates_for_user(
all_thread_updates, user_id
)
if not filtered_updates:
return {}, prev_batch_token
# Optionally fetch thread root events and their bundled aggregations
thread_root_event_map: ThreadRootsMap = {}
aggregations_map: AggregationsMap = {}
if include_roots:
# Fetch thread root events
thread_root_events = await self._main_store.get_events_as_list(
filtered_updates.keys()
)
thread_root_event_map = {e.event_id: e for e in thread_root_events}
# Fetch bundled aggregations for the thread roots
if thread_root_event_map:
aggregations_map = await self.get_bundled_aggregations(
thread_root_event_map.values(),
user_id,
)
# Build response structure with per-thread prev_batch tokens
thread_updates = self._build_thread_updates_response(
filtered_updates=filtered_updates,
thread_root_event_map=thread_root_event_map,
aggregations_map=aggregations_map,
global_prev_batch_token=prev_batch_token,
)
return thread_updates, prev_batch_token
@staticmethod
async def serialize_thread_updates(
thread_updates: Mapping[str, Mapping[str, ThreadUpdate]],
prev_batch_token: StreamToken | None,
event_serializer: "EventClientSerializer",
time_now: int,
store: "DataStore",
serialize_options: SerializeEventConfig,
) -> JsonDict:
"""
Serialize thread updates to JSON format.
This helper handles serialization of ThreadUpdate objects for both the
companion endpoint and the sliding sync extension.
Args:
thread_updates: Map of room_id to thread_root_id to ThreadUpdate
prev_batch_token: Global pagination token for fetching more updates
event_serializer: The event serializer to use
time_now: Current time in milliseconds for event serialization
store: Datastore for serializing stream tokens
serialize_options: Serialization config
Returns:
JSON-serializable dict with "updates" and optionally "prev_batch"
"""
updates_dict: JsonDict = {}
for room_id, room_threads in thread_updates.items():
room_updates: JsonDict = {}
for thread_root_id, update in room_threads.items():
update_dict: JsonDict = {}
# Serialize thread_root event if present
if update.thread_root is not None:
bundle_aggs_map = (
{thread_root_id: update.bundled_aggregations}
if update.bundled_aggregations is not None
else None
)
serialized_events = await event_serializer.serialize_events(
[update.thread_root],
time_now,
config=serialize_options,
bundle_aggregations=bundle_aggs_map,
)
if serialized_events:
update_dict["thread_root"] = serialized_events[0]
# Add per-thread prev_batch if present
if update.prev_batch is not None:
update_dict["prev_batch"] = await update.prev_batch.to_string(store)
room_updates[thread_root_id] = update_dict
updates_dict[room_id] = room_updates
result: JsonDict = {"updates": updates_dict}
# Add global prev_batch token if present
if prev_batch_token is not None:
result["prev_batch"] = await prev_batch_token.to_string(store)
return result
async def get_threads(
self,
requester: Requester,

View File

@@ -305,6 +305,7 @@ class SlidingSyncHandler:
# account data, read receipts, typing indicators, to-device messages, etc).
actual_room_ids=set(relevant_room_map.keys()),
actual_room_response_map=rooms,
room_membership_for_user_at_to_token_map=room_membership_for_user_map,
from_token=from_token,
to_token=to_token,
)

View File

@@ -26,8 +26,16 @@ from typing import (
from typing_extensions import TypeAlias, assert_never
from synapse.api.constants import AccountDataTypes, EduTypes
from synapse.api.constants import (
AccountDataTypes,
EduTypes,
EventContentFields,
MRelatesToFields,
RelationTypes,
)
from synapse.events import EventBase
from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.receipts import ReceiptInRoom
from synapse.types import (
@@ -73,7 +81,10 @@ class SlidingSyncExtensionHandler:
self.event_sources = hs.get_event_sources()
self.device_handler = hs.get_device_handler()
self.push_rules_handler = hs.get_push_rules_handler()
self.relations_handler = hs.get_relations_handler()
self._storage_controllers = hs.get_storage_controllers()
self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
self._enable_threads_ext = hs.config.experimental.msc4360_enabled
@trace
async def get_extensions_response(
@@ -84,6 +95,7 @@ class SlidingSyncExtensionHandler:
actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList],
actual_room_ids: set[str],
actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
room_membership_for_user_at_to_token_map: Mapping[str, RoomsForUserType],
to_token: StreamToken,
from_token: SlidingSyncStreamToken | None,
) -> SlidingSyncResult.Extensions:
@@ -99,6 +111,8 @@ class SlidingSyncExtensionHandler:
actual_room_ids: The actual room IDs in the the Sliding Sync response.
actual_room_response_map: A map of room ID to room results in the the
Sliding Sync response.
room_membership_for_user_at_to_token_map: A map of room ID to the membership
information for the user in the room at the time of `to_token`.
to_token: The latest point in the stream to sync up to.
from_token: The point in the stream to sync from.
"""
@@ -174,6 +188,18 @@ class SlidingSyncExtensionHandler:
from_token=from_token,
)
threads_coro = None
if sync_config.extensions.threads is not None and self._enable_threads_ext:
threads_coro = self.get_threads_extension_response(
sync_config=sync_config,
threads_request=sync_config.extensions.threads,
actual_room_ids=actual_room_ids,
actual_room_response_map=actual_room_response_map,
room_membership_for_user_at_to_token_map=room_membership_for_user_at_to_token_map,
to_token=to_token,
from_token=from_token,
)
(
to_device_response,
e2ee_response,
@@ -181,6 +207,7 @@ class SlidingSyncExtensionHandler:
receipts_response,
typing_response,
thread_subs_response,
threads_response,
) = await gather_optional_coroutines(
to_device_coro,
e2ee_coro,
@@ -188,6 +215,7 @@ class SlidingSyncExtensionHandler:
receipts_coro,
typing_coro,
thread_subs_coro,
threads_coro,
)
return SlidingSyncResult.Extensions(
@@ -197,6 +225,7 @@ class SlidingSyncExtensionHandler:
receipts=receipts_response,
typing=typing_response,
thread_subscriptions=thread_subs_response,
threads=threads_response,
)
def find_relevant_room_ids_for_extension(
@@ -967,3 +996,104 @@ class SlidingSyncExtensionHandler:
unsubscribed=unsubscribed_threads,
prev_batch=prev_batch,
)
def _extract_thread_id_from_event(self, event: EventBase) -> str | None:
"""Extract thread ID from event if it's a thread reply.
Args:
event: The event to check.
Returns:
The thread ID if the event is a thread reply, None otherwise.
"""
relates_to = event.content.get(EventContentFields.RELATIONS)
if isinstance(relates_to, dict):
if relates_to.get(MRelatesToFields.REL_TYPE) == RelationTypes.THREAD:
return relates_to.get(MRelatesToFields.EVENT_ID)
return None
def _find_threads_in_timeline(
self,
actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
) -> set[str]:
"""Find all thread IDs that have events in room timelines.
Args:
actual_room_response_map: A map of room ID to room results.
Returns:
A set of thread IDs (thread root event IDs) that appear in the timeline.
"""
threads_in_timeline: set[str] = set()
for room_result in actual_room_response_map.values():
if room_result.timeline_events:
for event in room_result.timeline_events:
thread_id = self._extract_thread_id_from_event(event)
if thread_id:
threads_in_timeline.add(thread_id)
return threads_in_timeline
async def get_threads_extension_response(
self,
sync_config: SlidingSyncConfig,
threads_request: SlidingSyncConfig.Extensions.ThreadsExtension,
actual_room_ids: set[str],
actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
room_membership_for_user_at_to_token_map: Mapping[str, RoomsForUserType],
to_token: StreamToken,
from_token: SlidingSyncStreamToken | None,
) -> SlidingSyncResult.Extensions.ThreadsExtension | None:
"""Handle Threads extension (MSC4360)
Args:
sync_config: Sync configuration.
threads_request: The threads extension from the request.
actual_room_ids: The actual room IDs in the the Sliding Sync response.
actual_room_response_map: A map of room ID to room results in the
sliding sync response. Used to determine which threads already have
events in the room timeline.
room_membership_for_user_at_to_token_map: A map of room ID to the membership
information for the user in the room at the time of `to_token`.
to_token: The point in the stream to sync up to.
from_token: The point in the stream to sync from.
Returns:
the response (None if empty or threads extension is disabled)
"""
if not threads_request.enabled:
return None
# Identify which threads already have events in the room timelines.
# If include_roots=False, we'll exclude these threads from the DB query
# since the client already sees the thread activity in the timeline.
# If include_roots=True, we fetch all threads regardless, because the client
# wants the thread root events.
threads_to_exclude: set[str] | None = None
if not threads_request.include_roots:
threads_to_exclude = self._find_threads_in_timeline(
actual_room_response_map
)
# Get thread updates using unified helper
user_id = sync_config.user.to_string()
(
thread_updates_response,
prev_batch_token,
) = await self.relations_handler.get_thread_updates_for_rooms(
room_ids=frozenset(actual_room_ids),
room_membership_map=room_membership_for_user_at_to_token_map,
user_id=user_id,
from_token=from_token.stream_token if from_token else None,
to_token=to_token,
limit=threads_request.limit,
include_roots=threads_request.include_roots,
exclude_thread_ids=threads_to_exclude,
)
if not thread_updates_response:
return None
return SlidingSyncResult.Extensions.ThreadsExtension(
updates=thread_updates_response,
prev_batch=prev_batch_token,
)

View File

@@ -20,17 +20,33 @@
import logging
import re
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Annotated
from synapse.api.constants import Direction
from pydantic import StrictBool, StrictStr
from pydantic.types import StringConstraints
from synapse.api.constants import Direction, Membership
from synapse.api.errors import SynapseError
from synapse.events.utils import SerializeEventConfig
from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.servlet import (
RestServlet,
parse_and_validate_json_object_from_request,
parse_boolean,
parse_integer,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.storage.databases.main.relations import ThreadsNextBatch
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict
from synapse.streams.config import (
PaginationConfig,
extract_stream_token_from_pagination_token,
)
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
from synapse.types.handlers.sliding_sync import PerConnectionState, SlidingSyncConfig
from synapse.types.rest.client import RequestBodyModel, SlidingSyncBody
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -38,6 +54,39 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class ThreadUpdatesBody(RequestBodyModel):
"""
Thread updates companion endpoint request body (MSC4360).
Allows paginating thread updates using the same room selection as a sliding sync
request. This enables clients to fetch thread updates for the same set of rooms
that were included in their sliding sync response.
Attributes:
lists: Sliding window API lists, using the same structure as SlidingSyncBody.lists.
If provided along with room_subscriptions, the union of rooms from both will
be used.
room_subscriptions: Room subscription API rooms, using the same structure as
SlidingSyncBody.room_subscriptions. If provided along with lists, the union
of rooms from both will be used.
include_roots: Whether to include the thread root events in the response.
Defaults to False.
If neither lists nor room_subscriptions are provided, thread updates from all
joined rooms are returned.
"""
lists: (
dict[
Annotated[str, StringConstraints(max_length=64, strict=True)],
SlidingSyncBody.SlidingSyncList,
]
| None
) = None
room_subscriptions: dict[StrictStr, SlidingSyncBody.RoomSubscription] | None = None
include_roots: StrictBool = False
class RelationPaginationServlet(RestServlet):
"""API to paginate relations on an event by topological ordering, optionally
filtered by relation type and event type.
@@ -133,6 +182,167 @@ class ThreadsServlet(RestServlet):
return 200, result
class ThreadUpdatesServlet(RestServlet):
"""
Companion endpoint to the Sliding Sync threads extension (MSC4360).
Allows clients to bulk fetch thread updates across all joined rooms.
"""
PATTERNS = client_patterns(
"/io.element.msc4360/thread_updates$",
unstable=True,
releases=(),
)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self.relations_handler = hs.get_relations_handler()
self.event_serializer = hs.get_event_client_serializer()
self._storage_controllers = hs.get_storage_controllers()
self.sliding_sync_handler = hs.get_sliding_sync_handler()
async def on_POST(self, request: SynapseRequest) -> tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
# Parse request body
body = parse_and_validate_json_object_from_request(request, ThreadUpdatesBody)
# Parse query parameters
dir_str = parse_string(request, "dir", default="b")
if dir_str != "b":
raise SynapseError(
400,
"The 'dir' parameter must be 'b' (backward). Forward pagination is not supported.",
)
limit = parse_integer(request, "limit", default=100)
if limit <= 0:
raise SynapseError(400, "The 'limit' parameter must be positive.")
from_token_str = parse_string(request, "from")
to_token_str = parse_string(request, "to")
# Parse pagination tokens
from_token: StreamToken | None = None
to_token: StreamToken | None = None
if from_token_str:
try:
stream_token_str = extract_stream_token_from_pagination_token(
from_token_str
)
from_token = await StreamToken.from_string(self.store, stream_token_str)
except Exception as e:
logger.exception("Error parsing 'from' token: %s", from_token_str)
raise SynapseError(400, "'from' parameter is invalid") from e
if to_token_str:
try:
stream_token_str = extract_stream_token_from_pagination_token(
to_token_str
)
to_token = await StreamToken.from_string(self.store, stream_token_str)
except Exception:
raise SynapseError(400, "'to' parameter is invalid")
# Get the list of rooms to fetch thread updates for
user_id = requester.user.to_string()
user = UserID.from_string(user_id)
# Get the current stream token for membership lookup
if from_token is None:
max_stream_ordering = self.store.get_room_max_stream_ordering()
current_token = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM, RoomStreamToken(stream=max_stream_ordering)
)
else:
current_token = from_token
# Get room membership information to properly handle LEAVE/BAN rooms
(
room_membership_for_user_at_to_token_map,
_,
_,
) = await self.sliding_sync_handler.room_lists.get_room_membership_for_user_at_to_token(
user=user,
to_token=current_token,
from_token=None,
)
# Determine which rooms to fetch updates for based on lists/room_subscriptions
if body.lists is not None or body.room_subscriptions is not None:
# Use sliding sync room selection logic
sync_config = SlidingSyncConfig(
user=user,
requester=requester,
lists=body.lists,
room_subscriptions=body.room_subscriptions,
)
# Use the sliding sync room list handler to get the same set of rooms
interested_rooms = (
await self.sliding_sync_handler.room_lists.compute_interested_rooms(
sync_config=sync_config,
previous_connection_state=PerConnectionState(),
to_token=current_token,
from_token=None,
)
)
room_ids = frozenset(interested_rooms.relevant_room_map.keys())
else:
# No lists or room_subscriptions, use only joined rooms
room_ids = frozenset(
room_id
for room_id, membership_info in room_membership_for_user_at_to_token_map.items()
if membership_info.membership == Membership.JOIN
)
# Get thread updates using unified helper
(
thread_updates,
prev_batch_token,
) = await self.relations_handler.get_thread_updates_for_rooms(
room_ids=room_ids,
room_membership_map=room_membership_for_user_at_to_token_map,
user_id=user_id,
from_token=to_token,
to_token=from_token if from_token else current_token,
limit=limit,
include_roots=body.include_roots,
)
if not thread_updates:
return 200, {"chunk": {}}
# Serialize thread updates using shared helper
time_now = self.clock.time_msec()
serialize_options = SerializeEventConfig(requester=requester)
serialized = await self.relations_handler.serialize_thread_updates(
thread_updates=thread_updates,
prev_batch_token=prev_batch_token,
event_serializer=self.event_serializer,
time_now=time_now,
store=self.store,
serialize_options=serialize_options,
)
# Build response with "chunk" wrapper and "next_batch" key
# (companion endpoint uses different key names than sliding sync)
response: JsonDict = {"chunk": serialized["updates"]}
if "prev_batch" in serialized:
response["next_batch"] = serialized["prev_batch"]
return 200, response
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RelationPaginationServlet(hs).register(http_server)
ThreadsServlet(hs).register(http_server)
if hs.config.experimental.msc4360_enabled:
ThreadUpdatesServlet(hs).register(http_server)

View File

@@ -31,11 +31,13 @@ from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.ratelimiting import Ratelimiter
from synapse.events.utils import (
EventClientSerializer,
SerializeEventConfig,
format_event_for_client_v2_without_room_id,
format_event_raw,
)
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.relations import RelationsHandler
from synapse.handlers.sliding_sync import SlidingSyncConfig, SlidingSyncResult
from synapse.handlers.sync import (
ArchivedSyncResult,
@@ -56,6 +58,7 @@ from synapse.http.servlet import (
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken
from synapse.types.rest.client import SlidingSyncBody
from synapse.util.caches.lrucache import LruCache
@@ -646,6 +649,7 @@ class SlidingSyncRestServlet(RestServlet):
- receipts (MSC3960)
- account data (MSC3959)
- thread subscriptions (MSC4308)
- threads (MSC4360)
Request query parameters:
timeout: How long to wait for new events in milliseconds.
@@ -849,7 +853,10 @@ class SlidingSyncRestServlet(RestServlet):
logger.info("Client has disconnected; not serializing response.")
return 200, {}
response_content = await self.encode_response(requester, sliding_sync_results)
time_now = self.clock.time_msec()
response_content = await self.encode_response(
requester, sliding_sync_results, time_now
)
return 200, response_content
@@ -858,6 +865,7 @@ class SlidingSyncRestServlet(RestServlet):
self,
requester: Requester,
sliding_sync_result: SlidingSyncResult,
time_now: int,
) -> JsonDict:
response: JsonDict = defaultdict(dict)
@@ -866,10 +874,10 @@ class SlidingSyncRestServlet(RestServlet):
if serialized_lists:
response["lists"] = serialized_lists
response["rooms"] = await self.encode_rooms(
requester, sliding_sync_result.rooms
requester, sliding_sync_result.rooms, time_now
)
response["extensions"] = await self.encode_extensions(
requester, sliding_sync_result.extensions
requester, sliding_sync_result.extensions, time_now
)
return response
@@ -901,9 +909,8 @@ class SlidingSyncRestServlet(RestServlet):
self,
requester: Requester,
rooms: dict[str, SlidingSyncResult.RoomResult],
time_now: int,
) -> JsonDict:
time_now = self.clock.time_msec()
serialize_options = SerializeEventConfig(
event_format=format_event_for_client_v2_without_room_id,
requester=requester,
@@ -1019,7 +1026,10 @@ class SlidingSyncRestServlet(RestServlet):
@trace_with_opname("sliding_sync.encode_extensions")
async def encode_extensions(
self, requester: Requester, extensions: SlidingSyncResult.Extensions
self,
requester: Requester,
extensions: SlidingSyncResult.Extensions,
time_now: int,
) -> JsonDict:
serialized_extensions: JsonDict = {}
@@ -1089,6 +1099,18 @@ class SlidingSyncRestServlet(RestServlet):
_serialise_thread_subscriptions(extensions.thread_subscriptions)
)
# excludes both None and falsy `threads`
if extensions.threads:
serialized_extensions[
"io.element.msc4360.threads"
] = await _serialise_threads(
self.event_serializer,
time_now,
extensions.threads,
self.store,
requester,
)
return serialized_extensions
@@ -1125,6 +1147,52 @@ def _serialise_thread_subscriptions(
return out
async def _serialise_threads(
event_serializer: EventClientSerializer,
time_now: int,
threads: SlidingSyncResult.Extensions.ThreadsExtension,
store: "DataStore",
requester: Requester,
) -> JsonDict:
"""
Serialize the threads extension response for sliding sync.
Args:
event_serializer: The event serializer to use for serializing thread root events.
time_now: The current time in milliseconds, used for event serialization.
threads: The threads extension data containing thread updates and pagination tokens.
store: The datastore, needed for serializing stream tokens.
requester: The user making the request, used for transaction_id inclusion.
Returns:
A JSON-serializable dict containing:
- "updates": A nested dict mapping room_id -> thread_root_id -> thread update.
Each thread update may contain:
- "thread_root": The serialized thread root event (if include_roots was True),
with bundled aggregations including the latest_event in unsigned.m.relations.m.thread.
- "prev_batch": A pagination token for fetching older events in the thread.
- "prev_batch": A pagination token for fetching older thread updates (if available).
"""
if not threads.updates:
out: JsonDict = {}
if threads.prev_batch:
out["prev_batch"] = await threads.prev_batch.to_string(store)
return out
# Create serialization config to include transaction_id for requester's events
serialize_options = SerializeEventConfig(requester=requester)
# Use shared serialization helper (static method)
return await RelationsHandler.serialize_thread_updates(
thread_updates=threads.updates,
prev_batch_token=threads.prev_batch,
event_serializer=event_serializer,
time_now=time_now,
store=store,
serialize_options=serialize_options,
)
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)

View File

@@ -19,6 +19,7 @@
#
import logging
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Collection,
@@ -40,13 +41,19 @@ from synapse.storage.database import (
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.stream import (
generate_next_token,
generate_pagination_bounds,
generate_pagination_where_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, StreamKeyType, StreamToken
from synapse.types import (
JsonDict,
RoomStreamToken,
StreamKeyType,
StreamToken,
)
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -88,7 +95,23 @@ class _RelatedEvent:
sender: str
class RelationsWorkerStore(SQLBaseStore):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadUpdateInfo:
"""
Information about a thread update for the sliding sync threads extension.
Attributes:
event_id: The event ID of the event in the thread.
room_id: The room ID where this thread exists.
stream_ordering: The stream ordering of this event.
"""
event_id: str
room_id: str
stream_ordering: int
class RelationsWorkerStore(EventsWorkerStore, SQLBaseStore):
def __init__(
self,
database: DatabasePool,
@@ -584,14 +607,18 @@ class RelationsWorkerStore(SQLBaseStore):
"get_applicable_edits", _get_applicable_edits_txn
)
edits = await self.get_events(edit_ids.values()) # type: ignore[attr-defined]
edits = await self.get_events(edit_ids.values())
# Map to the original event IDs to the edit events.
#
# There might not be an edit event due to there being no edits or
# due to the event not being known, either case is treated the same.
return {
original_event_id: edits.get(edit_ids.get(original_event_id))
original_event_id: (
edits.get(edit_id)
if (edit_id := edit_ids.get(original_event_id))
else None
)
for original_event_id in event_ids
}
@@ -699,7 +726,7 @@ class RelationsWorkerStore(SQLBaseStore):
"get_thread_summaries", _get_thread_summaries_txn
)
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
latest_events = await self.get_events(latest_event_ids.values())
# Map to the event IDs to the thread summary.
#
@@ -1111,6 +1138,148 @@ class RelationsWorkerStore(SQLBaseStore):
"get_related_thread_id", _get_related_thread_id
)
async def get_thread_updates_for_rooms(
self,
*,
room_ids: Collection[str],
from_token: RoomStreamToken | None = None,
to_token: RoomStreamToken | None = None,
limit: int = 5,
exclude_thread_ids: Collection[str] | None = None,
) -> tuple[dict[str, list[ThreadUpdateInfo]], StreamToken | None]:
"""Get a list of updated threads, ordered by stream ordering of their
latest reply, filtered to only include threads in rooms where the user
is currently joined.
Args:
room_ids: The room IDs to fetch thread updates for.
from_token: The lower bound (exclusive) for thread updates. If None,
fetch from the start of the room timeline.
to_token: The upper bound (inclusive) for thread updates. If None,
fetch up to the current position in the room timeline.
limit: Maximum number of thread updates to return.
exclude_thread_ids: Optional collection of thread root event IDs to exclude
from the results. Useful for filtering out threads already visible
in the room timeline.
Returns:
A tuple of:
A dict mapping thread_id to list of ThreadUpdateInfo objects,
ordered by stream_ordering descending (most recent first).
A prev_batch StreamToken (exclusive) if there are more results available,
None otherwise.
"""
# Ensure bad limits aren't being passed in.
assert limit > 0
if len(room_ids) == 0:
return ({}), None
def _get_thread_updates_for_user_txn(
txn: LoggingTransaction,
) -> tuple[list[tuple[str, str, str, int]], int | None]:
room_clause, room_id_values = make_in_list_sql_clause(
txn.database_engine, "e.room_id", room_ids
)
# Generate the pagination clause, if necessary.
pagination_clause = ""
pagination_args: list[str] = []
if from_token:
from_bound = from_token.stream
pagination_clause += " AND stream_ordering > ?"
pagination_args.append(str(from_bound))
if to_token:
to_bound = to_token.stream
pagination_clause += " AND stream_ordering <= ?"
pagination_args.append(str(to_bound))
# Generate the exclusion clause for thread IDs, if necessary.
exclusion_clause = ""
exclusion_args: list[str] = []
if exclude_thread_ids:
exclusion_clause, exclusion_args = make_in_list_sql_clause(
txn.database_engine,
"er.relates_to_id",
exclude_thread_ids,
negative=True,
)
exclusion_clause = f" AND {exclusion_clause}"
# TODO: improve the fact that multiple hits for the same thread means we
# won't get as many overall updates for the sss response
# Find any thread events between the stream ordering bounds.
sql = f"""
SELECT e.event_id, er.relates_to_id, e.room_id, e.stream_ordering
FROM event_relations AS er
INNER JOIN events AS e ON er.event_id = e.event_id
WHERE er.relation_type = '{RelationTypes.THREAD}'
AND {room_clause}
{exclusion_clause}
{pagination_clause}
ORDER BY stream_ordering DESC
LIMIT ?
"""
# Fetch `limit + 1` rows as a way to detect if there are more results beyond
# what we're returning. If we get exactly `limit + 1` rows back, we know there
# are more results available and we can set `next_token`. We only return the
# first `limit` rows to the caller. This avoids needing a separate COUNT query.
txn.execute(
sql,
(
*room_id_values,
*exclusion_args,
*pagination_args,
limit + 1,
),
)
# SQL returns: event_id, thread_id, room_id, stream_ordering
rows = cast(list[tuple[str, str, str, int]], txn.fetchall())
# If there are more events, generate the next pagination key from the
# last thread which will be returned.
next_token = None
if len(rows) > limit:
# Set the next_token to be the second last row in the result set since
# that will be the last row we return from this function.
# This works as an exclusive bound that can be backpaginated from.
# Use the stream_ordering field (index 2 in original rows)
next_token = rows[-2][3]
return rows[:limit], next_token
thread_infos, next_token_int = await self.db_pool.runInteraction(
"get_thread_updates_for_user", _get_thread_updates_for_user_txn
)
# Convert the next_token int (stream ordering) to a StreamToken.
# Use StreamToken.START as base (all other streams at 0) since only room
# position matters.
# Subtract 1 to make it exclusive - the client can paginate from this point without
# receiving the last thread update that was already returned.
next_token = None
if next_token_int is not None:
next_token = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM, RoomStreamToken(stream=next_token_int - 1)
)
# Build ThreadUpdateInfo objects.
thread_update_infos: dict[str, list[ThreadUpdateInfo]] = defaultdict(list)
for event_id, thread_id, room_id, stream_ordering in thread_infos:
thread_update_infos[thread_id].append(
ThreadUpdateInfo(
event_id=event_id,
room_id=room_id,
stream_ordering=stream_ordering,
)
)
return (thread_update_infos, next_token)
class RelationsStore(RelationsWorkerStore):
pass

View File

@@ -0,0 +1,33 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Add indexes to improve performance of the thread_updates endpoint and
-- sliding sync threads extension (MSC4360).
-- Index for efficiently finding all events that relate to a specific event
-- (e.g., all replies to a thread root). This is used by the correlated subquery
-- in get_thread_updates_for_user that counts thread updates.
-- Also useful for other relation queries (edits, reactions, etc.).
CREATE INDEX IF NOT EXISTS event_relations_relates_to_id_type
ON event_relations(relates_to_id, relation_type);
-- Index for the /thread_updates endpoint's cross-room query.
-- Allows efficient descending ordering and range filtering of threads
-- by stream_ordering across all rooms.
CREATE INDEX IF NOT EXISTS threads_stream_ordering_desc
ON threads(stream_ordering DESC);
-- Index for the EXISTS clause that filters threads to only joined rooms.
-- Allows efficient lookup of a user's current room memberships.
CREATE INDEX IF NOT EXISTS local_current_membership_user_room
ON local_current_membership(user_id, membership, room_id);

View File

@@ -35,6 +35,32 @@ logger = logging.getLogger(__name__)
MAX_LIMIT = 1000
def extract_stream_token_from_pagination_token(token_str: str) -> str:
"""
Extract the StreamToken portion from a pagination token string.
Handles both:
- StreamToken format: "s123_456_..."
- SlidingSyncStreamToken format: "5/s123_456_..." (extracts part after /)
This allows clients using sliding sync to use their pos tokens
with endpoints like /relations and /messages.
Args:
token_str: The token string to parse
Returns:
The StreamToken portion of the token
"""
if "/" in token_str:
# SlidingSyncStreamToken format: "connection_position/stream_token"
# Split and return just the stream_token part
parts = token_str.split("/", 1)
if len(parts) == 2:
return parts[1]
return token_str
@attr.s(slots=True, auto_attribs=True)
class PaginationConfig:
"""A configuration object which stores pagination parameters."""
@@ -62,14 +88,20 @@ class PaginationConfig:
if from_tok_str == "END":
from_tok = None # For backwards compat.
elif from_tok_str:
from_tok = await StreamToken.from_string(store, from_tok_str)
stream_token_str = extract_stream_token_from_pagination_token(
from_tok_str
)
from_tok = await StreamToken.from_string(store, stream_token_str)
except Exception:
raise SynapseError(400, "'from' parameter is invalid")
try:
to_tok = None
if to_tok_str:
to_tok = await StreamToken.from_string(store, to_tok_str)
stream_token_str = extract_stream_token_from_pagination_token(
to_tok_str
)
to_tok = await StreamToken.from_string(store, stream_token_str)
except Exception:
raise SynapseError(400, "'to' parameter is invalid")

View File

@@ -35,6 +35,10 @@ from pydantic import ConfigDict
from synapse.api.constants import EventTypes
from synapse.events import EventBase
if TYPE_CHECKING:
from synapse.handlers.relations import BundledAggregations, ThreadUpdate
from synapse.types import (
DeviceListUpdates,
JsonDict,
@@ -388,12 +392,37 @@ class SlidingSyncResult:
or bool(self.prev_batch)
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadsExtension:
"""The Threads extension (MSC4360)
Provides thread updates for threads that have new activity across all of the
user's joined rooms within the sync window.
Attributes:
updates: A nested mapping of room_id -> thread_root_id -> ThreadUpdate.
Each ThreadUpdate contains information about a thread that has new activity,
including the thread root event (if requested) and a pagination token
for fetching older events in that specific thread.
prev_batch: A pagination token for fetching more thread updates across all rooms.
If present, indicates there are more thread updates available beyond what
was returned in this response. This token can be used with a future request
to paginate through older thread updates.
"""
updates: Mapping[str, Mapping[str, "ThreadUpdate"]] | None
prev_batch: StreamToken | None
def __bool__(self) -> bool:
return bool(self.updates) or bool(self.prev_batch)
to_device: ToDeviceExtension | None = None
e2ee: E2eeExtension | None = None
account_data: AccountDataExtension | None = None
receipts: ReceiptsExtension | None = None
typing: TypingExtension | None = None
thread_subscriptions: ThreadSubscriptionsExtension | None = None
threads: ThreadsExtension | None = None
def __bool__(self) -> bool:
return bool(
@@ -403,6 +432,7 @@ class SlidingSyncResult:
or self.receipts
or self.typing
or self.thread_subscriptions
or self.threads
)
next_pos: SlidingSyncStreamToken
@@ -852,6 +882,7 @@ class PerConnectionState:
Attributes:
rooms: The status of each room for the events stream.
receipts: The status of each room for the receipts stream.
account_data: The status of each room for the account data stream.
room_configs: Map from room_id to the `RoomSyncConfig` of all
rooms that we have previously sent down.
"""

View File

@@ -383,6 +383,19 @@ class SlidingSyncBody(RequestBodyModel):
enabled: StrictBool | None = False
limit: StrictInt = 100
class ThreadsExtension(RequestBodyModel):
"""The Threads extension (MSC4360)
Attributes:
enabled: Whether the threads extension is enabled.
include_roots: whether to include thread root events in the extension response.
limit: maximum number of thread updates to return across all joined rooms.
"""
enabled: StrictBool | None = False
include_roots: StrictBool = False
limit: StrictInt = 100
to_device: ToDeviceExtension | None = None
e2ee: E2eeExtension | None = None
account_data: AccountDataExtension | None = None
@@ -391,6 +404,9 @@ class SlidingSyncBody(RequestBodyModel):
thread_subscriptions: ThreadSubscriptionsExtension | None = Field(
None, alias="io.element.msc4308.thread_subscriptions"
)
threads: ThreadsExtension | None = Field(
None, alias="io.element.msc4360.threads"
)
conn_id: StrictStr | None = None
lists: (

View File

@@ -340,6 +340,7 @@ T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
T6 = TypeVar("T6")
T7 = TypeVar("T7")
@overload
@@ -469,6 +470,30 @@ async def gather_optional_coroutines(
) -> tuple[T1 | None, T2 | None, T3 | None, T4 | None, T5 | None, T6 | None]: ...
@overload
async def gather_optional_coroutines(
*coroutines: Unpack[
tuple[
Coroutine[Any, Any, T1] | None,
Coroutine[Any, Any, T2] | None,
Coroutine[Any, Any, T3] | None,
Coroutine[Any, Any, T4] | None,
Coroutine[Any, Any, T5] | None,
Coroutine[Any, Any, T6] | None,
Coroutine[Any, Any, T7] | None,
]
],
) -> tuple[
T1 | None,
T2 | None,
T3 | None,
T4 | None,
T5 | None,
T6 | None,
T7 | None,
]: ...
async def gather_optional_coroutines(
*coroutines: Unpack[tuple[Coroutine[Any, Any, T1] | None, ...]],
) -> tuple[T1 | None, ...]:

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,957 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
import logging
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import RelationTypes
from synapse.rest.client import login, relations, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util.clock import Clock
from tests import unittest
logger = logging.getLogger(__name__)
class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"""
Test the /thread_updates companion endpoint (MSC4360).
"""
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
relations.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = {"msc4360_enabled": True}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
def test_no_updates_for_new_user(self) -> None:
"""
Test that a user with no thread updates gets an empty response.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Request thread updates
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
content={"include_roots": True},
access_token=user1_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Assert empty chunk and no next_batch
self.assertEqual(channel.json_body["chunk"], {})
self.assertNotIn("next_batch", channel.json_body)
def test_single_thread_update(self) -> None:
"""
Test that a single thread with one reply appears in the response.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create thread root
thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
thread_root_id = thread_root_resp["event_id"]
# Add reply to thread
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply 1",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
tok=user1_tok,
)
# Request thread updates
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={"include_roots": True},
)
self.assertEqual(channel.code, 200, channel.json_body)
# Assert thread is present
chunk = channel.json_body["chunk"]
self.assertIn(room_id, chunk)
self.assertIn(thread_root_id, chunk[room_id])
# Assert thread root is included
thread_update = chunk[room_id][thread_root_id]
self.assertIn("thread_root", thread_update)
self.assertEqual(thread_update["thread_root"]["event_id"], thread_root_id)
# Assert prev_batch is NOT present (only 1 update - the reply)
self.assertNotIn("prev_batch", thread_update)
def test_multiple_threads_single_room(self) -> None:
"""
Test that multiple threads in the same room are grouped correctly.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create two threads
thread1_root_id = self.helper.send(room_id, body="Thread 1", tok=user1_tok)[
"event_id"
]
thread2_root_id = self.helper.send(room_id, body="Thread 2", tok=user1_tok)[
"event_id"
]
# Add replies to both threads
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply to thread 1",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread1_root_id,
},
},
tok=user1_tok,
)
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply to thread 2",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread2_root_id,
},
},
tok=user1_tok,
)
# Request thread updates
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={"include_roots": True},
)
self.assertEqual(channel.code, 200, channel.json_body)
# Assert both threads are in the same room
chunk = channel.json_body["chunk"]
self.assertIn(room_id, chunk)
self.assertEqual(len(chunk), 1, "Should only have one room")
self.assertEqual(len(chunk[room_id]), 2, "Should have two threads")
self.assertIn(thread1_root_id, chunk[room_id])
self.assertIn(thread2_root_id, chunk[room_id])
def test_threads_across_multiple_rooms(self) -> None:
"""
Test that threads from different rooms are grouped by room_id.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_a_id = self.helper.create_room_as(user1_id, tok=user1_tok)
room_b_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create threads in both rooms
thread_a_root_id = self.helper.send(room_a_id, body="Thread A", tok=user1_tok)[
"event_id"
]
thread_b_root_id = self.helper.send(room_b_id, body="Thread B", tok=user1_tok)[
"event_id"
]
# Add replies
self.helper.send_event(
room_a_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply to A",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_a_root_id,
},
},
tok=user1_tok,
)
self.helper.send_event(
room_b_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply to B",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_b_root_id,
},
},
tok=user1_tok,
)
# Request thread updates
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={"include_roots": True},
)
self.assertEqual(channel.code, 200, channel.json_body)
# Assert both rooms are present with their threads
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 2, "Should have two rooms")
self.assertIn(room_a_id, chunk)
self.assertIn(room_b_id, chunk)
self.assertIn(thread_a_root_id, chunk[room_a_id])
self.assertIn(thread_b_root_id, chunk[room_b_id])
def test_pagination_with_from_token(self) -> None:
"""
Test that pagination works using the next_batch token.
This verifies that multiple calls to /thread_updates return all thread
updates with no duplicates and no gaps.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create many threads (more than default limit)
thread_ids = []
for i in range(5):
thread_root_id = self.helper.send(
room_id, body=f"Thread {i}", tok=user1_tok
)["event_id"]
thread_ids.append(thread_root_id)
# Add reply
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": f"Reply to thread {i}",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
tok=user1_tok,
)
# Request first page with small limit
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2",
access_token=user1_tok,
content={"include_roots": True},
)
self.assertEqual(channel.code, 200, channel.json_body)
# Should have 2 threads and a next_batch token
first_page_threads = set(channel.json_body["chunk"][room_id].keys())
self.assertEqual(len(first_page_threads), 2)
self.assertIn("next_batch", channel.json_body)
next_batch = channel.json_body["next_batch"]
# Request second page
channel = self.make_request(
"POST",
f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch}",
access_token=user1_tok,
content={"include_roots": True},
)
self.assertEqual(channel.code, 200, channel.json_body)
second_page_threads = set(channel.json_body["chunk"][room_id].keys())
self.assertEqual(len(second_page_threads), 2)
# Verify no overlap
self.assertEqual(
len(first_page_threads & second_page_threads),
0,
"Pages should not have overlapping threads",
)
# Request third page to get the remaining thread
self.assertIn("next_batch", channel.json_body)
next_batch_2 = channel.json_body["next_batch"]
channel = self.make_request(
"POST",
f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch_2}",
access_token=user1_tok,
content={"include_roots": True},
)
self.assertEqual(channel.code, 200, channel.json_body)
third_page_threads = set(channel.json_body["chunk"][room_id].keys())
self.assertEqual(len(third_page_threads), 1)
# Verify no overlap between any pages
self.assertEqual(len(first_page_threads & third_page_threads), 0)
self.assertEqual(len(second_page_threads & third_page_threads), 0)
# Verify no gaps - all threads should be accounted for across all pages
all_threads = set(thread_ids)
combined_threads = first_page_threads | second_page_threads | third_page_threads
self.assertEqual(
combined_threads,
all_threads,
"Combined pages should include all thread updates with no gaps",
)
def test_invalid_dir_parameter(self) -> None:
"""
Test that forward pagination (dir=f) is rejected with an error.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Request with forward direction should fail
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=f",
access_token=user1_tok,
content={"include_roots": True},
)
self.assertEqual(channel.code, 400)
def test_invalid_limit_parameter(self) -> None:
"""
Test that invalid limit values are rejected.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Zero limit should fail
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=0",
access_token=user1_tok,
content={"include_roots": True},
)
self.assertEqual(channel.code, 400)
# Negative limit should fail
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=-5",
access_token=user1_tok,
content={"include_roots": True},
)
self.assertEqual(channel.code, 400)
def test_invalid_pagination_tokens(self) -> None:
"""
Test that invalid from/to tokens are rejected with appropriate errors.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Invalid from token
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from=invalid_token",
access_token=user1_tok,
content={"include_roots": True},
)
self.assertEqual(channel.code, 400)
# Invalid to token
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to=invalid_token",
access_token=user1_tok,
content={"include_roots": True},
)
self.assertEqual(channel.code, 400)
def test_to_token_filtering(self) -> None:
"""
Test that the to_token parameter correctly limits pagination to updates
newer than the to_token (since we paginate backwards from newest to oldest).
This also verifies the to_token boundary is exclusive - updates at exactly
the to_token position should not be included (as they were already returned
in a previous response that synced up to that position).
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create two thread roots
thread1_root_id = self.helper.send(room_id, body="Thread 1", tok=user1_tok)[
"event_id"
]
thread2_root_id = self.helper.send(room_id, body="Thread 2", tok=user1_tok)[
"event_id"
]
# Send replies to both threads
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply to thread 1",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread1_root_id,
},
},
tok=user1_tok,
)
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply to thread 2",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread2_root_id,
},
},
tok=user1_tok,
)
# Request with limit=1 to get only the latest thread update
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=1",
access_token=user1_tok,
content={"include_roots": True},
)
self.assertEqual(channel.code, 200)
self.assertIn("next_batch", channel.json_body)
# next_batch points to before the update we just received
next_batch = channel.json_body["next_batch"]
first_response_threads = set(channel.json_body["chunk"][room_id].keys())
# Request again with to=next_batch (lower bound for backward pagination) and no
# limit.
# This should get only the same thread updates as before, not the additional
# update.
channel = self.make_request(
"POST",
f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to={next_batch}",
access_token=user1_tok,
content={"include_roots": True},
)
self.assertEqual(channel.code, 200)
chunk = channel.json_body["chunk"]
self.assertIn(room_id, chunk)
# Should have exactly one thread update
self.assertEqual(len(chunk[room_id]), 1)
second_response_threads = set(chunk[room_id].keys())
# Verify no overlap - the from parameter boundary should be exclusive
self.assertEqual(
first_response_threads,
second_response_threads,
"to parameter boundary should be exclusive - both responses should be identical",
)
def test_bundled_aggregations_on_thread_roots(self) -> None:
"""
Test that thread root events include bundled aggregations with latest thread event.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create thread root
thread_root_id = self.helper.send(room_id, body="Thread root", tok=user1_tok)[
"event_id"
]
# Send replies to create bundled aggregation data
for i in range(2):
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": f"Reply {i + 1}",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
tok=user1_tok,
)
# Request thread updates
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={"include_roots": True},
)
self.assertEqual(channel.code, 200)
# Check that thread root has bundled aggregations with latest event
chunk = channel.json_body["chunk"]
thread_update = chunk[room_id][thread_root_id]
thread_root_event = thread_update["thread_root"]
# Should have unsigned data with latest thread event content
self.assertIn("unsigned", thread_root_event)
self.assertIn("m.relations", thread_root_event["unsigned"])
relations = thread_root_event["unsigned"]["m.relations"]
self.assertIn(RelationTypes.THREAD, relations)
# Check latest event is present in bundled aggregations
thread_summary = relations[RelationTypes.THREAD]
self.assertIn("latest_event", thread_summary)
latest_event = thread_summary["latest_event"]
self.assertEqual(latest_event["content"]["body"], "Reply 2")
def test_only_joined_rooms(self) -> None:
"""
Test that thread updates only include rooms where the user is currently joined.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
# Create two rooms, user1 joins both
room1_id = self.helper.create_room_as(user1_id, tok=user1_tok)
room2_id = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.join(room2_id, user1_id, tok=user1_tok)
# Create threads in both rooms
thread1_root_id = self.helper.send(room1_id, body="Thread 1", tok=user1_tok)[
"event_id"
]
thread2_root_id = self.helper.send(room2_id, body="Thread 2", tok=user2_tok)[
"event_id"
]
# Add replies to both threads
self.helper.send_event(
room1_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply to thread 1",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread1_root_id,
},
},
tok=user1_tok,
)
self.helper.send_event(
room2_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply to thread 2",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread2_root_id,
},
},
tok=user2_tok,
)
# User1 leaves room2
self.helper.leave(room2_id, user1_id, tok=user1_tok)
# Request thread updates for user1 - should only get room1
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={"include_roots": True},
)
self.assertEqual(channel.code, 200)
chunk = channel.json_body["chunk"]
# Should only have room1, not room2
self.assertIn(room1_id, chunk)
self.assertNotIn(room2_id, chunk)
self.assertIn(thread1_root_id, chunk[room1_id])
def test_room_filtering_with_lists(self) -> None:
"""
Test that room filtering works correctly using the lists parameter.
This verifies that thread updates are only returned for rooms matching
the provided filters.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Create an encrypted room and an unencrypted room
encrypted_room_id = self.helper.create_room_as(
user1_id,
tok=user1_tok,
extra_content={
"initial_state": [
{
"type": "m.room.encryption",
"state_key": "",
"content": {"algorithm": "m.megolm.v1.aes-sha2"},
}
]
},
)
unencrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create threads in both rooms
enc_thread_root_id = self.helper.send(
encrypted_room_id, body="Encrypted thread", tok=user1_tok
)["event_id"]
unenc_thread_root_id = self.helper.send(
unencrypted_room_id, body="Unencrypted thread", tok=user1_tok
)["event_id"]
# Add replies to both threads
self.helper.send_event(
encrypted_room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply in encrypted room",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": enc_thread_root_id,
},
},
tok=user1_tok,
)
self.helper.send_event(
unencrypted_room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply in unencrypted room",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": unenc_thread_root_id,
},
},
tok=user1_tok,
)
# Request thread updates with filter for encrypted rooms only
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={
"lists": {
"encrypted_list": {
"ranges": [[0, 99]],
"required_state": [["m.room.encryption", ""]],
"timeline_limit": 10,
"filters": {"is_encrypted": True},
}
}
},
)
self.assertEqual(channel.code, 200, channel.json_body)
chunk = channel.json_body["chunk"]
# Should only include the encrypted room
self.assertIn(encrypted_room_id, chunk)
self.assertNotIn(unencrypted_room_id, chunk)
self.assertIn(enc_thread_root_id, chunk[encrypted_room_id])
def test_room_filtering_with_room_subscriptions(self) -> None:
"""
Test that room filtering works correctly using the room_subscriptions parameter.
This verifies that thread updates are only returned for explicitly subscribed rooms.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Create three rooms
room1_id = self.helper.create_room_as(user1_id, tok=user1_tok)
room2_id = self.helper.create_room_as(user1_id, tok=user1_tok)
room3_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create threads in all three rooms
thread1_root_id = self.helper.send(room1_id, body="Thread 1", tok=user1_tok)[
"event_id"
]
thread2_root_id = self.helper.send(room2_id, body="Thread 2", tok=user1_tok)[
"event_id"
]
thread3_root_id = self.helper.send(room3_id, body="Thread 3", tok=user1_tok)[
"event_id"
]
# Add replies to all threads
for room_id, thread_root_id in [
(room1_id, thread1_root_id),
(room2_id, thread2_root_id),
(room3_id, thread3_root_id),
]:
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
tok=user1_tok,
)
# Request thread updates with subscription to only room1 and room2
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={
"room_subscriptions": {
room1_id: {
"required_state": [["m.room.name", ""]],
"timeline_limit": 10,
},
room2_id: {
"required_state": [["m.room.name", ""]],
"timeline_limit": 10,
},
}
},
)
self.assertEqual(channel.code, 200, channel.json_body)
chunk = channel.json_body["chunk"]
# Should only include room1 and room2, not room3
self.assertIn(room1_id, chunk)
self.assertIn(room2_id, chunk)
self.assertNotIn(room3_id, chunk)
self.assertIn(thread1_root_id, chunk[room1_id])
self.assertIn(thread2_root_id, chunk[room2_id])
def test_room_filtering_with_lists_and_room_subscriptions(self) -> None:
"""
Test that room filtering works correctly when both lists and room_subscriptions
are provided. The union of rooms from both should be included.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Create an encrypted room and two unencrypted rooms
encrypted_room_id = self.helper.create_room_as(
user1_id,
tok=user1_tok,
extra_content={
"initial_state": [
{
"type": "m.room.encryption",
"state_key": "",
"content": {"algorithm": "m.megolm.v1.aes-sha2"},
}
]
},
)
unencrypted_room1_id = self.helper.create_room_as(user1_id, tok=user1_tok)
unencrypted_room2_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create threads in all three rooms
enc_thread_root_id = self.helper.send(
encrypted_room_id, body="Encrypted thread", tok=user1_tok
)["event_id"]
unenc1_thread_root_id = self.helper.send(
unencrypted_room1_id, body="Unencrypted thread 1", tok=user1_tok
)["event_id"]
unenc2_thread_root_id = self.helper.send(
unencrypted_room2_id, body="Unencrypted thread 2", tok=user1_tok
)["event_id"]
# Add replies to all threads
for room_id, thread_root_id in [
(encrypted_room_id, enc_thread_root_id),
(unencrypted_room1_id, unenc1_thread_root_id),
(unencrypted_room2_id, unenc2_thread_root_id),
]:
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
tok=user1_tok,
)
# Request thread updates with:
# - lists: filter for encrypted rooms
# - room_subscriptions: explicitly subscribe to unencrypted_room1_id
# Expected: should get both encrypted_room_id (from list) and unencrypted_room1_id
# (from subscription), but NOT unencrypted_room2_id
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={
"lists": {
"encrypted_list": {
"ranges": [[0, 99]],
"required_state": [["m.room.encryption", ""]],
"timeline_limit": 10,
"filters": {"is_encrypted": True},
}
},
"room_subscriptions": {
unencrypted_room1_id: {
"required_state": [["m.room.name", ""]],
"timeline_limit": 10,
}
},
},
)
self.assertEqual(channel.code, 200, channel.json_body)
chunk = channel.json_body["chunk"]
# Should include encrypted_room_id (from list filter) and unencrypted_room1_id
# (from subscription), but not unencrypted_room2_id
self.assertIn(encrypted_room_id, chunk)
self.assertIn(unencrypted_room1_id, chunk)
self.assertNotIn(unencrypted_room2_id, chunk)
self.assertIn(enc_thread_root_id, chunk[encrypted_room_id])
self.assertIn(unenc1_thread_root_id, chunk[unencrypted_room1_id])
def test_threads_not_returned_after_leaving_room(self) -> None:
"""
Test that thread updates are properly bounded when a user leaves a room.
Users should see thread updates that occurred up to the point they left,
but NOT updates that occurred after they left.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
# Create room and both users join
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
self.helper.join(room_id, user2_id, tok=user2_tok)
# Create thread
res = self.helper.send(room_id, body="Thread root", tok=user1_tok)
thread_root = res["event_id"]
# Reply in thread while user2 is joined
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply 1 while user2 joined",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root,
},
},
tok=user1_tok,
)
# User2 leaves the room
self.helper.leave(room_id, user2_id, tok=user2_tok)
# Another reply after user2 left
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply 2 after user2 left",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root,
},
},
tok=user1_tok,
)
# User2 gets thread updates with an explicit room subscription
# (We need to explicitly subscribe to the room to include it after leaving;
# otherwise only joined rooms are returned)
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=100",
{
"room_subscriptions": {
room_id: {
"required_state": [],
"timeline_limit": 0,
}
}
},
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Assert: User2 SHOULD see Reply 1 (happened while joined) but NOT Reply 2 (after leaving)
chunk = channel.json_body["chunk"]
self.assertIn(
room_id,
chunk,
"Thread updates should include the room user2 left",
)
self.assertIn(
thread_root,
chunk[room_id],
"Thread root should be in the updates",
)
# Verify that only a single update was seen (Reply 1) by checking that there's
# no prev_batch token. If Reply 2 was also included, there would be multiple
# updates and a prev_batch token would be present.
thread_update = chunk[room_id][thread_root]
self.assertNotIn(
"prev_batch",
thread_update,
"No prev_batch should be present since only one update (Reply 1) is visible",
)