Refactor thread updates to use the same logic between endpoint and extension

This commit is contained in:
Devon Hudson
2025-11-09 20:45:32 -07:00
parent c826157c52
commit ba3c8a5f3e
7 changed files with 845 additions and 391 deletions

View File

@@ -31,7 +31,7 @@ from typing import (
import attr import attr
from synapse.api.constants import Direction, EventTypes, RelationTypes from synapse.api.constants import Direction, EventTypes, Membership, RelationTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event from synapse.events import EventBase, relation_from_event
from synapse.events.utils import SerializeEventConfig from synapse.events.utils import SerializeEventConfig
@@ -43,12 +43,22 @@ from synapse.storage.databases.main.relations import (
_RelatedEvent, _RelatedEvent,
) )
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, UserID from synapse.types import (
JsonDict,
Requester,
RoomStreamToken,
StreamKeyType,
StreamToken,
UserID,
)
from synapse.util.async_helpers import gather_results from synapse.util.async_helpers import gather_results
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.events.utils import EventClientSerializer
from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -59,6 +69,22 @@ ThreadRootsMap = dict[str, EventBase]
AggregationsMap = dict[str, "BundledAggregations"] AggregationsMap = dict[str, "BundledAggregations"]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadUpdate:
"""
Data for a single thread update.
Attributes:
thread_root: The thread root event, or None if not requested/not visible
prev_batch: Per-thread pagination token for fetching older events in this thread
bundled_aggregations: Bundled aggregations for the thread root event
"""
thread_root: EventBase | None
prev_batch: StreamToken | None
bundled_aggregations: "BundledAggregations | None" = None
class ThreadsListInclude(str, enum.Enum): class ThreadsListInclude(str, enum.Enum):
"""Valid values for the 'include' flag of /threads.""" """Valid values for the 'include' flag of /threads."""
@@ -554,7 +580,7 @@ class RelationsHandler:
return results return results
async def process_thread_updates_for_visibility( async def _filter_thread_updates_for_user(
self, self,
all_thread_updates: ThreadUpdatesMap, all_thread_updates: ThreadUpdatesMap,
user_id: str, user_id: str,
@@ -596,37 +622,324 @@ class RelationsHandler:
return filtered_updates return filtered_updates
async def fetch_thread_roots_and_aggregations( def _build_thread_updates_response(
self, self,
thread_ids: Collection[str], filtered_updates: ThreadUpdatesMap,
user_id: str, thread_root_event_map: ThreadRootsMap,
) -> tuple[ThreadRootsMap, AggregationsMap]: aggregations_map: AggregationsMap,
"""Fetch thread root events and their bundled aggregations. global_prev_batch_token: StreamToken | None,
) -> dict[str, dict[str, ThreadUpdate]]:
"""Build thread update response structure with per-thread prev_batch tokens.
Args: Args:
thread_ids: The thread root event IDs to fetch filtered_updates: Map of thread_root_id to list of ThreadUpdateInfo
user_id: The user ID requesting the aggregations thread_root_event_map: Map of thread_root_id to EventBase
aggregations_map: Map of thread_root_id to BundledAggregations
global_prev_batch_token: Global pagination token, or None if no more results
Returns:
Map of room_id to thread_root_id to ThreadUpdate
"""
thread_updates: dict[str, dict[str, ThreadUpdate]] = {}
for thread_root_id, updates in filtered_updates.items():
# We only care about the latest update for the thread
# Updates are already sorted by stream_ordering DESC from the database query,
# and filter_events_for_client preserves order, so updates[0] is guaranteed to be
# the latest event for each thread.
latest_update = updates[0]
room_id = latest_update.room_id
# Generate per-thread prev_batch token if this thread has multiple visible updates
# or if we hit the global limit.
# When we hit the global limit, we generate prev_batch tokens for all threads, even if
# we only saw 1 update for them. This is to cover the case where we only saw
# a single update for a given thread, but the global limit prevents us from
# obtaining other updates which would have otherwise been included in the range.
per_thread_prev_batch = None
if len(updates) > 1 or global_prev_batch_token is not None:
# Create a token pointing to one position before the latest event's stream position.
# This makes it exclusive - /relations with dir=b won't return the latest event again.
# Use StreamToken.START as base (all other streams at 0) since only room position matters.
per_thread_prev_batch = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM,
RoomStreamToken(stream=latest_update.stream_ordering - 1),
)
if room_id not in thread_updates:
thread_updates[room_id] = {}
thread_updates[room_id][thread_root_id] = ThreadUpdate(
thread_root=thread_root_event_map.get(thread_root_id),
prev_batch=per_thread_prev_batch,
bundled_aggregations=aggregations_map.get(thread_root_id),
)
return thread_updates
async def _fetch_thread_updates(
self,
room_ids: frozenset[str],
room_membership_map: Mapping[str, "RoomsForUserType"],
from_token: StreamToken | None,
to_token: StreamToken,
limit: int,
exclude_thread_ids: set[str] | None = None,
) -> tuple[ThreadUpdatesMap, StreamToken | None]:
"""Fetch thread updates across multiple rooms, handling membership states properly.
This method separates rooms based on membership status (LEAVE/BAN vs others)
and queries them appropriately to prevent data leaks. For rooms where the user
has left or been banned, we bound the query to their leave/ban event position.
Args:
room_ids: The set of room IDs to fetch thread updates for
room_membership_map: Map of room_id to RoomsForUserType containing membership info
from_token: Lower bound (exclusive) for the query, or None for no lower bound
to_token: Upper bound for the query (for joined/invited/knocking rooms)
limit: Maximum number of thread updates to return across all rooms
exclude_thread_ids: Optional set of thread IDs to exclude from results
Returns: Returns:
A tuple of: A tuple of:
- Map of event_id to EventBase for thread root events - Map of thread_id to list of ThreadUpdateInfo objects
- Map of event_id to BundledAggregations for those events - Global prev_batch token if there are more results, None otherwise
""" """
# Fetch thread root events # Separate rooms based on membership to handle LEAVE/BAN rooms specially
thread_root_events = await self._main_store.get_events_as_list(thread_ids) leave_ban_rooms: set[str] = set()
thread_root_event_map: ThreadRootsMap = { other_rooms: set[str] = set()
e.event_id: e for e in thread_root_events
}
# Fetch bundled aggregations for the thread roots for room_id in room_ids:
aggregations_map: AggregationsMap = {} membership_info = room_membership_map.get(room_id)
if thread_root_event_map: if membership_info and membership_info.membership in (
aggregations_map = await self.get_bundled_aggregations( Membership.LEAVE,
thread_root_event_map.values(), Membership.BAN,
user_id, ):
leave_ban_rooms.add(room_id)
else:
other_rooms.add(room_id)
# Fetch thread updates from storage, handling LEAVE/BAN rooms separately
all_thread_updates: ThreadUpdatesMap = {}
prev_batch_token: StreamToken | None = None
remaining_limit = limit
# Query LEAVE/BAN rooms with bounded to_token to prevent data leaks
if leave_ban_rooms:
for room_id in leave_ban_rooms:
if remaining_limit <= 0:
# We've hit the limit, set prev_batch to indicate more results
prev_batch_token = to_token
break
membership_info = room_membership_map[room_id]
bounded_to_token = membership_info.event_pos.to_room_stream_token()
(
room_thread_updates,
room_prev_batch,
) = await self._main_store.get_thread_updates_for_rooms(
room_ids={room_id},
from_token=from_token.room_key if from_token else None,
to_token=bounded_to_token,
limit=remaining_limit,
exclude_thread_ids=exclude_thread_ids,
)
# Count updates and reduce remaining limit
num_updates = sum(
len(updates) for updates in room_thread_updates.values()
)
remaining_limit -= num_updates
# Merge updates
for thread_id, updates in room_thread_updates.items():
all_thread_updates.setdefault(thread_id, []).extend(updates)
# Merge prev_batch tokens (take the maximum for backward pagination)
if room_prev_batch is not None:
if prev_batch_token is None:
prev_batch_token = room_prev_batch
elif (
room_prev_batch.room_key.stream
> prev_batch_token.room_key.stream
):
prev_batch_token = room_prev_batch
# Query other rooms (joined/invited/knocking) with normal to_token
if other_rooms and remaining_limit > 0:
(
other_thread_updates,
other_prev_batch,
) = await self._main_store.get_thread_updates_for_rooms(
room_ids=other_rooms,
from_token=from_token.room_key if from_token else None,
to_token=to_token.room_key,
limit=remaining_limit,
exclude_thread_ids=exclude_thread_ids,
) )
return thread_root_event_map, aggregations_map # Merge updates
for thread_id, updates in other_thread_updates.items():
all_thread_updates.setdefault(thread_id, []).extend(updates)
# Merge prev_batch tokens
if other_prev_batch is not None:
if prev_batch_token is None:
prev_batch_token = other_prev_batch
elif (
other_prev_batch.room_key.stream > prev_batch_token.room_key.stream
):
prev_batch_token = other_prev_batch
return all_thread_updates, prev_batch_token
async def get_thread_updates_for_rooms(
self,
room_ids: frozenset[str],
room_membership_map: Mapping[str, "RoomsForUserType"],
user_id: str,
from_token: StreamToken | None,
to_token: StreamToken,
limit: int,
include_roots: bool = False,
exclude_thread_ids: set[str] | None = None,
) -> tuple[dict[str, dict[str, ThreadUpdate]], StreamToken | None]:
"""Get thread updates across multiple rooms with full processing pipeline.
This is the main entry point for fetching thread updates. It handles:
- Fetching updates with membership-based security
- Filtering for visibility
- Optionally fetching thread roots and aggregations
- Building the response structure
Args:
room_ids: The set of room IDs to fetch updates for
room_membership_map: Map of room_id to RoomsForUserType for membership info
user_id: The user requesting the updates
from_token: Lower bound (exclusive) for the query
to_token: Upper bound for the query
limit: Maximum number of updates to return
include_roots: Whether to fetch and include thread root events (default: False)
exclude_thread_ids: Optional set of thread IDs to exclude
Returns:
A tuple of:
- Map of room_id to thread_root_id to ThreadUpdate
- Global prev_batch token if there are more results, None otherwise
"""
# Fetch thread updates with membership handling
all_thread_updates, prev_batch_token = await self._fetch_thread_updates(
room_ids=room_ids,
room_membership_map=room_membership_map,
from_token=from_token,
to_token=to_token,
limit=limit,
exclude_thread_ids=exclude_thread_ids,
)
if not all_thread_updates:
return {}, prev_batch_token
# Filter thread updates for visibility
filtered_updates = await self._filter_thread_updates_for_user(
all_thread_updates, user_id
)
if not filtered_updates:
return {}, prev_batch_token
# Optionally fetch thread root events and their bundled aggregations
thread_root_event_map: ThreadRootsMap = {}
aggregations_map: AggregationsMap = {}
if include_roots:
# Fetch thread root events
thread_root_events = await self._main_store.get_events_as_list(
filtered_updates.keys()
)
thread_root_event_map = {e.event_id: e for e in thread_root_events}
# Fetch bundled aggregations for the thread roots
if thread_root_event_map:
aggregations_map = await self.get_bundled_aggregations(
thread_root_event_map.values(),
user_id,
)
# Build response structure with per-thread prev_batch tokens
thread_updates = self._build_thread_updates_response(
filtered_updates=filtered_updates,
thread_root_event_map=thread_root_event_map,
aggregations_map=aggregations_map,
global_prev_batch_token=prev_batch_token,
)
return thread_updates, prev_batch_token
@staticmethod
async def serialize_thread_updates(
thread_updates: Mapping[str, Mapping[str, ThreadUpdate]],
prev_batch_token: StreamToken | None,
event_serializer: "EventClientSerializer",
time_now: int,
store: "DataStore",
serialize_options: SerializeEventConfig,
) -> JsonDict:
"""
Serialize thread updates to JSON format.
This helper handles serialization of ThreadUpdate objects for both the
companion endpoint and the sliding sync extension.
Args:
thread_updates: Map of room_id to thread_root_id to ThreadUpdate
prev_batch_token: Global pagination token for fetching more updates
event_serializer: The event serializer to use
time_now: Current time in milliseconds for event serialization
store: Datastore for serializing stream tokens
serialize_options: Serialization config
Returns:
JSON-serializable dict with "updates" and optionally "prev_batch"
"""
updates_dict: JsonDict = {}
for room_id, room_threads in thread_updates.items():
room_updates: JsonDict = {}
for thread_root_id, update in room_threads.items():
update_dict: JsonDict = {}
# Serialize thread_root event if present
if update.thread_root is not None:
bundle_aggs_map = (
{thread_root_id: update.bundled_aggregations}
if update.bundled_aggregations is not None
else None
)
serialized_events = await event_serializer.serialize_events(
[update.thread_root],
time_now,
config=serialize_options,
bundle_aggregations=bundle_aggs_map,
)
if serialized_events:
update_dict["thread_root"] = serialized_events[0]
# Add per-thread prev_batch if present
if update.prev_batch is not None:
update_dict["prev_batch"] = await update.prev_batch.to_string(store)
room_updates[thread_root_id] = update_dict
updates_dict[room_id] = room_updates
result: JsonDict = {"updates": updates_dict}
# Add global prev_batch token if present
if prev_batch_token is not None:
result["prev_batch"] = await prev_batch_token.to_string(store)
return result
async def get_threads( async def get_threads(
self, self,

View File

@@ -30,28 +30,20 @@ from synapse.api.constants import (
AccountDataTypes, AccountDataTypes,
EduTypes, EduTypes,
EventContentFields, EventContentFields,
Membership,
MRelatesToFields, MRelatesToFields,
RelationTypes, RelationTypes,
) )
from synapse.events import EventBase from synapse.events import EventBase
from synapse.handlers.receipts import ReceiptEventSource from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.relations import (
AggregationsMap,
ThreadRootsMap,
)
from synapse.handlers.sliding_sync.room_lists import RoomsForUserType from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
from synapse.logging.opentracing import trace from synapse.logging.opentracing import trace
from synapse.storage.databases.main.receipts import ReceiptInRoom from synapse.storage.databases.main.receipts import ReceiptInRoom
from synapse.storage.databases.main.relations import ThreadUpdateInfo
from synapse.types import ( from synapse.types import (
DeviceListUpdates, DeviceListUpdates,
JsonMapping, JsonMapping,
MultiWriterStreamToken, MultiWriterStreamToken,
RoomStreamToken,
SlidingSyncStreamToken, SlidingSyncStreamToken,
StrCollection, StrCollection,
StreamKeyType,
StreamToken, StreamToken,
ThreadSubscriptionsToken, ThreadSubscriptionsToken,
) )
@@ -74,7 +66,6 @@ _ThreadSubscription: TypeAlias = (
_ThreadUnsubscription: TypeAlias = ( _ThreadUnsubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
) )
_ThreadUpdate: TypeAlias = SlidingSyncResult.Extensions.ThreadsExtension.ThreadUpdate
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@@ -1042,42 +1033,6 @@ class SlidingSyncExtensionHandler:
threads_in_timeline.add(thread_id) threads_in_timeline.add(thread_id)
return threads_in_timeline return threads_in_timeline
def _merge_prev_batch_token(
self,
current_token: StreamToken | None,
new_token: StreamToken | None,
) -> StreamToken | None:
"""Merge two prev_batch tokens, taking the maximum (latest) for backwards pagination.
Args:
current_token: The current prev_batch token (may be None)
new_token: The new prev_batch token to merge (may be None)
Returns:
The merged token (maximum of the two, or None if both are None)
"""
if new_token is None:
return current_token
if current_token is None:
return new_token
if new_token.room_key.stream > current_token.room_key.stream:
return new_token
return current_token
def _merge_thread_updates(
self,
target: dict[str, list[ThreadUpdateInfo]],
source: dict[str, list[ThreadUpdateInfo]],
) -> None:
"""Merge thread updates from source into target.
Args:
target: The target dict to merge into (modified in place)
source: The source dict to merge from
"""
for thread_id, updates in source.items():
target.setdefault(thread_id, []).extend(updates)
async def get_threads_extension_response( async def get_threads_extension_response(
self, self,
sync_config: SlidingSyncConfig, sync_config: SlidingSyncConfig,
@@ -1119,141 +1074,26 @@ class SlidingSyncExtensionHandler:
actual_room_response_map actual_room_response_map
) )
# Separate rooms into groups based on membership status. # Get thread updates using unified helper
# For LEAVE/BAN rooms, we need to bound the to_token to prevent leaking events
# that occurred after the user left/was banned.
leave_ban_rooms: set[str] = set()
other_rooms: set[str] = set()
for room_id in actual_room_ids:
membership_info = room_membership_for_user_at_to_token_map.get(room_id)
if membership_info and membership_info.membership in (
Membership.LEAVE,
Membership.BAN,
):
leave_ban_rooms.add(room_id)
else:
other_rooms.add(room_id)
# Fetch thread updates, handling LEAVE/BAN rooms separately to avoid data leaks.
all_thread_updates: dict[str, list[ThreadUpdateInfo]] = {}
prev_batch_token: StreamToken | None = None
remaining_limit = threads_request.limit
# Query for rooms where the user has left or been banned, using their leave/ban
# event position as the upper bound to prevent seeing events after they left.
if leave_ban_rooms:
for room_id in leave_ban_rooms:
if remaining_limit <= 0:
# We've already fetched enough updates, but we still need to set
# prev_batch to indicate there are more results.
prev_batch_token = to_token
break
membership_info = room_membership_for_user_at_to_token_map[room_id]
bounded_to_token = membership_info.event_pos.to_room_stream_token()
(
room_thread_updates,
room_prev_batch,
) = await self.store.get_thread_updates_for_rooms(
room_ids={room_id},
from_token=from_token.stream_token.room_key if from_token else None,
to_token=bounded_to_token,
limit=remaining_limit,
exclude_thread_ids=threads_to_exclude,
)
# Count how many updates we fetched and reduce the remaining limit
num_updates = sum(
len(updates) for updates in room_thread_updates.values()
)
remaining_limit -= num_updates
self._merge_thread_updates(all_thread_updates, room_thread_updates)
prev_batch_token = self._merge_prev_batch_token(
prev_batch_token, room_prev_batch
)
# Query for rooms where the user is joined, invited, or knocking, using the
# normal to_token as the upper bound.
if other_rooms and remaining_limit > 0:
(
other_thread_updates,
other_prev_batch,
) = await self.store.get_thread_updates_for_rooms(
room_ids=other_rooms,
from_token=from_token.stream_token.room_key if from_token else None,
to_token=to_token.room_key,
limit=remaining_limit,
exclude_thread_ids=threads_to_exclude,
)
self._merge_thread_updates(all_thread_updates, other_thread_updates)
prev_batch_token = self._merge_prev_batch_token(
prev_batch_token, other_prev_batch
)
if len(all_thread_updates) == 0:
return None
# Filter thread updates for visibility
user_id = sync_config.user.to_string() user_id = sync_config.user.to_string()
filtered_updates = ( (
await self.relations_handler.process_thread_updates_for_visibility( thread_updates_response,
all_thread_updates, user_id prev_batch_token,
) ) = await self.relations_handler.get_thread_updates_for_rooms(
room_ids=frozenset(actual_room_ids),
room_membership_map=room_membership_for_user_at_to_token_map,
user_id=user_id,
from_token=from_token.stream_token if from_token else None,
to_token=to_token,
limit=threads_request.limit,
include_roots=threads_request.include_roots,
exclude_thread_ids=threads_to_exclude,
) )
if not filtered_updates: if not thread_updates_response:
return None return None
# Note: Updates are already sorted by stream_ordering DESC from the database query,
# and filter_events_for_client preserves order, so updates[0] is guaranteed to be
# the latest event for each thread.
# Optionally fetch thread root events and their bundled aggregations
thread_root_event_map: ThreadRootsMap = {}
aggregations_map: AggregationsMap = {}
if threads_request.include_roots:
(
thread_root_event_map,
aggregations_map,
) = await self.relations_handler.fetch_thread_roots_and_aggregations(
filtered_updates.keys(), user_id
)
thread_updates: dict[str, dict[str, _ThreadUpdate]] = {}
for thread_root, updates in filtered_updates.items():
# We only care about the latest update for the thread.
# After sorting above, updates[0] is guaranteed to be the latest (highest stream_ordering).
latest_update = updates[0]
# Generate per-thread prev_batch token if this thread has multiple visible updates.
# When we hit the global limit, we generate prev_batch tokens for all threads, even if
# we only saw 1 update for them. This is to cover the case where we only saw
# a single update for a given thread, but the global limit prevents us from
# obtaining other updates which would have otherwise been included in the
# range.
per_thread_prev_batch = None
if len(updates) > 1 or prev_batch_token is not None:
# Create a token pointing to one position before the latest event's stream position.
# This makes it exclusive - /relations with dir=b won't return the latest event again.
# Use StreamToken.START as base (all other streams at 0) since only room position matters.
per_thread_prev_batch = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM,
RoomStreamToken(stream=latest_update.stream_ordering - 1),
)
thread_updates.setdefault(latest_update.room_id, {})[thread_root] = (
_ThreadUpdate(
thread_root=thread_root_event_map.get(thread_root),
prev_batch=per_thread_prev_batch,
bundled_aggregations=aggregations_map.get(thread_root),
)
)
return SlidingSyncResult.Extensions.ThreadsExtension( return SlidingSyncResult.Extensions.ThreadsExtension(
updates=thread_updates, updates=thread_updates_response,
prev_batch=prev_batch_token, prev_batch=prev_batch_token,
) )

View File

@@ -22,7 +22,7 @@ import logging
import re import re
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from synapse.api.constants import Direction from synapse.api.constants import Direction, Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events.utils import SerializeEventConfig from synapse.events.utils import SerializeEventConfig
from synapse.handlers.relations import ThreadsListInclude from synapse.handlers.relations import ThreadsListInclude
@@ -41,7 +41,8 @@ from synapse.streams.config import (
PaginationConfig, PaginationConfig,
extract_stream_token_from_pagination_token, extract_stream_token_from_pagination_token,
) )
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
from synapse.types.handlers.sliding_sync import PerConnectionState, SlidingSyncConfig
from synapse.types.rest.client import ThreadUpdatesBody from synapse.types.rest.client import ThreadUpdatesBody
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -166,8 +167,7 @@ class ThreadUpdatesServlet(RestServlet):
self.relations_handler = hs.get_relations_handler() self.relations_handler = hs.get_relations_handler()
self.event_serializer = hs.get_event_client_serializer() self.event_serializer = hs.get_event_client_serializer()
self._storage_controllers = hs.get_storage_controllers() self._storage_controllers = hs.get_storage_controllers()
# TODO: Get sliding sync handler for filter_rooms logic self.sliding_sync_handler = hs.get_sliding_sync_handler()
# self.sliding_sync_handler = hs.get_sliding_sync_handler()
async def on_POST(self, request: SynapseRequest) -> tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
@@ -191,135 +191,116 @@ class ThreadUpdatesServlet(RestServlet):
to_token_str = parse_string(request, "to") to_token_str = parse_string(request, "to")
# Parse pagination tokens # Parse pagination tokens
from_token: RoomStreamToken | None = None from_token: StreamToken | None = None
to_token: RoomStreamToken | None = None to_token: StreamToken | None = None
if from_token_str: if from_token_str:
try: try:
stream_token_str = extract_stream_token_from_pagination_token( stream_token_str = extract_stream_token_from_pagination_token(
from_token_str from_token_str
) )
stream_token = await StreamToken.from_string( from_token = await StreamToken.from_string(self.store, stream_token_str)
self.store, stream_token_str except Exception as e:
) logger.exception("Error parsing 'from' token: %s", from_token_str)
from_token = stream_token.room_key raise SynapseError(400, "'from' parameter is invalid") from e
except Exception:
raise SynapseError(400, "'from' parameter is invalid")
if to_token_str: if to_token_str:
try: try:
stream_token_str = extract_stream_token_from_pagination_token( stream_token_str = extract_stream_token_from_pagination_token(
to_token_str to_token_str
) )
stream_token = await StreamToken.from_string( to_token = await StreamToken.from_string(self.store, stream_token_str)
self.store, stream_token_str
)
to_token = stream_token.room_key
except Exception: except Exception:
raise SynapseError(400, "'to' parameter is invalid") raise SynapseError(400, "'to' parameter is invalid")
# Get the list of rooms to fetch thread updates for # Get the list of rooms to fetch thread updates for
# Start with all joined rooms, then apply filters if provided
user_id = requester.user.to_string() user_id = requester.user.to_string()
room_ids = await self.store.get_rooms_for_user(user_id) user = UserID.from_string(user_id)
if body.filters is not None: # Get the current stream token for membership lookup
# TODO: Apply filters using sliding sync room filter logic if from_token is None:
# For now, if filters are provided, we need to call the sliding sync max_stream_ordering = self.store.get_room_max_stream_ordering()
# filter_rooms method to get the applicable room IDs current_token = StreamToken.START.copy_and_replace(
raise SynapseError(501, "Room filters not yet implemented") StreamKeyType.ROOM, RoomStreamToken(stream=max_stream_ordering)
# Fetch thread updates from storage
# For backward pagination:
# - 'from' (upper bound, exclusive) maps to 'to_token' (inclusive with <=)
# Since next_batch is (last_returned - 1), <= excludes the last returned item
# - 'to' (lower bound, exclusive) maps to 'from_token' (exclusive with >)
(
all_thread_updates,
prev_batch_token,
) = await self.store.get_thread_updates_for_rooms(
room_ids=room_ids,
from_token=to_token,
to_token=from_token,
limit=limit,
)
if len(all_thread_updates) == 0:
return 200, {"chunk": {}}
# Filter thread updates for visibility
filtered_updates = (
await self.relations_handler.process_thread_updates_for_visibility(
all_thread_updates, user_id
) )
else:
current_token = from_token
# Get room membership information to properly handle LEAVE/BAN rooms
(
room_membership_for_user_at_to_token_map,
_,
_,
) = await self.sliding_sync_handler.room_lists.get_room_membership_for_user_at_to_token(
user=user,
to_token=current_token,
from_token=None,
) )
if not filtered_updates: # Determine which rooms to fetch updates for based on lists/room_subscriptions
if body.lists is not None or body.room_subscriptions is not None:
# Use sliding sync room selection logic
sync_config = SlidingSyncConfig(
user=user,
requester=requester,
lists=body.lists,
room_subscriptions=body.room_subscriptions,
)
# Use the sliding sync room list handler to get the same set of rooms
interested_rooms = (
await self.sliding_sync_handler.room_lists.compute_interested_rooms(
sync_config=sync_config,
previous_connection_state=PerConnectionState(),
to_token=current_token,
from_token=None,
)
)
room_ids = frozenset(interested_rooms.relevant_room_map.keys())
else:
# No lists or room_subscriptions, use only joined rooms
room_ids = frozenset(
room_id
for room_id, membership_info in room_membership_for_user_at_to_token_map.items()
if membership_info.membership == Membership.JOIN
)
# Get thread updates using unified helper
(
thread_updates,
prev_batch_token,
) = await self.relations_handler.get_thread_updates_for_rooms(
room_ids=room_ids,
room_membership_map=room_membership_for_user_at_to_token_map,
user_id=user_id,
from_token=to_token,
to_token=from_token if from_token else current_token,
limit=limit,
include_roots=body.include_roots,
)
if not thread_updates:
return 200, {"chunk": {}} return 200, {"chunk": {}}
# Fetch thread root events and their bundled aggregations # Serialize thread updates using shared helper
(
thread_root_event_map,
aggregations_map,
) = await self.relations_handler.fetch_thread_roots_and_aggregations(
filtered_updates.keys(), user_id
)
# Build response with per-thread data
# Updates are already sorted by stream_ordering DESC from the database query,
# and filter_events_for_client preserves order, so updates[0] is guaranteed to be
# the latest event for each thread.
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
serialize_options = SerializeEventConfig(requester=requester) serialize_options = SerializeEventConfig(requester=requester)
chunk: dict[str, dict[str, JsonDict]] = {}
for thread_root_id, updates in filtered_updates.items(): serialized = await self.relations_handler.serialize_thread_updates(
# We only care about the latest update for the thread thread_updates=thread_updates,
latest_update = updates[0] prev_batch_token=prev_batch_token,
room_id = latest_update.room_id event_serializer=self.event_serializer,
time_now=time_now,
store=self.store,
serialize_options=serialize_options,
)
if room_id not in chunk: # Build response with "chunk" wrapper and "next_batch" key
chunk[room_id] = {} # (companion endpoint uses different key names than sliding sync)
response: JsonDict = {"chunk": serialized["updates"]}
update_dict: JsonDict = {} if "prev_batch" in serialized:
response["next_batch"] = serialized["prev_batch"]
# Add thread root if present
thread_root_event = thread_root_event_map.get(thread_root_id)
if thread_root_event is not None:
bundle_aggs_map = (
{thread_root_id: aggregations_map[thread_root_id]}
if thread_root_id in aggregations_map
else None
)
serialized_events = await self.event_serializer.serialize_events(
[thread_root_event],
time_now,
config=serialize_options,
bundle_aggregations=bundle_aggs_map,
)
if serialized_events:
update_dict["thread_root"] = serialized_events[0]
# Add per-thread prev_batch if this thread has multiple visible updates
if len(updates) > 1:
# Create a token pointing to one position before the latest event's stream position.
# This makes it exclusive - /relations with dir=b won't return the latest event again.
per_thread_prev_batch = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM,
RoomStreamToken(stream=latest_update.stream_ordering - 1),
)
update_dict["prev_batch"] = await per_thread_prev_batch.to_string(
self.store
)
chunk[room_id][thread_root_id] = update_dict
# Build response
response: JsonDict = {"chunk": chunk}
# Add next_batch token for pagination
if prev_batch_token is not None:
response["next_batch"] = await prev_batch_token.to_string(self.store)
return 200, response return 200, response

