Refactor thread updates to use the same logic between endpoint and extension

This commit is contained in:
Devon Hudson
2025-11-09 20:45:32 -07:00
parent c826157c52
commit ba3c8a5f3e
7 changed files with 845 additions and 391 deletions

View File

@@ -31,7 +31,7 @@ from typing import (
import attr
from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.constants import Direction, EventTypes, Membership, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
from synapse.events.utils import SerializeEventConfig
@@ -43,12 +43,22 @@ from synapse.storage.databases.main.relations import (
_RelatedEvent,
)
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, UserID
from synapse.types import (
JsonDict,
Requester,
RoomStreamToken,
StreamKeyType,
StreamToken,
UserID,
)
from synapse.util.async_helpers import gather_results
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.events.utils import EventClientSerializer
from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -59,6 +69,22 @@ ThreadRootsMap = dict[str, EventBase]
AggregationsMap = dict[str, "BundledAggregations"]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadUpdate:
"""
Data for a single thread update.
Attributes:
thread_root: The thread root event, or None if not requested/not visible
prev_batch: Per-thread pagination token for fetching older events in this thread
bundled_aggregations: Bundled aggregations for the thread root event
"""
thread_root: EventBase | None
prev_batch: StreamToken | None
bundled_aggregations: "BundledAggregations | None" = None
class ThreadsListInclude(str, enum.Enum):
"""Valid values for the 'include' flag of /threads."""
@@ -554,7 +580,7 @@ class RelationsHandler:
return results
async def process_thread_updates_for_visibility(
async def _filter_thread_updates_for_user(
self,
all_thread_updates: ThreadUpdatesMap,
user_id: str,
@@ -596,37 +622,324 @@ class RelationsHandler:
return filtered_updates
async def fetch_thread_roots_and_aggregations(
def _build_thread_updates_response(
self,
thread_ids: Collection[str],
user_id: str,
) -> tuple[ThreadRootsMap, AggregationsMap]:
"""Fetch thread root events and their bundled aggregations.
filtered_updates: ThreadUpdatesMap,
thread_root_event_map: ThreadRootsMap,
aggregations_map: AggregationsMap,
global_prev_batch_token: StreamToken | None,
) -> dict[str, dict[str, ThreadUpdate]]:
"""Build thread update response structure with per-thread prev_batch tokens.
Args:
thread_ids: The thread root event IDs to fetch
user_id: The user ID requesting the aggregations
filtered_updates: Map of thread_root_id to list of ThreadUpdateInfo
thread_root_event_map: Map of thread_root_id to EventBase
aggregations_map: Map of thread_root_id to BundledAggregations
global_prev_batch_token: Global pagination token, or None if no more results
Returns:
Map of room_id to thread_root_id to ThreadUpdate
"""
thread_updates: dict[str, dict[str, ThreadUpdate]] = {}
for thread_root_id, updates in filtered_updates.items():
# We only care about the latest update for the thread
# Updates are already sorted by stream_ordering DESC from the database query,
# and filter_events_for_client preserves order, so updates[0] is guaranteed to be
# the latest event for each thread.
latest_update = updates[0]
room_id = latest_update.room_id
# Generate per-thread prev_batch token if this thread has multiple visible updates
# or if we hit the global limit.
# When we hit the global limit, we generate prev_batch tokens for all threads, even if
# we only saw 1 update for them. This is to cover the case where we only saw
# a single update for a given thread, but the global limit prevents us from
# obtaining other updates which would have otherwise been included in the range.
per_thread_prev_batch = None
if len(updates) > 1 or global_prev_batch_token is not None:
# Create a token pointing to one position before the latest event's stream position.
# This makes it exclusive - /relations with dir=b won't return the latest event again.
# Use StreamToken.START as base (all other streams at 0) since only room position matters.
per_thread_prev_batch = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM,
RoomStreamToken(stream=latest_update.stream_ordering - 1),
)
if room_id not in thread_updates:
thread_updates[room_id] = {}
thread_updates[room_id][thread_root_id] = ThreadUpdate(
thread_root=thread_root_event_map.get(thread_root_id),
prev_batch=per_thread_prev_batch,
bundled_aggregations=aggregations_map.get(thread_root_id),
)
return thread_updates
async def _fetch_thread_updates(
self,
room_ids: frozenset[str],
room_membership_map: Mapping[str, "RoomsForUserType"],
from_token: StreamToken | None,
to_token: StreamToken,
limit: int,
exclude_thread_ids: set[str] | None = None,
) -> tuple[ThreadUpdatesMap, StreamToken | None]:
"""Fetch thread updates across multiple rooms, handling membership states properly.
This method separates rooms based on membership status (LEAVE/BAN vs others)
and queries them appropriately to prevent data leaks. For rooms where the user
has left or been banned, we bound the query to their leave/ban event position.
Args:
room_ids: The set of room IDs to fetch thread updates for
room_membership_map: Map of room_id to RoomsForUserType containing membership info
from_token: Lower bound (exclusive) for the query, or None for no lower bound
to_token: Upper bound for the query (for joined/invited/knocking rooms)
limit: Maximum number of thread updates to return across all rooms
exclude_thread_ids: Optional set of thread IDs to exclude from results
Returns:
A tuple of:
- Map of event_id to EventBase for thread root events
- Map of event_id to BundledAggregations for those events
- Map of thread_id to list of ThreadUpdateInfo objects
- Global prev_batch token if there are more results, None otherwise
"""
# Separate rooms based on membership to handle LEAVE/BAN rooms specially
leave_ban_rooms: set[str] = set()
other_rooms: set[str] = set()
for room_id in room_ids:
membership_info = room_membership_map.get(room_id)
if membership_info and membership_info.membership in (
Membership.LEAVE,
Membership.BAN,
):
leave_ban_rooms.add(room_id)
else:
other_rooms.add(room_id)
# Fetch thread updates from storage, handling LEAVE/BAN rooms separately
all_thread_updates: ThreadUpdatesMap = {}
prev_batch_token: StreamToken | None = None
remaining_limit = limit
# Query LEAVE/BAN rooms with bounded to_token to prevent data leaks
if leave_ban_rooms:
for room_id in leave_ban_rooms:
if remaining_limit <= 0:
# We've hit the limit, set prev_batch to indicate more results
prev_batch_token = to_token
break
membership_info = room_membership_map[room_id]
bounded_to_token = membership_info.event_pos.to_room_stream_token()
(
room_thread_updates,
room_prev_batch,
) = await self._main_store.get_thread_updates_for_rooms(
room_ids={room_id},
from_token=from_token.room_key if from_token else None,
to_token=bounded_to_token,
limit=remaining_limit,
exclude_thread_ids=exclude_thread_ids,
)
# Count updates and reduce remaining limit
num_updates = sum(
len(updates) for updates in room_thread_updates.values()
)
remaining_limit -= num_updates
# Merge updates
for thread_id, updates in room_thread_updates.items():
all_thread_updates.setdefault(thread_id, []).extend(updates)
# Merge prev_batch tokens (take the maximum for backward pagination)
if room_prev_batch is not None:
if prev_batch_token is None:
prev_batch_token = room_prev_batch
elif (
room_prev_batch.room_key.stream
> prev_batch_token.room_key.stream
):
prev_batch_token = room_prev_batch
# Query other rooms (joined/invited/knocking) with normal to_token
if other_rooms and remaining_limit > 0:
(
other_thread_updates,
other_prev_batch,
) = await self._main_store.get_thread_updates_for_rooms(
room_ids=other_rooms,
from_token=from_token.room_key if from_token else None,
to_token=to_token.room_key,
limit=remaining_limit,
exclude_thread_ids=exclude_thread_ids,
)
# Merge updates
for thread_id, updates in other_thread_updates.items():
all_thread_updates.setdefault(thread_id, []).extend(updates)
# Merge prev_batch tokens
if other_prev_batch is not None:
if prev_batch_token is None:
prev_batch_token = other_prev_batch
elif (
other_prev_batch.room_key.stream > prev_batch_token.room_key.stream
):
prev_batch_token = other_prev_batch
return all_thread_updates, prev_batch_token
async def get_thread_updates_for_rooms(
self,
room_ids: frozenset[str],
room_membership_map: Mapping[str, "RoomsForUserType"],
user_id: str,
from_token: StreamToken | None,
to_token: StreamToken,
limit: int,
include_roots: bool = False,
exclude_thread_ids: set[str] | None = None,
) -> tuple[dict[str, dict[str, ThreadUpdate]], StreamToken | None]:
"""Get thread updates across multiple rooms with full processing pipeline.
This is the main entry point for fetching thread updates. It handles:
- Fetching updates with membership-based security
- Filtering for visibility
- Optionally fetching thread roots and aggregations
- Building the response structure
Args:
room_ids: The set of room IDs to fetch updates for
room_membership_map: Map of room_id to RoomsForUserType for membership info
user_id: The user requesting the updates
from_token: Lower bound (exclusive) for the query
to_token: Upper bound for the query
limit: Maximum number of updates to return
include_roots: Whether to fetch and include thread root events (default: False)
exclude_thread_ids: Optional set of thread IDs to exclude
Returns:
A tuple of:
- Map of room_id to thread_root_id to ThreadUpdate
- Global prev_batch token if there are more results, None otherwise
"""
# Fetch thread updates with membership handling
all_thread_updates, prev_batch_token = await self._fetch_thread_updates(
room_ids=room_ids,
room_membership_map=room_membership_map,
from_token=from_token,
to_token=to_token,
limit=limit,
exclude_thread_ids=exclude_thread_ids,
)
if not all_thread_updates:
return {}, prev_batch_token
# Filter thread updates for visibility
filtered_updates = await self._filter_thread_updates_for_user(
all_thread_updates, user_id
)
if not filtered_updates:
return {}, prev_batch_token
# Optionally fetch thread root events and their bundled aggregations
thread_root_event_map: ThreadRootsMap = {}
aggregations_map: AggregationsMap = {}
if include_roots:
# Fetch thread root events
thread_root_events = await self._main_store.get_events_as_list(thread_ids)
thread_root_event_map: ThreadRootsMap = {
e.event_id: e for e in thread_root_events
}
thread_root_events = await self._main_store.get_events_as_list(
filtered_updates.keys()
)
thread_root_event_map = {e.event_id: e for e in thread_root_events}
# Fetch bundled aggregations for the thread roots
aggregations_map: AggregationsMap = {}
if thread_root_event_map:
aggregations_map = await self.get_bundled_aggregations(
thread_root_event_map.values(),
user_id,
)
return thread_root_event_map, aggregations_map
# Build response structure with per-thread prev_batch tokens
thread_updates = self._build_thread_updates_response(
filtered_updates=filtered_updates,
thread_root_event_map=thread_root_event_map,
aggregations_map=aggregations_map,
global_prev_batch_token=prev_batch_token,
)
return thread_updates, prev_batch_token
@staticmethod
async def serialize_thread_updates(
thread_updates: Mapping[str, Mapping[str, ThreadUpdate]],
prev_batch_token: StreamToken | None,
event_serializer: "EventClientSerializer",
time_now: int,
store: "DataStore",
serialize_options: SerializeEventConfig,
) -> JsonDict:
"""
Serialize thread updates to JSON format.
This helper handles serialization of ThreadUpdate objects for both the
companion endpoint and the sliding sync extension.
Args:
thread_updates: Map of room_id to thread_root_id to ThreadUpdate
prev_batch_token: Global pagination token for fetching more updates
event_serializer: The event serializer to use
time_now: Current time in milliseconds for event serialization
store: Datastore for serializing stream tokens
serialize_options: Serialization config
Returns:
JSON-serializable dict with "updates" and optionally "prev_batch"
"""
updates_dict: JsonDict = {}
for room_id, room_threads in thread_updates.items():
room_updates: JsonDict = {}
for thread_root_id, update in room_threads.items():
update_dict: JsonDict = {}
# Serialize thread_root event if present
if update.thread_root is not None:
bundle_aggs_map = (
{thread_root_id: update.bundled_aggregations}
if update.bundled_aggregations is not None
else None
)
serialized_events = await event_serializer.serialize_events(
[update.thread_root],
time_now,
config=serialize_options,
bundle_aggregations=bundle_aggs_map,
)
if serialized_events:
update_dict["thread_root"] = serialized_events[0]
# Add per-thread prev_batch if present
if update.prev_batch is not None:
update_dict["prev_batch"] = await update.prev_batch.to_string(store)
room_updates[thread_root_id] = update_dict
updates_dict[room_id] = room_updates
result: JsonDict = {"updates": updates_dict}
# Add global prev_batch token if present
if prev_batch_token is not None:
result["prev_batch"] = await prev_batch_token.to_string(store)
return result
async def get_threads(
self,

View File

@@ -30,28 +30,20 @@ from synapse.api.constants import (
AccountDataTypes,
EduTypes,
EventContentFields,
Membership,
MRelatesToFields,
RelationTypes,
)
from synapse.events import EventBase
from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.relations import (
AggregationsMap,
ThreadRootsMap,
)
from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.receipts import ReceiptInRoom
from synapse.storage.databases.main.relations import ThreadUpdateInfo
from synapse.types import (
DeviceListUpdates,
JsonMapping,
MultiWriterStreamToken,
RoomStreamToken,
SlidingSyncStreamToken,
StrCollection,
StreamKeyType,
StreamToken,
ThreadSubscriptionsToken,
)
@@ -74,7 +66,6 @@ _ThreadSubscription: TypeAlias = (
_ThreadUnsubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
)
_ThreadUpdate: TypeAlias = SlidingSyncResult.Extensions.ThreadsExtension.ThreadUpdate
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -1042,42 +1033,6 @@ class SlidingSyncExtensionHandler:
threads_in_timeline.add(thread_id)
return threads_in_timeline
def _merge_prev_batch_token(
self,
current_token: StreamToken | None,
new_token: StreamToken | None,
) -> StreamToken | None:
"""Merge two prev_batch tokens, taking the maximum (latest) for backwards pagination.
Args:
current_token: The current prev_batch token (may be None)
new_token: The new prev_batch token to merge (may be None)
Returns:
The merged token (maximum of the two, or None if both are None)
"""
if new_token is None:
return current_token
if current_token is None:
return new_token
if new_token.room_key.stream > current_token.room_key.stream:
return new_token
return current_token
def _merge_thread_updates(
self,
target: dict[str, list[ThreadUpdateInfo]],
source: dict[str, list[ThreadUpdateInfo]],
) -> None:
"""Merge thread updates from source into target.
Args:
target: The target dict to merge into (modified in place)
source: The source dict to merge from
"""
for thread_id, updates in source.items():
target.setdefault(thread_id, []).extend(updates)
async def get_threads_extension_response(
self,
sync_config: SlidingSyncConfig,
@@ -1119,141 +1074,26 @@ class SlidingSyncExtensionHandler:
actual_room_response_map
)
# Separate rooms into groups based on membership status.
# For LEAVE/BAN rooms, we need to bound the to_token to prevent leaking events
# that occurred after the user left/was banned.
leave_ban_rooms: set[str] = set()
other_rooms: set[str] = set()
for room_id in actual_room_ids:
membership_info = room_membership_for_user_at_to_token_map.get(room_id)
if membership_info and membership_info.membership in (
Membership.LEAVE,
Membership.BAN,
):
leave_ban_rooms.add(room_id)
else:
other_rooms.add(room_id)
# Fetch thread updates, handling LEAVE/BAN rooms separately to avoid data leaks.
all_thread_updates: dict[str, list[ThreadUpdateInfo]] = {}
prev_batch_token: StreamToken | None = None
remaining_limit = threads_request.limit
# Query for rooms where the user has left or been banned, using their leave/ban
# event position as the upper bound to prevent seeing events after they left.
if leave_ban_rooms:
for room_id in leave_ban_rooms:
if remaining_limit <= 0:
# We've already fetched enough updates, but we still need to set
# prev_batch to indicate there are more results.
prev_batch_token = to_token
break
membership_info = room_membership_for_user_at_to_token_map[room_id]
bounded_to_token = membership_info.event_pos.to_room_stream_token()
(
room_thread_updates,
room_prev_batch,
) = await self.store.get_thread_updates_for_rooms(
room_ids={room_id},
from_token=from_token.stream_token.room_key if from_token else None,
to_token=bounded_to_token,
limit=remaining_limit,
exclude_thread_ids=threads_to_exclude,
)
# Count how many updates we fetched and reduce the remaining limit
num_updates = sum(
len(updates) for updates in room_thread_updates.values()
)
remaining_limit -= num_updates
self._merge_thread_updates(all_thread_updates, room_thread_updates)
prev_batch_token = self._merge_prev_batch_token(
prev_batch_token, room_prev_batch
)
# Query for rooms where the user is joined, invited, or knocking, using the
# normal to_token as the upper bound.
if other_rooms and remaining_limit > 0:
(
other_thread_updates,
other_prev_batch,
) = await self.store.get_thread_updates_for_rooms(
room_ids=other_rooms,
from_token=from_token.stream_token.room_key if from_token else None,
to_token=to_token.room_key,
limit=remaining_limit,
exclude_thread_ids=threads_to_exclude,
)
self._merge_thread_updates(all_thread_updates, other_thread_updates)
prev_batch_token = self._merge_prev_batch_token(
prev_batch_token, other_prev_batch
)
if len(all_thread_updates) == 0:
return None
# Filter thread updates for visibility
# Get thread updates using unified helper
user_id = sync_config.user.to_string()
filtered_updates = (
await self.relations_handler.process_thread_updates_for_visibility(
all_thread_updates, user_id
)
)
if not filtered_updates:
return None
# Note: Updates are already sorted by stream_ordering DESC from the database query,
# and filter_events_for_client preserves order, so updates[0] is guaranteed to be
# the latest event for each thread.
# Optionally fetch thread root events and their bundled aggregations
thread_root_event_map: ThreadRootsMap = {}
aggregations_map: AggregationsMap = {}
if threads_request.include_roots:
(
thread_root_event_map,
aggregations_map,
) = await self.relations_handler.fetch_thread_roots_and_aggregations(
filtered_updates.keys(), user_id
thread_updates_response,
prev_batch_token,
) = await self.relations_handler.get_thread_updates_for_rooms(
room_ids=frozenset(actual_room_ids),
room_membership_map=room_membership_for_user_at_to_token_map,
user_id=user_id,
from_token=from_token.stream_token if from_token else None,
to_token=to_token,
limit=threads_request.limit,
include_roots=threads_request.include_roots,
exclude_thread_ids=threads_to_exclude,
)
thread_updates: dict[str, dict[str, _ThreadUpdate]] = {}
for thread_root, updates in filtered_updates.items():
# We only care about the latest update for the thread.
# After sorting above, updates[0] is guaranteed to be the latest (highest stream_ordering).
latest_update = updates[0]
# Generate per-thread prev_batch token if this thread has multiple visible updates.
# When we hit the global limit, we generate prev_batch tokens for all threads, even if
# we only saw 1 update for them. This is to cover the case where we only saw
# a single update for a given thread, but the global limit prevents us from
# obtaining other updates which would have otherwise been included in the
# range.
per_thread_prev_batch = None
if len(updates) > 1 or prev_batch_token is not None:
# Create a token pointing to one position before the latest event's stream position.
# This makes it exclusive - /relations with dir=b won't return the latest event again.
# Use StreamToken.START as base (all other streams at 0) since only room position matters.
per_thread_prev_batch = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM,
RoomStreamToken(stream=latest_update.stream_ordering - 1),
)
thread_updates.setdefault(latest_update.room_id, {})[thread_root] = (
_ThreadUpdate(
thread_root=thread_root_event_map.get(thread_root),
prev_batch=per_thread_prev_batch,
bundled_aggregations=aggregations_map.get(thread_root),
)
)
if not thread_updates_response:
return None
return SlidingSyncResult.Extensions.ThreadsExtension(
updates=thread_updates,
updates=thread_updates_response,
prev_batch=prev_batch_token,
)

View File

@@ -22,7 +22,7 @@ import logging
import re
from typing import TYPE_CHECKING
from synapse.api.constants import Direction
from synapse.api.constants import Direction, Membership
from synapse.api.errors import SynapseError
from synapse.events.utils import SerializeEventConfig
from synapse.handlers.relations import ThreadsListInclude
@@ -41,7 +41,8 @@ from synapse.streams.config import (
PaginationConfig,
extract_stream_token_from_pagination_token,
)
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
from synapse.types.handlers.sliding_sync import PerConnectionState, SlidingSyncConfig
from synapse.types.rest.client import ThreadUpdatesBody
if TYPE_CHECKING:
@@ -166,8 +167,7 @@ class ThreadUpdatesServlet(RestServlet):
self.relations_handler = hs.get_relations_handler()
self.event_serializer = hs.get_event_client_serializer()
self._storage_controllers = hs.get_storage_controllers()
# TODO: Get sliding sync handler for filter_rooms logic
# self.sliding_sync_handler = hs.get_sliding_sync_handler()
self.sliding_sync_handler = hs.get_sliding_sync_handler()
async def on_POST(self, request: SynapseRequest) -> tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
@@ -191,135 +191,116 @@ class ThreadUpdatesServlet(RestServlet):
to_token_str = parse_string(request, "to")
# Parse pagination tokens
from_token: RoomStreamToken | None = None
to_token: RoomStreamToken | None = None
from_token: StreamToken | None = None
to_token: StreamToken | None = None
if from_token_str:
try:
stream_token_str = extract_stream_token_from_pagination_token(
from_token_str
)
stream_token = await StreamToken.from_string(
self.store, stream_token_str
)
from_token = stream_token.room_key
except Exception:
raise SynapseError(400, "'from' parameter is invalid")
from_token = await StreamToken.from_string(self.store, stream_token_str)
except Exception as e:
logger.exception("Error parsing 'from' token: %s", from_token_str)
raise SynapseError(400, "'from' parameter is invalid") from e
if to_token_str:
try:
stream_token_str = extract_stream_token_from_pagination_token(
to_token_str
)
stream_token = await StreamToken.from_string(
self.store, stream_token_str
)
to_token = stream_token.room_key
to_token = await StreamToken.from_string(self.store, stream_token_str)
except Exception:
raise SynapseError(400, "'to' parameter is invalid")
# Get the list of rooms to fetch thread updates for
# Start with all joined rooms, then apply filters if provided
user_id = requester.user.to_string()
room_ids = await self.store.get_rooms_for_user(user_id)
user = UserID.from_string(user_id)
if body.filters is not None:
# TODO: Apply filters using sliding sync room filter logic
# For now, if filters are provided, we need to call the sliding sync
# filter_rooms method to get the applicable room IDs
raise SynapseError(501, "Room filters not yet implemented")
# Get the current stream token for membership lookup
if from_token is None:
max_stream_ordering = self.store.get_room_max_stream_ordering()
current_token = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM, RoomStreamToken(stream=max_stream_ordering)
)
else:
current_token = from_token
# Fetch thread updates from storage
# For backward pagination:
# - 'from' (upper bound, exclusive) maps to 'to_token' (inclusive with <=)
# Since next_batch is (last_returned - 1), <= excludes the last returned item
# - 'to' (lower bound, exclusive) maps to 'from_token' (exclusive with >)
# Get room membership information to properly handle LEAVE/BAN rooms
(
all_thread_updates,
room_membership_for_user_at_to_token_map,
_,
_,
) = await self.sliding_sync_handler.room_lists.get_room_membership_for_user_at_to_token(
user=user,
to_token=current_token,
from_token=None,
)
# Determine which rooms to fetch updates for based on lists/room_subscriptions
if body.lists is not None or body.room_subscriptions is not None:
# Use sliding sync room selection logic
sync_config = SlidingSyncConfig(
user=user,
requester=requester,
lists=body.lists,
room_subscriptions=body.room_subscriptions,
)
# Use the sliding sync room list handler to get the same set of rooms
interested_rooms = (
await self.sliding_sync_handler.room_lists.compute_interested_rooms(
sync_config=sync_config,
previous_connection_state=PerConnectionState(),
to_token=current_token,
from_token=None,
)
)
room_ids = frozenset(interested_rooms.relevant_room_map.keys())
else:
# No lists or room_subscriptions, use only joined rooms
room_ids = frozenset(
room_id
for room_id, membership_info in room_membership_for_user_at_to_token_map.items()
if membership_info.membership == Membership.JOIN
)
# Get thread updates using unified helper
(
thread_updates,
prev_batch_token,
) = await self.store.get_thread_updates_for_rooms(
) = await self.relations_handler.get_thread_updates_for_rooms(
room_ids=room_ids,
room_membership_map=room_membership_for_user_at_to_token_map,
user_id=user_id,
from_token=to_token,
to_token=from_token,
to_token=from_token if from_token else current_token,
limit=limit,
include_roots=body.include_roots,
)
if len(all_thread_updates) == 0:
if not thread_updates:
return 200, {"chunk": {}}
# Filter thread updates for visibility
filtered_updates = (
await self.relations_handler.process_thread_updates_for_visibility(
all_thread_updates, user_id
)
)
if not filtered_updates:
return 200, {"chunk": {}}
# Fetch thread root events and their bundled aggregations
(
thread_root_event_map,
aggregations_map,
) = await self.relations_handler.fetch_thread_roots_and_aggregations(
filtered_updates.keys(), user_id
)
# Build response with per-thread data
# Updates are already sorted by stream_ordering DESC from the database query,
# and filter_events_for_client preserves order, so updates[0] is guaranteed to be
# the latest event for each thread.
# Serialize thread updates using shared helper
time_now = self.clock.time_msec()
serialize_options = SerializeEventConfig(requester=requester)
chunk: dict[str, dict[str, JsonDict]] = {}
for thread_root_id, updates in filtered_updates.items():
# We only care about the latest update for the thread
latest_update = updates[0]
room_id = latest_update.room_id
if room_id not in chunk:
chunk[room_id] = {}
update_dict: JsonDict = {}
# Add thread root if present
thread_root_event = thread_root_event_map.get(thread_root_id)
if thread_root_event is not None:
bundle_aggs_map = (
{thread_root_id: aggregations_map[thread_root_id]}
if thread_root_id in aggregations_map
else None
)
serialized_events = await self.event_serializer.serialize_events(
[thread_root_event],
time_now,
config=serialize_options,
bundle_aggregations=bundle_aggs_map,
)
if serialized_events:
update_dict["thread_root"] = serialized_events[0]
# Add per-thread prev_batch if this thread has multiple visible updates
if len(updates) > 1:
# Create a token pointing to one position before the latest event's stream position.
# This makes it exclusive - /relations with dir=b won't return the latest event again.
per_thread_prev_batch = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM,
RoomStreamToken(stream=latest_update.stream_ordering - 1),
)
update_dict["prev_batch"] = await per_thread_prev_batch.to_string(
self.store
serialized = await self.relations_handler.serialize_thread_updates(
thread_updates=thread_updates,
prev_batch_token=prev_batch_token,
event_serializer=self.event_serializer,
time_now=time_now,
store=self.store,
serialize_options=serialize_options,
)
chunk[room_id][thread_root_id] = update_dict
# Build response
response: JsonDict = {"chunk": chunk}
# Add next_batch token for pagination
if prev_batch_token is not None:
response["next_batch"] = await prev_batch_token.to_string(self.store)
# Build response with "chunk" wrapper and "next_batch" key
# (companion endpoint uses different key names than sliding sync)
response: JsonDict = {"chunk": serialized["updates"]}
if "prev_batch" in serialized:
response["next_batch"] = serialized["prev_batch"]
return 200, response

View File

@@ -37,6 +37,7 @@ from synapse.events.utils import (
format_event_raw,
)
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.relations import RelationsHandler
from synapse.handlers.sliding_sync import SlidingSyncConfig, SlidingSyncResult
from synapse.handlers.sync import (
ArchivedSyncResult,
@@ -1107,6 +1108,7 @@ class SlidingSyncRestServlet(RestServlet):
time_now,
extensions.threads,
self.store,
requester,
)
return serialized_extensions
@@ -1150,6 +1152,7 @@ async def _serialise_threads(
time_now: int,
threads: SlidingSyncResult.Extensions.ThreadsExtension,
store: "DataStore",
requester: Requester,
) -> JsonDict:
"""
Serialize the threads extension response for sliding sync.
@@ -1159,6 +1162,7 @@ async def _serialise_threads(
time_now: The current time in milliseconds, used for event serialization.
threads: The threads extension data containing thread updates and pagination tokens.
store: The datastore, needed for serializing stream tokens.
requester: The user making the request, used for transaction_id inclusion.
Returns:
A JSON-serializable dict containing:
@@ -1169,47 +1173,25 @@ async def _serialise_threads(
- "prev_batch": A pagination token for fetching older events in the thread.
- "prev_batch": A pagination token for fetching older thread updates (if available).
"""
if not threads.updates:
out: JsonDict = {}
if threads.updates:
updates_dict: JsonDict = {}
for room_id, thread_updates in threads.updates.items():
room_updates: JsonDict = {}
for thread_root_id, update in thread_updates.items():
# Serialize the update
update_dict: JsonDict = {}
# Serialize the thread_root event if present
if update.thread_root is not None:
# Create a mapping of event_id to bundled_aggregations
bundle_aggs_map = (
{thread_root_id: update.bundled_aggregations}
if update.bundled_aggregations
else None
)
serialized_events = await event_serializer.serialize_events(
[update.thread_root],
time_now,
bundle_aggregations=bundle_aggs_map,
)
if serialized_events:
update_dict["thread_root"] = serialized_events[0]
# Add prev_batch if present
if update.prev_batch is not None:
update_dict["prev_batch"] = await update.prev_batch.to_string(store)
room_updates[thread_root_id] = update_dict
updates_dict[room_id] = room_updates
out["updates"] = updates_dict
if threads.prev_batch:
out["prev_batch"] = await threads.prev_batch.to_string(store)
return out
# Create serialization config to include transaction_id for requester's events
serialize_options = SerializeEventConfig(requester=requester)
# Use shared serialization helper (static method)
return await RelationsHandler.serialize_thread_updates(
thread_updates=threads.updates,
prev_batch_token=threads.prev_batch,
event_serializer=event_serializer,
time_now=time_now,
store=store,
serialize_options=serialize_options,
)
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)

View File

@@ -37,7 +37,8 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase
if TYPE_CHECKING:
from synapse.handlers.relations import BundledAggregations
from synapse.handlers.relations import BundledAggregations, ThreadUpdate
from synapse.types import (
DeviceListUpdates,
JsonDict,
@@ -409,30 +410,7 @@ class SlidingSyncResult:
to paginate through older thread updates.
"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadUpdate:
"""Information about a single thread that has new activity.
Attributes:
thread_root: The thread root event, if requested via include_roots in the
request. This is the event that started the thread.
prev_batch: A pagination token (exclusive) for fetching older events in this
specific thread. Only present if the thread has multiple updates in the
sync window. This token can be used with the /relations endpoint with
dir=b to paginate backwards through the thread's history.
bundled_aggregations: Bundled aggregations for the thread root event,
including the latest_event in the thread (found in
unsigned.m.relations.m.thread). Only present if thread_root is included.
"""
thread_root: EventBase | None
prev_batch: StreamToken | None
bundled_aggregations: "BundledAggregations | None" = None
def __bool__(self) -> bool:
return bool(self.thread_root) or bool(self.prev_batch)
updates: Mapping[str, Mapping[str, ThreadUpdate]] | None
updates: Mapping[str, Mapping[str, "ThreadUpdate"]] | None
prev_batch: StreamToken | None
def __bool__(self) -> bool:

View File

@@ -433,14 +433,30 @@ class ThreadUpdatesBody(RequestBodyModel):
"""
Thread updates companion endpoint request body (MSC4360).
Allows filtering thread updates using the same filter criteria as sliding sync lists.
This enables clients to paginate thread updates using the same room filters that
were applied when generating the prev_batch token.
Allows paginating thread updates using the same room selection as a sliding sync
request. This enables clients to fetch thread updates for the same set of rooms
that were included in their sliding sync response.
Attributes:
filters: Optional room filters to apply, using the same structure as
SlidingSyncList.Filters. If not provided, thread updates from all
lists: Sliding window API lists, using the same structure as SlidingSyncBody.lists.
If provided along with room_subscriptions, the union of rooms from both will
be used.
room_subscriptions: Room subscription API rooms, using the same structure as
SlidingSyncBody.room_subscriptions. If provided along with lists, the union
of rooms from both will be used.
include_roots: Whether to include the thread root events in the response.
Defaults to False.
If neither lists nor room_subscriptions are provided, thread updates from all
joined rooms are returned.
"""
filters: SlidingSyncBody.SlidingSyncList.Filters | None = None
lists: (
dict[
Annotated[str, StringConstraints(max_length=64, strict=True)],
SlidingSyncBody.SlidingSyncList,
]
| None
) = None
room_subscriptions: dict[StrictStr, SlidingSyncBody.RoomSubscription] | None = None
include_roots: StrictBool = False

View File

@@ -58,7 +58,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
content={},
content={"include_roots": True},
access_token=user1_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -99,7 +99,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={},
content={"include_roots": True},
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -165,7 +165,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={},
content={"include_roots": True},
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -227,7 +227,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={},
content={"include_roots": True},
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -277,7 +277,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2",
access_token=user1_tok,
content={},
content={"include_roots": True},
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -293,7 +293,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST",
f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch}",
access_token=user1_tok,
content={},
content={"include_roots": True},
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -315,7 +315,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST",
f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch_2}",
access_token=user1_tok,
content={},
content={"include_roots": True},
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -347,7 +347,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=f",
access_token=user1_tok,
content={},
content={"include_roots": True},
)
self.assertEqual(channel.code, 400)
@@ -363,7 +363,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=0",
access_token=user1_tok,
content={},
content={"include_roots": True},
)
self.assertEqual(channel.code, 400)
@@ -372,7 +372,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=-5",
access_token=user1_tok,
content={},
content={"include_roots": True},
)
self.assertEqual(channel.code, 400)
@@ -388,7 +388,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from=invalid_token",
access_token=user1_tok,
content={},
content={"include_roots": True},
)
self.assertEqual(channel.code, 400)
@@ -397,7 +397,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to=invalid_token",
access_token=user1_tok,
content={},
content={"include_roots": True},
)
self.assertEqual(channel.code, 400)
@@ -454,7 +454,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=1",
access_token=user1_tok,
content={},
content={"include_roots": True},
)
self.assertEqual(channel.code, 200)
self.assertIn("next_batch", channel.json_body)
@@ -471,7 +471,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST",
f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to={next_batch}",
access_token=user1_tok,
content={},
content={"include_roots": True},
)
self.assertEqual(channel.code, 200)
@@ -523,7 +523,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={},
content={"include_roots": True},
)
self.assertEqual(channel.code, 200)
@@ -602,7 +602,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={},
content={"include_roots": True},
)
self.assertEqual(channel.code, 200)
@@ -611,3 +611,347 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
self.assertIn(room1_id, chunk)
self.assertNotIn(room2_id, chunk)
self.assertIn(thread1_root_id, chunk[room1_id])
def test_room_filtering_with_lists(self) -> None:
"""
Test that room filtering works correctly using the lists parameter.
This verifies that thread updates are only returned for rooms matching
the provided filters.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Create an encrypted room and an unencrypted room
encrypted_room_id = self.helper.create_room_as(
user1_id,
tok=user1_tok,
extra_content={
"initial_state": [
{
"type": "m.room.encryption",
"state_key": "",
"content": {"algorithm": "m.megolm.v1.aes-sha2"},
}
]
},
)
unencrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create threads in both rooms
enc_thread_root_id = self.helper.send(
encrypted_room_id, body="Encrypted thread", tok=user1_tok
)["event_id"]
unenc_thread_root_id = self.helper.send(
unencrypted_room_id, body="Unencrypted thread", tok=user1_tok
)["event_id"]
# Add replies to both threads
self.helper.send_event(
encrypted_room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply in encrypted room",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": enc_thread_root_id,
},
},
tok=user1_tok,
)
self.helper.send_event(
unencrypted_room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply in unencrypted room",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": unenc_thread_root_id,
},
},
tok=user1_tok,
)
# Request thread updates with filter for encrypted rooms only
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={
"lists": {
"encrypted_list": {
"ranges": [[0, 99]],
"required_state": [["m.room.encryption", ""]],
"timeline_limit": 10,
"filters": {"is_encrypted": True},
}
}
},
)
self.assertEqual(channel.code, 200, channel.json_body)
chunk = channel.json_body["chunk"]
# Should only include the encrypted room
self.assertIn(encrypted_room_id, chunk)
self.assertNotIn(unencrypted_room_id, chunk)
self.assertIn(enc_thread_root_id, chunk[encrypted_room_id])
def test_room_filtering_with_room_subscriptions(self) -> None:
"""
Test that room filtering works correctly using the room_subscriptions parameter.
This verifies that thread updates are only returned for explicitly subscribed rooms.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Create three rooms
room1_id = self.helper.create_room_as(user1_id, tok=user1_tok)
room2_id = self.helper.create_room_as(user1_id, tok=user1_tok)
room3_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create threads in all three rooms
thread1_root_id = self.helper.send(room1_id, body="Thread 1", tok=user1_tok)[
"event_id"
]
thread2_root_id = self.helper.send(room2_id, body="Thread 2", tok=user1_tok)[
"event_id"
]
thread3_root_id = self.helper.send(room3_id, body="Thread 3", tok=user1_tok)[
"event_id"
]
# Add replies to all threads
for room_id, thread_root_id in [
(room1_id, thread1_root_id),
(room2_id, thread2_root_id),
(room3_id, thread3_root_id),
]:
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
tok=user1_tok,
)
# Request thread updates with subscription to only room1 and room2
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={
"room_subscriptions": {
room1_id: {
"required_state": [["m.room.name", ""]],
"timeline_limit": 10,
},
room2_id: {
"required_state": [["m.room.name", ""]],
"timeline_limit": 10,
},
}
},
)
self.assertEqual(channel.code, 200, channel.json_body)
chunk = channel.json_body["chunk"]
# Should only include room1 and room2, not room3
self.assertIn(room1_id, chunk)
self.assertIn(room2_id, chunk)
self.assertNotIn(room3_id, chunk)
self.assertIn(thread1_root_id, chunk[room1_id])
self.assertIn(thread2_root_id, chunk[room2_id])
def test_room_filtering_with_lists_and_room_subscriptions(self) -> None:
"""
Test that room filtering works correctly when both lists and room_subscriptions
are provided. The union of rooms from both should be included.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Create an encrypted room and two unencrypted rooms
encrypted_room_id = self.helper.create_room_as(
user1_id,
tok=user1_tok,
extra_content={
"initial_state": [
{
"type": "m.room.encryption",
"state_key": "",
"content": {"algorithm": "m.megolm.v1.aes-sha2"},
}
]
},
)
unencrypted_room1_id = self.helper.create_room_as(user1_id, tok=user1_tok)
unencrypted_room2_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create threads in all three rooms
enc_thread_root_id = self.helper.send(
encrypted_room_id, body="Encrypted thread", tok=user1_tok
)["event_id"]
unenc1_thread_root_id = self.helper.send(
unencrypted_room1_id, body="Unencrypted thread 1", tok=user1_tok
)["event_id"]
unenc2_thread_root_id = self.helper.send(
unencrypted_room2_id, body="Unencrypted thread 2", tok=user1_tok
)["event_id"]
# Add replies to all threads
for room_id, thread_root_id in [
(encrypted_room_id, enc_thread_root_id),
(unencrypted_room1_id, unenc1_thread_root_id),
(unencrypted_room2_id, unenc2_thread_root_id),
]:
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
tok=user1_tok,
)
# Request thread updates with:
# - lists: filter for encrypted rooms
# - room_subscriptions: explicitly subscribe to unencrypted_room1_id
# Expected: should get both encrypted_room_id (from list) and unencrypted_room1_id
# (from subscription), but NOT unencrypted_room2_id
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
access_token=user1_tok,
content={
"lists": {
"encrypted_list": {
"ranges": [[0, 99]],
"required_state": [["m.room.encryption", ""]],
"timeline_limit": 10,
"filters": {"is_encrypted": True},
}
},
"room_subscriptions": {
unencrypted_room1_id: {
"required_state": [["m.room.name", ""]],
"timeline_limit": 10,
}
},
},
)
self.assertEqual(channel.code, 200, channel.json_body)
chunk = channel.json_body["chunk"]
# Should include encrypted_room_id (from list filter) and unencrypted_room1_id
# (from subscription), but not unencrypted_room2_id
self.assertIn(encrypted_room_id, chunk)
self.assertIn(unencrypted_room1_id, chunk)
self.assertNotIn(unencrypted_room2_id, chunk)
self.assertIn(enc_thread_root_id, chunk[encrypted_room_id])
self.assertIn(unenc1_thread_root_id, chunk[unencrypted_room1_id])
def test_threads_not_returned_after_leaving_room(self) -> None:
"""
Test that thread updates are properly bounded when a user leaves a room.
Users should see thread updates that occurred up to the point they left,
but NOT updates that occurred after they left.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
# Create room and both users join
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
self.helper.join(room_id, user2_id, tok=user2_tok)
# Create thread
res = self.helper.send(room_id, body="Thread root", tok=user1_tok)
thread_root = res["event_id"]
# Reply in thread while user2 is joined
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply 1 while user2 joined",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root,
},
},
tok=user1_tok,
)
# User2 leaves the room
self.helper.leave(room_id, user2_id, tok=user2_tok)
# Another reply after user2 left
self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": "Reply 2 after user2 left",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root,
},
},
tok=user1_tok,
)
# User2 gets thread updates with an explicit room subscription
# (We need to explicitly subscribe to the room to include it after leaving;
# otherwise only joined rooms are returned)
channel = self.make_request(
"POST",
"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=100",
{
"room_subscriptions": {
room_id: {
"required_state": [],
"timeline_limit": 0,
}
}
},
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Assert: User2 SHOULD see Reply 1 (happened while joined) but NOT Reply 2 (after leaving)
chunk = channel.json_body["chunk"]
self.assertIn(
room_id,
chunk,
"Thread updates should include the room user2 left",
)
self.assertIn(
thread_root,
chunk[room_id],
"Thread root should be in the updates",
)
# Verify that only a single update was seen (Reply 1) by checking that there's
# no prev_batch token. If Reply 2 was also included, there would be multiple
# updates and a prev_batch token would be present.
thread_update = chunk[room_id][thread_root]
self.assertNotIn(
"prev_batch",
thread_update,
"No prev_batch should be present since only one update (Reply 1) is visible",
)