1
0

Compare commits

...

9 Commits

Author SHA1 Message Date
Devon Hudson
d65fc3861b Fix lint 2025-11-09 20:54:21 -07:00
Devon Hudson
eff2503adc Move thread_update request body next to handler 2025-11-09 20:52:27 -07:00
Devon Hudson
ba3c8a5f3e Refactor thread updates to use the same logic between endpoint and extension 2025-11-09 20:45:32 -07:00
Devon Hudson
c826157c52 Fix linter errors 2025-11-09 16:21:56 -07:00
Devon Hudson
57c884ec83 Merge branch 'devon/ssext_threads' into devon/ssext_threads_companion 2025-11-09 16:15:20 -07:00
Devon Hudson
89f75cc70f Add newsfile 2025-10-10 16:20:03 -06:00
Devon Hudson
2f8568866e Remove unnecessary bits 2025-10-10 16:15:21 -06:00
Devon Hudson
af992dd0e2 Merge branch 'devon/ssext_threads' into devon/ssext_threads_companion 2025-10-10 15:40:06 -06:00
Devon Hudson
87e9fe8b38 Add implementation for /thread_updates MSC4360 companion endpoint 2025-10-10 15:09:01 -06:00
9 changed files with 1888 additions and 285 deletions

View File

@@ -0,0 +1 @@
Add companion endpoint for MSC4360: Sliding Sync Threads Extension.

View File

