Compare commits
9 Commits
devon/ssex...devon/ssex

| Author | SHA1 | Date |
|---|---|---|
| | d65fc3861b | |
| | eff2503adc | |
| | ba3c8a5f3e | |
| | c826157c52 | |
| | 57c884ec83 | |
| | 89f75cc70f | |
| | 2f8568866e | |
| | af992dd0e2 | |
| | 87e9fe8b38 | |
1 changelog.d/19041.feature Normal file
@@ -0,0 +1 @@
Add companion endpoint for MSC4360: Sliding Sync Threads Extension.

@@ -20,6 +20,7 @@
#
import enum
import logging
from collections import defaultdict
from typing import (
    TYPE_CHECKING,
    Collection,
@@ -30,24 +31,59 @@ from typing import (

import attr

from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.constants import Direction, EventTypes, Membership, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
from synapse.events.utils import SerializeEventConfig
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
from synapse.storage.databases.main.relations import (
    ThreadsNextBatch,
    ThreadUpdateInfo,
    _RelatedEvent,
)
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, UserID
from synapse.types import (
    JsonDict,
    Requester,
    RoomStreamToken,
    StreamKeyType,
    StreamToken,
    UserID,
)
from synapse.util.async_helpers import gather_results
from synapse.visibility import filter_events_for_client

if TYPE_CHECKING:
    from synapse.events.utils import EventClientSerializer
    from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
    from synapse.server import HomeServer
    from synapse.storage.databases.main import DataStore


logger = logging.getLogger(__name__)

# Type aliases for thread update processing
ThreadUpdatesMap = dict[str, list[ThreadUpdateInfo]]
ThreadRootsMap = dict[str, EventBase]
AggregationsMap = dict[str, "BundledAggregations"]


@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadUpdate:
    """
    Data for a single thread update.

    Attributes:
        thread_root: The thread root event, or None if not requested/not visible
        prev_batch: Per-thread pagination token for fetching older events in this thread
        bundled_aggregations: Bundled aggregations for the thread root event
    """

    thread_root: EventBase | None
    prev_batch: StreamToken | None
    bundled_aggregations: "BundledAggregations | None" = None


class ThreadsListInclude(str, enum.Enum):
    """Valid values for the 'include' flag of /threads."""
@@ -544,6 +580,367 @@ class RelationsHandler:
|
||||
|
||||
return results
|
||||
|
||||
async def _filter_thread_updates_for_user(
|
||||
self,
|
||||
all_thread_updates: ThreadUpdatesMap,
|
||||
user_id: str,
|
||||
) -> ThreadUpdatesMap:
|
||||
"""Process thread updates by filtering for visibility.
|
||||
|
||||
Takes raw thread updates from storage and filters them based on whether the
|
||||
user can see the events. Preserves the ordering of updates within each thread.
|
||||
|
||||
Args:
|
||||
all_thread_updates: Map of thread_id to list of ThreadUpdateInfo objects
|
||||
user_id: The user ID to filter events for
|
||||
|
||||
Returns:
|
||||
Filtered map of thread_id to list of ThreadUpdateInfo objects, containing
|
||||
only updates for events the user can see.
|
||||
"""
|
||||
# Build a mapping of event_id -> (thread_id, update) for efficient lookup
|
||||
# during visibility filtering.
|
||||
event_to_thread_map: dict[str, tuple[str, ThreadUpdateInfo]] = {}
|
||||
for thread_id, updates in all_thread_updates.items():
|
||||
for update in updates:
|
||||
event_to_thread_map[update.event_id] = (thread_id, update)
|
||||
|
||||
# Fetch and filter events for visibility
|
||||
all_events = await self._main_store.get_events_as_list(
|
||||
event_to_thread_map.keys()
|
||||
)
|
||||
filtered_events = await filter_events_for_client(
|
||||
self._storage_controllers, user_id, all_events
|
||||
)
|
||||
|
||||
# Rebuild thread updates from filtered events
|
||||
filtered_updates: ThreadUpdatesMap = defaultdict(list)
|
||||
for event in filtered_events:
|
||||
if event.event_id in event_to_thread_map:
|
||||
thread_id, update = event_to_thread_map[event.event_id]
|
||||
filtered_updates[thread_id].append(update)
|
||||
|
||||
return filtered_updates
|
||||
|
||||
def _build_thread_updates_response(
|
||||
self,
|
||||
filtered_updates: ThreadUpdatesMap,
|
||||
thread_root_event_map: ThreadRootsMap,
|
||||
aggregations_map: AggregationsMap,
|
||||
global_prev_batch_token: StreamToken | None,
|
||||
) -> dict[str, dict[str, ThreadUpdate]]:
|
||||
"""Build thread update response structure with per-thread prev_batch tokens.
|
||||
|
||||
Args:
|
||||
filtered_updates: Map of thread_root_id to list of ThreadUpdateInfo
|
||||
thread_root_event_map: Map of thread_root_id to EventBase
|
||||
aggregations_map: Map of thread_root_id to BundledAggregations
|
||||
global_prev_batch_token: Global pagination token, or None if no more results
|
||||
|
||||
Returns:
|
||||
Map of room_id to thread_root_id to ThreadUpdate
|
||||
"""
|
||||
thread_updates: dict[str, dict[str, ThreadUpdate]] = {}
|
||||
|
||||
for thread_root_id, updates in filtered_updates.items():
|
||||
# We only care about the latest update for the thread
|
||||
# Updates are already sorted by stream_ordering DESC from the database query,
|
||||
# and filter_events_for_client preserves order, so updates[0] is guaranteed to be
|
||||
# the latest event for each thread.
|
||||
latest_update = updates[0]
|
||||
room_id = latest_update.room_id
|
||||
|
||||
# Generate per-thread prev_batch token if this thread has multiple visible updates
|
||||
# or if we hit the global limit.
|
||||
# When we hit the global limit, we generate prev_batch tokens for all threads, even if
|
||||
# we only saw 1 update for them. This is to cover the case where we only saw
|
||||
# a single update for a given thread, but the global limit prevents us from
|
||||
# obtaining other updates which would have otherwise been included in the range.
|
||||
per_thread_prev_batch = None
|
||||
if len(updates) > 1 or global_prev_batch_token is not None:
|
||||
# Create a token pointing to one position before the latest event's stream position.
|
||||
# This makes it exclusive - /relations with dir=b won't return the latest event again.
|
||||
# Use StreamToken.START as base (all other streams at 0) since only room position matters.
|
||||
per_thread_prev_batch = StreamToken.START.copy_and_replace(
|
||||
StreamKeyType.ROOM,
|
||||
RoomStreamToken(stream=latest_update.stream_ordering - 1),
|
||||
)
|
||||
|
||||
if room_id not in thread_updates:
|
||||
thread_updates[room_id] = {}
|
||||
|
||||
thread_updates[room_id][thread_root_id] = ThreadUpdate(
|
||||
thread_root=thread_root_event_map.get(thread_root_id),
|
||||
prev_batch=per_thread_prev_batch,
|
||||
bundled_aggregations=aggregations_map.get(thread_root_id),
|
||||
)
|
||||
|
||||
return thread_updates
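To make the per-thread prev_batch construction above concrete, here is a minimal sketch using the same Synapse types; the stream ordering value is made up for illustration and is not part of the change:

from synapse.types import RoomStreamToken, StreamKeyType, StreamToken

latest_stream_ordering = 42  # hypothetical stream_ordering of the newest visible update

# Only the room position matters, so base the token on StreamToken.START and
# point it one position before the latest event to make it exclusive.
per_thread_prev_batch = StreamToken.START.copy_and_replace(
    StreamKeyType.ROOM,
    RoomStreamToken(stream=latest_stream_ordering - 1),
)
# Paginating /relations with dir=b from this token starts strictly before the
# latest event, so that event is not returned a second time.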
|
||||
|
||||
async def _fetch_thread_updates(
|
||||
self,
|
||||
room_ids: frozenset[str],
|
||||
room_membership_map: Mapping[str, "RoomsForUserType"],
|
||||
from_token: StreamToken | None,
|
||||
to_token: StreamToken,
|
||||
limit: int,
|
||||
exclude_thread_ids: set[str] | None = None,
|
||||
) -> tuple[ThreadUpdatesMap, StreamToken | None]:
|
||||
"""Fetch thread updates across multiple rooms, handling membership states properly.
|
||||
|
||||
This method separates rooms based on membership status (LEAVE/BAN vs others)
|
||||
and queries them appropriately to prevent data leaks. For rooms where the user
|
||||
has left or been banned, we bound the query to their leave/ban event position.
|
||||
|
||||
Args:
|
||||
room_ids: The set of room IDs to fetch thread updates for
|
||||
room_membership_map: Map of room_id to RoomsForUserType containing membership info
|
||||
from_token: Lower bound (exclusive) for the query, or None for no lower bound
|
||||
to_token: Upper bound for the query (for joined/invited/knocking rooms)
|
||||
limit: Maximum number of thread updates to return across all rooms
|
||||
exclude_thread_ids: Optional set of thread IDs to exclude from results
|
||||
|
||||
Returns:
|
||||
A tuple of:
|
||||
- Map of thread_id to list of ThreadUpdateInfo objects
|
||||
- Global prev_batch token if there are more results, None otherwise
|
||||
"""
|
||||
# Separate rooms based on membership to handle LEAVE/BAN rooms specially
|
||||
leave_ban_rooms: set[str] = set()
|
||||
other_rooms: set[str] = set()
|
||||
|
||||
for room_id in room_ids:
|
||||
membership_info = room_membership_map.get(room_id)
|
||||
if membership_info and membership_info.membership in (
|
||||
Membership.LEAVE,
|
||||
Membership.BAN,
|
||||
):
|
||||
leave_ban_rooms.add(room_id)
|
||||
else:
|
||||
other_rooms.add(room_id)
|
||||
|
||||
# Fetch thread updates from storage, handling LEAVE/BAN rooms separately
|
||||
all_thread_updates: ThreadUpdatesMap = {}
|
||||
prev_batch_token: StreamToken | None = None
|
||||
remaining_limit = limit
|
||||
|
||||
# Query LEAVE/BAN rooms with bounded to_token to prevent data leaks
|
||||
if leave_ban_rooms:
|
||||
for room_id in leave_ban_rooms:
|
||||
if remaining_limit <= 0:
|
||||
# We've hit the limit, set prev_batch to indicate more results
|
||||
prev_batch_token = to_token
|
||||
break
|
||||
|
||||
membership_info = room_membership_map[room_id]
|
||||
bounded_to_token = membership_info.event_pos.to_room_stream_token()
|
||||
|
||||
(
|
||||
room_thread_updates,
|
||||
room_prev_batch,
|
||||
) = await self._main_store.get_thread_updates_for_rooms(
|
||||
room_ids={room_id},
|
||||
from_token=from_token.room_key if from_token else None,
|
||||
to_token=bounded_to_token,
|
||||
limit=remaining_limit,
|
||||
exclude_thread_ids=exclude_thread_ids,
|
||||
)
|
||||
|
||||
# Count updates and reduce remaining limit
|
||||
num_updates = sum(
|
||||
len(updates) for updates in room_thread_updates.values()
|
||||
)
|
||||
remaining_limit -= num_updates
|
||||
|
||||
# Merge updates
|
||||
for thread_id, updates in room_thread_updates.items():
|
||||
all_thread_updates.setdefault(thread_id, []).extend(updates)
|
||||
|
||||
# Merge prev_batch tokens (take the maximum for backward pagination)
|
||||
if room_prev_batch is not None:
|
||||
if prev_batch_token is None:
|
||||
prev_batch_token = room_prev_batch
|
||||
elif (
|
||||
room_prev_batch.room_key.stream
|
||||
> prev_batch_token.room_key.stream
|
||||
):
|
||||
prev_batch_token = room_prev_batch
|
||||
|
||||
# Query other rooms (joined/invited/knocking) with normal to_token
|
||||
if other_rooms and remaining_limit > 0:
|
||||
(
|
||||
other_thread_updates,
|
||||
other_prev_batch,
|
||||
) = await self._main_store.get_thread_updates_for_rooms(
|
||||
room_ids=other_rooms,
|
||||
from_token=from_token.room_key if from_token else None,
|
||||
to_token=to_token.room_key,
|
||||
limit=remaining_limit,
|
||||
exclude_thread_ids=exclude_thread_ids,
|
||||
)
|
||||
|
||||
# Merge updates
|
||||
for thread_id, updates in other_thread_updates.items():
|
||||
all_thread_updates.setdefault(thread_id, []).extend(updates)
|
||||
|
||||
# Merge prev_batch tokens
|
||||
if other_prev_batch is not None:
|
||||
if prev_batch_token is None:
|
||||
prev_batch_token = other_prev_batch
|
||||
elif (
|
||||
other_prev_batch.room_key.stream > prev_batch_token.room_key.stream
|
||||
):
|
||||
prev_batch_token = other_prev_batch
|
||||
|
||||
return all_thread_updates, prev_batch_token
|
||||
|
||||
async def get_thread_updates_for_rooms(
|
||||
self,
|
||||
room_ids: frozenset[str],
|
||||
room_membership_map: Mapping[str, "RoomsForUserType"],
|
||||
user_id: str,
|
||||
from_token: StreamToken | None,
|
||||
to_token: StreamToken,
|
||||
limit: int,
|
||||
include_roots: bool = False,
|
||||
exclude_thread_ids: set[str] | None = None,
|
||||
) -> tuple[dict[str, dict[str, ThreadUpdate]], StreamToken | None]:
|
||||
"""Get thread updates across multiple rooms with full processing pipeline.
|
||||
|
||||
This is the main entry point for fetching thread updates. It handles:
|
||||
- Fetching updates with membership-based security
|
||||
- Filtering for visibility
|
||||
- Optionally fetching thread roots and aggregations
|
||||
- Building the response structure
|
||||
|
||||
Args:
|
||||
room_ids: The set of room IDs to fetch updates for
|
||||
room_membership_map: Map of room_id to RoomsForUserType for membership info
|
||||
user_id: The user requesting the updates
|
||||
from_token: Lower bound (exclusive) for the query
|
||||
to_token: Upper bound for the query
|
||||
limit: Maximum number of updates to return
|
||||
include_roots: Whether to fetch and include thread root events (default: False)
|
||||
exclude_thread_ids: Optional set of thread IDs to exclude
|
||||
|
||||
Returns:
|
||||
A tuple of:
|
||||
- Map of room_id to thread_root_id to ThreadUpdate
|
||||
- Global prev_batch token if there are more results, None otherwise
|
||||
"""
|
||||
# Fetch thread updates with membership handling
|
||||
all_thread_updates, prev_batch_token = await self._fetch_thread_updates(
|
||||
room_ids=room_ids,
|
||||
room_membership_map=room_membership_map,
|
||||
from_token=from_token,
|
||||
to_token=to_token,
|
||||
limit=limit,
|
||||
exclude_thread_ids=exclude_thread_ids,
|
||||
)
|
||||
|
||||
if not all_thread_updates:
|
||||
return {}, prev_batch_token
|
||||
|
||||
# Filter thread updates for visibility
|
||||
filtered_updates = await self._filter_thread_updates_for_user(
|
||||
all_thread_updates, user_id
|
||||
)
|
||||
|
||||
if not filtered_updates:
|
||||
return {}, prev_batch_token
|
||||
|
||||
# Optionally fetch thread root events and their bundled aggregations
|
||||
thread_root_event_map: ThreadRootsMap = {}
|
||||
aggregations_map: AggregationsMap = {}
|
||||
if include_roots:
|
||||
# Fetch thread root events
|
||||
thread_root_events = await self._main_store.get_events_as_list(
|
||||
filtered_updates.keys()
|
||||
)
|
||||
thread_root_event_map = {e.event_id: e for e in thread_root_events}
|
||||
|
||||
# Fetch bundled aggregations for the thread roots
|
||||
if thread_root_event_map:
|
||||
aggregations_map = await self.get_bundled_aggregations(
|
||||
thread_root_event_map.values(),
|
||||
user_id,
|
||||
)
|
||||
|
||||
# Build response structure with per-thread prev_batch tokens
|
||||
thread_updates = self._build_thread_updates_response(
|
||||
filtered_updates=filtered_updates,
|
||||
thread_root_event_map=thread_root_event_map,
|
||||
aggregations_map=aggregations_map,
|
||||
global_prev_batch_token=prev_batch_token,
|
||||
)
|
||||
|
||||
return thread_updates, prev_batch_token
|
||||
|
||||
@staticmethod
|
||||
async def serialize_thread_updates(
|
||||
thread_updates: Mapping[str, Mapping[str, ThreadUpdate]],
|
||||
prev_batch_token: StreamToken | None,
|
||||
event_serializer: "EventClientSerializer",
|
||||
time_now: int,
|
||||
store: "DataStore",
|
||||
serialize_options: SerializeEventConfig,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Serialize thread updates to JSON format.
|
||||
|
||||
This helper handles serialization of ThreadUpdate objects for both the
|
||||
companion endpoint and the sliding sync extension.
|
||||
|
||||
Args:
|
||||
thread_updates: Map of room_id to thread_root_id to ThreadUpdate
|
||||
prev_batch_token: Global pagination token for fetching more updates
|
||||
event_serializer: The event serializer to use
|
||||
time_now: Current time in milliseconds for event serialization
|
||||
store: Datastore for serializing stream tokens
|
||||
serialize_options: Serialization config
|
||||
|
||||
Returns:
|
||||
JSON-serializable dict with "updates" and optionally "prev_batch"
|
||||
"""
|
||||
updates_dict: JsonDict = {}
|
||||
|
||||
for room_id, room_threads in thread_updates.items():
|
||||
room_updates: JsonDict = {}
|
||||
for thread_root_id, update in room_threads.items():
|
||||
update_dict: JsonDict = {}
|
||||
|
||||
# Serialize thread_root event if present
|
||||
if update.thread_root is not None:
|
||||
bundle_aggs_map = (
|
||||
{thread_root_id: update.bundled_aggregations}
|
||||
if update.bundled_aggregations is not None
|
||||
else None
|
||||
)
|
||||
serialized_events = await event_serializer.serialize_events(
|
||||
[update.thread_root],
|
||||
time_now,
|
||||
config=serialize_options,
|
||||
bundle_aggregations=bundle_aggs_map,
|
||||
)
|
||||
if serialized_events:
|
||||
update_dict["thread_root"] = serialized_events[0]
|
||||
|
||||
# Add per-thread prev_batch if present
|
||||
if update.prev_batch is not None:
|
||||
update_dict["prev_batch"] = await update.prev_batch.to_string(store)
|
||||
|
||||
room_updates[thread_root_id] = update_dict
|
||||
|
||||
updates_dict[room_id] = room_updates
|
||||
|
||||
result: JsonDict = {"updates": updates_dict}
|
||||
|
||||
# Add global prev_batch token if present
|
||||
if prev_batch_token is not None:
|
||||
result["prev_batch"] = await prev_batch_token.to_string(store)
|
||||
|
||||
return result
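For orientation, the structure returned by this helper looks roughly as follows; the identifiers and token strings are hypothetical, and thread_root is a fully serialized client event in practice:

{
    "updates": {
        "!room:example.org": {
            "$thread_root_event_id": {
                "thread_root": {"event_id": "$thread_root_event_id"},
                "prev_batch": "<per-thread stream token>",
            },
        },
    },
    "prev_batch": "<global token for fetching older thread updates>",
}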
|
||||
|
||||
async def get_threads(
|
||||
self,
|
||||
requester: Requester,
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AbstractSet,
|
||||
@@ -31,7 +30,6 @@ from synapse.api.constants import (
|
||||
AccountDataTypes,
|
||||
EduTypes,
|
||||
EventContentFields,
|
||||
Membership,
|
||||
MRelatesToFields,
|
||||
RelationTypes,
|
||||
)
|
||||
@@ -40,15 +38,12 @@ from synapse.handlers.receipts import ReceiptEventSource
|
||||
from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.storage.databases.main.receipts import ReceiptInRoom
|
||||
from synapse.storage.databases.main.relations import ThreadUpdateInfo
|
||||
from synapse.types import (
|
||||
DeviceListUpdates,
|
||||
JsonMapping,
|
||||
MultiWriterStreamToken,
|
||||
RoomStreamToken,
|
||||
SlidingSyncStreamToken,
|
||||
StrCollection,
|
||||
StreamKeyType,
|
||||
StreamToken,
|
||||
ThreadSubscriptionsToken,
|
||||
)
|
||||
@@ -64,7 +59,6 @@ from synapse.util.async_helpers import (
|
||||
concurrently_execute,
|
||||
gather_optional_coroutines,
|
||||
)
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
_ThreadSubscription: TypeAlias = (
|
||||
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
|
||||
@@ -72,7 +66,6 @@ _ThreadSubscription: TypeAlias = (
|
||||
_ThreadUnsubscription: TypeAlias = (
|
||||
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
|
||||
)
|
||||
_ThreadUpdate: TypeAlias = SlidingSyncResult.Extensions.ThreadsExtension.ThreadUpdate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -1040,42 +1033,6 @@ class SlidingSyncExtensionHandler:
|
||||
threads_in_timeline.add(thread_id)
|
||||
return threads_in_timeline
|
||||
|
||||
def _merge_prev_batch_token(
|
||||
self,
|
||||
current_token: StreamToken | None,
|
||||
new_token: StreamToken | None,
|
||||
) -> StreamToken | None:
|
||||
"""Merge two prev_batch tokens, taking the maximum (latest) for backwards pagination.
|
||||
|
||||
Args:
|
||||
current_token: The current prev_batch token (may be None)
|
||||
new_token: The new prev_batch token to merge (may be None)
|
||||
|
||||
Returns:
|
||||
The merged token (maximum of the two, or None if both are None)
|
||||
"""
|
||||
if new_token is None:
|
||||
return current_token
|
||||
if current_token is None:
|
||||
return new_token
|
||||
if new_token.room_key.stream > current_token.room_key.stream:
|
||||
return new_token
|
||||
return current_token
|
||||
|
||||
def _merge_thread_updates(
|
||||
self,
|
||||
target: dict[str, list[ThreadUpdateInfo]],
|
||||
source: dict[str, list[ThreadUpdateInfo]],
|
||||
) -> None:
|
||||
"""Merge thread updates from source into target.
|
||||
|
||||
Args:
|
||||
target: The target dict to merge into (modified in place)
|
||||
source: The source dict to merge from
|
||||
"""
|
||||
for thread_id, updates in source.items():
|
||||
target.setdefault(thread_id, []).extend(updates)
|
||||
|
||||
async def get_threads_extension_response(
|
||||
self,
|
||||
sync_config: SlidingSyncConfig,
|
||||
@@ -1117,159 +1074,26 @@ class SlidingSyncExtensionHandler:
|
||||
actual_room_response_map
|
||||
)
|
||||
|
||||
# Separate rooms into groups based on membership status.
|
||||
# For LEAVE/BAN rooms, we need to bound the to_token to prevent leaking events
|
||||
# that occurred after the user left/was banned.
|
||||
leave_ban_rooms: set[str] = set()
|
||||
other_rooms: set[str] = set()
|
||||
|
||||
for room_id in actual_room_ids:
|
||||
membership_info = room_membership_for_user_at_to_token_map.get(room_id)
|
||||
if membership_info and membership_info.membership in (
|
||||
Membership.LEAVE,
|
||||
Membership.BAN,
|
||||
):
|
||||
leave_ban_rooms.add(room_id)
|
||||
else:
|
||||
other_rooms.add(room_id)
|
||||
|
||||
# Fetch thread updates, handling LEAVE/BAN rooms separately to avoid data leaks.
|
||||
all_thread_updates: dict[str, list[ThreadUpdateInfo]] = {}
|
||||
prev_batch_token: StreamToken | None = None
|
||||
remaining_limit = threads_request.limit
|
||||
|
||||
# Query for rooms where the user has left or been banned, using their leave/ban
|
||||
# event position as the upper bound to prevent seeing events after they left.
|
||||
if leave_ban_rooms:
|
||||
for room_id in leave_ban_rooms:
|
||||
if remaining_limit <= 0:
|
||||
# We've already fetched enough updates, but we still need to set
|
||||
# prev_batch to indicate there are more results.
|
||||
prev_batch_token = to_token
|
||||
break
|
||||
|
||||
membership_info = room_membership_for_user_at_to_token_map[room_id]
|
||||
bounded_to_token = membership_info.event_pos.to_room_stream_token()
|
||||
|
||||
(
|
||||
room_thread_updates,
|
||||
room_prev_batch,
|
||||
) = await self.store.get_thread_updates_for_rooms(
|
||||
room_ids={room_id},
|
||||
from_token=from_token.stream_token.room_key if from_token else None,
|
||||
to_token=bounded_to_token,
|
||||
limit=remaining_limit,
|
||||
exclude_thread_ids=threads_to_exclude,
|
||||
)
|
||||
|
||||
# Count how many updates we fetched and reduce the remaining limit
|
||||
num_updates = sum(
|
||||
len(updates) for updates in room_thread_updates.values()
|
||||
)
|
||||
remaining_limit -= num_updates
|
||||
|
||||
self._merge_thread_updates(all_thread_updates, room_thread_updates)
|
||||
prev_batch_token = self._merge_prev_batch_token(
|
||||
prev_batch_token, room_prev_batch
|
||||
)
|
||||
|
||||
# Query for rooms where the user is joined, invited, or knocking, using the
|
||||
# normal to_token as the upper bound.
|
||||
if other_rooms and remaining_limit > 0:
|
||||
(
|
||||
other_thread_updates,
|
||||
other_prev_batch,
|
||||
) = await self.store.get_thread_updates_for_rooms(
|
||||
room_ids=other_rooms,
|
||||
from_token=from_token.stream_token.room_key if from_token else None,
|
||||
to_token=to_token.room_key,
|
||||
limit=remaining_limit,
|
||||
exclude_thread_ids=threads_to_exclude,
|
||||
)
|
||||
|
||||
self._merge_thread_updates(all_thread_updates, other_thread_updates)
|
||||
prev_batch_token = self._merge_prev_batch_token(
|
||||
prev_batch_token, other_prev_batch
|
||||
)
|
||||
|
||||
if len(all_thread_updates) == 0:
|
||||
return None
|
||||
|
||||
# Build a mapping of event_id -> (thread_id, update) for efficient lookup
|
||||
# during visibility filtering.
|
||||
event_to_thread_map: dict[str, tuple[str, ThreadUpdateInfo]] = {}
|
||||
for thread_id, updates in all_thread_updates.items():
|
||||
for update in updates:
|
||||
event_to_thread_map[update.event_id] = (thread_id, update)
|
||||
|
||||
# Fetch and filter events for visibility
|
||||
all_events = await self.store.get_events_as_list(event_to_thread_map.keys())
|
||||
filtered_events = await filter_events_for_client(
|
||||
self._storage_controllers, sync_config.user.to_string(), all_events
|
||||
# Get thread updates using unified helper
|
||||
user_id = sync_config.user.to_string()
|
||||
(
|
||||
thread_updates_response,
|
||||
prev_batch_token,
|
||||
) = await self.relations_handler.get_thread_updates_for_rooms(
|
||||
room_ids=frozenset(actual_room_ids),
|
||||
room_membership_map=room_membership_for_user_at_to_token_map,
|
||||
user_id=user_id,
|
||||
from_token=from_token.stream_token if from_token else None,
|
||||
to_token=to_token,
|
||||
limit=threads_request.limit,
|
||||
include_roots=threads_request.include_roots,
|
||||
exclude_thread_ids=threads_to_exclude,
|
||||
)
|
||||
|
||||
# Rebuild thread updates from filtered events
|
||||
filtered_updates: dict[str, list[ThreadUpdateInfo]] = defaultdict(list)
|
||||
for event in filtered_events:
|
||||
if event.event_id in event_to_thread_map:
|
||||
thread_id, update = event_to_thread_map[event.event_id]
|
||||
filtered_updates[thread_id].append(update)
|
||||
|
||||
if not filtered_updates:
|
||||
if not thread_updates_response:
|
||||
return None
|
||||
|
||||
# Note: Updates are already sorted by stream_ordering DESC from the database query,
|
||||
# and filter_events_for_client preserves order, so updates[0] is guaranteed to be
|
||||
# the latest event for each thread.
|
||||
|
||||
# Optionally fetch thread root events and their bundled aggregations
|
||||
thread_root_event_map = {}
|
||||
aggregations_map = {}
|
||||
if threads_request.include_roots:
|
||||
thread_root_events = await self.store.get_events_as_list(
|
||||
filtered_updates.keys()
|
||||
)
|
||||
thread_root_event_map = {e.event_id: e for e in thread_root_events}
|
||||
|
||||
if thread_root_event_map:
|
||||
aggregations_map = (
|
||||
await self.relations_handler.get_bundled_aggregations(
|
||||
thread_root_event_map.values(),
|
||||
sync_config.user.to_string(),
|
||||
)
|
||||
)
|
||||
|
||||
thread_updates: dict[str, dict[str, _ThreadUpdate]] = {}
|
||||
for thread_root, updates in filtered_updates.items():
|
||||
# We only care about the latest update for the thread.
|
||||
# After sorting above, updates[0] is guaranteed to be the latest (highest stream_ordering).
|
||||
latest_update = updates[0]
|
||||
|
||||
# Generate per-thread prev_batch token if this thread has multiple visible updates.
|
||||
# When we hit the global limit, we generate prev_batch tokens for all threads, even if
|
||||
# we only saw 1 update for them. This is to cover the case where we only saw
|
||||
# a single update for a given thread, but the global limit prevents us from
|
||||
# obtaining other updates which would have otherwise been included in the
|
||||
# range.
|
||||
per_thread_prev_batch = None
|
||||
if len(updates) > 1 or prev_batch_token is not None:
|
||||
# Create a token pointing to one position before the latest event's stream position.
|
||||
# This makes it exclusive - /relations with dir=b won't return the latest event again.
|
||||
# Use StreamToken.START as base (all other streams at 0) since only room position matters.
|
||||
per_thread_prev_batch = StreamToken.START.copy_and_replace(
|
||||
StreamKeyType.ROOM,
|
||||
RoomStreamToken(stream=latest_update.stream_ordering - 1),
|
||||
)
|
||||
|
||||
thread_updates.setdefault(latest_update.room_id, {})[thread_root] = (
|
||||
_ThreadUpdate(
|
||||
thread_root=thread_root_event_map.get(thread_root),
|
||||
prev_batch=per_thread_prev_batch,
|
||||
bundled_aggregations=aggregations_map.get(thread_root),
|
||||
)
|
||||
)
|
||||
|
||||
return SlidingSyncResult.Extensions.ThreadsExtension(
|
||||
updates=thread_updates,
|
||||
updates=thread_updates_response,
|
||||
prev_batch=prev_batch_token,
|
||||
)
|
||||
|
||||
@@ -20,17 +20,33 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
from synapse.api.constants import Direction
|
||||
from pydantic import StrictBool, StrictStr
|
||||
from pydantic.types import StringConstraints
|
||||
|
||||
from synapse.api.constants import Direction, Membership
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.events.utils import SerializeEventConfig
|
||||
from synapse.handlers.relations import ThreadsListInclude
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
parse_and_validate_json_object_from_request,
|
||||
parse_boolean,
|
||||
parse_integer,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.storage.databases.main.relations import ThreadsNextBatch
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import JsonDict
|
||||
from synapse.streams.config import (
|
||||
PaginationConfig,
|
||||
extract_stream_token_from_pagination_token,
|
||||
)
|
||||
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
|
||||
from synapse.types.handlers.sliding_sync import PerConnectionState, SlidingSyncConfig
|
||||
from synapse.types.rest.client import RequestBodyModel, SlidingSyncBody
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -38,6 +54,39 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThreadUpdatesBody(RequestBodyModel):
|
||||
"""
|
||||
Thread updates companion endpoint request body (MSC4360).
|
||||
|
||||
Allows paginating thread updates using the same room selection as a sliding sync
|
||||
request. This enables clients to fetch thread updates for the same set of rooms
|
||||
that were included in their sliding sync response.
|
||||
|
||||
Attributes:
|
||||
lists: Sliding window API lists, using the same structure as SlidingSyncBody.lists.
|
||||
If provided along with room_subscriptions, the union of rooms from both will
|
||||
be used.
|
||||
room_subscriptions: Room subscription API rooms, using the same structure as
|
||||
SlidingSyncBody.room_subscriptions. If provided along with lists, the union
|
||||
of rooms from both will be used.
|
||||
include_roots: Whether to include the thread root events in the response.
|
||||
Defaults to False.
|
||||
|
||||
If neither lists nor room_subscriptions are provided, thread updates from all
|
||||
joined rooms are returned.
|
||||
"""
|
||||
|
||||
lists: (
|
||||
dict[
|
||||
Annotated[str, StringConstraints(max_length=64, strict=True)],
|
||||
SlidingSyncBody.SlidingSyncList,
|
||||
]
|
||||
| None
|
||||
) = None
|
||||
room_subscriptions: dict[StrictStr, SlidingSyncBody.RoomSubscription] | None = None
|
||||
include_roots: StrictBool = False
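As a rough example of a request body (values are hypothetical; the list options mirror SlidingSyncBody.SlidingSyncList, as used in the tests later in this diff), a client selecting rooms via a sliding window list and asking for thread roots might send:

{
    "lists": {
        "all-rooms": {
            "ranges": [[0, 10]],
            "required_state": [],
            "timeline_limit": 0,
        },
    },
    "include_roots": True,
}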
|
||||
|
||||
|
||||
class RelationPaginationServlet(RestServlet):
|
||||
"""API to paginate relations on an event by topological ordering, optionally
|
||||
filtered by relation type and event type.
|
||||
@@ -133,6 +182,167 @@ class ThreadsServlet(RestServlet):
|
||||
return 200, result
|
||||
|
||||
|
||||
class ThreadUpdatesServlet(RestServlet):
|
||||
"""
|
||||
Companion endpoint to the Sliding Sync threads extension (MSC4360).
|
||||
Allows clients to bulk fetch thread updates across all joined rooms.
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/io.element.msc4360/thread_updates$",
|
||||
unstable=True,
|
||||
releases=(),
|
||||
)
|
||||
CATEGORY = "Client API requests"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.clock = hs.get_clock()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
self.relations_handler = hs.get_relations_handler()
|
||||
self.event_serializer = hs.get_event_client_serializer()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.sliding_sync_handler = hs.get_sliding_sync_handler()
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
# Parse request body
|
||||
body = parse_and_validate_json_object_from_request(request, ThreadUpdatesBody)
|
||||
|
||||
# Parse query parameters
|
||||
dir_str = parse_string(request, "dir", default="b")
|
||||
if dir_str != "b":
|
||||
raise SynapseError(
|
||||
400,
|
||||
"The 'dir' parameter must be 'b' (backward). Forward pagination is not supported.",
|
||||
)
|
||||
|
||||
limit = parse_integer(request, "limit", default=100)
|
||||
if limit <= 0:
|
||||
raise SynapseError(400, "The 'limit' parameter must be positive.")
|
||||
|
||||
from_token_str = parse_string(request, "from")
|
||||
to_token_str = parse_string(request, "to")
|
||||
|
||||
# Parse pagination tokens
|
||||
from_token: StreamToken | None = None
|
||||
to_token: StreamToken | None = None
|
||||
|
||||
if from_token_str:
|
||||
try:
|
||||
stream_token_str = extract_stream_token_from_pagination_token(
|
||||
from_token_str
|
||||
)
|
||||
from_token = await StreamToken.from_string(self.store, stream_token_str)
|
||||
except Exception as e:
|
||||
logger.exception("Error parsing 'from' token: %s", from_token_str)
|
||||
raise SynapseError(400, "'from' parameter is invalid") from e
|
||||
|
||||
if to_token_str:
|
||||
try:
|
||||
stream_token_str = extract_stream_token_from_pagination_token(
|
||||
to_token_str
|
||||
)
|
||||
to_token = await StreamToken.from_string(self.store, stream_token_str)
|
||||
except Exception:
|
||||
raise SynapseError(400, "'to' parameter is invalid")
|
||||
|
||||
# Get the list of rooms to fetch thread updates for
|
||||
user_id = requester.user.to_string()
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
# Get the current stream token for membership lookup
|
||||
if from_token is None:
|
||||
max_stream_ordering = self.store.get_room_max_stream_ordering()
|
||||
current_token = StreamToken.START.copy_and_replace(
|
||||
StreamKeyType.ROOM, RoomStreamToken(stream=max_stream_ordering)
|
||||
)
|
||||
else:
|
||||
current_token = from_token
|
||||
|
||||
# Get room membership information to properly handle LEAVE/BAN rooms
|
||||
(
|
||||
room_membership_for_user_at_to_token_map,
|
||||
_,
|
||||
_,
|
||||
) = await self.sliding_sync_handler.room_lists.get_room_membership_for_user_at_to_token(
|
||||
user=user,
|
||||
to_token=current_token,
|
||||
from_token=None,
|
||||
)
|
||||
|
||||
# Determine which rooms to fetch updates for based on lists/room_subscriptions
|
||||
if body.lists is not None or body.room_subscriptions is not None:
|
||||
# Use sliding sync room selection logic
|
||||
sync_config = SlidingSyncConfig(
|
||||
user=user,
|
||||
requester=requester,
|
||||
lists=body.lists,
|
||||
room_subscriptions=body.room_subscriptions,
|
||||
)
|
||||
|
||||
# Use the sliding sync room list handler to get the same set of rooms
|
||||
interested_rooms = (
|
||||
await self.sliding_sync_handler.room_lists.compute_interested_rooms(
|
||||
sync_config=sync_config,
|
||||
previous_connection_state=PerConnectionState(),
|
||||
to_token=current_token,
|
||||
from_token=None,
|
||||
)
|
||||
)
|
||||
|
||||
room_ids = frozenset(interested_rooms.relevant_room_map.keys())
|
||||
else:
|
||||
# No lists or room_subscriptions, use only joined rooms
|
||||
room_ids = frozenset(
|
||||
room_id
|
||||
for room_id, membership_info in room_membership_for_user_at_to_token_map.items()
|
||||
if membership_info.membership == Membership.JOIN
|
||||
)
|
||||
|
||||
# Get thread updates using unified helper
|
||||
(
|
||||
thread_updates,
|
||||
prev_batch_token,
|
||||
) = await self.relations_handler.get_thread_updates_for_rooms(
|
||||
room_ids=room_ids,
|
||||
room_membership_map=room_membership_for_user_at_to_token_map,
|
||||
user_id=user_id,
|
||||
from_token=to_token,
|
||||
to_token=from_token if from_token else current_token,
|
||||
limit=limit,
|
||||
include_roots=body.include_roots,
|
||||
)
|
||||
|
||||
if not thread_updates:
|
||||
return 200, {"chunk": {}}
|
||||
|
||||
# Serialize thread updates using shared helper
|
||||
time_now = self.clock.time_msec()
|
||||
serialize_options = SerializeEventConfig(requester=requester)
|
||||
|
||||
serialized = await self.relations_handler.serialize_thread_updates(
|
||||
thread_updates=thread_updates,
|
||||
prev_batch_token=prev_batch_token,
|
||||
event_serializer=self.event_serializer,
|
||||
time_now=time_now,
|
||||
store=self.store,
|
||||
serialize_options=serialize_options,
|
||||
)
|
||||
|
||||
# Build response with "chunk" wrapper and "next_batch" key
|
||||
# (companion endpoint uses different key names than sliding sync)
|
||||
response: JsonDict = {"chunk": serialized["updates"]}
|
||||
if "prev_batch" in serialized:
|
||||
response["next_batch"] = serialized["prev_batch"]
|
||||
|
||||
return 200, response
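Putting the key renaming above together, a successful response from the companion endpoint has roughly this shape (identifiers hypothetical; thread_root only appears when include_roots was requested):

{
    "chunk": {
        "!room:example.org": {
            "$thread_root_event_id": {
                "thread_root": {"event_id": "$thread_root_event_id"},
                "prev_batch": "<per-thread stream token>",
            },
        },
    },
    "next_batch": "<token for fetching older thread updates>",
}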
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
RelationPaginationServlet(hs).register(http_server)
|
||||
ThreadsServlet(hs).register(http_server)
|
||||
if hs.config.experimental.msc4360_enabled:
|
||||
ThreadUpdatesServlet(hs).register(http_server)
|
||||
|
||||
@@ -37,6 +37,7 @@ from synapse.events.utils import (
|
||||
format_event_raw,
|
||||
)
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.handlers.relations import RelationsHandler
|
||||
from synapse.handlers.sliding_sync import SlidingSyncConfig, SlidingSyncResult
|
||||
from synapse.handlers.sync import (
|
||||
ArchivedSyncResult,
|
||||
@@ -1107,6 +1108,7 @@ class SlidingSyncRestServlet(RestServlet):
|
||||
time_now,
|
||||
extensions.threads,
|
||||
self.store,
|
||||
requester,
|
||||
)
|
||||
|
||||
return serialized_extensions
|
||||
@@ -1150,6 +1152,7 @@ async def _serialise_threads(
|
||||
time_now: int,
|
||||
threads: SlidingSyncResult.Extensions.ThreadsExtension,
|
||||
store: "DataStore",
|
||||
requester: Requester,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Serialize the threads extension response for sliding sync.
|
||||
@@ -1159,6 +1162,7 @@ async def _serialise_threads(
|
||||
time_now: The current time in milliseconds, used for event serialization.
|
||||
threads: The threads extension data containing thread updates and pagination tokens.
|
||||
store: The datastore, needed for serializing stream tokens.
|
||||
requester: The user making the request, used for transaction_id inclusion.
|
||||
|
||||
Returns:
|
||||
A JSON-serializable dict containing:
|
||||
@@ -1169,46 +1173,24 @@ async def _serialise_threads(
|
||||
- "prev_batch": A pagination token for fetching older events in the thread.
|
||||
- "prev_batch": A pagination token for fetching older thread updates (if available).
|
||||
"""
|
||||
out: JsonDict = {}
|
||||
if not threads.updates:
|
||||
out: JsonDict = {}
|
||||
if threads.prev_batch:
|
||||
out["prev_batch"] = await threads.prev_batch.to_string(store)
|
||||
return out
|
||||
|
||||
if threads.updates:
|
||||
updates_dict: JsonDict = {}
|
||||
for room_id, thread_updates in threads.updates.items():
|
||||
room_updates: JsonDict = {}
|
||||
for thread_root_id, update in thread_updates.items():
|
||||
# Serialize the update
|
||||
update_dict: JsonDict = {}
|
||||
# Create serialization config to include transaction_id for requester's events
|
||||
serialize_options = SerializeEventConfig(requester=requester)
|
||||
|
||||
# Serialize the thread_root event if present
|
||||
if update.thread_root is not None:
|
||||
# Create a mapping of event_id to bundled_aggregations
|
||||
bundle_aggs_map = (
|
||||
{thread_root_id: update.bundled_aggregations}
|
||||
if update.bundled_aggregations
|
||||
else None
|
||||
)
|
||||
serialized_events = await event_serializer.serialize_events(
|
||||
[update.thread_root],
|
||||
time_now,
|
||||
bundle_aggregations=bundle_aggs_map,
|
||||
)
|
||||
if serialized_events:
|
||||
update_dict["thread_root"] = serialized_events[0]
|
||||
|
||||
# Add prev_batch if present
|
||||
if update.prev_batch is not None:
|
||||
update_dict["prev_batch"] = await update.prev_batch.to_string(store)
|
||||
|
||||
room_updates[thread_root_id] = update_dict
|
||||
|
||||
updates_dict[room_id] = room_updates
|
||||
|
||||
out["updates"] = updates_dict
|
||||
|
||||
if threads.prev_batch:
|
||||
out["prev_batch"] = await threads.prev_batch.to_string(store)
|
||||
|
||||
return out
|
||||
# Use shared serialization helper (static method)
|
||||
return await RelationsHandler.serialize_thread_updates(
|
||||
thread_updates=threads.updates,
|
||||
prev_batch_token=threads.prev_batch,
|
||||
event_serializer=event_serializer,
|
||||
time_now=time_now,
|
||||
store=store,
|
||||
serialize_options=serialize_options,
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
|
||||
@@ -35,6 +35,32 @@ logger = logging.getLogger(__name__)
|
||||
MAX_LIMIT = 1000
|
||||
|
||||
|
||||
def extract_stream_token_from_pagination_token(token_str: str) -> str:
|
||||
"""
|
||||
Extract the StreamToken portion from a pagination token string.
|
||||
|
||||
Handles both:
|
||||
- StreamToken format: "s123_456_..."
|
||||
- SlidingSyncStreamToken format: "5/s123_456_..." (extracts part after /)
|
||||
|
||||
This allows clients using sliding sync to use their pos tokens
|
||||
with endpoints like /relations and /messages.
|
||||
|
||||
Args:
|
||||
token_str: The token string to parse
|
||||
|
||||
Returns:
|
||||
The StreamToken portion of the token
|
||||
"""
|
||||
if "/" in token_str:
|
||||
# SlidingSyncStreamToken format: "connection_position/stream_token"
|
||||
# Split and return just the stream_token part
|
||||
parts = token_str.split("/", 1)
|
||||
if len(parts) == 2:
|
||||
return parts[1]
|
||||
return token_str
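A quick usage sketch of the two accepted formats, using the example values from the docstring:

extract_stream_token_from_pagination_token("s123_456")
# -> "s123_456"

extract_stream_token_from_pagination_token("5/s123_456")
# -> "s123_456" (the connection position before "/" is dropped)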
|
||||
|
||||
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class PaginationConfig:
|
||||
"""A configuration object which stores pagination parameters."""
|
||||
@@ -57,32 +83,14 @@ class PaginationConfig:
|
||||
from_tok_str = parse_string(request, "from")
|
||||
to_tok_str = parse_string(request, "to")
|
||||
|
||||
# Helper function to extract StreamToken from either StreamToken or SlidingSyncStreamToken format
|
||||
def extract_stream_token(token_str: str) -> str:
|
||||
"""
|
||||
Extract the StreamToken portion from a token string.
|
||||
|
||||
Handles both:
|
||||
- StreamToken format: "s123_456_..."
|
||||
- SlidingSyncStreamToken format: "5/s123_456_..." (extracts part after /)
|
||||
|
||||
This allows clients using sliding sync to use their pos tokens
|
||||
with endpoints like /relations and /messages.
|
||||
"""
|
||||
if "/" in token_str:
|
||||
# SlidingSyncStreamToken format: "connection_position/stream_token"
|
||||
# Split and return just the stream_token part
|
||||
parts = token_str.split("/", 1)
|
||||
if len(parts) == 2:
|
||||
return parts[1]
|
||||
return token_str
|
||||
|
||||
try:
|
||||
from_tok = None
|
||||
if from_tok_str == "END":
|
||||
from_tok = None # For backwards compat.
|
||||
elif from_tok_str:
|
||||
stream_token_str = extract_stream_token(from_tok_str)
|
||||
stream_token_str = extract_stream_token_from_pagination_token(
|
||||
from_tok_str
|
||||
)
|
||||
from_tok = await StreamToken.from_string(store, stream_token_str)
|
||||
except Exception:
|
||||
raise SynapseError(400, "'from' parameter is invalid")
|
||||
@@ -90,7 +98,9 @@ class PaginationConfig:
|
||||
try:
|
||||
to_tok = None
|
||||
if to_tok_str:
|
||||
stream_token_str = extract_stream_token(to_tok_str)
|
||||
stream_token_str = extract_stream_token_from_pagination_token(
|
||||
to_tok_str
|
||||
)
|
||||
to_tok = await StreamToken.from_string(store, stream_token_str)
|
||||
except Exception:
|
||||
raise SynapseError(400, "'to' parameter is invalid")
|
||||
|
||||
@@ -37,7 +37,8 @@ from synapse.api.constants import EventTypes
|
||||
from synapse.events import EventBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.handlers.relations import BundledAggregations
|
||||
from synapse.handlers.relations import BundledAggregations, ThreadUpdate
|
||||
|
||||
from synapse.types import (
|
||||
DeviceListUpdates,
|
||||
JsonDict,
|
||||
@@ -409,30 +410,7 @@ class SlidingSyncResult:
|
||||
to paginate through older thread updates.
|
||||
"""
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class ThreadUpdate:
|
||||
"""Information about a single thread that has new activity.
|
||||
|
||||
Attributes:
|
||||
thread_root: The thread root event, if requested via include_roots in the
|
||||
request. This is the event that started the thread.
|
||||
prev_batch: A pagination token (exclusive) for fetching older events in this
|
||||
specific thread. Only present if the thread has multiple updates in the
|
||||
sync window. This token can be used with the /relations endpoint with
|
||||
dir=b to paginate backwards through the thread's history.
|
||||
bundled_aggregations: Bundled aggregations for the thread root event,
|
||||
including the latest_event in the thread (found in
|
||||
unsigned.m.relations.m.thread). Only present if thread_root is included.
|
||||
"""
|
||||
|
||||
thread_root: EventBase | None
|
||||
prev_batch: StreamToken | None
|
||||
bundled_aggregations: "BundledAggregations | None" = None
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.thread_root) or bool(self.prev_batch)
|
||||
|
||||
updates: Mapping[str, Mapping[str, ThreadUpdate]] | None
|
||||
updates: Mapping[str, Mapping[str, "ThreadUpdate"]] | None
|
||||
prev_batch: StreamToken | None
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
|
||||
@@ -924,6 +924,250 @@ class SlidingSyncThreadsExtensionTestCase(SlidingSyncBase):
|
||||
# Verify the thread root event is present
|
||||
self.assertIn("thread_root", thread_updates[thread_root_id])
|
||||
|
||||
def test_thread_updates_initial_sync(self) -> None:
|
||||
"""
|
||||
Test that prev_batch from the threads extension response can be used
|
||||
with the /thread_updates endpoint to get additional thread updates during
|
||||
initial sync. This verifies:
|
||||
1. The from parameter boundary is exclusive (no duplicates)
|
||||
2. Using prev_batch as 'from' provides complete coverage (no gaps)
|
||||
3. Works correctly with different numbers of threads
|
||||
"""
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass")
|
||||
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
|
||||
|
||||
# Create 5 thread roots
|
||||
thread_ids = []
|
||||
for i in range(5):
|
||||
thread_root_id = self.helper.send(
|
||||
room_id, body=f"Thread {i}", tok=user1_tok
|
||||
)["event_id"]
|
||||
thread_ids.append(thread_root_id)
|
||||
|
||||
# Add reply to each thread
|
||||
self.helper.send_event(
|
||||
room_id,
|
||||
type="m.room.message",
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": f"Reply to thread {i}",
|
||||
"m.relates_to": {
|
||||
"rel_type": RelationTypes.THREAD,
|
||||
"event_id": thread_root_id,
|
||||
},
|
||||
},
|
||||
tok=user1_tok,
|
||||
)
|
||||
|
||||
# Do initial sync with threads extension enabled and limit=2
|
||||
sync_body = {
|
||||
"lists": {
|
||||
"all-rooms": {
|
||||
"ranges": [[0, 10]],
|
||||
"required_state": [],
|
||||
"timeline_limit": 0,
|
||||
}
|
||||
},
|
||||
"extensions": {
|
||||
EXT_NAME: {
|
||||
"enabled": True,
|
||||
"limit": 2,
|
||||
}
|
||||
},
|
||||
}
|
||||
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
|
||||
|
||||
# Should get 2 thread updates
|
||||
thread_updates = response_body["extensions"][EXT_NAME]["updates"][room_id]
|
||||
self.assertEqual(len(thread_updates), 2)
|
||||
first_sync_threads = set(thread_updates.keys())
|
||||
|
||||
# Get the top-level prev_batch token from the extension
|
||||
self.assertIn("prev_batch", response_body["extensions"][EXT_NAME])
|
||||
prev_batch = response_body["extensions"][EXT_NAME]["prev_batch"]
|
||||
|
||||
# Use prev_batch with /thread_updates endpoint to get remaining updates
|
||||
# Note: prev_batch should be used as 'from' parameter (upper bound for backward pagination)
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from={prev_batch}",
|
||||
access_token=user1_tok,
|
||||
content={},
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
# Should get the remaining 3 thread updates
|
||||
chunk = channel.json_body["chunk"]
|
||||
self.assertIn(room_id, chunk)
|
||||
self.assertEqual(len(chunk[room_id]), 3)
|
||||
|
||||
thread_updates_response_threads = set(chunk[room_id].keys())
|
||||
|
||||
# Verify no overlap - the from parameter boundary should be exclusive
|
||||
self.assertEqual(
|
||||
len(first_sync_threads & thread_updates_response_threads),
|
||||
0,
|
||||
"from parameter boundary should be exclusive - no thread should appear in both responses",
|
||||
)
|
||||
|
||||
# Verify no gaps - all threads should be accounted for
|
||||
all_threads = set(thread_ids)
|
||||
combined_threads = first_sync_threads | thread_updates_response_threads
|
||||
self.assertEqual(
|
||||
combined_threads,
|
||||
all_threads,
|
||||
"Combined responses should include all thread updates with no gaps",
|
||||
)
|
||||
|
||||
def test_thread_updates_incremental_sync(self) -> None:
|
||||
"""
|
||||
Test the intended usage pattern from MSC4360: using prev_batch as 'from'
|
||||
and a previous sync pos as 'to' with /thread_updates to fill gaps between
|
||||
syncs. This verifies that using both bounds together provides complete
|
||||
coverage with no gaps or duplicates.
|
||||
"""
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass")
|
||||
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
|
||||
|
||||
# Create 3 threads initially
|
||||
initial_thread_ids = []
|
||||
for i in range(3):
|
||||
thread_root_id = self.helper.send(
|
||||
room_id, body=f"Thread {i}", tok=user1_tok
|
||||
)["event_id"]
|
||||
initial_thread_ids.append(thread_root_id)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id,
|
||||
type="m.room.message",
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": f"Reply to thread {i}",
|
||||
"m.relates_to": {
|
||||
"rel_type": RelationTypes.THREAD,
|
||||
"event_id": thread_root_id,
|
||||
},
|
||||
},
|
||||
tok=user1_tok,
|
||||
)
|
||||
|
||||
# First sync
|
||||
sync_body = {
|
||||
"lists": {
|
||||
"all-rooms": {
|
||||
"ranges": [[0, 10]],
|
||||
"required_state": [],
|
||||
"timeline_limit": 0,
|
||||
}
|
||||
},
|
||||
"extensions": {
|
||||
EXT_NAME: {
|
||||
"enabled": True,
|
||||
}
|
||||
},
|
||||
}
|
||||
response_body, pos1 = self.do_sync(sync_body, tok=user1_tok)
|
||||
|
||||
# Should get 3 thread updates
|
||||
first_sync_threads = set(
|
||||
response_body["extensions"][EXT_NAME]["updates"][room_id].keys()
|
||||
)
|
||||
self.assertEqual(len(first_sync_threads), 3)
|
||||
|
||||
# Create 3 more threads after the first sync
|
||||
new_thread_ids = []
|
||||
for i in range(3, 6):
|
||||
thread_root_id = self.helper.send(
|
||||
room_id, body=f"Thread {i}", tok=user1_tok
|
||||
)["event_id"]
|
||||
new_thread_ids.append(thread_root_id)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id,
|
||||
type="m.room.message",
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": f"Reply to thread {i}",
|
||||
"m.relates_to": {
|
||||
"rel_type": RelationTypes.THREAD,
|
||||
"event_id": thread_root_id,
|
||||
},
|
||||
},
|
||||
tok=user1_tok,
|
||||
)
|
||||
|
||||
# Second sync with limit=1 to get only some of the new threads
|
||||
sync_body_with_limit = {
|
||||
"lists": {
|
||||
"all-rooms": {
|
||||
"ranges": [[0, 10]],
|
||||
"required_state": [],
|
||||
"timeline_limit": 0,
|
||||
}
|
||||
},
|
||||
"extensions": {
|
||||
EXT_NAME: {
|
||||
"enabled": True,
|
||||
"limit": 1,
|
||||
}
|
||||
},
|
||||
}
|
||||
response_body, pos2 = self.do_sync(
|
||||
sync_body_with_limit, tok=user1_tok, since=pos1
|
||||
)
|
||||
|
||||
# Should get 1 thread update
|
||||
second_sync_threads = set(
|
||||
response_body["extensions"][EXT_NAME]["updates"][room_id].keys()
|
||||
)
|
||||
self.assertEqual(len(second_sync_threads), 1)
|
||||
|
||||
# Get prev_batch from the extension
|
||||
self.assertIn("prev_batch", response_body["extensions"][EXT_NAME])
|
||||
prev_batch = response_body["extensions"][EXT_NAME]["prev_batch"]
|
||||
|
||||
# Now use /thread_updates with from=prev_batch and to=pos1
|
||||
# This should get the 2 remaining new threads (created after pos1, not returned in second sync)
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from={prev_batch}&to={pos1}",
|
||||
access_token=user1_tok,
|
||||
content={},
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
chunk = channel.json_body["chunk"]
|
||||
self.assertIn(room_id, chunk)
|
||||
thread_updates_threads = set(chunk[room_id].keys())
|
||||
|
||||
# Should get exactly 2 threads
|
||||
self.assertEqual(len(thread_updates_threads), 2)
|
||||
|
||||
# Verify no overlap with second sync
|
||||
self.assertEqual(
|
||||
len(second_sync_threads & thread_updates_threads),
|
||||
0,
|
||||
"No thread should appear in both second sync and thread_updates responses",
|
||||
)
|
||||
|
||||
# Verify no overlap with first sync (to=pos1 should exclude those)
|
||||
self.assertEqual(
|
||||
len(first_sync_threads & thread_updates_threads),
|
||||
0,
|
||||
"Threads from first sync should not appear in thread_updates (to=pos1 excludes them)",
|
||||
)
|
||||
|
||||
# Verify no gaps - all new threads should be accounted for
|
||||
all_new_threads = set(new_thread_ids)
|
||||
combined_new_threads = second_sync_threads | thread_updates_threads
|
||||
self.assertEqual(
|
||||
combined_new_threads,
|
||||
all_new_threads,
|
||||
"Combined responses should include all new thread updates with no gaps",
|
||||
)
|
||||
|
||||
def test_threads_only_from_rooms_in_list(self) -> None:
|
||||
"""
|
||||
Test that thread updates are only returned for rooms that are in the
|
||||
|
||||
957 tests/rest/client/test_thread_updates.py Normal file
@@ -0,0 +1,957 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
import logging
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import RelationTypes
|
||||
from synapse.rest.client import login, relations, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.clock import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
|
||||
"""
|
||||
Test the /thread_updates companion endpoint (MSC4360).
|
||||
"""
|
||||
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
relations.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self) -> JsonDict:
|
||||
config = super().default_config()
|
||||
config["experimental_features"] = {"msc4360_enabled": True}
|
||||
return config
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
|
||||
    def test_no_updates_for_new_user(self) -> None:
        """
        Test that a user with no thread updates gets an empty response.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Request thread updates
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            content={"include_roots": True},
            access_token=user1_tok,
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        # Assert empty chunk and no next_batch
        self.assertEqual(channel.json_body["chunk"], {})
        self.assertNotIn("next_batch", channel.json_body)

    def test_single_thread_update(self) -> None:
        """
        Test that a single thread with one reply appears in the response.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create thread root
        thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
        thread_root_id = thread_root_resp["event_id"]

        # Add reply to thread
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply 1",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread_root_id,
                },
            },
            tok=user1_tok,
        )

        # Request thread updates
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        # Assert thread is present
        chunk = channel.json_body["chunk"]
        self.assertIn(room_id, chunk)
        self.assertIn(thread_root_id, chunk[room_id])

        # Assert thread root is included
        thread_update = chunk[room_id][thread_root_id]
        self.assertIn("thread_root", thread_update)
        self.assertEqual(thread_update["thread_root"]["event_id"], thread_root_id)

        # Assert prev_batch is NOT present (only 1 update - the reply)
        self.assertNotIn("prev_batch", thread_update)

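    # Sketch of a single thread update as asserted above (only the fields this
    # file touches; other fields may exist but are not covered here):
    #
    #   {
    #       "thread_root": { ...serialized root event, with bundled aggregations
    #                        under "unsigned"."m.relations"... },
    #       "prev_batch": "<token>",  # only when more than one update exists for the thread
    #   }
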
    def test_multiple_threads_single_room(self) -> None:
        """
        Test that multiple threads in the same room are grouped correctly.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create two threads
        thread1_root_id = self.helper.send(room_id, body="Thread 1", tok=user1_tok)[
            "event_id"
        ]
        thread2_root_id = self.helper.send(room_id, body="Thread 2", tok=user1_tok)[
            "event_id"
        ]

        # Add replies to both threads
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply to thread 1",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread1_root_id,
                },
            },
            tok=user1_tok,
        )
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply to thread 2",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread2_root_id,
                },
            },
            tok=user1_tok,
        )

        # Request thread updates
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        # Assert both threads are in the same room
        chunk = channel.json_body["chunk"]
        self.assertIn(room_id, chunk)
        self.assertEqual(len(chunk), 1, "Should only have one room")
        self.assertEqual(len(chunk[room_id]), 2, "Should have two threads")
        self.assertIn(thread1_root_id, chunk[room_id])
        self.assertIn(thread2_root_id, chunk[room_id])

    def test_threads_across_multiple_rooms(self) -> None:
        """
        Test that threads from different rooms are grouped by room_id.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_a_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        room_b_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create threads in both rooms
        thread_a_root_id = self.helper.send(room_a_id, body="Thread A", tok=user1_tok)[
            "event_id"
        ]
        thread_b_root_id = self.helper.send(room_b_id, body="Thread B", tok=user1_tok)[
            "event_id"
        ]

        # Add replies
        self.helper.send_event(
            room_a_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply to A",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread_a_root_id,
                },
            },
            tok=user1_tok,
        )
        self.helper.send_event(
            room_b_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply to B",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread_b_root_id,
                },
            },
            tok=user1_tok,
        )

        # Request thread updates
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        # Assert both rooms are present with their threads
        chunk = channel.json_body["chunk"]
        self.assertEqual(len(chunk), 2, "Should have two rooms")
        self.assertIn(room_a_id, chunk)
        self.assertIn(room_b_id, chunk)
        self.assertIn(thread_a_root_id, chunk[room_a_id])
        self.assertIn(thread_b_root_id, chunk[room_b_id])

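    # As the two tests above demonstrate, updates are grouped two levels deep:
    # chunk[room_id][thread_root_event_id] -> update. A client can therefore
    # iterate rooms and threads independently, e.g. (illustrative only, not a
    # helper used by these tests):
    #
    #   for room_id, threads in channel.json_body["chunk"].items():
    #       for thread_root_id, update in threads.items():
    #           ...
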
    def test_pagination_with_from_token(self) -> None:
        """
        Test that pagination works using the next_batch token.
        This verifies that multiple calls to /thread_updates return all thread
        updates with no duplicates and no gaps.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create many threads (more than default limit)
        thread_ids = []
        for i in range(5):
            thread_root_id = self.helper.send(
                room_id, body=f"Thread {i}", tok=user1_tok
            )["event_id"]
            thread_ids.append(thread_root_id)

            # Add reply
            self.helper.send_event(
                room_id,
                type="m.room.message",
                content={
                    "msgtype": "m.text",
                    "body": f"Reply to thread {i}",
                    "m.relates_to": {
                        "rel_type": RelationTypes.THREAD,
                        "event_id": thread_root_id,
                    },
                },
                tok=user1_tok,
            )

        # Request first page with small limit
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        # Should have 2 threads and a next_batch token
        first_page_threads = set(channel.json_body["chunk"][room_id].keys())
        self.assertEqual(len(first_page_threads), 2)
        self.assertIn("next_batch", channel.json_body)

        next_batch = channel.json_body["next_batch"]

        # Request second page
        channel = self.make_request(
            "POST",
            f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch}",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        second_page_threads = set(channel.json_body["chunk"][room_id].keys())
        self.assertEqual(len(second_page_threads), 2)

        # Verify no overlap
        self.assertEqual(
            len(first_page_threads & second_page_threads),
            0,
            "Pages should not have overlapping threads",
        )

        # Request third page to get the remaining thread
        self.assertIn("next_batch", channel.json_body)
        next_batch_2 = channel.json_body["next_batch"]

        channel = self.make_request(
            "POST",
            f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch_2}",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        third_page_threads = set(channel.json_body["chunk"][room_id].keys())
        self.assertEqual(len(third_page_threads), 1)

        # Verify no overlap between any pages
        self.assertEqual(len(first_page_threads & third_page_threads), 0)
        self.assertEqual(len(second_page_threads & third_page_threads), 0)

        # Verify no gaps - all threads should be accounted for across all pages
        all_threads = set(thread_ids)
        combined_threads = first_page_threads | second_page_threads | third_page_threads
        self.assertEqual(
            combined_threads,
            all_threads,
            "Combined pages should include all thread updates with no gaps",
        )

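    # The pagination contract exercised above suggests the following client loop
    # (illustrative sketch only, not a helper used by these tests; BASE_URL stands
    # for the unstable endpoint path): keep passing the previous "next_batch" as
    # "from" until the response no longer contains "next_batch".
    #
    #   from_token = None
    #   while True:
    #       query = "?dir=b&limit=2" + (f"&from={from_token}" if from_token else "")
    #       channel = self.make_request(
    #           "POST", BASE_URL + query, content={}, access_token=user1_tok
    #       )
    #       # ...collect channel.json_body["chunk"]...
    #       from_token = channel.json_body.get("next_batch")
    #       if from_token is None:
    #           break
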
    def test_invalid_dir_parameter(self) -> None:
        """
        Test that forward pagination (dir=f) is rejected with an error.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Request with forward direction should fail
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=f",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 400)

    def test_invalid_limit_parameter(self) -> None:
        """
        Test that invalid limit values are rejected.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Zero limit should fail
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=0",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 400)

        # Negative limit should fail
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=-5",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 400)

    def test_invalid_pagination_tokens(self) -> None:
        """
        Test that invalid from/to tokens are rejected with appropriate errors.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Invalid from token
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from=invalid_token",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 400)

        # Invalid to token
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to=invalid_token",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 400)

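    # Taken together, the three validation tests above pin down the parameter
    # handling exercised here: only dir=b is accepted, limit must be a positive
    # integer, and from/to must be parseable pagination tokens; anything else
    # results in a 400 response.
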
    def test_to_token_filtering(self) -> None:
        """
        Test that the to_token parameter correctly limits pagination to updates
        newer than the to_token (since we paginate backwards from newest to oldest).
        This also verifies the to_token boundary is exclusive - updates at exactly
        the to_token position should not be included (as they were already returned
        in a previous response that synced up to that position).
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create two thread roots
        thread1_root_id = self.helper.send(room_id, body="Thread 1", tok=user1_tok)[
            "event_id"
        ]
        thread2_root_id = self.helper.send(room_id, body="Thread 2", tok=user1_tok)[
            "event_id"
        ]

        # Send replies to both threads
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply to thread 1",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread1_root_id,
                },
            },
            tok=user1_tok,
        )
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply to thread 2",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread2_root_id,
                },
            },
            tok=user1_tok,
        )

        # Request with limit=1 to get only the latest thread update
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=1",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200)
        self.assertIn("next_batch", channel.json_body)

        # next_batch points to before the update we just received
        next_batch = channel.json_body["next_batch"]
        first_response_threads = set(channel.json_body["chunk"][room_id].keys())

        # Request again with to=next_batch (the lower bound for backward pagination)
        # and no limit. This should return only the same thread update as before,
        # not the additional, older update.
        channel = self.make_request(
            "POST",
            f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to={next_batch}",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200)

        chunk = channel.json_body["chunk"]
        self.assertIn(room_id, chunk)
        # Should have exactly one thread update
        self.assertEqual(len(chunk[room_id]), 1)

        second_response_threads = set(chunk[room_id].keys())

        # Verify the to parameter boundary is exclusive - both responses should be
        # identical.
        self.assertEqual(
            first_response_threads,
            second_response_threads,
            "to parameter boundary should be exclusive - both responses should be identical",
        )

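    # Reading the pagination tests together: with dir=b the update stream is
    # walked from newest to oldest, "from" is the (exclusive) position to start
    # walking back from, and "to" is the (exclusive) lower bound at which to
    # stop - which is why re-requesting with to=next_batch above returns exactly
    # the same update as the first limited request.
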
    def test_bundled_aggregations_on_thread_roots(self) -> None:
        """
        Test that thread root events include bundled aggregations with latest thread event.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create thread root
        thread_root_id = self.helper.send(room_id, body="Thread root", tok=user1_tok)[
            "event_id"
        ]

        # Send replies to create bundled aggregation data
        for i in range(2):
            self.helper.send_event(
                room_id,
                type="m.room.message",
                content={
                    "msgtype": "m.text",
                    "body": f"Reply {i + 1}",
                    "m.relates_to": {
                        "rel_type": RelationTypes.THREAD,
                        "event_id": thread_root_id,
                    },
                },
                tok=user1_tok,
            )

        # Request thread updates
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200)

        # Check that thread root has bundled aggregations with latest event
        chunk = channel.json_body["chunk"]
        thread_update = chunk[room_id][thread_root_id]
        thread_root_event = thread_update["thread_root"]

        # Should have unsigned data with latest thread event content
        self.assertIn("unsigned", thread_root_event)
        self.assertIn("m.relations", thread_root_event["unsigned"])
        relations = thread_root_event["unsigned"]["m.relations"]
        self.assertIn(RelationTypes.THREAD, relations)

        # Check latest event is present in bundled aggregations
        thread_summary = relations[RelationTypes.THREAD]
        self.assertIn("latest_event", thread_summary)
        latest_event = thread_summary["latest_event"]
        self.assertEqual(latest_event["content"]["body"], "Reply 2")

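    # Shape of the bundled aggregations asserted above (only the fields this test
    # touches; the relations handler computes the full structure):
    #
    #   thread_root["unsigned"]["m.relations"][RelationTypes.THREAD] == {
    #       "latest_event": { ...serialized most recent thread reply... },
    #       ...
    #   }
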
    def test_only_joined_rooms(self) -> None:
        """
        Test that thread updates only include rooms where the user is currently joined.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        user2_id = self.register_user("user2", "pass")
        user2_tok = self.login(user2_id, "pass")

        # Create two rooms; user1 is in both (creates room1 and joins room2)
        room1_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        room2_id = self.helper.create_room_as(user2_id, tok=user2_tok)
        self.helper.join(room2_id, user1_id, tok=user1_tok)

        # Create threads in both rooms
        thread1_root_id = self.helper.send(room1_id, body="Thread 1", tok=user1_tok)[
            "event_id"
        ]
        thread2_root_id = self.helper.send(room2_id, body="Thread 2", tok=user2_tok)[
            "event_id"
        ]

        # Add replies to both threads
        self.helper.send_event(
            room1_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply to thread 1",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread1_root_id,
                },
            },
            tok=user1_tok,
        )
        self.helper.send_event(
            room2_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply to thread 2",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread2_root_id,
                },
            },
            tok=user2_tok,
        )

        # User1 leaves room2
        self.helper.leave(room2_id, user1_id, tok=user1_tok)

        # Request thread updates for user1 - should only get room1
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200)

        chunk = channel.json_body["chunk"]
        # Should only have room1, not room2
        self.assertIn(room1_id, chunk)
        self.assertNotIn(room2_id, chunk)
        self.assertIn(thread1_root_id, chunk[room1_id])

    def test_room_filtering_with_lists(self) -> None:
        """
        Test that room filtering works correctly using the lists parameter.
        This verifies that thread updates are only returned for rooms matching
        the provided filters.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Create an encrypted room and an unencrypted room
        encrypted_room_id = self.helper.create_room_as(
            user1_id,
            tok=user1_tok,
            extra_content={
                "initial_state": [
                    {
                        "type": "m.room.encryption",
                        "state_key": "",
                        "content": {"algorithm": "m.megolm.v1.aes-sha2"},
                    }
                ]
            },
        )
        unencrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create threads in both rooms
        enc_thread_root_id = self.helper.send(
            encrypted_room_id, body="Encrypted thread", tok=user1_tok
        )["event_id"]
        unenc_thread_root_id = self.helper.send(
            unencrypted_room_id, body="Unencrypted thread", tok=user1_tok
        )["event_id"]

        # Add replies to both threads
        self.helper.send_event(
            encrypted_room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply in encrypted room",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": enc_thread_root_id,
                },
            },
            tok=user1_tok,
        )
        self.helper.send_event(
            unencrypted_room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply in unencrypted room",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": unenc_thread_root_id,
                },
            },
            tok=user1_tok,
        )

        # Request thread updates with filter for encrypted rooms only
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            access_token=user1_tok,
            content={
                "lists": {
                    "encrypted_list": {
                        "ranges": [[0, 99]],
                        "required_state": [["m.room.encryption", ""]],
                        "timeline_limit": 10,
                        "filters": {"is_encrypted": True},
                    }
                }
            },
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        chunk = channel.json_body["chunk"]
        # Should only include the encrypted room
        self.assertIn(encrypted_room_id, chunk)
        self.assertNotIn(unencrypted_room_id, chunk)
        self.assertIn(enc_thread_root_id, chunk[encrypted_room_id])

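    # The "lists" and "room_subscriptions" bodies used in these filtering tests
    # mirror the sliding sync request fields (ranges, required_state,
    # timeline_limit, filters); here they only serve to scope which rooms'
    # thread updates are returned.
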
    def test_room_filtering_with_room_subscriptions(self) -> None:
        """
        Test that room filtering works correctly using the room_subscriptions parameter.
        This verifies that thread updates are only returned for explicitly subscribed rooms.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Create three rooms
        room1_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        room2_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        room3_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create threads in all three rooms
        thread1_root_id = self.helper.send(room1_id, body="Thread 1", tok=user1_tok)[
            "event_id"
        ]
        thread2_root_id = self.helper.send(room2_id, body="Thread 2", tok=user1_tok)[
            "event_id"
        ]
        thread3_root_id = self.helper.send(room3_id, body="Thread 3", tok=user1_tok)[
            "event_id"
        ]

        # Add replies to all threads
        for room_id, thread_root_id in [
            (room1_id, thread1_root_id),
            (room2_id, thread2_root_id),
            (room3_id, thread3_root_id),
        ]:
            self.helper.send_event(
                room_id,
                type="m.room.message",
                content={
                    "msgtype": "m.text",
                    "body": "Reply",
                    "m.relates_to": {
                        "rel_type": RelationTypes.THREAD,
                        "event_id": thread_root_id,
                    },
                },
                tok=user1_tok,
            )

        # Request thread updates with subscription to only room1 and room2
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            access_token=user1_tok,
            content={
                "room_subscriptions": {
                    room1_id: {
                        "required_state": [["m.room.name", ""]],
                        "timeline_limit": 10,
                    },
                    room2_id: {
                        "required_state": [["m.room.name", ""]],
                        "timeline_limit": 10,
                    },
                }
            },
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        chunk = channel.json_body["chunk"]
        # Should only include room1 and room2, not room3
        self.assertIn(room1_id, chunk)
        self.assertIn(room2_id, chunk)
        self.assertNotIn(room3_id, chunk)
        self.assertIn(thread1_root_id, chunk[room1_id])
        self.assertIn(thread2_root_id, chunk[room2_id])

    def test_room_filtering_with_lists_and_room_subscriptions(self) -> None:
        """
        Test that room filtering works correctly when both lists and room_subscriptions
        are provided. The union of rooms from both should be included.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Create an encrypted room and two unencrypted rooms
        encrypted_room_id = self.helper.create_room_as(
            user1_id,
            tok=user1_tok,
            extra_content={
                "initial_state": [
                    {
                        "type": "m.room.encryption",
                        "state_key": "",
                        "content": {"algorithm": "m.megolm.v1.aes-sha2"},
                    }
                ]
            },
        )
        unencrypted_room1_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        unencrypted_room2_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create threads in all three rooms
        enc_thread_root_id = self.helper.send(
            encrypted_room_id, body="Encrypted thread", tok=user1_tok
        )["event_id"]
        unenc1_thread_root_id = self.helper.send(
            unencrypted_room1_id, body="Unencrypted thread 1", tok=user1_tok
        )["event_id"]
        unenc2_thread_root_id = self.helper.send(
            unencrypted_room2_id, body="Unencrypted thread 2", tok=user1_tok
        )["event_id"]

        # Add replies to all threads
        for room_id, thread_root_id in [
            (encrypted_room_id, enc_thread_root_id),
            (unencrypted_room1_id, unenc1_thread_root_id),
            (unencrypted_room2_id, unenc2_thread_root_id),
        ]:
            self.helper.send_event(
                room_id,
                type="m.room.message",
                content={
                    "msgtype": "m.text",
                    "body": "Reply",
                    "m.relates_to": {
                        "rel_type": RelationTypes.THREAD,
                        "event_id": thread_root_id,
                    },
                },
                tok=user1_tok,
            )

        # Request thread updates with:
        # - lists: filter for encrypted rooms
        # - room_subscriptions: explicitly subscribe to unencrypted_room1_id
        # Expected: should get both encrypted_room_id (from list) and unencrypted_room1_id
        # (from subscription), but NOT unencrypted_room2_id
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            access_token=user1_tok,
            content={
                "lists": {
                    "encrypted_list": {
                        "ranges": [[0, 99]],
                        "required_state": [["m.room.encryption", ""]],
                        "timeline_limit": 10,
                        "filters": {"is_encrypted": True},
                    }
                },
                "room_subscriptions": {
                    unencrypted_room1_id: {
                        "required_state": [["m.room.name", ""]],
                        "timeline_limit": 10,
                    }
                },
            },
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        chunk = channel.json_body["chunk"]
        # Should include encrypted_room_id (from list filter) and unencrypted_room1_id
        # (from subscription), but not unencrypted_room2_id
        self.assertIn(encrypted_room_id, chunk)
        self.assertIn(unencrypted_room1_id, chunk)
        self.assertNotIn(unencrypted_room2_id, chunk)
        self.assertIn(enc_thread_root_id, chunk[encrypted_room_id])
        self.assertIn(unenc1_thread_root_id, chunk[unencrypted_room1_id])

    def test_threads_not_returned_after_leaving_room(self) -> None:
        """
        Test that thread updates are properly bounded when a user leaves a room.

        Users should see thread updates that occurred up to the point they left,
        but NOT updates that occurred after they left.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        user2_id = self.register_user("user2", "pass")
        user2_tok = self.login(user2_id, "pass")

        # Create room and both users join
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        self.helper.join(room_id, user2_id, tok=user2_tok)

        # Create thread
        res = self.helper.send(room_id, body="Thread root", tok=user1_tok)
        thread_root = res["event_id"]

        # Reply in thread while user2 is joined
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply 1 while user2 joined",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread_root,
                },
            },
            tok=user1_tok,
        )

        # User2 leaves the room
        self.helper.leave(room_id, user2_id, tok=user2_tok)

        # Another reply after user2 left
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply 2 after user2 left",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread_root,
                },
            },
            tok=user1_tok,
        )

        # User2 gets thread updates with an explicit room subscription
        # (We need to explicitly subscribe to the room to include it after leaving;
        # otherwise only joined rooms are returned)
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=100",
            {
                "room_subscriptions": {
                    room_id: {
                        "required_state": [],
                        "timeline_limit": 0,
                    }
                }
            },
            access_token=user2_tok,
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        # Assert: User2 SHOULD see Reply 1 (happened while joined) but NOT Reply 2 (after leaving)
        chunk = channel.json_body["chunk"]
        self.assertIn(
            room_id,
            chunk,
            "Thread updates should include the room user2 left",
        )
        self.assertIn(
            thread_root,
            chunk[room_id],
            "Thread root should be in the updates",
        )

        # Verify that only a single update was seen (Reply 1) by checking that there's
        # no prev_batch token. If Reply 2 was also included, there would be multiple
        # updates and a prev_batch token would be present.
        thread_update = chunk[room_id][thread_root]
        self.assertNotIn(
            "prev_batch",
            thread_update,
            "No prev_batch should be present since only one update (Reply 1) is visible",
        )