View File

@@ -37,6 +37,7 @@ from synapse.events.utils import (
format_event_raw, format_event_raw,
) )
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.relations import RelationsHandler
from synapse.handlers.sliding_sync import SlidingSyncConfig, SlidingSyncResult from synapse.handlers.sliding_sync import SlidingSyncConfig, SlidingSyncResult
from synapse.handlers.sync import ( from synapse.handlers.sync import (
ArchivedSyncResult, ArchivedSyncResult,
@@ -1107,6 +1108,7 @@ class SlidingSyncRestServlet(RestServlet):
time_now, time_now,
extensions.threads, extensions.threads,
self.store, self.store,
requester,
) )
return serialized_extensions return serialized_extensions
@@ -1150,6 +1152,7 @@ async def _serialise_threads(
time_now: int, time_now: int,
threads: SlidingSyncResult.Extensions.ThreadsExtension, threads: SlidingSyncResult.Extensions.ThreadsExtension,
store: "DataStore", store: "DataStore",
requester: Requester,
) -> JsonDict: ) -> JsonDict:
""" """
Serialize the threads extension response for sliding sync. Serialize the threads extension response for sliding sync.
@@ -1159,6 +1162,7 @@ async def _serialise_threads(
time_now: The current time in milliseconds, used for event serialization. time_now: The current time in milliseconds, used for event serialization.
threads: The threads extension data containing thread updates and pagination tokens. threads: The threads extension data containing thread updates and pagination tokens.
store: The datastore, needed for serializing stream tokens. store: The datastore, needed for serializing stream tokens.
requester: The user making the request, used for transaction_id inclusion.
Returns: Returns:
A JSON-serializable dict containing: A JSON-serializable dict containing:
@@ -1169,46 +1173,24 @@ async def _serialise_threads(
- "prev_batch": A pagination token for fetching older events in the thread. - "prev_batch": A pagination token for fetching older events in the thread.
- "prev_batch": A pagination token for fetching older thread updates (if available). - "prev_batch": A pagination token for fetching older thread updates (if available).
""" """
out: JsonDict = {} if not threads.updates:
out: JsonDict = {}
if threads.prev_batch:
out["prev_batch"] = await threads.prev_batch.to_string(store)
return out
if threads.updates: # Create serialization config to include transaction_id for requester's events
updates_dict: JsonDict = {} serialize_options = SerializeEventConfig(requester=requester)
for room_id, thread_updates in threads.updates.items():
room_updates: JsonDict = {}
for thread_root_id, update in thread_updates.items():
# Serialize the update
update_dict: JsonDict = {}
# Serialize the thread_root event if present # Use shared serialization helper (static method)
if update.thread_root is not None: return await RelationsHandler.serialize_thread_updates(
# Create a mapping of event_id to bundled_aggregations thread_updates=threads.updates,
bundle_aggs_map = ( prev_batch_token=threads.prev_batch,
{thread_root_id: update.bundled_aggregations} event_serializer=event_serializer,
if update.bundled_aggregations time_now=time_now,
else None store=store,
) serialize_options=serialize_options,
serialized_events = await event_serializer.serialize_events( )
[update.thread_root],
time_now,
bundle_aggregations=bundle_aggs_map,
)
if serialized_events:
update_dict["thread_root"] = serialized_events[0]
# Add prev_batch if present
if update.prev_batch is not None:
update_dict["prev_batch"] = await update.prev_batch.to_string(store)
room_updates[thread_root_id] = update_dict
updates_dict[room_id] = room_updates
out["updates"] = updates_dict
if threads.prev_batch:
out["prev_batch"] = await threads.prev_batch.to_string(store)
return out
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:

View File

@@ -37,7 +37,8 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase from synapse.events import EventBase
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.handlers.relations import BundledAggregations from synapse.handlers.relations import BundledAggregations, ThreadUpdate
from synapse.types import ( from synapse.types import (
DeviceListUpdates, DeviceListUpdates,
JsonDict, JsonDict,
@@ -409,30 +410,7 @@ class SlidingSyncResult:
to paginate through older thread updates. to paginate through older thread updates.
""" """
@attr.s(slots=True, frozen=True, auto_attribs=True) updates: Mapping[str, Mapping[str, "ThreadUpdate"]] | None
class ThreadUpdate:
"""Information about a single thread that has new activity.
Attributes:
thread_root: The thread root event, if requested via include_roots in the
request. This is the event that started the thread.
prev_batch: A pagination token (exclusive) for fetching older events in this
specific thread. Only present if the thread has multiple updates in the
sync window. This token can be used with the /relations endpoint with
dir=b to paginate backwards through the thread's history.
bundled_aggregations: Bundled aggregations for the thread root event,
including the latest_event in the thread (found in
unsigned.m.relations.m.thread). Only present if thread_root is included.
"""
thread_root: EventBase | None
prev_batch: StreamToken | None
bundled_aggregations: "BundledAggregations | None" = None
def __bool__(self) -> bool:
return bool(self.thread_root) or bool(self.prev_batch)
updates: Mapping[str, Mapping[str, ThreadUpdate]] | None
prev_batch: StreamToken | None prev_batch: StreamToken | None
def __bool__(self) -> bool: def __bool__(self) -> bool:

View File

@@ -433,14 +433,30 @@ class ThreadUpdatesBody(RequestBodyModel):
""" """
Thread updates companion endpoint request body (MSC4360). Thread updates companion endpoint request body (MSC4360).
Allows filtering thread updates using the same filter criteria as sliding sync lists. Allows paginating thread updates using the same room selection as a sliding sync
This enables clients to paginate thread updates using the same room filters that request. This enables clients to fetch thread updates for the same set of rooms
were applied when generating the prev_batch token. that were included in their sliding sync response.
Attributes: Attributes:
filters: Optional room filters to apply, using the same structure as lists: Sliding window API lists, using the same structure as SlidingSyncBody.lists.
SlidingSyncList.Filters. If not provided, thread updates from all If provided along with room_subscriptions, the union of rooms from both will
joined rooms are returned. be used.
room_subscriptions: Room subscription API rooms, using the same structure as
SlidingSyncBody.room_subscriptions. If provided along with lists, the union
of rooms from both will be used.
include_roots: Whether to include the thread root events in the response.
Defaults to False.
If neither lists nor room_subscriptions are provided, thread updates from all
joined rooms are returned.
""" """
filters: SlidingSyncBody.SlidingSyncList.Filters | None = None lists: (
dict[
Annotated[str, StringConstraints(max_length=64, strict=True)],
SlidingSyncBody.SlidingSyncList,
]
| None
) = None
room_subscriptions: dict[StrictStr, SlidingSyncBody.RoomSubscription] | None = None
include_roots: StrictBool = False

View File

@@ -58,7 +58,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"POST", "POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
content={}, content={"include_roots": True},
access_token=user1_tok, access_token=user1_tok,
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
@@ -99,7 +99,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST", "POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok, access_token=user1_tok,
content={}, content={"include_roots": True},
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
@@ -165,7 +165,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST", "POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok, access_token=user1_tok,
content={}, content={"include_roots": True},
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
@@ -227,7 +227,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST", "POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok, access_token=user1_tok,
content={}, content={"include_roots": True},
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
@@ -277,7 +277,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST", "POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2",
access_token=user1_tok, access_token=user1_tok,
content={}, content={"include_roots": True},
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
@@ -293,7 +293,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST", "POST",
f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch}", f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch}",
access_token=user1_tok, access_token=user1_tok,
content={}, content={"include_roots": True},
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
@@ -315,7 +315,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST", "POST",
f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch_2}", f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch_2}",
access_token=user1_tok, access_token=user1_tok,
content={}, content={"include_roots": True},
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
@@ -347,7 +347,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST", "POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=f", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=f",
access_token=user1_tok, access_token=user1_tok,
content={}, content={"include_roots": True},
) )
self.assertEqual(channel.code, 400) self.assertEqual(channel.code, 400)
@@ -363,7 +363,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST", "POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=0", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=0",
access_token=user1_tok, access_token=user1_tok,
content={}, content={"include_roots": True},
) )
self.assertEqual(channel.code, 400) self.assertEqual(channel.code, 400)
@@ -372,7 +372,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST", "POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=-5", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=-5",
access_token=user1_tok, access_token=user1_tok,
content={}, content={"include_roots": True},
) )
self.assertEqual(channel.code, 400) self.assertEqual(channel.code, 400)
@@ -388,7 +388,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST", "POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from=invalid_token", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from=invalid_token",
access_token=user1_tok, access_token=user1_tok,
content={}, content={"include_roots": True},
) )
self.assertEqual(channel.code, 400) self.assertEqual(channel.code, 400)
@@ -397,7 +397,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST", "POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to=invalid_token", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to=invalid_token",
access_token=user1_tok, access_token=user1_tok,
content={}, content={"include_roots": True},
) )
self.assertEqual(channel.code, 400) self.assertEqual(channel.code, 400)
@@ -454,7 +454,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST", "POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=1", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=1",
access_token=user1_tok, access_token=user1_tok,
content={}, content={"include_roots": True},
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertIn("next_batch", channel.json_body) self.assertIn("next_batch", channel.json_body)
@@ -471,7 +471,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST", "POST",
f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to={next_batch}", f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to={next_batch}",
access_token=user1_tok, access_token=user1_tok,
content={}, content={"include_roots": True},
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
@@ -523,7 +523,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST", "POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok, access_token=user1_tok,
content={}, content={"include_roots": True},
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
@@ -602,7 +602,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST", "POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok, access_token=user1_tok,
content={}, content={"include_roots": True},
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
@@ -611,3 +611,347 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
self.assertIn(room1_id, chunk) self.assertIn(room1_id, chunk)
self.assertNotIn(room2_id, chunk) self.assertNotIn(room2_id, chunk)
self.assertIn(thread1_root_id, chunk[room1_id]) self.assertIn(thread1_root_id, chunk[room1_id])
def test_room_filtering_with_lists(self) -> None:
"""
Test that room filtering works correctly using the lists parameter.
This verifies that thread updates are only returned for rooms matching
the provided filters.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Create an encrypted room and an unencrypted room
encrypted_room_id = self.helper.create_room_as(
user1_id,
tok=user1_tok,
extra_content={
"initial_state": [
{
"type": "m.room.encryption",
"state_key": "",
"content": {"algorithm": "m.megolm.v1.aes-sha2"},
}
]
},
)
unencrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create threads in both rooms
enc_thread_root_id = self.helper.send(
encrypted_room_id, body="Encrypted thread", tok=user1_tok
)["event_id"]
unenc_thread_root_id = self.helper.send(
unencrypted_room_id, body="Unencrypted thread", tok=user1_tok
)["event_id"]
# Add replies to both threads
self.helper.send_event(
encrypted_room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply in encrypted room",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": enc_thread_root_id,
},
},
tok=user1_tok,
)
self.helper.send_event(
unencrypted_room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply in unencrypted room",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": unenc_thread_root_id,
},
},
tok=user1_tok,
)
# Request thread updates with filter for encrypted rooms only
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={
"lists": {
"encrypted_list": {
"ranges": [[0, 99]],
"required_state": [["m.room.encryption", ""]],
"timeline_limit": 10,
"filters": {"is_encrypted": True},
}
}
},
)
self.assertEqual(channel.code, 200, channel.json_body)
chunk = channel.json_body["chunk"]
# Should only include the encrypted room
self.assertIn(encrypted_room_id, chunk)
self.assertNotIn(unencrypted_room_id, chunk)
self.assertIn(enc_thread_root_id, chunk[encrypted_room_id])
def test_room_filtering_with_room_subscriptions(self) -> None:
"""
Test that room filtering works correctly using the room_subscriptions parameter.
This verifies that thread updates are only returned for explicitly subscribed rooms.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Create three rooms
room1_id = self.helper.create_room_as(user1_id, tok=user1_tok)
room2_id = self.helper.create_room_as(user1_id, tok=user1_tok)
room3_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create threads in all three rooms
thread1_root_id = self.helper.send(room1_id, body="Thread 1", tok=user1_tok)[
"event_id"
]
thread2_root_id = self.helper.send(room2_id, body="Thread 2", tok=user1_tok)[
"event_id"
]
thread3_root_id = self.helper.send(room3_id, body="Thread 3", tok=user1_tok)[
"event_id"
]
# Add replies to all threads
for room_id, thread_root_id in [
(room1_id, thread1_root_id),
(room2_id, thread2_root_id),
(room3_id, thread3_root_id),
]:
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
tok=user1_tok,
)
# Request thread updates with subscription to only room1 and room2
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={
"room_subscriptions": {
room1_id: {
"required_state": [["m.room.name", ""]],
"timeline_limit": 10,
},
room2_id: {
"required_state": [["m.room.name", ""]],
"timeline_limit": 10,
},
}
},
)
self.assertEqual(channel.code, 200, channel.json_body)
chunk = channel.json_body["chunk"]
# Should only include room1 and room2, not room3
self.assertIn(room1_id, chunk)
self.assertIn(room2_id, chunk)
self.assertNotIn(room3_id, chunk)
self.assertIn(thread1_root_id, chunk[room1_id])
self.assertIn(thread2_root_id, chunk[room2_id])
def test_room_filtering_with_lists_and_room_subscriptions(self) -> None:
"""
Test that room filtering works correctly when both lists and room_subscriptions
are provided. The union of rooms from both should be included.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Create an encrypted room and two unencrypted rooms
encrypted_room_id = self.helper.create_room_as(
user1_id,
tok=user1_tok,
extra_content={
"initial_state": [
{
"type": "m.room.encryption",
"state_key": "",
"content": {"algorithm": "m.megolm.v1.aes-sha2"},
}
]
},
)
unencrypted_room1_id = self.helper.create_room_as(user1_id, tok=user1_tok)
unencrypted_room2_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create threads in all three rooms
enc_thread_root_id = self.helper.send(
encrypted_room_id, body="Encrypted thread", tok=user1_tok
)["event_id"]
unenc1_thread_root_id = self.helper.send(
unencrypted_room1_id, body="Unencrypted thread 1", tok=user1_tok
)["event_id"]
unenc2_thread_root_id = self.helper.send(
unencrypted_room2_id, body="Unencrypted thread 2", tok=user1_tok
)["event_id"]
# Add replies to all threads
for room_id, thread_root_id in [
(encrypted_room_id, enc_thread_root_id),
(unencrypted_room1_id, unenc1_thread_root_id),
(unencrypted_room2_id, unenc2_thread_root_id),
]:
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
tok=user1_tok,
)
# Request thread updates with:
# - lists: filter for encrypted rooms
# - room_subscriptions: explicitly subscribe to unencrypted_room1_id
# Expected: should get both encrypted_room_id (from list) and unencrypted_room1_id
# (from subscription), but NOT unencrypted_room2_id
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={
"lists": {
"encrypted_list": {
"ranges": [[0, 99]],
"required_state": [["m.room.encryption", ""]],
"timeline_limit": 10,
"filters": {"is_encrypted": True},
}
},
"room_subscriptions": {
unencrypted_room1_id: {
"required_state": [["m.room.name", ""]],
"timeline_limit": 10,
}
},
},
)
self.assertEqual(channel.code, 200, channel.json_body)
chunk = channel.json_body["chunk"]
# Should include encrypted_room_id (from list filter) and unencrypted_room1_id
# (from subscription), but not unencrypted_room2_id
self.assertIn(encrypted_room_id, chunk)
self.assertIn(unencrypted_room1_id, chunk)
self.assertNotIn(unencrypted_room2_id, chunk)
self.assertIn(enc_thread_root_id, chunk[encrypted_room_id])
self.assertIn(unenc1_thread_root_id, chunk[unencrypted_room1_id])
def test_threads_not_returned_after_leaving_room(self) -> None:
"""
Test that thread updates are properly bounded when a user leaves a room.
Users should see thread updates that occurred up to the point they left,
but NOT updates that occurred after they left.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
# Create room and both users join
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
self.helper.join(room_id, user2_id, tok=user2_tok)
# Create thread
res = self.helper.send(room_id, body="Thread root", tok=user1_tok)
thread_root = res["event_id"]
# Reply in thread while user2 is joined
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply 1 while user2 joined",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root,
},
},
tok=user1_tok,
)
# User2 leaves the room
self.helper.leave(room_id, user2_id, tok=user2_tok)
# Another reply after user2 left
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply 2 after user2 left",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root,
},
},
tok=user1_tok,
)
# User2 gets thread updates with an explicit room subscription
# (We need to explicitly subscribe to the room to include it after leaving;
# otherwise only joined rooms are returned)
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=100",
{
"room_subscriptions": {
room_id: {
"required_state": [],
"timeline_limit": 0,
}
}
},
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Assert: User2 SHOULD see Reply 1 (happened while joined) but NOT Reply 2 (after leaving)
chunk = channel.json_body["chunk"]
self.assertIn(
room_id,
chunk,
"Thread updates should include the room user2 left",
)
self.assertIn(
thread_root,
chunk[room_id],
"Thread root should be in the updates",
)
# Verify that only a single update was seen (Reply 1) by checking that there's
# no prev_batch token. If Reply 2 was also included, there would be multiple
# updates and a prev_batch token would be present.
thread_update = chunk[room_id][thread_root]
self.assertNotIn(
"prev_batch",
thread_update,
"No prev_batch should be present since only one update (Reply 1) is visible",
)