@@ -20,6 +20,7 @@
#
import enum
import logging
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Collection,
@@ -30,24 +31,59 @@ from typing import (
import attr
from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.constants import Direction, EventTypes, Membership, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
from synapse.events.utils import SerializeEventConfig
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
from synapse.storage.databases.main.relations import (
ThreadsNextBatch,
ThreadUpdateInfo,
_RelatedEvent,
)
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, UserID
from synapse.types import (
JsonDict,
Requester,
RoomStreamToken,
StreamKeyType,
StreamToken,
UserID,
)
from synapse.util.async_helpers import gather_results
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.events.utils import EventClientSerializer
from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
# Type aliases for thread update processing
ThreadUpdatesMap = dict[str, list[ThreadUpdateInfo]]
ThreadRootsMap = dict[str, EventBase]
AggregationsMap = dict[str, "BundledAggregations"]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadUpdate:
    """
    Immutable record describing the most recent update to one thread.

    Attributes:
        thread_root: The event that started the thread, or None when the root
            was not requested or is not available to the requesting user.
        prev_batch: Token scoped to this thread from which older events in
            the thread can be paginated; None when no token was generated.
        bundled_aggregations: Aggregations bundled with the thread root
            event, when they were fetched.
    """

    thread_root: EventBase | None
    prev_batch: StreamToken | None
    bundled_aggregations: "BundledAggregations | None" = None
class ThreadsListInclude(str, enum.Enum):
"""Valid values for the 'include' flag of /threads."""
@@ -544,6 +580,367 @@ class RelationsHandler:
return results
async def _filter_thread_updates_for_user(
    self,
    all_thread_updates: ThreadUpdatesMap,
    user_id: str,
) -> ThreadUpdatesMap:
    """Drop thread updates whose underlying event the user may not see.

    The events referenced by the raw updates are loaded and passed through
    `filter_events_for_client`; only updates whose event survives that filter
    are kept. Updates are re-grouped by thread in the order the surviving
    events come back.

    Args:
        all_thread_updates: Map of thread_id to list of ThreadUpdateInfo objects
        user_id: The user ID to filter events for

    Returns:
        Map of thread_id to the ThreadUpdateInfo objects the user can see.
        Threads with no visible update are absent from the map.
    """
    # Index each update by its event ID so that a surviving event can be
    # traced back to its thread in constant time.
    update_by_event_id: dict[str, tuple[str, ThreadUpdateInfo]] = {
        update.event_id: (thread_id, update)
        for thread_id, updates in all_thread_updates.items()
        for update in updates
    }

    # Load the events and let the visibility rules decide which ones remain.
    candidate_events = await self._main_store.get_events_as_list(
        update_by_event_id.keys()
    )
    visible_events = await filter_events_for_client(
        self._storage_controllers, user_id, candidate_events
    )

    # Regroup the surviving updates under their thread.
    visible_updates: ThreadUpdatesMap = defaultdict(list)
    for event in visible_events:
        entry = update_by_event_id.get(event.event_id)
        if entry is None:
            continue
        thread_id, update = entry
        visible_updates[thread_id].append(update)

    return visible_updates
def _build_thread_updates_response(
    self,
    filtered_updates: ThreadUpdatesMap,
    thread_root_event_map: ThreadRootsMap,
    aggregations_map: AggregationsMap,
    global_prev_batch_token: StreamToken | None,
) -> dict[str, dict[str, ThreadUpdate]]:
    """Assemble the per-room, per-thread response from visible updates.

    Args:
        filtered_updates: Map of thread_root_id to list of ThreadUpdateInfo
        thread_root_event_map: Map of thread_root_id to EventBase
        aggregations_map: Map of thread_root_id to BundledAggregations
        global_prev_batch_token: Global pagination token, or None if no more results

    Returns:
        Map of room_id to thread_root_id to ThreadUpdate
    """
    response: dict[str, dict[str, ThreadUpdate]] = {}

    # A global token means the overall limit was hit. In that case every
    # thread gets its own prev_batch, even one with a single visible update,
    # because further updates for it may have been cut off by the limit.
    hit_global_limit = global_prev_batch_token is not None

    for thread_root_id, updates in filtered_updates.items():
        # Only the first entry is used: it is expected to be the newest
        # update, relying on the storage query returning updates newest-first
        # and on visibility filtering preserving that order.
        newest = updates[0]

        thread_prev_batch: StreamToken | None = None
        if hit_global_limit or len(updates) > 1:
            # Point one position before the newest event so that paginating
            # backwards from this token does not return that event again.
            # Only the room stream position is set; everything else comes
            # from StreamToken.START.
            thread_prev_batch = StreamToken.START.copy_and_replace(
                StreamKeyType.ROOM,
                RoomStreamToken(stream=newest.stream_ordering - 1),
            )

        response.setdefault(newest.room_id, {})[thread_root_id] = ThreadUpdate(
            thread_root=thread_root_event_map.get(thread_root_id),
            prev_batch=thread_prev_batch,
            bundled_aggregations=aggregations_map.get(thread_root_id),
        )

    return response
async def _fetch_thread_updates(
    self,
    room_ids: frozenset[str],
    room_membership_map: Mapping[str, "RoomsForUserType"],
    from_token: StreamToken | None,
    to_token: StreamToken,
    limit: int,
    exclude_thread_ids: set[str] | None = None,
) -> tuple[ThreadUpdatesMap, StreamToken | None]:
    """Fetch thread updates across multiple rooms, handling membership states properly.

    Rooms are split by membership (LEAVE/BAN vs. everything else). Rooms the
    user has left or been banned from are queried one at a time with the upper
    bound set to the position of their leave/ban event, so nothing after that
    point is returned. All remaining rooms are queried together with the
    regular `to_token`.

    The `limit` is shared across all queries. If it is exhausted before every
    room has been queried, `to_token` is returned as the global prev_batch
    token so that callers can tell that more results exist.

    Args:
        room_ids: The set of room IDs to fetch thread updates for
        room_membership_map: Map of room_id to RoomsForUserType containing membership info
        from_token: Lower bound (exclusive) for the query, or None for no lower bound
        to_token: Upper bound for the query (for joined/invited/knocking rooms)
        limit: Maximum number of thread updates to return across all rooms
        exclude_thread_ids: Optional set of thread IDs to exclude from results

    Returns:
        A tuple of:
            - Map of thread_id to list of ThreadUpdateInfo objects
            - Global prev_batch token if there are more results, None otherwise
    """

    def merge_prev_batch(
        current: StreamToken | None, new: StreamToken | None
    ) -> StreamToken | None:
        # Keep whichever token has the greater room stream position.
        if new is None:
            return current
        if current is None:
            return new
        if new.room_key.stream > current.room_key.stream:
            return new
        return current

    # Separate rooms based on membership to handle LEAVE/BAN rooms specially
    leave_ban_rooms: set[str] = set()
    other_rooms: set[str] = set()
    for room_id in room_ids:
        membership_info = room_membership_map.get(room_id)
        if membership_info and membership_info.membership in (
            Membership.LEAVE,
            Membership.BAN,
        ):
            leave_ban_rooms.add(room_id)
        else:
            other_rooms.add(room_id)

    # The lower bound is identical for every storage query.
    from_room_key = from_token.room_key if from_token else None

    all_thread_updates: ThreadUpdatesMap = {}
    prev_batch_token: StreamToken | None = None
    remaining_limit = limit

    # Query LEAVE/BAN rooms with a bounded to_token to prevent data leaks
    for room_id in leave_ban_rooms:
        if remaining_limit <= 0:
            # We've hit the limit, set prev_batch to indicate more results
            prev_batch_token = to_token
            break

        membership_info = room_membership_map[room_id]
        bounded_to_token = membership_info.event_pos.to_room_stream_token()

        (
            room_thread_updates,
            room_prev_batch,
        ) = await self._main_store.get_thread_updates_for_rooms(
            room_ids={room_id},
            from_token=from_room_key,
            to_token=bounded_to_token,
            limit=remaining_limit,
            exclude_thread_ids=exclude_thread_ids,
        )

        # Every returned update counts against the shared limit.
        remaining_limit -= sum(
            len(updates) for updates in room_thread_updates.values()
        )

        for thread_id, updates in room_thread_updates.items():
            all_thread_updates.setdefault(thread_id, []).extend(updates)

        prev_batch_token = merge_prev_batch(prev_batch_token, room_prev_batch)

    # Query other rooms (joined/invited/knocking) with the normal to_token
    if other_rooms:
        if remaining_limit > 0:
            (
                other_thread_updates,
                other_prev_batch,
            ) = await self._main_store.get_thread_updates_for_rooms(
                room_ids=other_rooms,
                from_token=from_room_key,
                to_token=to_token.room_key,
                limit=remaining_limit,
                exclude_thread_ids=exclude_thread_ids,
            )

            for thread_id, updates in other_thread_updates.items():
                all_thread_updates.setdefault(thread_id, []).extend(updates)

            prev_batch_token = merge_prev_batch(prev_batch_token, other_prev_batch)
        elif prev_batch_token is None:
            # The LEAVE/BAN rooms consumed the whole limit, so these rooms
            # were never queried. Without a token the caller would conclude
            # that there is nothing more to fetch.
            prev_batch_token = to_token

    return all_thread_updates, prev_batch_token
async def get_thread_updates_for_rooms(
    self,
    room_ids: frozenset[str],
    room_membership_map: Mapping[str, "RoomsForUserType"],
    user_id: str,
    from_token: StreamToken | None,
    to_token: StreamToken,
    limit: int,
    include_roots: bool = False,
    exclude_thread_ids: set[str] | None = None,
) -> tuple[dict[str, dict[str, ThreadUpdate]], StreamToken | None]:
    """Get thread updates across multiple rooms with full processing pipeline.

    This is the main entry point for fetching thread updates. It handles:
    - Fetching updates with membership-based security
    - Filtering for visibility
    - Optionally fetching thread roots and aggregations
    - Building the response structure

    Thread root events are subject to the same visibility filtering as the
    updates themselves: a root the user cannot see is omitted, leaving
    `ThreadUpdate.thread_root` as None for that thread.

    Args:
        room_ids: The set of room IDs to fetch updates for
        room_membership_map: Map of room_id to RoomsForUserType for membership info
        user_id: The user requesting the updates
        from_token: Lower bound (exclusive) for the query
        to_token: Upper bound for the query
        limit: Maximum number of updates to return
        include_roots: Whether to fetch and include thread root events (default: False)
        exclude_thread_ids: Optional set of thread IDs to exclude

    Returns:
        A tuple of:
            - Map of room_id to thread_root_id to ThreadUpdate
            - Global prev_batch token if there are more results, None otherwise
    """
    # Fetch thread updates with membership handling
    all_thread_updates, prev_batch_token = await self._fetch_thread_updates(
        room_ids=room_ids,
        room_membership_map=room_membership_map,
        from_token=from_token,
        to_token=to_token,
        limit=limit,
        exclude_thread_ids=exclude_thread_ids,
    )

    if not all_thread_updates:
        return {}, prev_batch_token

    # Filter thread updates for visibility
    filtered_updates = await self._filter_thread_updates_for_user(
        all_thread_updates, user_id
    )

    # The token is still returned when nothing is visible so that callers can
    # keep paginating past a page that was filtered out entirely.
    if not filtered_updates:
        return {}, prev_batch_token

    # Optionally fetch thread root events and their bundled aggregations
    thread_root_event_map: ThreadRootsMap = {}
    aggregations_map: AggregationsMap = {}
    if include_roots:
        thread_root_events = await self._main_store.get_events_as_list(
            filtered_updates.keys()
        )

        # Seeing a reply does not imply that the root is visible as well, so
        # the roots go through the visibility filter too.
        visible_root_events = await filter_events_for_client(
            self._storage_controllers, user_id, thread_root_events
        )
        thread_root_event_map = {e.event_id: e for e in visible_root_events}

        # Fetch bundled aggregations for the visible thread roots
        if thread_root_event_map:
            aggregations_map = await self.get_bundled_aggregations(
                thread_root_event_map.values(),
                user_id,
            )

    # Build response structure with per-thread prev_batch tokens
    thread_updates = self._build_thread_updates_response(
        filtered_updates=filtered_updates,
        thread_root_event_map=thread_root_event_map,
        aggregations_map=aggregations_map,
        global_prev_batch_token=prev_batch_token,
    )

    return thread_updates, prev_batch_token
@staticmethod
async def serialize_thread_updates(
    thread_updates: Mapping[str, Mapping[str, ThreadUpdate]],
    prev_batch_token: StreamToken | None,
    event_serializer: "EventClientSerializer",
    time_now: int,
    store: "DataStore",
    serialize_options: SerializeEventConfig,
) -> JsonDict:
    """
    Turn ThreadUpdate objects into their JSON representation.

    Shared by the companion endpoint and the sliding sync extension so that
    both emit the same per-thread structure.

    Args:
        thread_updates: Map of room_id to thread_root_id to ThreadUpdate
        prev_batch_token: Global pagination token for fetching more updates
        event_serializer: The event serializer to use
        time_now: Current time in milliseconds for event serialization
        store: Datastore for serializing stream tokens
        serialize_options: Serialization config

    Returns:
        JSON-serializable dict with "updates" and optionally "prev_batch"
    """
    serialized_rooms: JsonDict = {}

    for room_id, room_threads in thread_updates.items():
        serialized_threads: JsonDict = {}

        for thread_root_id, update in room_threads.items():
            entry: JsonDict = {}

            root_event = update.thread_root
            if root_event is not None:
                # Aggregations are passed keyed by the root's event ID, or
                # not at all when none were fetched.
                aggregations = update.bundled_aggregations
                if aggregations is None:
                    bundle_aggs_map = None
                else:
                    bundle_aggs_map = {thread_root_id: aggregations}

                serialized_roots = await event_serializer.serialize_events(
                    [root_event],
                    time_now,
                    config=serialize_options,
                    bundle_aggregations=bundle_aggs_map,
                )
                if serialized_roots:
                    entry["thread_root"] = serialized_roots[0]

            # The per-thread token is only emitted when one was generated.
            if update.prev_batch is not None:
                entry["prev_batch"] = await update.prev_batch.to_string(store)

            serialized_threads[thread_root_id] = entry

        serialized_rooms[room_id] = serialized_threads

    result: JsonDict = {"updates": serialized_rooms}

    # The global token is likewise optional.
    if prev_batch_token is not None:
        result["prev_batch"] = await prev_batch_token.to_string(store)

    return result
async def get_threads(
self,
requester: Requester,

View File

@@ -14,7 +14,6 @@
import itertools
import logging
from collections import defaultdict
from typing import (
TYPE_CHECKING,
AbstractSet,
@@ -31,7 +30,6 @@ from synapse.api.constants import (
AccountDataTypes,
EduTypes,
EventContentFields,
Membership,
MRelatesToFields,
RelationTypes,
)
@@ -40,15 +38,12 @@ from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.receipts import ReceiptInRoom
from synapse.storage.databases.main.relations import ThreadUpdateInfo
from synapse.types import (
DeviceListUpdates,
JsonMapping,
MultiWriterStreamToken,
RoomStreamToken,
SlidingSyncStreamToken,
StrCollection,
StreamKeyType,
StreamToken,
ThreadSubscriptionsToken,
)
@@ -64,7 +59,6 @@ from synapse.util.async_helpers import (
concurrently_execute,
gather_optional_coroutines,
)
from synapse.visibility import filter_events_for_client
_ThreadSubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
@@ -72,7 +66,6 @@ _ThreadSubscription: TypeAlias = (
_ThreadUnsubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
)
_ThreadUpdate: TypeAlias = SlidingSyncResult.Extensions.ThreadsExtension.ThreadUpdate
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -1040,42 +1033,6 @@ class SlidingSyncExtensionHandler:
threads_in_timeline.add(thread_id)
return threads_in_timeline
def _merge_prev_batch_token(
self,
current_token: StreamToken | None,
new_token: StreamToken | None,
) -> StreamToken | None:
"""Merge two prev_batch tokens, taking the maximum (latest) for backwards pagination.
Args:
current_token: The current prev_batch token (may be None)
new_token: The new prev_batch token to merge (may be None)
Returns:
The merged token (maximum of the two, or None if both are None)
"""
if new_token is None:
return current_token
if current_token is None:
return new_token
if new_token.room_key.stream > current_token.room_key.stream:
return new_token
return current_token
def _merge_thread_updates(
self,
target: dict[str, list[ThreadUpdateInfo]],
source: dict[str, list[ThreadUpdateInfo]],
) -> None:
"""Merge thread updates from source into target.
Args:
target: The target dict to merge into (modified in place)
source: The source dict to merge from
"""
for thread_id, updates in source.items():
target.setdefault(thread_id, []).extend(updates)
async def get_threads_extension_response(
self,
sync_config: SlidingSyncConfig,
@@ -1117,159 +1074,26 @@ class SlidingSyncExtensionHandler:
actual_room_response_map
)
# Separate rooms into groups based on membership status.
# For LEAVE/BAN rooms, we need to bound the to_token to prevent leaking events
# that occurred after the user left/was banned.
leave_ban_rooms: set[str] = set()
other_rooms: set[str] = set()
for room_id in actual_room_ids:
membership_info = room_membership_for_user_at_to_token_map.get(room_id)
if membership_info and membership_info.membership in (
Membership.LEAVE,
Membership.BAN,
):
leave_ban_rooms.add(room_id)
else:
other_rooms.add(room_id)
# Fetch thread updates, handling LEAVE/BAN rooms separately to avoid data leaks.
all_thread_updates: dict[str, list[ThreadUpdateInfo]] = {}
prev_batch_token: StreamToken | None = None
remaining_limit = threads_request.limit
# Query for rooms where the user has left or been banned, using their leave/ban
# event position as the upper bound to prevent seeing events after they left.
if leave_ban_rooms:
for room_id in leave_ban_rooms:
if remaining_limit <= 0:
# We've already fetched enough updates, but we still need to set
# prev_batch to indicate there are more results.
prev_batch_token = to_token
break
membership_info = room_membership_for_user_at_to_token_map[room_id]
bounded_to_token = membership_info.event_pos.to_room_stream_token()
(
room_thread_updates,
room_prev_batch,
) = await self.store.get_thread_updates_for_rooms(
room_ids={room_id},
from_token=from_token.stream_token.room_key if from_token else None,
to_token=bounded_to_token,
limit=remaining_limit,
exclude_thread_ids=threads_to_exclude,
)
# Count how many updates we fetched and reduce the remaining limit
num_updates = sum(
len(updates) for updates in room_thread_updates.values()
)
remaining_limit -= num_updates
self._merge_thread_updates(all_thread_updates, room_thread_updates)
prev_batch_token = self._merge_prev_batch_token(
prev_batch_token, room_prev_batch
)
# Query for rooms where the user is joined, invited, or knocking, using the
# normal to_token as the upper bound.
if other_rooms and remaining_limit > 0:
(
other_thread_updates,
other_prev_batch,
) = await self.store.get_thread_updates_for_rooms(
room_ids=other_rooms,
from_token=from_token.stream_token.room_key if from_token else None,
to_token=to_token.room_key,
limit=remaining_limit,
exclude_thread_ids=threads_to_exclude,
)
self._merge_thread_updates(all_thread_updates, other_thread_updates)
prev_batch_token = self._merge_prev_batch_token(
prev_batch_token, other_prev_batch
)
if len(all_thread_updates) == 0:
return None
# Build a mapping of event_id -> (thread_id, update) for efficient lookup
# during visibility filtering.
event_to_thread_map: dict[str, tuple[str, ThreadUpdateInfo]] = {}
for thread_id, updates in all_thread_updates.items():
for update in updates:
event_to_thread_map[update.event_id] = (thread_id, update)
# Fetch and filter events for visibility
all_events = await self.store.get_events_as_list(event_to_thread_map.keys())
filtered_events = await filter_events_for_client(
self._storage_controllers, sync_config.user.to_string(), all_events
# Get thread updates using unified helper
user_id = sync_config.user.to_string()
(
thread_updates_response,
prev_batch_token,
) = await self.relations_handler.get_thread_updates_for_rooms(
room_ids=frozenset(actual_room_ids),
room_membership_map=room_membership_for_user_at_to_token_map,
user_id=user_id,
from_token=from_token.stream_token if from_token else None,
to_token=to_token,
limit=threads_request.limit,
include_roots=threads_request.include_roots,
exclude_thread_ids=threads_to_exclude,
)
# Rebuild thread updates from filtered events
filtered_updates: dict[str, list[ThreadUpdateInfo]] = defaultdict(list)
for event in filtered_events:
if event.event_id in event_to_thread_map:
thread_id, update = event_to_thread_map[event.event_id]
filtered_updates[thread_id].append(update)
if not filtered_updates:
if not thread_updates_response:
return None
# Note: Updates are already sorted by stream_ordering DESC from the database query,
# and filter_events_for_client preserves order, so updates[0] is guaranteed to be
# the latest event for each thread.
# Optionally fetch thread root events and their bundled aggregations
thread_root_event_map = {}
aggregations_map = {}
if threads_request.include_roots:
thread_root_events = await self.store.get_events_as_list(
filtered_updates.keys()
)
thread_root_event_map = {e.event_id: e for e in thread_root_events}
if thread_root_event_map:
aggregations_map = (
await self.relations_handler.get_bundled_aggregations(
thread_root_event_map.values(),
sync_config.user.to_string(),
)
)
thread_updates: dict[str, dict[str, _ThreadUpdate]] = {}
for thread_root, updates in filtered_updates.items():
# We only care about the latest update for the thread.
# After sorting above, updates[0] is guaranteed to be the latest (highest stream_ordering).
latest_update = updates[0]
# Generate per-thread prev_batch token if this thread has multiple visible updates.
# When we hit the global limit, we generate prev_batch tokens for all threads, even if
# we only saw 1 update for them. This is to cover the case where we only saw
# a single update for a given thread, but the global limit prevents us from
# obtaining other updates which would have otherwise been included in the
# range.
per_thread_prev_batch = None
if len(updates) > 1 or prev_batch_token is not None:
# Create a token pointing to one position before the latest event's stream position.
# This makes it exclusive - /relations with dir=b won't return the latest event again.
# Use StreamToken.START as base (all other streams at 0) since only room position matters.
per_thread_prev_batch = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM,
RoomStreamToken(stream=latest_update.stream_ordering - 1),
)
thread_updates.setdefault(latest_update.room_id, {})[thread_root] = (
_ThreadUpdate(
thread_root=thread_root_event_map.get(thread_root),
prev_batch=per_thread_prev_batch,
bundled_aggregations=aggregations_map.get(thread_root),
)
)
return SlidingSyncResult.Extensions.ThreadsExtension(
updates=thread_updates,
updates=thread_updates_response,
prev_batch=prev_batch_token,
)

View File

@@ -20,17 +20,33 @@
import logging
import re
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Annotated
from synapse.api.constants import Direction
from pydantic import StrictBool, StrictStr
from pydantic.types import StringConstraints
from synapse.api.constants import Direction, Membership
from synapse.api.errors import SynapseError
from synapse.events.utils import SerializeEventConfig
from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.servlet import (
RestServlet,
parse_and_validate_json_object_from_request,
parse_boolean,
parse_integer,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.storage.databases.main.relations import ThreadsNextBatch
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict
from synapse.streams.config import (
PaginationConfig,
extract_stream_token_from_pagination_token,
)
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
from synapse.types.handlers.sliding_sync import PerConnectionState, SlidingSyncConfig
from synapse.types.rest.client import RequestBodyModel, SlidingSyncBody
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -38,6 +54,39 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class ThreadUpdatesBody(RequestBodyModel):
    """
    Request body of the thread updates companion endpoint (MSC4360).

    The body reuses the room selection structures of a sliding sync request,
    which lets a client page through thread updates for the same rooms that
    its sliding sync response covered.

    Attributes:
        lists: Sliding window lists, shaped like SlidingSyncBody.lists. List
            names are limited to 64 characters. When given together with
            room_subscriptions, the union of the rooms from both is used.
        room_subscriptions: Explicit room subscriptions, shaped like
            SlidingSyncBody.room_subscriptions. When given together with
            lists, the union of the rooms from both is used.
        include_roots: Whether the thread root events should be part of the
            response. Defaults to False.

    When both lists and room_subscriptions are omitted, thread updates from
    all joined rooms are returned.
    """

    lists: (
        dict[
            Annotated[str, StringConstraints(max_length=64, strict=True)],
            SlidingSyncBody.SlidingSyncList,
        ]
        | None
    ) = None
    room_subscriptions: dict[StrictStr, SlidingSyncBody.RoomSubscription] | None = None
    include_roots: StrictBool = False
class RelationPaginationServlet(RestServlet):
"""API to paginate relations on an event by topological ordering, optionally
filtered by relation type and event type.
@@ -133,6 +182,167 @@ class ThreadsServlet(RestServlet):
return 200, result
class ThreadUpdatesServlet(RestServlet):
    """
    Companion endpoint to the Sliding Sync threads extension (MSC4360).

    Allows clients to bulk fetch thread updates across rooms. The rooms are
    either selected with the same `lists` / `room_subscriptions` structures as
    a sliding sync request or, when neither is given, are all rooms the user
    is joined to.

    Only backward pagination (`dir=b`) is supported: `from` is the upper
    bound to start from and `to` the lower bound to stop at. The response
    carries the updates under "chunk" and, when more results exist, a
    "next_batch" token.
    """

    PATTERNS = client_patterns(
        "/io.element.msc4360/thread_updates$",
        unstable=True,
        releases=(),
    )
    CATEGORY = "Client API requests"

    def __init__(self, hs: "HomeServer"):
        super().__init__()
        self.clock = hs.get_clock()
        self.auth = hs.get_auth()
        self.store = hs.get_datastores().main
        self.relations_handler = hs.get_relations_handler()
        self.event_serializer = hs.get_event_client_serializer()
        self._storage_controllers = hs.get_storage_controllers()
        self.sliding_sync_handler = hs.get_sliding_sync_handler()

    async def _parse_pagination_token(
        self, token_str: str | None, param_name: str
    ) -> StreamToken | None:
        """Parse an optional pagination token query parameter.

        Args:
            token_str: The raw parameter value, or None/empty if not supplied.
            param_name: Name of the query parameter, used in the error message.

        Returns:
            The parsed StreamToken, or None if the parameter was not supplied.

        Raises:
            SynapseError: 400 if the token cannot be parsed.
        """
        if not token_str:
            return None

        try:
            stream_token_str = extract_stream_token_from_pagination_token(token_str)
            return await StreamToken.from_string(self.store, stream_token_str)
        except Exception as e:
            # A malformed token is a client error and is reported as such;
            # it is not logged as a server-side exception.
            raise SynapseError(400, f"'{param_name}' parameter is invalid") from e

    async def on_POST(self, request: SynapseRequest) -> tuple[int, JsonDict]:
        """Handle a thread updates request.

        Returns:
            A 200 response whose body has a "chunk" map of room_id to
            thread_root_id to serialized update, plus "next_batch" when more
            results are available.

        Raises:
            SynapseError: 400 for an unsupported `dir`, a non-positive `limit`
                or an unparseable `from` / `to` token.
        """
        requester = await self.auth.get_user_by_req(request)

        # Parse request body
        body = parse_and_validate_json_object_from_request(request, ThreadUpdatesBody)

        # Parse query parameters
        dir_str = parse_string(request, "dir", default="b")
        if dir_str != "b":
            raise SynapseError(
                400,
                "The 'dir' parameter must be 'b' (backward). Forward pagination is not supported.",
            )

        limit = parse_integer(request, "limit", default=100)
        if limit <= 0:
            raise SynapseError(400, "The 'limit' parameter must be positive.")

        # Parse pagination tokens
        from_token = await self._parse_pagination_token(
            parse_string(request, "from"), "from"
        )
        to_token = await self._parse_pagination_token(
            parse_string(request, "to"), "to"
        )

        user = requester.user
        user_id = user.to_string()

        # The position to paginate backwards from: the supplied `from` token,
        # or the current maximum room stream position when none was given.
        # It is also used for the membership lookup.
        if from_token is None:
            max_stream_ordering = self.store.get_room_max_stream_ordering()
            current_token = StreamToken.START.copy_and_replace(
                StreamKeyType.ROOM, RoomStreamToken(stream=max_stream_ordering)
            )
        else:
            current_token = from_token

        # Get room membership information to properly handle LEAVE/BAN rooms
        (
            room_membership_for_user_at_to_token_map,
            _,
            _,
        ) = await self.sliding_sync_handler.room_lists.get_room_membership_for_user_at_to_token(
            user=user,
            to_token=current_token,
            from_token=None,
        )

        # Determine which rooms to fetch updates for based on lists/room_subscriptions
        if body.lists is not None or body.room_subscriptions is not None:
            # Use sliding sync room selection logic
            sync_config = SlidingSyncConfig(
                user=user,
                requester=requester,
                lists=body.lists,
                room_subscriptions=body.room_subscriptions,
            )

            # Use the sliding sync room list handler to get the same set of rooms
            interested_rooms = (
                await self.sliding_sync_handler.room_lists.compute_interested_rooms(
                    sync_config=sync_config,
                    previous_connection_state=PerConnectionState(),
                    to_token=current_token,
                    from_token=None,
                )
            )
            room_ids = frozenset(interested_rooms.relevant_room_map.keys())
        else:
            # No lists or room_subscriptions, use only joined rooms
            room_ids = frozenset(
                room_id
                for room_id, membership_info in room_membership_for_user_at_to_token_map.items()
                if membership_info.membership == Membership.JOIN
            )

        # Get thread updates using unified helper. The handler's bounds are in
        # stream order, so the request's `to` becomes the lower bound and the
        # starting position becomes the upper bound.
        (
            thread_updates,
            prev_batch_token,
        ) = await self.relations_handler.get_thread_updates_for_rooms(
            room_ids=room_ids,
            room_membership_map=room_membership_for_user_at_to_token_map,
            user_id=user_id,
            from_token=to_token,
            to_token=current_token,
            limit=limit,
            include_roots=body.include_roots,
        )

        if not thread_updates:
            # A page can be empty (e.g. everything in it was filtered out for
            # visibility) while more results remain, so the pagination token
            # still has to be passed on.
            empty_response: JsonDict = {"chunk": {}}
            if prev_batch_token is not None:
                empty_response["next_batch"] = await prev_batch_token.to_string(
                    self.store
                )
            return 200, empty_response

        # Serialize thread updates using shared helper
        time_now = self.clock.time_msec()
        serialize_options = SerializeEventConfig(requester=requester)

        serialized = await self.relations_handler.serialize_thread_updates(
            thread_updates=thread_updates,
            prev_batch_token=prev_batch_token,
            event_serializer=self.event_serializer,
            time_now=time_now,
            store=self.store,
            serialize_options=serialize_options,
        )

        # Build response with "chunk" wrapper and "next_batch" key
        # (companion endpoint uses different key names than sliding sync)
        response: JsonDict = {"chunk": serialized["updates"]}
        if "prev_batch" in serialized:
            response["next_batch"] = serialized["prev_batch"]

        return 200, response
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
    """Register the relations servlets on the given HTTP server.

    The thread updates companion endpoint is only registered when the
    experimental MSC4360 option is enabled.
    """
    servlet_classes = [RelationPaginationServlet, ThreadsServlet]
    if hs.config.experimental.msc4360_enabled:
        servlet_classes.append(ThreadUpdatesServlet)

    for servlet_class in servlet_classes:
        servlet_class(hs).register(http_server)

View File

@@ -37,6 +37,7 @@ from synapse.events.utils import (
format_event_raw,
)
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.relations import RelationsHandler
from synapse.handlers.sliding_sync import SlidingSyncConfig, SlidingSyncResult
from synapse.handlers.sync import (
ArchivedSyncResult,
@@ -1107,6 +1108,7 @@ class SlidingSyncRestServlet(RestServlet):
time_now,
extensions.threads,
self.store,
requester,
)
return serialized_extensions
@@ -1150,6 +1152,7 @@ async def _serialise_threads(
time_now: int,
threads: SlidingSyncResult.Extensions.ThreadsExtension,
store: "DataStore",
requester: Requester,
) -> JsonDict:
"""
Serialize the threads extension response for sliding sync.
@@ -1159,6 +1162,7 @@ async def _serialise_threads(
time_now: The current time in milliseconds, used for event serialization.
threads: The threads extension data containing thread updates and pagination tokens.
store: The datastore, needed for serializing stream tokens.
requester: The user making the request, used for transaction_id inclusion.
Returns:
A JSON-serializable dict containing:
@@ -1169,46 +1173,24 @@ async def _serialise_threads(
- "prev_batch": A pagination token for fetching older events in the thread.
- "prev_batch": A pagination token for fetching older thread updates (if available).
"""
out: JsonDict = {}
if not threads.updates:
out: JsonDict = {}
if threads.prev_batch:
out["prev_batch"] = await threads.prev_batch.to_string(store)
return out
if threads.updates:
updates_dict: JsonDict = {}
for room_id, thread_updates in threads.updates.items():
room_updates: JsonDict = {}
for thread_root_id, update in thread_updates.items():
# Serialize the update
update_dict: JsonDict = {}
# Create serialization config to include transaction_id for requester's events
serialize_options = SerializeEventConfig(requester=requester)
# Serialize the thread_root event if present
if update.thread_root is not None:
# Create a mapping of event_id to bundled_aggregations
bundle_aggs_map = (
{thread_root_id: update.bundled_aggregations}
if update.bundled_aggregations
else None
)
serialized_events = await event_serializer.serialize_events(
[update.thread_root],
time_now,
bundle_aggregations=bundle_aggs_map,
)
if serialized_events:
update_dict["thread_root"] = serialized_events[0]
# Add prev_batch if present
if update.prev_batch is not None:
update_dict["prev_batch"] = await update.prev_batch.to_string(store)
room_updates[thread_root_id] = update_dict
updates_dict[room_id] = room_updates
out["updates"] = updates_dict
if threads.prev_batch:
out["prev_batch"] = await threads.prev_batch.to_string(store)
return out
# Use shared serialization helper (static method)
return await RelationsHandler.serialize_thread_updates(
thread_updates=threads.updates,
prev_batch_token=threads.prev_batch,
event_serializer=event_serializer,
time_now=time_now,
store=store,
serialize_options=serialize_options,
)
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:

View File

@@ -35,6 +35,32 @@ logger = logging.getLogger(__name__)
MAX_LIMIT = 1000
def extract_stream_token_from_pagination_token(token_str: str) -> str:
    """
    Extract the StreamToken portion from a pagination token string.

    Handles both:
    - StreamToken format: "s123_456_..."
    - SlidingSyncStreamToken format: "5/s123_456_..." (extracts part after /)

    This allows clients using sliding sync to use their pos tokens
    with endpoints like /relations and /messages.

    Args:
        token_str: The token string to parse

    Returns:
        The StreamToken portion of the token. If `token_str` contains no "/",
        it is returned unchanged.
    """
    # SlidingSyncStreamToken format is "connection_position/stream_token".
    # Only the first "/" acts as the separator; anything after it (including
    # further "/" characters) is returned as-is.
    _connection_position, separator, stream_token = token_str.partition("/")
    return stream_token if separator else token_str
@attr.s(slots=True, auto_attribs=True)
class PaginationConfig:
"""A configuration object which stores pagination parameters."""
@@ -57,32 +83,14 @@ class PaginationConfig:
from_tok_str = parse_string(request, "from")
to_tok_str = parse_string(request, "to")
# Helper function to extract StreamToken from either StreamToken or SlidingSyncStreamToken format
def extract_stream_token(token_str: str) -> str:
"""
Extract the StreamToken portion from a token string.
Handles both:
- StreamToken format: "s123_456_..."
- SlidingSyncStreamToken format: "5/s123_456_..." (extracts part after /)
This allows clients using sliding sync to use their pos tokens
with endpoints like /relations and /messages.
"""
if "/" in token_str:
# SlidingSyncStreamToken format: "connection_position/stream_token"
# Split and return just the stream_token part
parts = token_str.split("/", 1)
if len(parts) == 2:
return parts[1]
return token_str
try:
from_tok = None
if from_tok_str == "END":
from_tok = None # For backwards compat.
elif from_tok_str:
stream_token_str = extract_stream_token(from_tok_str)
stream_token_str = extract_stream_token_from_pagination_token(
from_tok_str
)
from_tok = await StreamToken.from_string(store, stream_token_str)
except Exception:
raise SynapseError(400, "'from' parameter is invalid")
@@ -90,7 +98,9 @@ class PaginationConfig:
try:
to_tok = None
if to_tok_str:
stream_token_str = extract_stream_token(to_tok_str)
stream_token_str = extract_stream_token_from_pagination_token(
to_tok_str
)
to_tok = await StreamToken.from_string(store, stream_token_str)
except Exception:
raise SynapseError(400, "'to' parameter is invalid")

View File

@@ -37,7 +37,8 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase
if TYPE_CHECKING:
from synapse.handlers.relations import BundledAggregations
from synapse.handlers.relations import BundledAggregations, ThreadUpdate
from synapse.types import (
DeviceListUpdates,
JsonDict,
@@ -409,30 +410,7 @@ class SlidingSyncResult:
to paginate through older thread updates.
"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadUpdate:
"""Information about a single thread that has new activity.
Attributes:
thread_root: The thread root event, if requested via include_roots in the
request. This is the event that started the thread.
prev_batch: A pagination token (exclusive) for fetching older events in this
specific thread. Only present if the thread has multiple updates in the
sync window. This token can be used with the /relations endpoint with
dir=b to paginate backwards through the thread's history.
bundled_aggregations: Bundled aggregations for the thread root event,
including the latest_event in the thread (found in
unsigned.m.relations.m.thread). Only present if thread_root is included.
"""
thread_root: EventBase | None
prev_batch: StreamToken | None
bundled_aggregations: "BundledAggregations | None" = None
def __bool__(self) -> bool:
return bool(self.thread_root) or bool(self.prev_batch)
updates: Mapping[str, Mapping[str, ThreadUpdate]] | None
updates: Mapping[str, Mapping[str, "ThreadUpdate"]] | None
prev_batch: StreamToken | None
def __bool__(self) -> bool:

View File

@@ -924,6 +924,250 @@ class SlidingSyncThreadsExtensionTestCase(SlidingSyncBase):
# Verify the thread root event is present
self.assertIn("thread_root", thread_updates[thread_root_id])
def test_thread_updates_initial_sync(self) -> None:
    """
    Test that prev_batch from the threads extension response can be used
    with the /thread_updates endpoint to get additional thread updates during
    initial sync. This verifies:
    1. The from parameter boundary is exclusive (no duplicates)
    2. Using prev_batch as 'from' provides complete coverage (no gaps)
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")
    room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

    # Create 5 thread roots
    thread_ids = []
    for i in range(5):
        thread_root_id = self.helper.send(
            room_id, body=f"Thread {i}", tok=user1_tok
        )["event_id"]
        thread_ids.append(thread_root_id)

        # Add reply to each thread
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": f"Reply to thread {i}",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread_root_id,
                },
            },
            tok=user1_tok,
        )

    # Do initial sync with threads extension enabled and limit=2
    sync_body = {
        "lists": {
            "all-rooms": {
                "ranges": [[0, 10]],
                "required_state": [],
                "timeline_limit": 0,
            }
        },
        "extensions": {
            EXT_NAME: {
                "enabled": True,
                "limit": 2,
            }
        },
    }
    response_body, _ = self.do_sync(sync_body, tok=user1_tok)

    # Should get 2 thread updates
    thread_updates = response_body["extensions"][EXT_NAME]["updates"][room_id]
    self.assertEqual(len(thread_updates), 2)
    first_sync_threads = set(thread_updates.keys())

    # Get the top-level prev_batch token from the extension
    self.assertIn("prev_batch", response_body["extensions"][EXT_NAME])
    prev_batch = response_body["extensions"][EXT_NAME]["prev_batch"]

    # Use prev_batch with /thread_updates endpoint to get remaining updates
    # Note: prev_batch should be used as 'from' parameter (upper bound for backward pagination)
    channel = self.make_request(
        "POST",
        f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from={prev_batch}",
        access_token=user1_tok,
        content={},
    )
    # Include the response body so a failure shows the server's error.
    self.assertEqual(channel.code, 200, channel.json_body)

    # Should get the remaining 3 thread updates
    chunk = channel.json_body["chunk"]
    self.assertIn(room_id, chunk)
    self.assertEqual(len(chunk[room_id]), 3)
    thread_updates_response_threads = set(chunk[room_id].keys())

    # Verify no overlap - the from parameter boundary should be exclusive
    self.assertEqual(
        len(first_sync_threads & thread_updates_response_threads),
        0,
        "from parameter boundary should be exclusive - no thread should appear in both responses",
    )

    # Verify no gaps - all threads should be accounted for
    all_threads = set(thread_ids)
    combined_threads = first_sync_threads | thread_updates_response_threads
    self.assertEqual(
        combined_threads,
        all_threads,
        "Combined responses should include all thread updates with no gaps",
    )
def test_thread_updates_incremental_sync(self) -> None:
    """
    Test the intended usage pattern from MSC4360: using prev_batch as 'from'
    and a previous sync pos as 'to' with /thread_updates to fill gaps between
    syncs. This verifies that using both bounds together provides complete
    coverage with no gaps or duplicates.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")
    room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

    def create_threads(indices: range) -> list[str]:
        """Create one thread (root plus a single reply) per index.

        Returns the thread root event IDs in creation order.
        """
        root_ids = []
        for i in indices:
            thread_root_id = self.helper.send(
                room_id, body=f"Thread {i}", tok=user1_tok
            )["event_id"]
            root_ids.append(thread_root_id)
            self.helper.send_event(
                room_id,
                type="m.room.message",
                content={
                    "msgtype": "m.text",
                    "body": f"Reply to thread {i}",
                    "m.relates_to": {
                        "rel_type": RelationTypes.THREAD,
                        "event_id": thread_root_id,
                    },
                },
                tok=user1_tok,
            )
        return root_ids

    # Create 3 threads initially
    create_threads(range(3))

    # First sync
    sync_body = {
        "lists": {
            "all-rooms": {
                "ranges": [[0, 10]],
                "required_state": [],
                "timeline_limit": 0,
            }
        },
        "extensions": {
            EXT_NAME: {
                "enabled": True,
            }
        },
    }
    response_body, pos1 = self.do_sync(sync_body, tok=user1_tok)

    # Should get 3 thread updates
    first_sync_threads = set(
        response_body["extensions"][EXT_NAME]["updates"][room_id].keys()
    )
    self.assertEqual(len(first_sync_threads), 3)

    # Create 3 more threads after the first sync
    new_thread_ids = create_threads(range(3, 6))

    # Second sync with limit=1 to get only some of the new threads
    sync_body_with_limit = {
        "lists": {
            "all-rooms": {
                "ranges": [[0, 10]],
                "required_state": [],
                "timeline_limit": 0,
            }
        },
        "extensions": {
            EXT_NAME: {
                "enabled": True,
                "limit": 1,
            }
        },
    }
    response_body, _ = self.do_sync(
        sync_body_with_limit, tok=user1_tok, since=pos1
    )

    # Should get 1 thread update
    second_sync_threads = set(
        response_body["extensions"][EXT_NAME]["updates"][room_id].keys()
    )
    self.assertEqual(len(second_sync_threads), 1)

    # Get prev_batch from the extension
    self.assertIn("prev_batch", response_body["extensions"][EXT_NAME])
    prev_batch = response_body["extensions"][EXT_NAME]["prev_batch"]

    # Now use /thread_updates with from=prev_batch and to=pos1
    # This should get the 2 remaining new threads (created after pos1, not returned in second sync)
    channel = self.make_request(
        "POST",
        f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from={prev_batch}&to={pos1}",
        access_token=user1_tok,
        content={},
    )
    # Include the response body so a failure shows the server's error.
    self.assertEqual(channel.code, 200, channel.json_body)

    chunk = channel.json_body["chunk"]
    self.assertIn(room_id, chunk)
    thread_updates_threads = set(chunk[room_id].keys())

    # Should get exactly 2 threads
    self.assertEqual(len(thread_updates_threads), 2)

    # Verify no overlap with second sync
    self.assertEqual(
        len(second_sync_threads & thread_updates_threads),
        0,
        "No thread should appear in both second sync and thread_updates responses",
    )

    # Verify no overlap with first sync (to=pos1 should exclude those)
    self.assertEqual(
        len(first_sync_threads & thread_updates_threads),
        0,
        "Threads from first sync should not appear in thread_updates (to=pos1 excludes them)",
    )

    # Verify no gaps - all new threads should be accounted for
    all_new_threads = set(new_thread_ids)
    combined_new_threads = second_sync_threads | thread_updates_threads
    self.assertEqual(
        combined_new_threads,
        all_new_threads,
        "Combined responses should include all new thread updates with no gaps",
    )
def test_threads_only_from_rooms_in_list(self) -> None:
"""
Test that thread updates are only returned for rooms that are in the

View File

@@ -0,0 +1,957 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
import logging
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import RelationTypes
from synapse.rest.client import login, relations, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util.clock import Clock
from tests import unittest
logger = logging.getLogger(__name__)
class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
"""
Test the /thread_updates companion endpoint (MSC4360).
"""
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
relations.register_servlets,
]
def default_config(self) -> JsonDict:
    """Return the parent's default config with ``msc4360_enabled`` turned on.

    The ``experimental_features`` section is replaced wholesale, not merged.
    """
    base_config = super().default_config()
    base_config["experimental_features"] = {"msc4360_enabled": True}
    return base_config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
    """Stash the homeserver's main datastore on the test case as ``self.store``."""
    datastores = hs.get_datastores()
    self.store = datastores.main
def test_no_updates_for_new_user(self) -> None:
    """
    A freshly registered user with no threads should receive an empty
    ``chunk`` and no ``next_batch`` token.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")

    # Hit the endpoint without having created any room or thread.
    endpoint = "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b"
    channel = self.make_request(
        "POST",
        endpoint,
        content={"include_roots": True},
        access_token=user1_tok,
    )
    self.assertEqual(channel.code, 200, channel.json_body)

    # Nothing to report: the chunk is empty and there is no token to page with.
    body = channel.json_body
    self.assertEqual(body["chunk"], {})
    self.assertNotIn("next_batch", body)
def test_single_thread_update(self) -> None:
    """
    A thread made of a root and a single reply should be reported, with the
    root event included and no per-thread ``prev_batch``.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")
    room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

    # Start a thread: send the root, then one reply relating to it.
    thread_root_id = self.helper.send(room_id, body="Thread root", tok=user1_tok)[
        "event_id"
    ]
    reply_content = {
        "msgtype": "m.text",
        "body": "Reply 1",
        "m.relates_to": {
            "rel_type": RelationTypes.THREAD,
            "event_id": thread_root_id,
        },
    }
    self.helper.send_event(
        room_id, type="m.room.message", content=reply_content, tok=user1_tok
    )

    # Fetch thread updates, asking for the root events as well.
    channel = self.make_request(
        "POST",
        "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
        access_token=user1_tok,
        content={"include_roots": True},
    )
    self.assertEqual(channel.code, 200, channel.json_body)

    # The thread must be listed under its room.
    rooms = channel.json_body["chunk"]
    self.assertIn(room_id, rooms)
    self.assertIn(thread_root_id, rooms[room_id])

    # The root event is part of the update.
    update = rooms[room_id][thread_root_id]
    self.assertIn("thread_root", update)
    self.assertEqual(update["thread_root"]["event_id"], thread_root_id)

    # Only one update (the reply) exists, so no prev_batch is expected.
    self.assertNotIn("prev_batch", update)
def test_multiple_threads_single_room(self) -> None:
    """
    Two threads living in one room should both be reported under that
    single room entry.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")
    room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

    # Send both roots first, then one reply per thread (same order as roots).
    thread1_root_id = self.helper.send(room_id, body="Thread 1", tok=user1_tok)[
        "event_id"
    ]
    thread2_root_id = self.helper.send(room_id, body="Thread 2", tok=user1_tok)[
        "event_id"
    ]
    replies = [
        ("Reply to thread 1", thread1_root_id),
        ("Reply to thread 2", thread2_root_id),
    ]
    for reply_body, root_id in replies:
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": reply_body,
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": root_id,
                },
            },
            tok=user1_tok,
        )

    # Fetch thread updates.
    channel = self.make_request(
        "POST",
        "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
        access_token=user1_tok,
        content={"include_roots": True},
    )
    self.assertEqual(channel.code, 200, channel.json_body)

    # One room entry containing exactly the two threads.
    rooms = channel.json_body["chunk"]
    self.assertIn(room_id, rooms)
    self.assertEqual(len(rooms), 1, "Should only have one room")
    self.assertEqual(len(rooms[room_id]), 2, "Should have two threads")
    self.assertIn(thread1_root_id, rooms[room_id])
    self.assertIn(thread2_root_id, rooms[room_id])
def test_threads_across_multiple_rooms(self) -> None:
    """
    Threads in different rooms should each be reported under their own
    room_id key.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")
    room_a_id = self.helper.create_room_as(user1_id, tok=user1_tok)
    room_b_id = self.helper.create_room_as(user1_id, tok=user1_tok)

    # One thread root per room.
    thread_a_root_id = self.helper.send(room_a_id, body="Thread A", tok=user1_tok)[
        "event_id"
    ]
    thread_b_root_id = self.helper.send(room_b_id, body="Thread B", tok=user1_tok)[
        "event_id"
    ]

    # One reply per thread, room A first and then room B.
    replies = [
        (room_a_id, thread_a_root_id, "Reply to A"),
        (room_b_id, thread_b_root_id, "Reply to B"),
    ]
    for target_room_id, root_id, reply_body in replies:
        self.helper.send_event(
            target_room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": reply_body,
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": root_id,
                },
            },
            tok=user1_tok,
        )

    # Fetch thread updates.
    channel = self.make_request(
        "POST",
        "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
        access_token=user1_tok,
        content={"include_roots": True},
    )
    self.assertEqual(channel.code, 200, channel.json_body)

    # Each room appears with its own thread.
    rooms = channel.json_body["chunk"]
    self.assertEqual(len(rooms), 2, "Should have two rooms")
    self.assertIn(room_a_id, rooms)
    self.assertIn(room_b_id, rooms)
    self.assertIn(thread_a_root_id, rooms[room_a_id])
    self.assertIn(thread_b_root_id, rooms[room_b_id])
