From ba3c8a5f3ed32af670cf2d93eea5d48985f9384b Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Sun, 9 Nov 2025 20:45:32 -0700 Subject: [PATCH] Refactor thread updates to use the same logic between endpoint and extension --- synapse/handlers/relations.py | 361 +++++++++++++++++-- synapse/handlers/sliding_sync/extensions.py | 190 +--------- synapse/rest/client/relations.py | 193 +++++----- synapse/rest/client/sync.py | 58 ++- synapse/types/handlers/sliding_sync.py | 28 +- synapse/types/rest/client/__init__.py | 30 +- tests/rest/client/test_thread_updates.py | 376 +++++++++++++++++++- 7 files changed, 845 insertions(+), 391 deletions(-) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 22d9f418fc..3cd964886b 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -31,7 +31,7 @@ from typing import ( import attr -from synapse.api.constants import Direction, EventTypes, RelationTypes +from synapse.api.constants import Direction, EventTypes, Membership, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.events.utils import SerializeEventConfig @@ -43,12 +43,22 @@ from synapse.storage.databases.main.relations import ( _RelatedEvent, ) from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, UserID +from synapse.types import ( + JsonDict, + Requester, + RoomStreamToken, + StreamKeyType, + StreamToken, + UserID, +) from synapse.util.async_helpers import gather_results from synapse.visibility import filter_events_for_client if TYPE_CHECKING: + from synapse.events.utils import EventClientSerializer + from synapse.handlers.sliding_sync.room_lists import RoomsForUserType from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -59,6 +69,22 @@ ThreadRootsMap = dict[str, EventBase] AggregationsMap = dict[str, "BundledAggregations"] +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThreadUpdate: + """ + Data for a single thread update. + + Attributes: + thread_root: The thread root event, or None if not requested/not visible + prev_batch: Per-thread pagination token for fetching older events in this thread + bundled_aggregations: Bundled aggregations for the thread root event + """ + + thread_root: EventBase | None + prev_batch: StreamToken | None + bundled_aggregations: "BundledAggregations | None" = None + + class ThreadsListInclude(str, enum.Enum): """Valid values for the 'include' flag of /threads.""" @@ -554,7 +580,7 @@ class RelationsHandler: return results - async def process_thread_updates_for_visibility( + async def _filter_thread_updates_for_user( self, all_thread_updates: ThreadUpdatesMap, user_id: str, @@ -596,37 +622,324 @@ class RelationsHandler: return filtered_updates - async def fetch_thread_roots_and_aggregations( + def _build_thread_updates_response( self, - thread_ids: Collection[str], - user_id: str, - ) -> tuple[ThreadRootsMap, AggregationsMap]: - """Fetch thread root events and their bundled aggregations. + filtered_updates: ThreadUpdatesMap, + thread_root_event_map: ThreadRootsMap, + aggregations_map: AggregationsMap, + global_prev_batch_token: StreamToken | None, + ) -> dict[str, dict[str, ThreadUpdate]]: + """Build thread update response structure with per-thread prev_batch tokens. 
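+        Per-thread prev_batch tokens are exclusive: each points one stream
+        position before the thread's latest visible event, so /relations with
+        dir=b will not return that event again.
+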
Args: - thread_ids: The thread root event IDs to fetch - user_id: The user ID requesting the aggregations + filtered_updates: Map of thread_root_id to list of ThreadUpdateInfo + thread_root_event_map: Map of thread_root_id to EventBase + aggregations_map: Map of thread_root_id to BundledAggregations + global_prev_batch_token: Global pagination token, or None if no more results + + Returns: + Map of room_id to thread_root_id to ThreadUpdate + """ + thread_updates: dict[str, dict[str, ThreadUpdate]] = {} + + for thread_root_id, updates in filtered_updates.items(): + # We only care about the latest update for the thread + # Updates are already sorted by stream_ordering DESC from the database query, + # and filter_events_for_client preserves order, so updates[0] is guaranteed to be + # the latest event for each thread. + latest_update = updates[0] + room_id = latest_update.room_id + + # Generate per-thread prev_batch token if this thread has multiple visible updates + # or if we hit the global limit. + # When we hit the global limit, we generate prev_batch tokens for all threads, even if + # we only saw 1 update for them. This is to cover the case where we only saw + # a single update for a given thread, but the global limit prevents us from + # obtaining other updates which would have otherwise been included in the range. + per_thread_prev_batch = None + if len(updates) > 1 or global_prev_batch_token is not None: + # Create a token pointing to one position before the latest event's stream position. + # This makes it exclusive - /relations with dir=b won't return the latest event again. + # Use StreamToken.START as base (all other streams at 0) since only room position matters. + per_thread_prev_batch = StreamToken.START.copy_and_replace( + StreamKeyType.ROOM, + RoomStreamToken(stream=latest_update.stream_ordering - 1), + ) + + if room_id not in thread_updates: + thread_updates[room_id] = {} + + thread_updates[room_id][thread_root_id] = ThreadUpdate( + thread_root=thread_root_event_map.get(thread_root_id), + prev_batch=per_thread_prev_batch, + bundled_aggregations=aggregations_map.get(thread_root_id), + ) + + return thread_updates + + async def _fetch_thread_updates( + self, + room_ids: frozenset[str], + room_membership_map: Mapping[str, "RoomsForUserType"], + from_token: StreamToken | None, + to_token: StreamToken, + limit: int, + exclude_thread_ids: set[str] | None = None, + ) -> tuple[ThreadUpdatesMap, StreamToken | None]: + """Fetch thread updates across multiple rooms, handling membership states properly. + + This method separates rooms based on membership status (LEAVE/BAN vs others) + and queries them appropriately to prevent data leaks. For rooms where the user + has left or been banned, we bound the query to their leave/ban event position. 
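+
+        The global limit is consumed as rooms are queried in turn; once it is
+        exhausted, the current to_token is returned as the prev_batch token so
+        the caller can paginate for the remainder.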
+ + Args: + room_ids: The set of room IDs to fetch thread updates for + room_membership_map: Map of room_id to RoomsForUserType containing membership info + from_token: Lower bound (exclusive) for the query, or None for no lower bound + to_token: Upper bound for the query (for joined/invited/knocking rooms) + limit: Maximum number of thread updates to return across all rooms + exclude_thread_ids: Optional set of thread IDs to exclude from results Returns: A tuple of: - - Map of event_id to EventBase for thread root events - - Map of event_id to BundledAggregations for those events + - Map of thread_id to list of ThreadUpdateInfo objects + - Global prev_batch token if there are more results, None otherwise """ - # Fetch thread root events - thread_root_events = await self._main_store.get_events_as_list(thread_ids) - thread_root_event_map: ThreadRootsMap = { - e.event_id: e for e in thread_root_events - } + # Separate rooms based on membership to handle LEAVE/BAN rooms specially + leave_ban_rooms: set[str] = set() + other_rooms: set[str] = set() - # Fetch bundled aggregations for the thread roots - aggregations_map: AggregationsMap = {} - if thread_root_event_map: - aggregations_map = await self.get_bundled_aggregations( - thread_root_event_map.values(), - user_id, + for room_id in room_ids: + membership_info = room_membership_map.get(room_id) + if membership_info and membership_info.membership in ( + Membership.LEAVE, + Membership.BAN, + ): + leave_ban_rooms.add(room_id) + else: + other_rooms.add(room_id) + + # Fetch thread updates from storage, handling LEAVE/BAN rooms separately + all_thread_updates: ThreadUpdatesMap = {} + prev_batch_token: StreamToken | None = None + remaining_limit = limit + + # Query LEAVE/BAN rooms with bounded to_token to prevent data leaks + if leave_ban_rooms: + for room_id in leave_ban_rooms: + if remaining_limit <= 0: + # We've hit the limit, set prev_batch to indicate more results + prev_batch_token = to_token + break + + membership_info = room_membership_map[room_id] + bounded_to_token = membership_info.event_pos.to_room_stream_token() + + ( + room_thread_updates, + room_prev_batch, + ) = await self._main_store.get_thread_updates_for_rooms( + room_ids={room_id}, + from_token=from_token.room_key if from_token else None, + to_token=bounded_to_token, + limit=remaining_limit, + exclude_thread_ids=exclude_thread_ids, + ) + + # Count updates and reduce remaining limit + num_updates = sum( + len(updates) for updates in room_thread_updates.values() + ) + remaining_limit -= num_updates + + # Merge updates + for thread_id, updates in room_thread_updates.items(): + all_thread_updates.setdefault(thread_id, []).extend(updates) + + # Merge prev_batch tokens (take the maximum for backward pagination) + if room_prev_batch is not None: + if prev_batch_token is None: + prev_batch_token = room_prev_batch + elif ( + room_prev_batch.room_key.stream + > prev_batch_token.room_key.stream + ): + prev_batch_token = room_prev_batch + + # Query other rooms (joined/invited/knocking) with normal to_token + if other_rooms and remaining_limit > 0: + ( + other_thread_updates, + other_prev_batch, + ) = await self._main_store.get_thread_updates_for_rooms( + room_ids=other_rooms, + from_token=from_token.room_key if from_token else None, + to_token=to_token.room_key, + limit=remaining_limit, + exclude_thread_ids=exclude_thread_ids, ) - return thread_root_event_map, aggregations_map + # Merge updates + for thread_id, updates in other_thread_updates.items(): + 
all_thread_updates.setdefault(thread_id, []).extend(updates) + + # Merge prev_batch tokens + if other_prev_batch is not None: + if prev_batch_token is None: + prev_batch_token = other_prev_batch + elif ( + other_prev_batch.room_key.stream > prev_batch_token.room_key.stream + ): + prev_batch_token = other_prev_batch + + return all_thread_updates, prev_batch_token + + async def get_thread_updates_for_rooms( + self, + room_ids: frozenset[str], + room_membership_map: Mapping[str, "RoomsForUserType"], + user_id: str, + from_token: StreamToken | None, + to_token: StreamToken, + limit: int, + include_roots: bool = False, + exclude_thread_ids: set[str] | None = None, + ) -> tuple[dict[str, dict[str, ThreadUpdate]], StreamToken | None]: + """Get thread updates across multiple rooms with full processing pipeline. + + This is the main entry point for fetching thread updates. It handles: + - Fetching updates with membership-based security + - Filtering for visibility + - Optionally fetching thread roots and aggregations + - Building the response structure + + Args: + room_ids: The set of room IDs to fetch updates for + room_membership_map: Map of room_id to RoomsForUserType for membership info + user_id: The user requesting the updates + from_token: Lower bound (exclusive) for the query + to_token: Upper bound for the query + limit: Maximum number of updates to return + include_roots: Whether to fetch and include thread root events (default: False) + exclude_thread_ids: Optional set of thread IDs to exclude + + Returns: + A tuple of: + - Map of room_id to thread_root_id to ThreadUpdate + - Global prev_batch token if there are more results, None otherwise + """ + # Fetch thread updates with membership handling + all_thread_updates, prev_batch_token = await self._fetch_thread_updates( + room_ids=room_ids, + room_membership_map=room_membership_map, + from_token=from_token, + to_token=to_token, + limit=limit, + exclude_thread_ids=exclude_thread_ids, + ) + + if not all_thread_updates: + return {}, prev_batch_token + + # Filter thread updates for visibility + filtered_updates = await self._filter_thread_updates_for_user( + all_thread_updates, user_id + ) + + if not filtered_updates: + return {}, prev_batch_token + + # Optionally fetch thread root events and their bundled aggregations + thread_root_event_map: ThreadRootsMap = {} + aggregations_map: AggregationsMap = {} + if include_roots: + # Fetch thread root events + thread_root_events = await self._main_store.get_events_as_list( + filtered_updates.keys() + ) + thread_root_event_map = {e.event_id: e for e in thread_root_events} + + # Fetch bundled aggregations for the thread roots + if thread_root_event_map: + aggregations_map = await self.get_bundled_aggregations( + thread_root_event_map.values(), + user_id, + ) + + # Build response structure with per-thread prev_batch tokens + thread_updates = self._build_thread_updates_response( + filtered_updates=filtered_updates, + thread_root_event_map=thread_root_event_map, + aggregations_map=aggregations_map, + global_prev_batch_token=prev_batch_token, + ) + + return thread_updates, prev_batch_token + + @staticmethod + async def serialize_thread_updates( + thread_updates: Mapping[str, Mapping[str, ThreadUpdate]], + prev_batch_token: StreamToken | None, + event_serializer: "EventClientSerializer", + time_now: int, + store: "DataStore", + serialize_options: SerializeEventConfig, + ) -> JsonDict: + """ + Serialize thread updates to JSON format. 
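+        Key names follow the sliding sync extension ("updates" and "prev_batch");
+        the companion endpoint remaps these to "chunk" and "next_batch".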
+ + This helper handles serialization of ThreadUpdate objects for both the + companion endpoint and the sliding sync extension. + + Args: + thread_updates: Map of room_id to thread_root_id to ThreadUpdate + prev_batch_token: Global pagination token for fetching more updates + event_serializer: The event serializer to use + time_now: Current time in milliseconds for event serialization + store: Datastore for serializing stream tokens + serialize_options: Serialization config + + Returns: + JSON-serializable dict with "updates" and optionally "prev_batch" + """ + updates_dict: JsonDict = {} + + for room_id, room_threads in thread_updates.items(): + room_updates: JsonDict = {} + for thread_root_id, update in room_threads.items(): + update_dict: JsonDict = {} + + # Serialize thread_root event if present + if update.thread_root is not None: + bundle_aggs_map = ( + {thread_root_id: update.bundled_aggregations} + if update.bundled_aggregations is not None + else None + ) + serialized_events = await event_serializer.serialize_events( + [update.thread_root], + time_now, + config=serialize_options, + bundle_aggregations=bundle_aggs_map, + ) + if serialized_events: + update_dict["thread_root"] = serialized_events[0] + + # Add per-thread prev_batch if present + if update.prev_batch is not None: + update_dict["prev_batch"] = await update.prev_batch.to_string(store) + + room_updates[thread_root_id] = update_dict + + updates_dict[room_id] = room_updates + + result: JsonDict = {"updates": updates_dict} + + # Add global prev_batch token if present + if prev_batch_token is not None: + result["prev_batch"] = await prev_batch_token.to_string(store) + + return result async def get_threads( self, diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py index 04d35e05bb..914ef1f8cc 100644 --- a/synapse/handlers/sliding_sync/extensions.py +++ b/synapse/handlers/sliding_sync/extensions.py @@ -30,28 +30,20 @@ from synapse.api.constants import ( AccountDataTypes, EduTypes, EventContentFields, - Membership, MRelatesToFields, RelationTypes, ) from synapse.events import EventBase from synapse.handlers.receipts import ReceiptEventSource -from synapse.handlers.relations import ( - AggregationsMap, - ThreadRootsMap, -) from synapse.handlers.sliding_sync.room_lists import RoomsForUserType from synapse.logging.opentracing import trace from synapse.storage.databases.main.receipts import ReceiptInRoom -from synapse.storage.databases.main.relations import ThreadUpdateInfo from synapse.types import ( DeviceListUpdates, JsonMapping, MultiWriterStreamToken, - RoomStreamToken, SlidingSyncStreamToken, StrCollection, - StreamKeyType, StreamToken, ThreadSubscriptionsToken, ) @@ -74,7 +66,6 @@ _ThreadSubscription: TypeAlias = ( _ThreadUnsubscription: TypeAlias = ( SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription ) -_ThreadUpdate: TypeAlias = SlidingSyncResult.Extensions.ThreadsExtension.ThreadUpdate if TYPE_CHECKING: from synapse.server import HomeServer @@ -1042,42 +1033,6 @@ class SlidingSyncExtensionHandler: threads_in_timeline.add(thread_id) return threads_in_timeline - def _merge_prev_batch_token( - self, - current_token: StreamToken | None, - new_token: StreamToken | None, - ) -> StreamToken | None: - """Merge two prev_batch tokens, taking the maximum (latest) for backwards pagination. 
- - Args: - current_token: The current prev_batch token (may be None) - new_token: The new prev_batch token to merge (may be None) - - Returns: - The merged token (maximum of the two, or None if both are None) - """ - if new_token is None: - return current_token - if current_token is None: - return new_token - if new_token.room_key.stream > current_token.room_key.stream: - return new_token - return current_token - - def _merge_thread_updates( - self, - target: dict[str, list[ThreadUpdateInfo]], - source: dict[str, list[ThreadUpdateInfo]], - ) -> None: - """Merge thread updates from source into target. - - Args: - target: The target dict to merge into (modified in place) - source: The source dict to merge from - """ - for thread_id, updates in source.items(): - target.setdefault(thread_id, []).extend(updates) - async def get_threads_extension_response( self, sync_config: SlidingSyncConfig, @@ -1119,141 +1074,26 @@ class SlidingSyncExtensionHandler: actual_room_response_map ) - # Separate rooms into groups based on membership status. - # For LEAVE/BAN rooms, we need to bound the to_token to prevent leaking events - # that occurred after the user left/was banned. - leave_ban_rooms: set[str] = set() - other_rooms: set[str] = set() - - for room_id in actual_room_ids: - membership_info = room_membership_for_user_at_to_token_map.get(room_id) - if membership_info and membership_info.membership in ( - Membership.LEAVE, - Membership.BAN, - ): - leave_ban_rooms.add(room_id) - else: - other_rooms.add(room_id) - - # Fetch thread updates, handling LEAVE/BAN rooms separately to avoid data leaks. - all_thread_updates: dict[str, list[ThreadUpdateInfo]] = {} - prev_batch_token: StreamToken | None = None - remaining_limit = threads_request.limit - - # Query for rooms where the user has left or been banned, using their leave/ban - # event position as the upper bound to prevent seeing events after they left. - if leave_ban_rooms: - for room_id in leave_ban_rooms: - if remaining_limit <= 0: - # We've already fetched enough updates, but we still need to set - # prev_batch to indicate there are more results. - prev_batch_token = to_token - break - - membership_info = room_membership_for_user_at_to_token_map[room_id] - bounded_to_token = membership_info.event_pos.to_room_stream_token() - - ( - room_thread_updates, - room_prev_batch, - ) = await self.store.get_thread_updates_for_rooms( - room_ids={room_id}, - from_token=from_token.stream_token.room_key if from_token else None, - to_token=bounded_to_token, - limit=remaining_limit, - exclude_thread_ids=threads_to_exclude, - ) - - # Count how many updates we fetched and reduce the remaining limit - num_updates = sum( - len(updates) for updates in room_thread_updates.values() - ) - remaining_limit -= num_updates - - self._merge_thread_updates(all_thread_updates, room_thread_updates) - prev_batch_token = self._merge_prev_batch_token( - prev_batch_token, room_prev_batch - ) - - # Query for rooms where the user is joined, invited, or knocking, using the - # normal to_token as the upper bound. 
- if other_rooms and remaining_limit > 0: - ( - other_thread_updates, - other_prev_batch, - ) = await self.store.get_thread_updates_for_rooms( - room_ids=other_rooms, - from_token=from_token.stream_token.room_key if from_token else None, - to_token=to_token.room_key, - limit=remaining_limit, - exclude_thread_ids=threads_to_exclude, - ) - - self._merge_thread_updates(all_thread_updates, other_thread_updates) - prev_batch_token = self._merge_prev_batch_token( - prev_batch_token, other_prev_batch - ) - - if len(all_thread_updates) == 0: - return None - - # Filter thread updates for visibility + # Get thread updates using unified helper user_id = sync_config.user.to_string() - filtered_updates = ( - await self.relations_handler.process_thread_updates_for_visibility( - all_thread_updates, user_id - ) + ( + thread_updates_response, + prev_batch_token, + ) = await self.relations_handler.get_thread_updates_for_rooms( + room_ids=frozenset(actual_room_ids), + room_membership_map=room_membership_for_user_at_to_token_map, + user_id=user_id, + from_token=from_token.stream_token if from_token else None, + to_token=to_token, + limit=threads_request.limit, + include_roots=threads_request.include_roots, + exclude_thread_ids=threads_to_exclude, ) - if not filtered_updates: + if not thread_updates_response: return None - # Note: Updates are already sorted by stream_ordering DESC from the database query, - # and filter_events_for_client preserves order, so updates[0] is guaranteed to be - # the latest event for each thread. - - # Optionally fetch thread root events and their bundled aggregations - thread_root_event_map: ThreadRootsMap = {} - aggregations_map: AggregationsMap = {} - if threads_request.include_roots: - ( - thread_root_event_map, - aggregations_map, - ) = await self.relations_handler.fetch_thread_roots_and_aggregations( - filtered_updates.keys(), user_id - ) - - thread_updates: dict[str, dict[str, _ThreadUpdate]] = {} - for thread_root, updates in filtered_updates.items(): - # We only care about the latest update for the thread. - # After sorting above, updates[0] is guaranteed to be the latest (highest stream_ordering). - latest_update = updates[0] - - # Generate per-thread prev_batch token if this thread has multiple visible updates. - # When we hit the global limit, we generate prev_batch tokens for all threads, even if - # we only saw 1 update for them. This is to cover the case where we only saw - # a single update for a given thread, but the global limit prevents us from - # obtaining other updates which would have otherwise been included in the - # range. - per_thread_prev_batch = None - if len(updates) > 1 or prev_batch_token is not None: - # Create a token pointing to one position before the latest event's stream position. - # This makes it exclusive - /relations with dir=b won't return the latest event again. - # Use StreamToken.START as base (all other streams at 0) since only room position matters. 
- per_thread_prev_batch = StreamToken.START.copy_and_replace( - StreamKeyType.ROOM, - RoomStreamToken(stream=latest_update.stream_ordering - 1), - ) - - thread_updates.setdefault(latest_update.room_id, {})[thread_root] = ( - _ThreadUpdate( - thread_root=thread_root_event_map.get(thread_root), - prev_batch=per_thread_prev_batch, - bundled_aggregations=aggregations_map.get(thread_root), - ) - ) - return SlidingSyncResult.Extensions.ThreadsExtension( - updates=thread_updates, + updates=thread_updates_response, prev_batch=prev_batch_token, ) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 09a61425cc..6ce64906e5 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -22,7 +22,7 @@ import logging import re from typing import TYPE_CHECKING -from synapse.api.constants import Direction +from synapse.api.constants import Direction, Membership from synapse.api.errors import SynapseError from synapse.events.utils import SerializeEventConfig from synapse.handlers.relations import ThreadsListInclude @@ -41,7 +41,8 @@ from synapse.streams.config import ( PaginationConfig, extract_stream_token_from_pagination_token, ) -from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken +from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID +from synapse.types.handlers.sliding_sync import PerConnectionState, SlidingSyncConfig from synapse.types.rest.client import ThreadUpdatesBody if TYPE_CHECKING: @@ -166,8 +167,7 @@ class ThreadUpdatesServlet(RestServlet): self.relations_handler = hs.get_relations_handler() self.event_serializer = hs.get_event_client_serializer() self._storage_controllers = hs.get_storage_controllers() - # TODO: Get sliding sync handler for filter_rooms logic - # self.sliding_sync_handler = hs.get_sliding_sync_handler() + self.sliding_sync_handler = hs.get_sliding_sync_handler() async def on_POST(self, request: SynapseRequest) -> tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -191,135 +191,116 @@ class ThreadUpdatesServlet(RestServlet): to_token_str = parse_string(request, "to") # Parse pagination tokens - from_token: RoomStreamToken | None = None - to_token: RoomStreamToken | None = None + from_token: StreamToken | None = None + to_token: StreamToken | None = None if from_token_str: try: stream_token_str = extract_stream_token_from_pagination_token( from_token_str ) - stream_token = await StreamToken.from_string( - self.store, stream_token_str - ) - from_token = stream_token.room_key - except Exception: - raise SynapseError(400, "'from' parameter is invalid") + from_token = await StreamToken.from_string(self.store, stream_token_str) + except Exception as e: + logger.exception("Error parsing 'from' token: %s", from_token_str) + raise SynapseError(400, "'from' parameter is invalid") from e if to_token_str: try: stream_token_str = extract_stream_token_from_pagination_token( to_token_str ) - stream_token = await StreamToken.from_string( - self.store, stream_token_str - ) - to_token = stream_token.room_key + to_token = await StreamToken.from_string(self.store, stream_token_str) except Exception: raise SynapseError(400, "'to' parameter is invalid") # Get the list of rooms to fetch thread updates for - # Start with all joined rooms, then apply filters if provided user_id = requester.user.to_string() - room_ids = await self.store.get_rooms_for_user(user_id) + user = UserID.from_string(user_id) - if body.filters is not None: - # TODO: Apply filters 
using sliding sync room filter logic - # For now, if filters are provided, we need to call the sliding sync - # filter_rooms method to get the applicable room IDs - raise SynapseError(501, "Room filters not yet implemented") - - # Fetch thread updates from storage - # For backward pagination: - # - 'from' (upper bound, exclusive) maps to 'to_token' (inclusive with <=) - # Since next_batch is (last_returned - 1), <= excludes the last returned item - # - 'to' (lower bound, exclusive) maps to 'from_token' (exclusive with >) - ( - all_thread_updates, - prev_batch_token, - ) = await self.store.get_thread_updates_for_rooms( - room_ids=room_ids, - from_token=to_token, - to_token=from_token, - limit=limit, - ) - - if len(all_thread_updates) == 0: - return 200, {"chunk": {}} - - # Filter thread updates for visibility - filtered_updates = ( - await self.relations_handler.process_thread_updates_for_visibility( - all_thread_updates, user_id + # Get the current stream token for membership lookup + if from_token is None: + max_stream_ordering = self.store.get_room_max_stream_ordering() + current_token = StreamToken.START.copy_and_replace( + StreamKeyType.ROOM, RoomStreamToken(stream=max_stream_ordering) ) + else: + current_token = from_token + + # Get room membership information to properly handle LEAVE/BAN rooms + ( + room_membership_for_user_at_to_token_map, + _, + _, + ) = await self.sliding_sync_handler.room_lists.get_room_membership_for_user_at_to_token( + user=user, + to_token=current_token, + from_token=None, ) - if not filtered_updates: + # Determine which rooms to fetch updates for based on lists/room_subscriptions + if body.lists is not None or body.room_subscriptions is not None: + # Use sliding sync room selection logic + sync_config = SlidingSyncConfig( + user=user, + requester=requester, + lists=body.lists, + room_subscriptions=body.room_subscriptions, + ) + + # Use the sliding sync room list handler to get the same set of rooms + interested_rooms = ( + await self.sliding_sync_handler.room_lists.compute_interested_rooms( + sync_config=sync_config, + previous_connection_state=PerConnectionState(), + to_token=current_token, + from_token=None, + ) + ) + + room_ids = frozenset(interested_rooms.relevant_room_map.keys()) + else: + # No lists or room_subscriptions, use only joined rooms + room_ids = frozenset( + room_id + for room_id, membership_info in room_membership_for_user_at_to_token_map.items() + if membership_info.membership == Membership.JOIN + ) + + # Get thread updates using unified helper + ( + thread_updates, + prev_batch_token, + ) = await self.relations_handler.get_thread_updates_for_rooms( + room_ids=room_ids, + room_membership_map=room_membership_for_user_at_to_token_map, + user_id=user_id, + from_token=to_token, + to_token=from_token if from_token else current_token, + limit=limit, + include_roots=body.include_roots, + ) + + if not thread_updates: return 200, {"chunk": {}} - # Fetch thread root events and their bundled aggregations - ( - thread_root_event_map, - aggregations_map, - ) = await self.relations_handler.fetch_thread_roots_and_aggregations( - filtered_updates.keys(), user_id - ) - - # Build response with per-thread data - # Updates are already sorted by stream_ordering DESC from the database query, - # and filter_events_for_client preserves order, so updates[0] is guaranteed to be - # the latest event for each thread. 
+ # Serialize thread updates using shared helper time_now = self.clock.time_msec() serialize_options = SerializeEventConfig(requester=requester) - chunk: dict[str, dict[str, JsonDict]] = {} - for thread_root_id, updates in filtered_updates.items(): - # We only care about the latest update for the thread - latest_update = updates[0] - room_id = latest_update.room_id + serialized = await self.relations_handler.serialize_thread_updates( + thread_updates=thread_updates, + prev_batch_token=prev_batch_token, + event_serializer=self.event_serializer, + time_now=time_now, + store=self.store, + serialize_options=serialize_options, + ) - if room_id not in chunk: - chunk[room_id] = {} - - update_dict: JsonDict = {} - - # Add thread root if present - thread_root_event = thread_root_event_map.get(thread_root_id) - if thread_root_event is not None: - bundle_aggs_map = ( - {thread_root_id: aggregations_map[thread_root_id]} - if thread_root_id in aggregations_map - else None - ) - serialized_events = await self.event_serializer.serialize_events( - [thread_root_event], - time_now, - config=serialize_options, - bundle_aggregations=bundle_aggs_map, - ) - if serialized_events: - update_dict["thread_root"] = serialized_events[0] - - # Add per-thread prev_batch if this thread has multiple visible updates - if len(updates) > 1: - # Create a token pointing to one position before the latest event's stream position. - # This makes it exclusive - /relations with dir=b won't return the latest event again. - per_thread_prev_batch = StreamToken.START.copy_and_replace( - StreamKeyType.ROOM, - RoomStreamToken(stream=latest_update.stream_ordering - 1), - ) - update_dict["prev_batch"] = await per_thread_prev_batch.to_string( - self.store - ) - - chunk[room_id][thread_root_id] = update_dict - - # Build response - response: JsonDict = {"chunk": chunk} - - # Add next_batch token for pagination - if prev_batch_token is not None: - response["next_batch"] = await prev_batch_token.to_string(self.store) + # Build response with "chunk" wrapper and "next_batch" key + # (companion endpoint uses different key names than sliding sync) + response: JsonDict = {"chunk": serialized["updates"]} + if "prev_batch" in serialized: + response["next_batch"] = serialized["prev_batch"] return 200, response diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index b02ac8e4e1..e8f7417f50 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -37,6 +37,7 @@ from synapse.events.utils import ( format_event_raw, ) from synapse.handlers.presence import format_user_presence_state +from synapse.handlers.relations import RelationsHandler from synapse.handlers.sliding_sync import SlidingSyncConfig, SlidingSyncResult from synapse.handlers.sync import ( ArchivedSyncResult, @@ -1107,6 +1108,7 @@ class SlidingSyncRestServlet(RestServlet): time_now, extensions.threads, self.store, + requester, ) return serialized_extensions @@ -1150,6 +1152,7 @@ async def _serialise_threads( time_now: int, threads: SlidingSyncResult.Extensions.ThreadsExtension, store: "DataStore", + requester: Requester, ) -> JsonDict: """ Serialize the threads extension response for sliding sync. @@ -1159,6 +1162,7 @@ async def _serialise_threads( time_now: The current time in milliseconds, used for event serialization. threads: The threads extension data containing thread updates and pagination tokens. store: The datastore, needed for serializing stream tokens. + requester: The user making the request, used for transaction_id inclusion. 
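+            Used to build a SerializeEventConfig so that the requester's own
+            events are serialized with their transaction IDs.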
Returns: A JSON-serializable dict containing: @@ -1169,46 +1173,24 @@ async def _serialise_threads( - "prev_batch": A pagination token for fetching older events in the thread. - "prev_batch": A pagination token for fetching older thread updates (if available). """ - out: JsonDict = {} + if not threads.updates: + out: JsonDict = {} + if threads.prev_batch: + out["prev_batch"] = await threads.prev_batch.to_string(store) + return out - if threads.updates: - updates_dict: JsonDict = {} - for room_id, thread_updates in threads.updates.items(): - room_updates: JsonDict = {} - for thread_root_id, update in thread_updates.items(): - # Serialize the update - update_dict: JsonDict = {} + # Create serialization config to include transaction_id for requester's events + serialize_options = SerializeEventConfig(requester=requester) - # Serialize the thread_root event if present - if update.thread_root is not None: - # Create a mapping of event_id to bundled_aggregations - bundle_aggs_map = ( - {thread_root_id: update.bundled_aggregations} - if update.bundled_aggregations - else None - ) - serialized_events = await event_serializer.serialize_events( - [update.thread_root], - time_now, - bundle_aggregations=bundle_aggs_map, - ) - if serialized_events: - update_dict["thread_root"] = serialized_events[0] - - # Add prev_batch if present - if update.prev_batch is not None: - update_dict["prev_batch"] = await update.prev_batch.to_string(store) - - room_updates[thread_root_id] = update_dict - - updates_dict[room_id] = room_updates - - out["updates"] = updates_dict - - if threads.prev_batch: - out["prev_batch"] = await threads.prev_batch.to_string(store) - - return out + # Use shared serialization helper (static method) + return await RelationsHandler.serialize_thread_updates( + thread_updates=threads.updates, + prev_batch_token=threads.prev_batch, + event_serializer=event_serializer, + time_now=time_now, + store=store, + serialize_options=serialize_options, + ) def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: diff --git a/synapse/types/handlers/sliding_sync.py b/synapse/types/handlers/sliding_sync.py index a5d90252b7..def7a709fa 100644 --- a/synapse/types/handlers/sliding_sync.py +++ b/synapse/types/handlers/sliding_sync.py @@ -37,7 +37,8 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase if TYPE_CHECKING: - from synapse.handlers.relations import BundledAggregations + from synapse.handlers.relations import BundledAggregations, ThreadUpdate + from synapse.types import ( DeviceListUpdates, JsonDict, @@ -409,30 +410,7 @@ class SlidingSyncResult: to paginate through older thread updates. """ - @attr.s(slots=True, frozen=True, auto_attribs=True) - class ThreadUpdate: - """Information about a single thread that has new activity. - - Attributes: - thread_root: The thread root event, if requested via include_roots in the - request. This is the event that started the thread. - prev_batch: A pagination token (exclusive) for fetching older events in this - specific thread. Only present if the thread has multiple updates in the - sync window. This token can be used with the /relations endpoint with - dir=b to paginate backwards through the thread's history. - bundled_aggregations: Bundled aggregations for the thread root event, - including the latest_event in the thread (found in - unsigned.m.relations.m.thread). Only present if thread_root is included. 
- """ - - thread_root: EventBase | None - prev_batch: StreamToken | None - bundled_aggregations: "BundledAggregations | None" = None - - def __bool__(self) -> bool: - return bool(self.thread_root) or bool(self.prev_batch) - - updates: Mapping[str, Mapping[str, ThreadUpdate]] | None + updates: Mapping[str, Mapping[str, "ThreadUpdate"]] | None prev_batch: StreamToken | None def __bool__(self) -> bool: diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py index e10cca7d9b..28d5a2b566 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py @@ -433,14 +433,30 @@ class ThreadUpdatesBody(RequestBodyModel): """ Thread updates companion endpoint request body (MSC4360). - Allows filtering thread updates using the same filter criteria as sliding sync lists. - This enables clients to paginate thread updates using the same room filters that - were applied when generating the prev_batch token. + Allows paginating thread updates using the same room selection as a sliding sync + request. This enables clients to fetch thread updates for the same set of rooms + that were included in their sliding sync response. Attributes: - filters: Optional room filters to apply, using the same structure as - SlidingSyncList.Filters. If not provided, thread updates from all - joined rooms are returned. + lists: Sliding window API lists, using the same structure as SlidingSyncBody.lists. + If provided along with room_subscriptions, the union of rooms from both will + be used. + room_subscriptions: Room subscription API rooms, using the same structure as + SlidingSyncBody.room_subscriptions. If provided along with lists, the union + of rooms from both will be used. + include_roots: Whether to include the thread root events in the response. + Defaults to False. + + If neither lists nor room_subscriptions are provided, thread updates from all + joined rooms are returned. 
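+
+    Illustrative request body (values are examples only):
+
+        {
+            "lists": {
+                "all": {
+                    "ranges": [[0, 99]],
+                    "required_state": [],
+                    "timeline_limit": 0
+                }
+            },
+            "include_roots": true
+        }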
""" - filters: SlidingSyncBody.SlidingSyncList.Filters | None = None + lists: ( + dict[ + Annotated[str, StringConstraints(max_length=64, strict=True)], + SlidingSyncBody.SlidingSyncList, + ] + | None + ) = None + room_subscriptions: dict[StrictStr, SlidingSyncBody.RoomSubscription] | None = None + include_roots: StrictBool = False diff --git a/tests/rest/client/test_thread_updates.py b/tests/rest/client/test_thread_updates.py index fb62b877e7..1418a35dd5 100644 --- a/tests/rest/client/test_thread_updates.py +++ b/tests/rest/client/test_thread_updates.py @@ -58,7 +58,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", - content={}, + content={"include_roots": True}, access_token=user1_tok, ) self.assertEqual(channel.code, 200, channel.json_body) @@ -99,7 +99,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): "POST", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", access_token=user1_tok, - content={}, + content={"include_roots": True}, ) self.assertEqual(channel.code, 200, channel.json_body) @@ -165,7 +165,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): "POST", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", access_token=user1_tok, - content={}, + content={"include_roots": True}, ) self.assertEqual(channel.code, 200, channel.json_body) @@ -227,7 +227,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): "POST", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", access_token=user1_tok, - content={}, + content={"include_roots": True}, ) self.assertEqual(channel.code, 200, channel.json_body) @@ -277,7 +277,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): "POST", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2", access_token=user1_tok, - content={}, + content={"include_roots": True}, ) self.assertEqual(channel.code, 200, channel.json_body) @@ -293,7 +293,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): "POST", f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch}", access_token=user1_tok, - content={}, + content={"include_roots": True}, ) self.assertEqual(channel.code, 200, channel.json_body) @@ -315,7 +315,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): "POST", f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch_2}", access_token=user1_tok, - content={}, + content={"include_roots": True}, ) self.assertEqual(channel.code, 200, channel.json_body) @@ -347,7 +347,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): "POST", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=f", access_token=user1_tok, - content={}, + content={"include_roots": True}, ) self.assertEqual(channel.code, 400) @@ -363,7 +363,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): "POST", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=0", access_token=user1_tok, - content={}, + content={"include_roots": True}, ) self.assertEqual(channel.code, 400) @@ -372,7 +372,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): "POST", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=-5", access_token=user1_tok, - content={}, + content={"include_roots": True}, ) self.assertEqual(channel.code, 400) @@ -388,7 +388,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): "POST", 
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from=invalid_token", access_token=user1_tok, - content={}, + content={"include_roots": True}, ) self.assertEqual(channel.code, 400) @@ -397,7 +397,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): "POST", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to=invalid_token", access_token=user1_tok, - content={}, + content={"include_roots": True}, ) self.assertEqual(channel.code, 400) @@ -454,7 +454,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): "POST", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=1", access_token=user1_tok, - content={}, + content={"include_roots": True}, ) self.assertEqual(channel.code, 200) self.assertIn("next_batch", channel.json_body) @@ -471,7 +471,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): "POST", f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to={next_batch}", access_token=user1_tok, - content={}, + content={"include_roots": True}, ) self.assertEqual(channel.code, 200) @@ -523,7 +523,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): "POST", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", access_token=user1_tok, - content={}, + content={"include_roots": True}, ) self.assertEqual(channel.code, 200) @@ -602,7 +602,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): "POST", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", access_token=user1_tok, - content={}, + content={"include_roots": True}, ) self.assertEqual(channel.code, 200) @@ -611,3 +611,347 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase): self.assertIn(room1_id, chunk) self.assertNotIn(room2_id, chunk) self.assertIn(thread1_root_id, chunk[room1_id]) + + def test_room_filtering_with_lists(self) -> None: + """ + Test that room filtering works correctly using the lists parameter. + This verifies that thread updates are only returned for rooms matching + the provided filters. 
+ """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Create an encrypted room and an unencrypted room + encrypted_room_id = self.helper.create_room_as( + user1_id, + tok=user1_tok, + extra_content={ + "initial_state": [ + { + "type": "m.room.encryption", + "state_key": "", + "content": {"algorithm": "m.megolm.v1.aes-sha2"}, + } + ] + }, + ) + unencrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create threads in both rooms + enc_thread_root_id = self.helper.send( + encrypted_room_id, body="Encrypted thread", tok=user1_tok + )["event_id"] + unenc_thread_root_id = self.helper.send( + unencrypted_room_id, body="Unencrypted thread", tok=user1_tok + )["event_id"] + + # Add replies to both threads + self.helper.send_event( + encrypted_room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply in encrypted room", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": enc_thread_root_id, + }, + }, + tok=user1_tok, + ) + self.helper.send_event( + unencrypted_room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply in unencrypted room", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": unenc_thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Request thread updates with filter for encrypted rooms only + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + content={ + "lists": { + "encrypted_list": { + "ranges": [[0, 99]], + "required_state": [["m.room.encryption", ""]], + "timeline_limit": 10, + "filters": {"is_encrypted": True}, + } + } + }, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + # Should only include the encrypted room + self.assertIn(encrypted_room_id, chunk) + self.assertNotIn(unencrypted_room_id, chunk) + self.assertIn(enc_thread_root_id, chunk[encrypted_room_id]) + + def test_room_filtering_with_room_subscriptions(self) -> None: + """ + Test that room filtering works correctly using the room_subscriptions parameter. + This verifies that thread updates are only returned for explicitly subscribed rooms. 
+ """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Create three rooms + room1_id = self.helper.create_room_as(user1_id, tok=user1_tok) + room2_id = self.helper.create_room_as(user1_id, tok=user1_tok) + room3_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create threads in all three rooms + thread1_root_id = self.helper.send(room1_id, body="Thread 1", tok=user1_tok)[ + "event_id" + ] + thread2_root_id = self.helper.send(room2_id, body="Thread 2", tok=user1_tok)[ + "event_id" + ] + thread3_root_id = self.helper.send(room3_id, body="Thread 3", tok=user1_tok)[ + "event_id" + ] + + # Add replies to all threads + for room_id, thread_root_id in [ + (room1_id, thread1_root_id), + (room2_id, thread2_root_id), + (room3_id, thread3_root_id), + ]: + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Request thread updates with subscription to only room1 and room2 + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + content={ + "room_subscriptions": { + room1_id: { + "required_state": [["m.room.name", ""]], + "timeline_limit": 10, + }, + room2_id: { + "required_state": [["m.room.name", ""]], + "timeline_limit": 10, + }, + } + }, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + # Should only include room1 and room2, not room3 + self.assertIn(room1_id, chunk) + self.assertIn(room2_id, chunk) + self.assertNotIn(room3_id, chunk) + self.assertIn(thread1_root_id, chunk[room1_id]) + self.assertIn(thread2_root_id, chunk[room2_id]) + + def test_room_filtering_with_lists_and_room_subscriptions(self) -> None: + """ + Test that room filtering works correctly when both lists and room_subscriptions + are provided. The union of rooms from both should be included. 
+ """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Create an encrypted room and two unencrypted rooms + encrypted_room_id = self.helper.create_room_as( + user1_id, + tok=user1_tok, + extra_content={ + "initial_state": [ + { + "type": "m.room.encryption", + "state_key": "", + "content": {"algorithm": "m.megolm.v1.aes-sha2"}, + } + ] + }, + ) + unencrypted_room1_id = self.helper.create_room_as(user1_id, tok=user1_tok) + unencrypted_room2_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create threads in all three rooms + enc_thread_root_id = self.helper.send( + encrypted_room_id, body="Encrypted thread", tok=user1_tok + )["event_id"] + unenc1_thread_root_id = self.helper.send( + unencrypted_room1_id, body="Unencrypted thread 1", tok=user1_tok + )["event_id"] + unenc2_thread_root_id = self.helper.send( + unencrypted_room2_id, body="Unencrypted thread 2", tok=user1_tok + )["event_id"] + + # Add replies to all threads + for room_id, thread_root_id in [ + (encrypted_room_id, enc_thread_root_id), + (unencrypted_room1_id, unenc1_thread_root_id), + (unencrypted_room2_id, unenc2_thread_root_id), + ]: + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Request thread updates with: + # - lists: filter for encrypted rooms + # - room_subscriptions: explicitly subscribe to unencrypted_room1_id + # Expected: should get both encrypted_room_id (from list) and unencrypted_room1_id + # (from subscription), but NOT unencrypted_room2_id + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + content={ + "lists": { + "encrypted_list": { + "ranges": [[0, 99]], + "required_state": [["m.room.encryption", ""]], + "timeline_limit": 10, + "filters": {"is_encrypted": True}, + } + }, + "room_subscriptions": { + unencrypted_room1_id: { + "required_state": [["m.room.name", ""]], + "timeline_limit": 10, + } + }, + }, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + # Should include encrypted_room_id (from list filter) and unencrypted_room1_id + # (from subscription), but not unencrypted_room2_id + self.assertIn(encrypted_room_id, chunk) + self.assertIn(unencrypted_room1_id, chunk) + self.assertNotIn(unencrypted_room2_id, chunk) + self.assertIn(enc_thread_root_id, chunk[encrypted_room_id]) + self.assertIn(unenc1_thread_root_id, chunk[unencrypted_room1_id]) + + def test_threads_not_returned_after_leaving_room(self) -> None: + """ + Test that thread updates are properly bounded when a user leaves a room. + + Users should see thread updates that occurred up to the point they left, + but NOT updates that occurred after they left. 
+ """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # Create room and both users join + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + self.helper.join(room_id, user2_id, tok=user2_tok) + + # Create thread + res = self.helper.send(room_id, body="Thread root", tok=user1_tok) + thread_root = res["event_id"] + + # Reply in thread while user2 is joined + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply 1 while user2 joined", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root, + }, + }, + tok=user1_tok, + ) + + # User2 leaves the room + self.helper.leave(room_id, user2_id, tok=user2_tok) + + # Another reply after user2 left + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply 2 after user2 left", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root, + }, + }, + tok=user1_tok, + ) + + # User2 gets thread updates with an explicit room subscription + # (We need to explicitly subscribe to the room to include it after leaving; + # otherwise only joined rooms are returned) + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=100", + { + "room_subscriptions": { + room_id: { + "required_state": [], + "timeline_limit": 0, + } + } + }, + access_token=user2_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Assert: User2 SHOULD see Reply 1 (happened while joined) but NOT Reply 2 (after leaving) + chunk = channel.json_body["chunk"] + self.assertIn( + room_id, + chunk, + "Thread updates should include the room user2 left", + ) + self.assertIn( + thread_root, + chunk[room_id], + "Thread root should be in the updates", + ) + + # Verify that only a single update was seen (Reply 1) by checking that there's + # no prev_batch token. If Reply 2 was also included, there would be multiple + # updates and a prev_batch token would be present. + thread_update = chunk[room_id][thread_root] + self.assertNotIn( + "prev_batch", + thread_update, + "No prev_batch should be present since only one update (Reply 1) is visible", + )