def test_pagination_with_from_token(self) -> None:
    """
    Page through thread updates with the ``next_batch`` token.

    Three successive calls with limit=2 over five threads must return every
    thread exactly once: no duplicates and no gaps.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")
    room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

    # Five threads (root plus one reply each), more than fits in one page of 2.
    thread_ids = []
    for i in range(5):
        root_id = self.helper.send(room_id, body=f"Thread {i}", tok=user1_tok)[
            "event_id"
        ]
        thread_ids.append(root_id)
        reply_content = {
            "msgtype": "m.text",
            "body": f"Reply to thread {i}",
            "m.relates_to": {
                "rel_type": RelationTypes.THREAD,
                "event_id": root_id,
            },
        }
        self.helper.send_event(
            room_id, type="m.room.message", content=reply_content, tok=user1_tok
        )

    # Page 1: no 'from' token.
    channel = self.make_request(
        "POST",
        "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2",
        access_token=user1_tok,
        content={"include_roots": True},
    )
    self.assertEqual(channel.code, 200, channel.json_body)

    # Expect a full page and a token to continue from.
    first_page_threads = set(channel.json_body["chunk"][room_id].keys())
    self.assertEqual(len(first_page_threads), 2)
    self.assertIn("next_batch", channel.json_body)
    next_batch = channel.json_body["next_batch"]

    # Page 2: continue from the first page's token.
    channel = self.make_request(
        "POST",
        f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch}",
        access_token=user1_tok,
        content={"include_roots": True},
    )
    self.assertEqual(channel.code, 200, channel.json_body)
    second_page_threads = set(channel.json_body["chunk"][room_id].keys())
    self.assertEqual(len(second_page_threads), 2)

    # Pages 1 and 2 must be disjoint.
    self.assertEqual(
        len(first_page_threads & second_page_threads),
        0,
        "Pages should not have overlapping threads",
    )

    # Page 3: the single remaining thread.
    self.assertIn("next_batch", channel.json_body)
    next_batch_2 = channel.json_body["next_batch"]
    channel = self.make_request(
        "POST",
        f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch_2}",
        access_token=user1_tok,
        content={"include_roots": True},
    )
    self.assertEqual(channel.code, 200, channel.json_body)
    third_page_threads = set(channel.json_body["chunk"][room_id].keys())
    self.assertEqual(len(third_page_threads), 1)

    # Page 3 must not repeat anything from the earlier pages.
    self.assertEqual(len(first_page_threads & third_page_threads), 0)
    self.assertEqual(len(second_page_threads & third_page_threads), 0)

    # Together the three pages cover every thread that was created.
    seen_threads = first_page_threads | second_page_threads | third_page_threads
    self.assertEqual(
        seen_threads,
        set(thread_ids),
        "Combined pages should include all thread updates with no gaps",
    )
def test_invalid_dir_parameter(self) -> None:
    """
    Test that forward pagination (dir=f) is rejected with an error.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")

    # Request with forward direction should fail
    channel = self.make_request(
        "POST",
        "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=f",
        access_token=user1_tok,
        content={"include_roots": True},
    )
    # Include the response body so an unexpected status is easy to diagnose.
    self.assertEqual(channel.code, 400, channel.json_body)
def test_invalid_limit_parameter(self) -> None:
    """
    Test that invalid limit values are rejected.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")

    # Zero limit should fail
    channel = self.make_request(
        "POST",
        "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=0",
        access_token=user1_tok,
        content={"include_roots": True},
    )
    # Include the response body so an unexpected status is easy to diagnose.
    self.assertEqual(channel.code, 400, channel.json_body)

    # Negative limit should fail
    channel = self.make_request(
        "POST",
        "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=-5",
        access_token=user1_tok,
        content={"include_roots": True},
    )
    self.assertEqual(channel.code, 400, channel.json_body)
def test_invalid_pagination_tokens(self) -> None:
    """
    Test that invalid from/to tokens are rejected with appropriate errors.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")

    # Invalid from token
    channel = self.make_request(
        "POST",
        "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from=invalid_token",
        access_token=user1_tok,
        content={"include_roots": True},
    )
    # Include the response body so an unexpected status is easy to diagnose.
    self.assertEqual(channel.code, 400, channel.json_body)

    # Invalid to token
    channel = self.make_request(
        "POST",
        "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to=invalid_token",
        access_token=user1_tok,
        content={"include_roots": True},
    )
    self.assertEqual(channel.code, 400, channel.json_body)
def test_to_token_filtering(self) -> None:
    """
    Test that the to_token parameter correctly limits pagination to updates
    newer than the to_token (since we paginate backwards from newest to oldest).

    This also verifies the to_token boundary is exclusive - updates at exactly
    the to_token position should not be included (as they were already returned
    in a previous response that synced up to that position).
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")
    room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

    # Create two thread roots
    thread1_root_id = self.helper.send(room_id, body="Thread 1", tok=user1_tok)[
        "event_id"
    ]
    thread2_root_id = self.helper.send(room_id, body="Thread 2", tok=user1_tok)[
        "event_id"
    ]

    # Send replies to both threads
    self.helper.send_event(
        room_id,
        type="m.room.message",
        content={
            "msgtype": "m.text",
            "body": "Reply to thread 1",
            "m.relates_to": {
                "rel_type": RelationTypes.THREAD,
                "event_id": thread1_root_id,
            },
        },
        tok=user1_tok,
    )
    self.helper.send_event(
        room_id,
        type="m.room.message",
        content={
            "msgtype": "m.text",
            "body": "Reply to thread 2",
            "m.relates_to": {
                "rel_type": RelationTypes.THREAD,
                "event_id": thread2_root_id,
            },
        },
        tok=user1_tok,
    )

    # Request with limit=1 to get only the latest thread update
    channel = self.make_request(
        "POST",
        "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=1",
        access_token=user1_tok,
        content={"include_roots": True},
    )
    # Include the response body so an unexpected status is easy to diagnose.
    self.assertEqual(channel.code, 200, channel.json_body)
    self.assertIn("next_batch", channel.json_body)

    # next_batch points to before the update we just received
    next_batch = channel.json_body["next_batch"]
    first_response_threads = set(channel.json_body["chunk"][room_id].keys())

    # Request again with to=next_batch (lower bound for backward pagination) and no
    # limit.
    # This should get only the same thread updates as before, not the additional
    # (older) update.
    channel = self.make_request(
        "POST",
        f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to={next_batch}",
        access_token=user1_tok,
        content={"include_roots": True},
    )
    self.assertEqual(channel.code, 200, channel.json_body)

    chunk = channel.json_body["chunk"]
    self.assertIn(room_id, chunk)
    # Should have exactly one thread update
    self.assertEqual(len(chunk[room_id]), 1)
    second_response_threads = set(chunk[room_id].keys())

    # Verify the `to` bound cut pagination off at next_batch: the second
    # response must contain exactly the thread(s) from the first response.
    self.assertEqual(
        first_response_threads,
        second_response_threads,
        "to parameter boundary should be exclusive - both responses should be identical",
    )
def test_bundled_aggregations_on_thread_roots(self) -> None:
    """
    Test that thread root events include bundled aggregations with latest thread event.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")
    room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

    # Create thread root
    thread_root_id = self.helper.send(room_id, body="Thread root", tok=user1_tok)[
        "event_id"
    ]

    # Send replies to create bundled aggregation data
    for i in range(2):
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": f"Reply {i + 1}",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread_root_id,
                },
            },
            tok=user1_tok,
        )

    # Request thread updates
    channel = self.make_request(
        "POST",
        "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
        access_token=user1_tok,
        content={"include_roots": True},
    )
    # Include the response body so an unexpected status is easy to diagnose.
    self.assertEqual(channel.code, 200, channel.json_body)

    # Check that thread root has bundled aggregations with latest event
    chunk = channel.json_body["chunk"]
    thread_update = chunk[room_id][thread_root_id]
    thread_root_event = thread_update["thread_root"]

    # Should have unsigned data with latest thread event content
    self.assertIn("unsigned", thread_root_event)
    self.assertIn("m.relations", thread_root_event["unsigned"])
    # Named `bundled_relations` so it does not shadow the imported
    # `relations` servlet module.
    bundled_relations = thread_root_event["unsigned"]["m.relations"]
    self.assertIn(RelationTypes.THREAD, bundled_relations)

    # Check latest event is present in bundled aggregations
    thread_summary = bundled_relations[RelationTypes.THREAD]
    self.assertIn("latest_event", thread_summary)
    latest_event = thread_summary["latest_event"]
    self.assertEqual(latest_event["content"]["body"], "Reply 2")
def test_only_joined_rooms(self) -> None:
    """
    Test that thread updates only include rooms where the user is currently joined.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")
    user2_id = self.register_user("user2", "pass")
    user2_tok = self.login(user2_id, "pass")

    # Create two rooms, user1 joins both
    room1_id = self.helper.create_room_as(user1_id, tok=user1_tok)
    room2_id = self.helper.create_room_as(user2_id, tok=user2_tok)
    self.helper.join(room2_id, user1_id, tok=user1_tok)

    # Create threads in both rooms
    thread1_root_id = self.helper.send(room1_id, body="Thread 1", tok=user1_tok)[
        "event_id"
    ]
    thread2_root_id = self.helper.send(room2_id, body="Thread 2", tok=user2_tok)[
        "event_id"
    ]

    # Add replies to both threads
    self.helper.send_event(
        room1_id,
        type="m.room.message",
        content={
            "msgtype": "m.text",
            "body": "Reply to thread 1",
            "m.relates_to": {
                "rel_type": RelationTypes.THREAD,
                "event_id": thread1_root_id,
            },
        },
        tok=user1_tok,
    )
    self.helper.send_event(
        room2_id,
        type="m.room.message",
        content={
            "msgtype": "m.text",
            "body": "Reply to thread 2",
            "m.relates_to": {
                "rel_type": RelationTypes.THREAD,
                "event_id": thread2_root_id,
            },
        },
        tok=user2_tok,
    )

    # User1 leaves room2
    self.helper.leave(room2_id, user1_id, tok=user1_tok)

    # Request thread updates for user1 - should only get room1
    channel = self.make_request(
        "POST",
        "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
        access_token=user1_tok,
        content={"include_roots": True},
    )
    # Include the response body so an unexpected status is easy to diagnose.
    self.assertEqual(channel.code, 200, channel.json_body)

    chunk = channel.json_body["chunk"]
    # Should only have room1, not room2
    self.assertIn(room1_id, chunk)
    self.assertNotIn(room2_id, chunk)
    self.assertIn(thread1_root_id, chunk[room1_id])
def test_room_filtering_with_lists(self) -> None:
    """
    Restrict thread updates with the ``lists`` request parameter.

    With a list filtered on ``is_encrypted``, only threads from the
    encrypted room should be returned.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")

    # One room with encryption enabled at creation, one plain room.
    encryption_state = {
        "type": "m.room.encryption",
        "state_key": "",
        "content": {"algorithm": "m.megolm.v1.aes-sha2"},
    }
    encrypted_room_id = self.helper.create_room_as(
        user1_id,
        tok=user1_tok,
        extra_content={"initial_state": [encryption_state]},
    )
    unencrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

    # A thread root in each room.
    enc_thread_root_id = self.helper.send(
        encrypted_room_id, body="Encrypted thread", tok=user1_tok
    )["event_id"]
    unenc_thread_root_id = self.helper.send(
        unencrypted_room_id, body="Unencrypted thread", tok=user1_tok
    )["event_id"]

    # A reply in each thread, encrypted room first.
    replies = [
        (encrypted_room_id, enc_thread_root_id, "Reply in encrypted room"),
        (unencrypted_room_id, unenc_thread_root_id, "Reply in unencrypted room"),
    ]
    for target_room_id, root_id, reply_body in replies:
        self.helper.send_event(
            target_room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": reply_body,
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": root_id,
                },
            },
            tok=user1_tok,
        )

    # Ask for thread updates, restricted to encrypted rooms via a list filter.
    request_body = {
        "lists": {
            "encrypted_list": {
                "ranges": [[0, 99]],
                "required_state": [["m.room.encryption", ""]],
                "timeline_limit": 10,
                "filters": {"is_encrypted": True},
            }
        }
    }
    channel = self.make_request(
        "POST",
        "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
        access_token=user1_tok,
        content=request_body,
    )
    self.assertEqual(channel.code, 200, channel.json_body)

    # Only the encrypted room (and its thread) may appear.
    rooms = channel.json_body["chunk"]
    self.assertIn(encrypted_room_id, rooms)
    self.assertNotIn(unencrypted_room_id, rooms)
    self.assertIn(enc_thread_root_id, rooms[encrypted_room_id])
def test_room_filtering_with_room_subscriptions(self) -> None:
    """
    Check that `room_subscriptions` restricts which rooms produce thread updates.

    A thread is created in each of three rooms, but the request only subscribes
    to the first two. The response chunk is expected to contain those two rooms
    (with their thread roots) and to omit the third.
    """
    user_id = self.register_user("user1", "pass")
    token = self.login(user_id, "pass")

    # Three rooms, created one after another.
    first_room, second_room, third_room = (
        self.helper.create_room_as(user_id, tok=token) for _ in range(3)
    )
    subscribed_rooms = (first_room, second_room)

    # Send one thread root per room, remembering the root event ID by room.
    thread_roots = {}
    for room, body in (
        (first_room, "Thread 1"),
        (second_room, "Thread 2"),
        (third_room, "Thread 3"),
    ):
        thread_roots[room] = self.helper.send(room, body=body, tok=token)[
            "event_id"
        ]

    # Turn every root into an actual thread by replying to it.
    for room, root in thread_roots.items():
        reply = {
            "msgtype": "m.text",
            "body": "Reply",
            "m.relates_to": {
                "rel_type": RelationTypes.THREAD,
                "event_id": root,
            },
        }
        self.helper.send_event(
            room,
            type="m.room.message",
            content=reply,
            tok=token,
        )

    # Subscribe to the first two rooms only; the third is deliberately left out.
    subscriptions = {
        room: {
            "required_state": [["m.room.name", ""]],
            "timeline_limit": 10,
        }
        for room in subscribed_rooms
    }
    channel = self.make_request(
        "POST",
        "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
        access_token=token,
        content={"room_subscriptions": subscriptions},
    )
    self.assertEqual(channel.code, 200, channel.json_body)

    updates = channel.json_body["chunk"]

    # Subscribed rooms are present, the unsubscribed one is not.
    for room in subscribed_rooms:
        self.assertIn(room, updates)
    self.assertNotIn(third_room, updates)

    # Each subscribed room reports the thread that was created in it.
    for room in subscribed_rooms:
        self.assertIn(thread_roots[room], updates[room])
def test_room_filtering_with_lists_and_room_subscriptions(self) -> None:
    """
    Check room filtering when `lists` and `room_subscriptions` are combined.

    The request carries a list that filters for encrypted rooms plus an explicit
    subscription to one unencrypted room. The response is expected to cover the
    union of both (the encrypted room and the subscribed unencrypted room) while
    leaving out the second unencrypted room, which matches neither.
    """
    user_id = self.register_user("user1", "pass")
    token = self.login(user_id, "pass")

    # One encrypted room (encryption enabled through initial state) followed by
    # two plain rooms.
    encryption_event = {
        "type": "m.room.encryption",
        "state_key": "",
        "content": {"algorithm": "m.megolm.v1.aes-sha2"},
    }
    encrypted_room = self.helper.create_room_as(
        user_id,
        tok=token,
        extra_content={"initial_state": [encryption_event]},
    )
    subscribed_plain_room = self.helper.create_room_as(user_id, tok=token)
    ignored_plain_room = self.helper.create_room_as(user_id, tok=token)

    # Send one thread root per room, remembering the root event ID by room.
    thread_roots = {}
    for room, body in (
        (encrypted_room, "Encrypted thread"),
        (subscribed_plain_room, "Unencrypted thread 1"),
        (ignored_plain_room, "Unencrypted thread 2"),
    ):
        thread_roots[room] = self.helper.send(room, body=body, tok=token)[
            "event_id"
        ]

    # Reply to every root so that each one becomes a thread.
    for room, root in thread_roots.items():
        reply = {
            "msgtype": "m.text",
            "body": "Reply",
            "m.relates_to": {
                "rel_type": RelationTypes.THREAD,
                "event_id": root,
            },
        }
        self.helper.send_event(
            room,
            type="m.room.message",
            content=reply,
            tok=token,
        )

    # The request combines:
    # - a list restricted to encrypted rooms, and
    # - a direct subscription to the first unencrypted room.
    # The second unencrypted room is matched by neither of them.
    request_body = {
        "lists": {
            "encrypted_list": {
                "ranges": [[0, 99]],
                "required_state": [["m.room.encryption", ""]],
                "timeline_limit": 10,
                "filters": {"is_encrypted": True},
            }
        },
        "room_subscriptions": {
            subscribed_plain_room: {
                "required_state": [["m.room.name", ""]],
                "timeline_limit": 10,
            }
        },
    }
    channel = self.make_request(
        "POST",
        "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
        access_token=token,
        content=request_body,
    )
    self.assertEqual(channel.code, 200, channel.json_body)

    updates = channel.json_body["chunk"]
    expected_rooms = (encrypted_room, subscribed_plain_room)

    # Both the list match and the explicit subscription show up; the room that
    # matched nothing does not.
    for room in expected_rooms:
        self.assertIn(room, updates)
    self.assertNotIn(ignored_plain_room, updates)

    # Each returned room reports the thread that was created in it.
    for room in expected_rooms:
        self.assertIn(thread_roots[room], updates[room])
def test_threads_not_returned_after_leaving_room(self) -> None:
    """
    Check that thread updates stop at the point where the user left the room.

    A second user joins a room, sees one thread reply, leaves, and a further
    reply is sent afterwards. When that user then requests thread updates for
    the room, the thread should be reported, but only with the reply sent while
    they were still joined.
    """
    creator_id = self.register_user("user1", "pass")
    creator_tok = self.login(creator_id, "pass")
    leaver_id = self.register_user("user2", "pass")
    leaver_tok = self.login(leaver_id, "pass")

    # The creator makes the room and the second user joins it.
    room_id = self.helper.create_room_as(creator_id, tok=creator_tok)
    self.helper.join(room_id, leaver_id, tok=leaver_tok)

    # Root event of the thread.
    thread_root = self.helper.send(room_id, body="Thread root", tok=creator_tok)[
        "event_id"
    ]

    def reply_in_thread(body: str) -> None:
        """Send a threaded reply to `thread_root` as the room creator."""
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": body,
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread_root,
                },
            },
            tok=creator_tok,
        )

    # First reply lands while the second user is still a member...
    reply_in_thread("Reply 1 while user2 joined")

    # ...then the second user leaves...
    self.helper.leave(room_id, leaver_id, tok=leaver_tok)

    # ...and the second reply is sent once they are gone.
    reply_in_thread("Reply 2 after user2 left")

    # The departed user asks for thread updates. The room has to be subscribed
    # to explicitly so that it is included after leaving; otherwise only joined
    # rooms are returned.
    request_body = {
        "room_subscriptions": {
            room_id: {
                "required_state": [],
                "timeline_limit": 0,
            }
        }
    }
    channel = self.make_request(
        "POST",
        "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=100",
        request_body,
        access_token=leaver_tok,
    )
    self.assertEqual(channel.code, 200, channel.json_body)

    # The reply sent during membership should be visible, the later one not.
    updates = channel.json_body["chunk"]
    self.assertIn(
        room_id,
        updates,
        "Thread updates should include the room user2 left",
    )
    room_updates = updates[room_id]
    self.assertIn(
        thread_root,
        room_updates,
        "Thread root should be in the updates",
    )

    # A missing prev_batch token is used as the signal that only one update
    # (Reply 1) was seen. Had Reply 2 been included as well, there would be
    # multiple updates and a prev_batch token would be present.
    self.assertNotIn(
        "prev_batch",
        room_updates[thread_root],
        "No prev_batch should be present since only one update (Reply 1) is visible",
    )