Compare commits
35 Commits
devon/acl- ... devon/ssex
| Author | SHA1 | Date |
|---|---|---|
| | d65fc3861b | |
| | eff2503adc | |
| | ba3c8a5f3e | |
| | c826157c52 | |
| | 57c884ec83 | |
| | 6fa43cb0b4 | |
| | f778ac32c1 | |
| | 003fc725db | |
| | 934f99a694 | |
| | 78e8ec6161 | |
| | f59419377d | |
| | a3b34dfafd | |
| | cb82a4a687 | |
| | 0c0ece9612 | |
| | 46e3f6756c | |
| | dedd6e35e6 | |
| | a3c7b3ecb9 | |
| | bf594a28a8 | |
| | 89f75cc70f | |
| | 2f8568866e | |
| | af992dd0e2 | |
| | c757969597 | |
| | 87e9fe8b38 | |
| | 4cb0eeabdf | |
| | 4d7826b006 | |
| | ab7e5a2b17 | |
| | 4c51247cb3 | |
| | 4dd82e581a | |
| | 6e69338abc | |
| | 79ea4bed33 | |
| | 9ef4ca173e | |
| | 24b38733df | |
| | 4602b56643 | |
| | 6c460b3eae | |
| | cd4f4223de | |
changelog.d/19005.feature (new file)

```diff
@@ -0,0 +1 @@
+Add experimental support for MSC4360: Sliding Sync Threads Extension.
```

changelog.d/19041.feature (new file)

```diff
@@ -0,0 +1 @@
+Add companion endpoint for MSC4360: Sliding Sync Threads Extension.
```
```diff
@@ -272,6 +272,9 @@ class EventContentFields:
     M_TOPIC: Final = "m.topic"
     M_TEXT: Final = "m.text"
 
+    # Event relations
+    RELATIONS: Final = "m.relates_to"
+
 
 class EventUnsignedContentFields:
     """Fields found inside the 'unsigned' data on events"""
```
```diff
@@ -360,3 +363,10 @@ class Direction(enum.Enum):
 class ProfileFields:
     DISPLAYNAME: Final = "displayname"
     AVATAR_URL: Final = "avatar_url"
+
+
+class MRelatesToFields:
+    """Fields found inside m.relates_to content blocks."""
+
+    EVENT_ID: Final = "event_id"
+    REL_TYPE: Final = "rel_type"
```
```diff
@@ -593,3 +593,6 @@ class ExperimentalConfig(Config):
         # MSC4306: Thread Subscriptions
         # (and MSC4308: Thread Subscriptions extension to Sliding Sync)
         self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False)
+
+        # MSC4360: Threads Extension to Sliding Sync
+        self.msc4360_enabled: bool = experimental.get("msc4360_enabled", False)
```
```diff
@@ -20,6 +20,7 @@
 #
 import enum
 import logging
+from collections import defaultdict
 from typing import (
     TYPE_CHECKING,
     Collection,
```
```diff
@@ -30,24 +31,59 @@ from typing import (
 
 import attr
 
-from synapse.api.constants import Direction, EventTypes, RelationTypes
+from synapse.api.constants import Direction, EventTypes, Membership, RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase, relation_from_event
 from synapse.events.utils import SerializeEventConfig
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import trace
-from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
+from synapse.storage.databases.main.relations import (
+    ThreadsNextBatch,
+    ThreadUpdateInfo,
+    _RelatedEvent,
+)
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, UserID
+from synapse.types import (
+    JsonDict,
+    Requester,
+    RoomStreamToken,
+    StreamKeyType,
+    StreamToken,
+    UserID,
+)
 from synapse.util.async_helpers import gather_results
 from synapse.visibility import filter_events_for_client
 
 if TYPE_CHECKING:
+    from synapse.events.utils import EventClientSerializer
+    from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
     from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore
 
 
 logger = logging.getLogger(__name__)
 
+# Type aliases for thread update processing
+ThreadUpdatesMap = dict[str, list[ThreadUpdateInfo]]
+ThreadRootsMap = dict[str, EventBase]
+AggregationsMap = dict[str, "BundledAggregations"]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThreadUpdate:
+    """
+    Data for a single thread update.
+
+    Attributes:
+        thread_root: The thread root event, or None if not requested/not visible
+        prev_batch: Per-thread pagination token for fetching older events in this thread
+        bundled_aggregations: Bundled aggregations for the thread root event
+    """
+
+    thread_root: EventBase | None
+    prev_batch: StreamToken | None
+    bundled_aggregations: "BundledAggregations | None" = None
+
+
 class ThreadsListInclude(str, enum.Enum):
     """Valid values for the 'include' flag of /threads."""
```
```diff
@@ -105,8 +141,6 @@ class RelationsHandler:
     ) -> JsonDict:
         """Get related events of a event, ordered by topological ordering.
 
-        TODO Accept a PaginationConfig instead of individual pagination parameters.
-
         Args:
             requester: The user requesting the relations.
             event_id: Fetch events that relate to this event ID.
```
@@ -546,6 +580,367 @@ class RelationsHandler:

```python
        return results

    async def _filter_thread_updates_for_user(
        self,
        all_thread_updates: ThreadUpdatesMap,
        user_id: str,
    ) -> ThreadUpdatesMap:
        """Process thread updates by filtering for visibility.

        Takes raw thread updates from storage and filters them based on whether the
        user can see the events. Preserves the ordering of updates within each thread.

        Args:
            all_thread_updates: Map of thread_id to list of ThreadUpdateInfo objects
            user_id: The user ID to filter events for

        Returns:
            Filtered map of thread_id to list of ThreadUpdateInfo objects, containing
            only updates for events the user can see.
        """
        # Build a mapping of event_id -> (thread_id, update) for efficient lookup
        # during visibility filtering.
        event_to_thread_map: dict[str, tuple[str, ThreadUpdateInfo]] = {}
        for thread_id, updates in all_thread_updates.items():
            for update in updates:
                event_to_thread_map[update.event_id] = (thread_id, update)

        # Fetch and filter events for visibility
        all_events = await self._main_store.get_events_as_list(
            event_to_thread_map.keys()
        )
        filtered_events = await filter_events_for_client(
            self._storage_controllers, user_id, all_events
        )

        # Rebuild thread updates from filtered events
        filtered_updates: ThreadUpdatesMap = defaultdict(list)
        for event in filtered_events:
            if event.event_id in event_to_thread_map:
                thread_id, update = event_to_thread_map[event.event_id]
                filtered_updates[thread_id].append(update)

        return filtered_updates
```
```python
    def _build_thread_updates_response(
        self,
        filtered_updates: ThreadUpdatesMap,
        thread_root_event_map: ThreadRootsMap,
        aggregations_map: AggregationsMap,
        global_prev_batch_token: StreamToken | None,
    ) -> dict[str, dict[str, ThreadUpdate]]:
        """Build thread update response structure with per-thread prev_batch tokens.

        Args:
            filtered_updates: Map of thread_root_id to list of ThreadUpdateInfo
            thread_root_event_map: Map of thread_root_id to EventBase
            aggregations_map: Map of thread_root_id to BundledAggregations
            global_prev_batch_token: Global pagination token, or None if no more results

        Returns:
            Map of room_id to thread_root_id to ThreadUpdate
        """
        thread_updates: dict[str, dict[str, ThreadUpdate]] = {}

        for thread_root_id, updates in filtered_updates.items():
            # We only care about the latest update for the thread.
            # Updates are already sorted by stream_ordering DESC from the database
            # query, and filter_events_for_client preserves order, so updates[0] is
            # guaranteed to be the latest event for each thread.
            latest_update = updates[0]
            room_id = latest_update.room_id

            # Generate a per-thread prev_batch token if this thread has multiple
            # visible updates or if we hit the global limit.
            # When we hit the global limit, we generate prev_batch tokens for all
            # threads, even if we only saw 1 update for them. This is to cover the
            # case where we only saw a single update for a given thread, but the
            # global limit prevents us from obtaining other updates which would have
            # otherwise been included in the range.
            per_thread_prev_batch = None
            if len(updates) > 1 or global_prev_batch_token is not None:
                # Create a token pointing to one position before the latest event's
                # stream position. This makes it exclusive - /relations with dir=b
                # won't return the latest event again.
                # Use StreamToken.START as base (all other streams at 0) since only
                # the room position matters.
                per_thread_prev_batch = StreamToken.START.copy_and_replace(
                    StreamKeyType.ROOM,
                    RoomStreamToken(stream=latest_update.stream_ordering - 1),
                )

            if room_id not in thread_updates:
                thread_updates[room_id] = {}

            thread_updates[room_id][thread_root_id] = ThreadUpdate(
                thread_root=thread_root_event_map.get(thread_root_id),
                prev_batch=per_thread_prev_batch,
                bundled_aggregations=aggregations_map.get(thread_root_id),
            )

        return thread_updates
```
```python
    async def _fetch_thread_updates(
        self,
        room_ids: frozenset[str],
        room_membership_map: Mapping[str, "RoomsForUserType"],
        from_token: StreamToken | None,
        to_token: StreamToken,
        limit: int,
        exclude_thread_ids: set[str] | None = None,
    ) -> tuple[ThreadUpdatesMap, StreamToken | None]:
        """Fetch thread updates across multiple rooms, handling membership states properly.

        This method separates rooms based on membership status (LEAVE/BAN vs others)
        and queries them appropriately to prevent data leaks. For rooms where the user
        has left or been banned, we bound the query to their leave/ban event position.

        Args:
            room_ids: The set of room IDs to fetch thread updates for
            room_membership_map: Map of room_id to RoomsForUserType containing membership info
            from_token: Lower bound (exclusive) for the query, or None for no lower bound
            to_token: Upper bound for the query (for joined/invited/knocking rooms)
            limit: Maximum number of thread updates to return across all rooms
            exclude_thread_ids: Optional set of thread IDs to exclude from results

        Returns:
            A tuple of:
                - Map of thread_id to list of ThreadUpdateInfo objects
                - Global prev_batch token if there are more results, None otherwise
        """
        # Separate rooms based on membership to handle LEAVE/BAN rooms specially
        leave_ban_rooms: set[str] = set()
        other_rooms: set[str] = set()

        for room_id in room_ids:
            membership_info = room_membership_map.get(room_id)
            if membership_info and membership_info.membership in (
                Membership.LEAVE,
                Membership.BAN,
            ):
                leave_ban_rooms.add(room_id)
            else:
                other_rooms.add(room_id)

        # Fetch thread updates from storage, handling LEAVE/BAN rooms separately
        all_thread_updates: ThreadUpdatesMap = {}
        prev_batch_token: StreamToken | None = None
        remaining_limit = limit

        # Query LEAVE/BAN rooms with bounded to_token to prevent data leaks
        if leave_ban_rooms:
            for room_id in leave_ban_rooms:
                if remaining_limit <= 0:
                    # We've hit the limit, set prev_batch to indicate more results
                    prev_batch_token = to_token
                    break

                membership_info = room_membership_map[room_id]
                bounded_to_token = membership_info.event_pos.to_room_stream_token()

                (
                    room_thread_updates,
                    room_prev_batch,
                ) = await self._main_store.get_thread_updates_for_rooms(
                    room_ids={room_id},
                    from_token=from_token.room_key if from_token else None,
                    to_token=bounded_to_token,
                    limit=remaining_limit,
                    exclude_thread_ids=exclude_thread_ids,
                )

                # Count updates and reduce remaining limit
                num_updates = sum(
                    len(updates) for updates in room_thread_updates.values()
                )
                remaining_limit -= num_updates

                # Merge updates
                for thread_id, updates in room_thread_updates.items():
                    all_thread_updates.setdefault(thread_id, []).extend(updates)

                # Merge prev_batch tokens (take the maximum for backward pagination)
                if room_prev_batch is not None:
                    if prev_batch_token is None:
                        prev_batch_token = room_prev_batch
                    elif (
                        room_prev_batch.room_key.stream
                        > prev_batch_token.room_key.stream
                    ):
                        prev_batch_token = room_prev_batch

        # Query other rooms (joined/invited/knocking) with normal to_token
        if other_rooms and remaining_limit > 0:
            (
                other_thread_updates,
                other_prev_batch,
            ) = await self._main_store.get_thread_updates_for_rooms(
                room_ids=other_rooms,
                from_token=from_token.room_key if from_token else None,
                to_token=to_token.room_key,
                limit=remaining_limit,
                exclude_thread_ids=exclude_thread_ids,
            )

            # Merge updates
            for thread_id, updates in other_thread_updates.items():
                all_thread_updates.setdefault(thread_id, []).extend(updates)

            # Merge prev_batch tokens
            if other_prev_batch is not None:
                if prev_batch_token is None:
                    prev_batch_token = other_prev_batch
                elif (
                    other_prev_batch.room_key.stream > prev_batch_token.room_key.stream
                ):
                    prev_batch_token = other_prev_batch

        return all_thread_updates, prev_batch_token
```
```python
    async def get_thread_updates_for_rooms(
        self,
        room_ids: frozenset[str],
        room_membership_map: Mapping[str, "RoomsForUserType"],
        user_id: str,
        from_token: StreamToken | None,
        to_token: StreamToken,
        limit: int,
        include_roots: bool = False,
        exclude_thread_ids: set[str] | None = None,
    ) -> tuple[dict[str, dict[str, ThreadUpdate]], StreamToken | None]:
        """Get thread updates across multiple rooms with full processing pipeline.

        This is the main entry point for fetching thread updates. It handles:
        - Fetching updates with membership-based security
        - Filtering for visibility
        - Optionally fetching thread roots and aggregations
        - Building the response structure

        Args:
            room_ids: The set of room IDs to fetch updates for
            room_membership_map: Map of room_id to RoomsForUserType for membership info
            user_id: The user requesting the updates
            from_token: Lower bound (exclusive) for the query
            to_token: Upper bound for the query
            limit: Maximum number of updates to return
            include_roots: Whether to fetch and include thread root events (default: False)
            exclude_thread_ids: Optional set of thread IDs to exclude

        Returns:
            A tuple of:
                - Map of room_id to thread_root_id to ThreadUpdate
                - Global prev_batch token if there are more results, None otherwise
        """
        # Fetch thread updates with membership handling
        all_thread_updates, prev_batch_token = await self._fetch_thread_updates(
            room_ids=room_ids,
            room_membership_map=room_membership_map,
            from_token=from_token,
            to_token=to_token,
            limit=limit,
            exclude_thread_ids=exclude_thread_ids,
        )

        if not all_thread_updates:
            return {}, prev_batch_token

        # Filter thread updates for visibility
        filtered_updates = await self._filter_thread_updates_for_user(
            all_thread_updates, user_id
        )

        if not filtered_updates:
            return {}, prev_batch_token

        # Optionally fetch thread root events and their bundled aggregations
        thread_root_event_map: ThreadRootsMap = {}
        aggregations_map: AggregationsMap = {}
        if include_roots:
            # Fetch thread root events
            thread_root_events = await self._main_store.get_events_as_list(
                filtered_updates.keys()
            )
            thread_root_event_map = {e.event_id: e for e in thread_root_events}

            # Fetch bundled aggregations for the thread roots
            if thread_root_event_map:
                aggregations_map = await self.get_bundled_aggregations(
                    thread_root_event_map.values(),
                    user_id,
                )

        # Build response structure with per-thread prev_batch tokens
        thread_updates = self._build_thread_updates_response(
            filtered_updates=filtered_updates,
            thread_root_event_map=thread_root_event_map,
            aggregations_map=aggregations_map,
            global_prev_batch_token=prev_batch_token,
        )

        return thread_updates, prev_batch_token
```
```python
    @staticmethod
    async def serialize_thread_updates(
        thread_updates: Mapping[str, Mapping[str, ThreadUpdate]],
        prev_batch_token: StreamToken | None,
        event_serializer: "EventClientSerializer",
        time_now: int,
        store: "DataStore",
        serialize_options: SerializeEventConfig,
    ) -> JsonDict:
        """
        Serialize thread updates to JSON format.

        This helper handles serialization of ThreadUpdate objects for both the
        companion endpoint and the sliding sync extension.

        Args:
            thread_updates: Map of room_id to thread_root_id to ThreadUpdate
            prev_batch_token: Global pagination token for fetching more updates
            event_serializer: The event serializer to use
            time_now: Current time in milliseconds for event serialization
            store: Datastore for serializing stream tokens
            serialize_options: Serialization config

        Returns:
            JSON-serializable dict with "updates" and optionally "prev_batch"
        """
        updates_dict: JsonDict = {}

        for room_id, room_threads in thread_updates.items():
            room_updates: JsonDict = {}
            for thread_root_id, update in room_threads.items():
                update_dict: JsonDict = {}

                # Serialize thread_root event if present
                if update.thread_root is not None:
                    bundle_aggs_map = (
                        {thread_root_id: update.bundled_aggregations}
                        if update.bundled_aggregations is not None
                        else None
                    )
                    serialized_events = await event_serializer.serialize_events(
                        [update.thread_root],
                        time_now,
                        config=serialize_options,
                        bundle_aggregations=bundle_aggs_map,
                    )
                    if serialized_events:
                        update_dict["thread_root"] = serialized_events[0]

                # Add per-thread prev_batch if present
                if update.prev_batch is not None:
                    update_dict["prev_batch"] = await update.prev_batch.to_string(store)

                room_updates[thread_root_id] = update_dict

            updates_dict[room_id] = room_updates

        result: JsonDict = {"updates": updates_dict}

        # Add global prev_batch token if present
        if prev_batch_token is not None:
            result["prev_batch"] = await prev_batch_token.to_string(store)

        return result
```
```python
    async def get_threads(
        self,
        requester: Requester,
```
```diff
@@ -305,6 +305,7 @@ class SlidingSyncHandler:
             # account data, read receipts, typing indicators, to-device messages, etc).
             actual_room_ids=set(relevant_room_map.keys()),
             actual_room_response_map=rooms,
+            room_membership_for_user_at_to_token_map=room_membership_for_user_map,
             from_token=from_token,
             to_token=to_token,
         )
```
```diff
@@ -26,8 +26,16 @@ from typing import (
 
 from typing_extensions import TypeAlias, assert_never
 
-from synapse.api.constants import AccountDataTypes, EduTypes
+from synapse.api.constants import (
+    AccountDataTypes,
+    EduTypes,
+    EventContentFields,
+    MRelatesToFields,
+    RelationTypes,
+)
+from synapse.events import EventBase
 from synapse.handlers.receipts import ReceiptEventSource
+from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
 from synapse.logging.opentracing import trace
 from synapse.storage.databases.main.receipts import ReceiptInRoom
 from synapse.types import (
```
```diff
@@ -73,7 +81,10 @@ class SlidingSyncExtensionHandler:
         self.event_sources = hs.get_event_sources()
         self.device_handler = hs.get_device_handler()
         self.push_rules_handler = hs.get_push_rules_handler()
+        self.relations_handler = hs.get_relations_handler()
+        self._storage_controllers = hs.get_storage_controllers()
         self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
+        self._enable_threads_ext = hs.config.experimental.msc4360_enabled
 
     @trace
     async def get_extensions_response(
```
```diff
@@ -84,6 +95,7 @@
         actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList],
         actual_room_ids: set[str],
         actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
+        room_membership_for_user_at_to_token_map: Mapping[str, RoomsForUserType],
         to_token: StreamToken,
         from_token: SlidingSyncStreamToken | None,
     ) -> SlidingSyncResult.Extensions:
```
```diff
@@ -99,6 +111,8 @@
             actual_room_ids: The actual room IDs in the the Sliding Sync response.
             actual_room_response_map: A map of room ID to room results in the the
                 Sliding Sync response.
+            room_membership_for_user_at_to_token_map: A map of room ID to the membership
+                information for the user in the room at the time of `to_token`.
             to_token: The latest point in the stream to sync up to.
             from_token: The point in the stream to sync from.
         """
```
```diff
@@ -174,6 +188,18 @@
                 from_token=from_token,
             )
 
+        threads_coro = None
+        if sync_config.extensions.threads is not None and self._enable_threads_ext:
+            threads_coro = self.get_threads_extension_response(
+                sync_config=sync_config,
+                threads_request=sync_config.extensions.threads,
+                actual_room_ids=actual_room_ids,
+                actual_room_response_map=actual_room_response_map,
+                room_membership_for_user_at_to_token_map=room_membership_for_user_at_to_token_map,
+                to_token=to_token,
+                from_token=from_token,
+            )
+
         (
             to_device_response,
             e2ee_response,
```
```diff
@@ -181,6 +207,7 @@
             receipts_response,
             typing_response,
             thread_subs_response,
+            threads_response,
         ) = await gather_optional_coroutines(
             to_device_coro,
             e2ee_coro,
@@ -188,6 +215,7 @@
             receipts_coro,
             typing_coro,
             thread_subs_coro,
+            threads_coro,
         )
 
         return SlidingSyncResult.Extensions(
@@ -197,6 +225,7 @@
             receipts=receipts_response,
             typing=typing_response,
             thread_subscriptions=thread_subs_response,
+            threads=threads_response,
         )
 
     def find_relevant_room_ids_for_extension(
```
@@ -967,3 +996,104 @@

```python
            unsubscribed=unsubscribed_threads,
            prev_batch=prev_batch,
        )

    def _extract_thread_id_from_event(self, event: EventBase) -> str | None:
        """Extract thread ID from event if it's a thread reply.

        Args:
            event: The event to check.

        Returns:
            The thread ID if the event is a thread reply, None otherwise.
        """
        relates_to = event.content.get(EventContentFields.RELATIONS)
        if isinstance(relates_to, dict):
            if relates_to.get(MRelatesToFields.REL_TYPE) == RelationTypes.THREAD:
                return relates_to.get(MRelatesToFields.EVENT_ID)
        return None
```
```python
    def _find_threads_in_timeline(
        self,
        actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
    ) -> set[str]:
        """Find all thread IDs that have events in room timelines.

        Args:
            actual_room_response_map: A map of room ID to room results.

        Returns:
            A set of thread IDs (thread root event IDs) that appear in the timeline.
        """
        threads_in_timeline: set[str] = set()
        for room_result in actual_room_response_map.values():
            if room_result.timeline_events:
                for event in room_result.timeline_events:
                    thread_id = self._extract_thread_id_from_event(event)
                    if thread_id:
                        threads_in_timeline.add(thread_id)
        return threads_in_timeline
```
```python
    async def get_threads_extension_response(
        self,
        sync_config: SlidingSyncConfig,
        threads_request: SlidingSyncConfig.Extensions.ThreadsExtension,
        actual_room_ids: set[str],
        actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
        room_membership_for_user_at_to_token_map: Mapping[str, RoomsForUserType],
        to_token: StreamToken,
        from_token: SlidingSyncStreamToken | None,
    ) -> SlidingSyncResult.Extensions.ThreadsExtension | None:
        """Handle Threads extension (MSC4360)

        Args:
            sync_config: Sync configuration.
            threads_request: The threads extension from the request.
            actual_room_ids: The actual room IDs in the the Sliding Sync response.
            actual_room_response_map: A map of room ID to room results in the
                sliding sync response. Used to determine which threads already have
                events in the room timeline.
            room_membership_for_user_at_to_token_map: A map of room ID to the membership
                information for the user in the room at the time of `to_token`.
            to_token: The point in the stream to sync up to.
            from_token: The point in the stream to sync from.

        Returns:
            the response (None if empty or threads extension is disabled)
        """
        if not threads_request.enabled:
            return None

        # Identify which threads already have events in the room timelines.
        # If include_roots=False, we'll exclude these threads from the DB query
        # since the client already sees the thread activity in the timeline.
        # If include_roots=True, we fetch all threads regardless, because the client
        # wants the thread root events.
        threads_to_exclude: set[str] | None = None
        if not threads_request.include_roots:
            threads_to_exclude = self._find_threads_in_timeline(
                actual_room_response_map
            )

        # Get thread updates using unified helper
        user_id = sync_config.user.to_string()
        (
            thread_updates_response,
            prev_batch_token,
        ) = await self.relations_handler.get_thread_updates_for_rooms(
            room_ids=frozenset(actual_room_ids),
            room_membership_map=room_membership_for_user_at_to_token_map,
            user_id=user_id,
            from_token=from_token.stream_token if from_token else None,
            to_token=to_token,
            limit=threads_request.limit,
            include_roots=threads_request.include_roots,
            exclude_thread_ids=threads_to_exclude,
        )

        if not thread_updates_response:
            return None

        return SlidingSyncResult.Extensions.ThreadsExtension(
            updates=thread_updates_response,
            prev_batch=prev_batch_token,
        )
```
```diff
@@ -20,17 +20,33 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Annotated
 
-from synapse.api.constants import Direction
+from pydantic import StrictBool, StrictStr
+from pydantic.types import StringConstraints
+
+from synapse.api.constants import Direction, Membership
 from synapse.api.errors import SynapseError
+from synapse.events.utils import SerializeEventConfig
 from synapse.handlers.relations import ThreadsListInclude
 from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
+from synapse.http.servlet import (
+    RestServlet,
+    parse_and_validate_json_object_from_request,
+    parse_boolean,
+    parse_integer,
+    parse_string,
+)
 from synapse.http.site import SynapseRequest
 from synapse.rest.client._base import client_patterns
 from synapse.storage.databases.main.relations import ThreadsNextBatch
-from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict
+from synapse.streams.config import (
+    PaginationConfig,
+    extract_stream_token_from_pagination_token,
+)
+from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
+from synapse.types.handlers.sliding_sync import PerConnectionState, SlidingSyncConfig
+from synapse.types.rest.client import RequestBodyModel, SlidingSyncBody
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
```
@@ -38,6 +54,39 @@ if TYPE_CHECKING:

```python
logger = logging.getLogger(__name__)


class ThreadUpdatesBody(RequestBodyModel):
    """
    Thread updates companion endpoint request body (MSC4360).

    Allows paginating thread updates using the same room selection as a sliding sync
    request. This enables clients to fetch thread updates for the same set of rooms
    that were included in their sliding sync response.

    Attributes:
        lists: Sliding window API lists, using the same structure as SlidingSyncBody.lists.
            If provided along with room_subscriptions, the union of rooms from both will
            be used.
        room_subscriptions: Room subscription API rooms, using the same structure as
            SlidingSyncBody.room_subscriptions. If provided along with lists, the union
            of rooms from both will be used.
        include_roots: Whether to include the thread root events in the response.
            Defaults to False.

    If neither lists nor room_subscriptions are provided, thread updates from all
    joined rooms are returned.
    """

    lists: (
        dict[
            Annotated[str, StringConstraints(max_length=64, strict=True)],
            SlidingSyncBody.SlidingSyncList,
        ]
        | None
    ) = None
    room_subscriptions: dict[StrictStr, SlidingSyncBody.RoomSubscription] | None = None
    include_roots: StrictBool = False
```
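A hypothetical request body validated by this model (list name, room ID, and field values are illustrative):

```python
# A minimal sketch of a companion-endpoint request body; the list name and
# room ID are illustrative, and the per-list/per-room fields follow
# SlidingSyncBody's structure.
thread_updates_body = {
    "lists": {
        "all_rooms": {"ranges": [[0, 99]], "required_state": [], "timeline_limit": 1},
    },
    "room_subscriptions": {
        "!room:example.org": {"required_state": [], "timeline_limit": 10},
    },
    "include_roots": False,
}
```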
```python
class RelationPaginationServlet(RestServlet):
    """API to paginate relations on an event by topological ordering, optionally
    filtered by relation type and event type.
```
@@ -133,6 +182,167 @@ class ThreadsServlet(RestServlet):

```python
        return 200, result


class ThreadUpdatesServlet(RestServlet):
    """
    Companion endpoint to the Sliding Sync threads extension (MSC4360).
    Allows clients to bulk fetch thread updates across all joined rooms.
    """

    PATTERNS = client_patterns(
        "/io.element.msc4360/thread_updates$",
        unstable=True,
        releases=(),
    )
    CATEGORY = "Client API requests"

    def __init__(self, hs: "HomeServer"):
        super().__init__()
        self.clock = hs.get_clock()
        self.auth = hs.get_auth()
        self.store = hs.get_datastores().main
        self.relations_handler = hs.get_relations_handler()
        self.event_serializer = hs.get_event_client_serializer()
        self._storage_controllers = hs.get_storage_controllers()
        self.sliding_sync_handler = hs.get_sliding_sync_handler()
```
```python
    async def on_POST(self, request: SynapseRequest) -> tuple[int, JsonDict]:
        requester = await self.auth.get_user_by_req(request)

        # Parse request body
        body = parse_and_validate_json_object_from_request(request, ThreadUpdatesBody)

        # Parse query parameters
        dir_str = parse_string(request, "dir", default="b")
        if dir_str != "b":
            raise SynapseError(
                400,
                "The 'dir' parameter must be 'b' (backward). Forward pagination is not supported.",
            )

        limit = parse_integer(request, "limit", default=100)
        if limit <= 0:
            raise SynapseError(400, "The 'limit' parameter must be positive.")

        from_token_str = parse_string(request, "from")
        to_token_str = parse_string(request, "to")

        # Parse pagination tokens
        from_token: StreamToken | None = None
        to_token: StreamToken | None = None

        if from_token_str:
            try:
                stream_token_str = extract_stream_token_from_pagination_token(
                    from_token_str
                )
                from_token = await StreamToken.from_string(self.store, stream_token_str)
            except Exception as e:
                logger.exception("Error parsing 'from' token: %s", from_token_str)
                raise SynapseError(400, "'from' parameter is invalid") from e

        if to_token_str:
            try:
                stream_token_str = extract_stream_token_from_pagination_token(
                    to_token_str
                )
                to_token = await StreamToken.from_string(self.store, stream_token_str)
            except Exception:
                raise SynapseError(400, "'to' parameter is invalid")

        # Get the list of rooms to fetch thread updates for
        user_id = requester.user.to_string()
        user = UserID.from_string(user_id)

        # Get the current stream token for membership lookup
        if from_token is None:
            max_stream_ordering = self.store.get_room_max_stream_ordering()
            current_token = StreamToken.START.copy_and_replace(
                StreamKeyType.ROOM, RoomStreamToken(stream=max_stream_ordering)
            )
        else:
            current_token = from_token

        # Get room membership information to properly handle LEAVE/BAN rooms
        (
            room_membership_for_user_at_to_token_map,
            _,
            _,
        ) = await self.sliding_sync_handler.room_lists.get_room_membership_for_user_at_to_token(
            user=user,
            to_token=current_token,
            from_token=None,
        )

        # Determine which rooms to fetch updates for based on lists/room_subscriptions
        if body.lists is not None or body.room_subscriptions is not None:
            # Use sliding sync room selection logic
            sync_config = SlidingSyncConfig(
                user=user,
                requester=requester,
                lists=body.lists,
                room_subscriptions=body.room_subscriptions,
            )

            # Use the sliding sync room list handler to get the same set of rooms
            interested_rooms = (
                await self.sliding_sync_handler.room_lists.compute_interested_rooms(
                    sync_config=sync_config,
                    previous_connection_state=PerConnectionState(),
                    to_token=current_token,
                    from_token=None,
                )
            )

            room_ids = frozenset(interested_rooms.relevant_room_map.keys())
        else:
            # No lists or room_subscriptions, use only joined rooms
            room_ids = frozenset(
                room_id
                for room_id, membership_info in room_membership_for_user_at_to_token_map.items()
                if membership_info.membership == Membership.JOIN
            )

        # Get thread updates using unified helper.
        # Note the deliberate swap for backward pagination: the request's 'to'
        # token is the older bound (the query's exclusive lower bound) and the
        # request's 'from' token is the newer bound (the query's upper bound).
        (
            thread_updates,
            prev_batch_token,
        ) = await self.relations_handler.get_thread_updates_for_rooms(
            room_ids=room_ids,
            room_membership_map=room_membership_for_user_at_to_token_map,
            user_id=user_id,
            from_token=to_token,
            to_token=from_token if from_token else current_token,
            limit=limit,
            include_roots=body.include_roots,
        )

        if not thread_updates:
            return 200, {"chunk": {}}

        # Serialize thread updates using shared helper
        time_now = self.clock.time_msec()
        serialize_options = SerializeEventConfig(requester=requester)

        serialized = await self.relations_handler.serialize_thread_updates(
            thread_updates=thread_updates,
            prev_batch_token=prev_batch_token,
            event_serializer=self.event_serializer,
            time_now=time_now,
            store=self.store,
            serialize_options=serialize_options,
        )

        # Build response with "chunk" wrapper and "next_batch" key
        # (companion endpoint uses different key names than sliding sync)
        response: JsonDict = {"chunk": serialized["updates"]}
        if "prev_batch" in serialized:
            response["next_batch"] = serialized["prev_batch"]

        return 200, response
```
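A sketch of the endpoint's response shape, with illustrative IDs and tokens; note the `chunk`/`next_batch` key names, versus `updates`/`prev_batch` in the sliding sync extension:

```python
# A minimal sketch of the companion endpoint's response; IDs and tokens are
# illustrative.
example_response = {
    "chunk": {
        "!room:example.org": {
            "$thread_root_event_id": {
                "prev_batch": "s41_0_0_0_0_0_0_0_0_0",  # per-thread, for /relations dir=b
            },
        },
    },
    # Present only when more thread updates remain:
    "next_batch": "s57_0_0_0_0_0_0_0_0_0",
}
```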
```diff
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RelationPaginationServlet(hs).register(http_server)
     ThreadsServlet(hs).register(http_server)
+    if hs.config.experimental.msc4360_enabled:
+        ThreadUpdatesServlet(hs).register(http_server)
```
```diff
@@ -31,11 +31,13 @@ from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.events.utils import (
+    EventClientSerializer,
     SerializeEventConfig,
     format_event_for_client_v2_without_room_id,
     format_event_raw,
 )
 from synapse.handlers.presence import format_user_presence_state
+from synapse.handlers.relations import RelationsHandler
 from synapse.handlers.sliding_sync import SlidingSyncConfig, SlidingSyncResult
 from synapse.handlers.sync import (
     ArchivedSyncResult,
@@ -56,6 +58,7 @@ from synapse.http.servlet import (
 from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
 from synapse.rest.admin.experimental_features import ExperimentalFeature
+from synapse.storage.databases.main import DataStore
 from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken
 from synapse.types.rest.client import SlidingSyncBody
 from synapse.util.caches.lrucache import LruCache
```
```diff
@@ -646,6 +649,7 @@ class SlidingSyncRestServlet(RestServlet):
         - receipts (MSC3960)
         - account data (MSC3959)
         - thread subscriptions (MSC4308)
+        - threads (MSC4360)
 
     Request query parameters:
         timeout: How long to wait for new events in milliseconds.
```
```diff
@@ -849,7 +853,10 @@
             logger.info("Client has disconnected; not serializing response.")
             return 200, {}
 
-        response_content = await self.encode_response(requester, sliding_sync_results)
+        time_now = self.clock.time_msec()
+        response_content = await self.encode_response(
+            requester, sliding_sync_results, time_now
+        )
 
         return 200, response_content
```
```diff
@@ -858,6 +865,7 @@
         self,
         requester: Requester,
         sliding_sync_result: SlidingSyncResult,
+        time_now: int,
     ) -> JsonDict:
         response: JsonDict = defaultdict(dict)
```
```diff
@@ -866,10 +874,10 @@
         if serialized_lists:
             response["lists"] = serialized_lists
         response["rooms"] = await self.encode_rooms(
-            requester, sliding_sync_result.rooms
+            requester, sliding_sync_result.rooms, time_now
         )
         response["extensions"] = await self.encode_extensions(
-            requester, sliding_sync_result.extensions
+            requester, sliding_sync_result.extensions, time_now
         )
 
         return response
```
```diff
@@ -901,9 +909,8 @@
         self,
         requester: Requester,
         rooms: dict[str, SlidingSyncResult.RoomResult],
+        time_now: int,
     ) -> JsonDict:
-        time_now = self.clock.time_msec()
-
         serialize_options = SerializeEventConfig(
             event_format=format_event_for_client_v2_without_room_id,
             requester=requester,
```
```diff
@@ -1019,7 +1026,10 @@
 
     @trace_with_opname("sliding_sync.encode_extensions")
     async def encode_extensions(
-        self, requester: Requester, extensions: SlidingSyncResult.Extensions
+        self,
+        requester: Requester,
+        extensions: SlidingSyncResult.Extensions,
+        time_now: int,
     ) -> JsonDict:
         serialized_extensions: JsonDict = {}
```
```diff
@@ -1089,6 +1099,18 @@
                 _serialise_thread_subscriptions(extensions.thread_subscriptions)
             )
 
+        # excludes both None and falsy `threads`
+        if extensions.threads:
+            serialized_extensions[
+                "io.element.msc4360.threads"
+            ] = await _serialise_threads(
+                self.event_serializer,
+                time_now,
+                extensions.threads,
+                self.store,
+                requester,
+            )
+
         return serialized_extensions
```
@@ -1125,6 +1147,52 @@ def _serialise_thread_subscriptions(

```python
    return out
```
```python
async def _serialise_threads(
    event_serializer: EventClientSerializer,
    time_now: int,
    threads: SlidingSyncResult.Extensions.ThreadsExtension,
    store: "DataStore",
    requester: Requester,
) -> JsonDict:
    """
    Serialize the threads extension response for sliding sync.

    Args:
        event_serializer: The event serializer to use for serializing thread root events.
        time_now: The current time in milliseconds, used for event serialization.
        threads: The threads extension data containing thread updates and pagination tokens.
        store: The datastore, needed for serializing stream tokens.
        requester: The user making the request, used for transaction_id inclusion.

    Returns:
        A JSON-serializable dict containing:
            - "updates": A nested dict mapping room_id -> thread_root_id -> thread update.
              Each thread update may contain:
                - "thread_root": The serialized thread root event (if include_roots was
                  True), with bundled aggregations including the latest_event in
                  unsigned.m.relations.m.thread.
                - "prev_batch": A pagination token for fetching older events in the thread.
            - "prev_batch": A pagination token for fetching older thread updates (if
              available).
    """
    if not threads.updates:
        out: JsonDict = {}
        if threads.prev_batch:
            out["prev_batch"] = await threads.prev_batch.to_string(store)
        return out

    # Create serialization config to include transaction_id for requester's events
    serialize_options = SerializeEventConfig(requester=requester)

    # Use shared serialization helper (static method)
    return await RelationsHandler.serialize_thread_updates(
        thread_updates=threads.updates,
        prev_batch_token=threads.prev_batch,
        event_serializer=event_serializer,
        time_now=time_now,
        store=store,
        serialize_options=serialize_options,
    )
```
```python
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
    SyncRestServlet(hs).register(http_server)
```
```diff
@@ -19,6 +19,7 @@
 #
 
 import logging
+from collections import defaultdict
 from typing import (
     TYPE_CHECKING,
     Collection,
```
```diff
@@ -40,13 +41,19 @@ from synapse.storage.database import (
     LoggingTransaction,
     make_in_list_sql_clause,
 )
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.stream import (
     generate_next_token,
     generate_pagination_bounds,
     generate_pagination_where_clause,
 )
 from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, StreamKeyType, StreamToken
+from synapse.types import (
+    JsonDict,
+    RoomStreamToken,
+    StreamKeyType,
+    StreamToken,
+)
 from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
```
```diff
@@ -88,7 +95,23 @@ class _RelatedEvent:
     sender: str
 
 
-class RelationsWorkerStore(SQLBaseStore):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThreadUpdateInfo:
+    """
+    Information about a thread update for the sliding sync threads extension.
+
+    Attributes:
+        event_id: The event ID of the event in the thread.
+        room_id: The room ID where this thread exists.
+        stream_ordering: The stream ordering of this event.
+    """
+
+    event_id: str
+    room_id: str
+    stream_ordering: int
+
+
+class RelationsWorkerStore(EventsWorkerStore, SQLBaseStore):
     def __init__(
         self,
         database: DatabasePool,
```
```diff
@@ -584,14 +607,18 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_applicable_edits", _get_applicable_edits_txn
         )
 
-        edits = await self.get_events(edit_ids.values())  # type: ignore[attr-defined]
+        edits = await self.get_events(edit_ids.values())
 
         # Map to the original event IDs to the edit events.
         #
         # There might not be an edit event due to there being no edits or
         # due to the event not being known, either case is treated the same.
         return {
-            original_event_id: edits.get(edit_ids.get(original_event_id))
+            original_event_id: (
+                edits.get(edit_id)
+                if (edit_id := edit_ids.get(original_event_id))
+                else None
+            )
             for original_event_id in event_ids
         }
```
```diff
@@ -699,7 +726,7 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_thread_summaries", _get_thread_summaries_txn
         )
 
-        latest_events = await self.get_events(latest_event_ids.values())  # type: ignore[attr-defined]
+        latest_events = await self.get_events(latest_event_ids.values())
 
         # Map to the event IDs to the thread summary.
         #
```
@@ -1111,6 +1138,148 @@ class RelationsWorkerStore(SQLBaseStore):

```python
            "get_related_thread_id", _get_related_thread_id
        )
```
```python
    async def get_thread_updates_for_rooms(
        self,
        *,
        room_ids: Collection[str],
        from_token: RoomStreamToken | None = None,
        to_token: RoomStreamToken | None = None,
        limit: int = 5,
        exclude_thread_ids: Collection[str] | None = None,
    ) -> tuple[dict[str, list[ThreadUpdateInfo]], StreamToken | None]:
        """Get a list of updated threads, ordered by stream ordering of their
        latest reply, filtered to only include threads in rooms where the user
        is currently joined.

        Args:
            room_ids: The room IDs to fetch thread updates for.
            from_token: The lower bound (exclusive) for thread updates. If None,
                fetch from the start of the room timeline.
            to_token: The upper bound (inclusive) for thread updates. If None,
                fetch up to the current position in the room timeline.
            limit: Maximum number of thread updates to return.
            exclude_thread_ids: Optional collection of thread root event IDs to exclude
                from the results. Useful for filtering out threads already visible
                in the room timeline.

        Returns:
            A tuple of:
                A dict mapping thread_id to list of ThreadUpdateInfo objects,
                ordered by stream_ordering descending (most recent first).
                A prev_batch StreamToken (exclusive) if there are more results
                available, None otherwise.
        """
        # Ensure bad limits aren't being passed in.
        assert limit > 0

        if len(room_ids) == 0:
            return ({}), None

        def _get_thread_updates_for_user_txn(
            txn: LoggingTransaction,
        ) -> tuple[list[tuple[str, str, str, int]], int | None]:
            room_clause, room_id_values = make_in_list_sql_clause(
                txn.database_engine, "e.room_id", room_ids
            )

            # Generate the pagination clause, if necessary.
            pagination_clause = ""
            pagination_args: list[str] = []
            if from_token:
                from_bound = from_token.stream
                pagination_clause += " AND stream_ordering > ?"
                pagination_args.append(str(from_bound))

            if to_token:
                to_bound = to_token.stream
                pagination_clause += " AND stream_ordering <= ?"
                pagination_args.append(str(to_bound))

            # Generate the exclusion clause for thread IDs, if necessary.
            exclusion_clause = ""
            exclusion_args: list[str] = []
            if exclude_thread_ids:
                exclusion_clause, exclusion_args = make_in_list_sql_clause(
                    txn.database_engine,
                    "er.relates_to_id",
                    exclude_thread_ids,
                    negative=True,
                )
                exclusion_clause = f" AND {exclusion_clause}"

            # TODO: improve the fact that multiple hits for the same thread means we
            # won't get as many overall updates for the sss response

            # Find any thread events between the stream ordering bounds.
            sql = f"""
                SELECT e.event_id, er.relates_to_id, e.room_id, e.stream_ordering
                FROM event_relations AS er
                INNER JOIN events AS e ON er.event_id = e.event_id
                WHERE er.relation_type = '{RelationTypes.THREAD}'
                    AND {room_clause}
                    {exclusion_clause}
                    {pagination_clause}
                ORDER BY stream_ordering DESC
                LIMIT ?
            """

            # Fetch `limit + 1` rows as a way to detect if there are more results
            # beyond what we're returning. If we get exactly `limit + 1` rows back,
            # we know there are more results available and we can set `next_token`.
            # We only return the first `limit` rows to the caller. This avoids
            # needing a separate COUNT query.
            txn.execute(
                sql,
                (
                    *room_id_values,
                    *exclusion_args,
                    *pagination_args,
                    limit + 1,
                ),
            )

            # SQL returns: event_id, thread_id, room_id, stream_ordering
            rows = cast(list[tuple[str, str, str, int]], txn.fetchall())

            # If there are more events, generate the next pagination key from the
            # last thread which will be returned.
            next_token = None
            if len(rows) > limit:
                # Set the next_token to be the second last row in the result set
                # since that will be the last row we return from this function.
                # This works as an exclusive bound that can be backpaginated from.
                # Use the stream_ordering field (index 3 in each row).
                next_token = rows[-2][3]

            return rows[:limit], next_token

        thread_infos, next_token_int = await self.db_pool.runInteraction(
            "get_thread_updates_for_user", _get_thread_updates_for_user_txn
        )

        # Convert the next_token int (stream ordering) to a StreamToken.
        # Use StreamToken.START as base (all other streams at 0) since only the
        # room position matters.
        # Subtract 1 to make it exclusive - the client can paginate from this point
        # without receiving the last thread update that was already returned.
        next_token = None
        if next_token_int is not None:
            next_token = StreamToken.START.copy_and_replace(
                StreamKeyType.ROOM, RoomStreamToken(stream=next_token_int - 1)
            )

        # Build ThreadUpdateInfo objects.
        thread_update_infos: dict[str, list[ThreadUpdateInfo]] = defaultdict(list)
        for event_id, thread_id, room_id, stream_ordering in thread_infos:
            thread_update_infos[thread_id].append(
                ThreadUpdateInfo(
                    event_id=event_id,
                    room_id=room_id,
                    stream_ordering=stream_ordering,
                )
            )

        return (thread_update_infos, next_token)
```
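The `limit + 1` over-fetch is a standard pagination idiom: request one extra row, and its presence alone signals that another page exists. A standalone sketch with plain tuples standing in for the `(event_id, thread_id, room_id, stream_ordering)` rows:

```python
# A minimal sketch of the limit + 1 over-fetch, with plain tuples standing in
# for the (event_id, thread_id, room_id, stream_ordering) rows.
def paginate(rows_from_db: list[tuple], limit: int):
    # The caller asked the database for `limit + 1` rows.
    next_token = None
    if len(rows_from_db) > limit:
        # More results exist; resume from the last row we actually return.
        next_token = rows_from_db[-2][3]
    return rows_from_db[:limit], next_token

rows = [("$e1", "$t1", "!r", 90), ("$e2", "$t1", "!r", 80), ("$e3", "$t2", "!r", 70)]
page, token = paginate(rows, limit=2)
assert page == rows[:2] and token == 80  # resume below stream_ordering 80
```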
```python
class RelationsStore(RelationsWorkerStore):
    pass
```
@@ -0,0 +1,33 @@ (new file)

```sql
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.

-- Add indexes to improve performance of the thread_updates endpoint and
-- sliding sync threads extension (MSC4360).

-- Index for efficiently finding all events that relate to a specific event
-- (e.g., all replies to a thread root). This is used by the correlated subquery
-- in get_thread_updates_for_user that counts thread updates.
-- Also useful for other relation queries (edits, reactions, etc.).
CREATE INDEX IF NOT EXISTS event_relations_relates_to_id_type
    ON event_relations(relates_to_id, relation_type);

-- Index for the /thread_updates endpoint's cross-room query.
-- Allows efficient descending ordering and range filtering of threads
-- by stream_ordering across all rooms.
CREATE INDEX IF NOT EXISTS threads_stream_ordering_desc
    ON threads(stream_ordering DESC);

-- Index for the EXISTS clause that filters threads to only joined rooms.
-- Allows efficient lookup of a user's current room memberships.
CREATE INDEX IF NOT EXISTS local_current_membership_user_room
    ON local_current_membership(user_id, membership, room_id);
```
@@ -35,6 +35,32 @@ logger = logging.getLogger(__name__)

```python
MAX_LIMIT = 1000


def extract_stream_token_from_pagination_token(token_str: str) -> str:
    """
    Extract the StreamToken portion from a pagination token string.

    Handles both:
    - StreamToken format: "s123_456_..."
    - SlidingSyncStreamToken format: "5/s123_456_..." (extracts part after /)

    This allows clients using sliding sync to use their pos tokens
    with endpoints like /relations and /messages.

    Args:
        token_str: The token string to parse

    Returns:
        The StreamToken portion of the token
    """
    if "/" in token_str:
        # SlidingSyncStreamToken format: "connection_position/stream_token"
        # Split and return just the stream_token part
        parts = token_str.split("/", 1)
        if len(parts) == 2:
            return parts[1]
    return token_str
```
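Usage is straightforward; both token shapes named in the docstring reduce to the plain stream token:

```python
# Both token shapes from the docstring, exercised directly.
assert extract_stream_token_from_pagination_token("s123_456") == "s123_456"
assert extract_stream_token_from_pagination_token("5/s123_456") == "s123_456"
```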
```python
@attr.s(slots=True, auto_attribs=True)
class PaginationConfig:
    """A configuration object which stores pagination parameters."""
```

@@ -62,14 +88,20 @@ class PaginationConfig:

```diff
             if from_tok_str == "END":
                 from_tok = None  # For backwards compat.
             elif from_tok_str:
-                from_tok = await StreamToken.from_string(store, from_tok_str)
+                stream_token_str = extract_stream_token_from_pagination_token(
+                    from_tok_str
+                )
+                from_tok = await StreamToken.from_string(store, stream_token_str)
         except Exception:
             raise SynapseError(400, "'from' parameter is invalid")
 
         try:
             to_tok = None
             if to_tok_str:
-                to_tok = await StreamToken.from_string(store, to_tok_str)
+                stream_token_str = extract_stream_token_from_pagination_token(
+                    to_tok_str
+                )
+                to_tok = await StreamToken.from_string(store, stream_token_str)
         except Exception:
             raise SynapseError(400, "'to' parameter is invalid")
```
```diff
@@ -35,6 +35,10 @@ from pydantic import ConfigDict
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
+
+if TYPE_CHECKING:
+    from synapse.handlers.relations import BundledAggregations, ThreadUpdate
+
 from synapse.types import (
     DeviceListUpdates,
     JsonDict,
```
```diff
@@ -388,12 +392,37 @@ class SlidingSyncResult:
                 or bool(self.prev_batch)
             )
 
+        @attr.s(slots=True, frozen=True, auto_attribs=True)
+        class ThreadsExtension:
+            """The Threads extension (MSC4360)
+
+            Provides thread updates for threads that have new activity across all of the
+            user's joined rooms within the sync window.
+
+            Attributes:
+                updates: A nested mapping of room_id -> thread_root_id -> ThreadUpdate.
+                    Each ThreadUpdate contains information about a thread that has new
+                    activity, including the thread root event (if requested) and a
+                    pagination token for fetching older events in that specific thread.
+                prev_batch: A pagination token for fetching more thread updates across
+                    all rooms. If present, indicates there are more thread updates
+                    available beyond what was returned in this response. This token can
+                    be used with a future request to paginate through older thread
+                    updates.
+            """
+
+            updates: Mapping[str, Mapping[str, "ThreadUpdate"]] | None
+            prev_batch: StreamToken | None
+
+            def __bool__(self) -> bool:
+                return bool(self.updates) or bool(self.prev_batch)
+
         to_device: ToDeviceExtension | None = None
         e2ee: E2eeExtension | None = None
         account_data: AccountDataExtension | None = None
         receipts: ReceiptsExtension | None = None
         typing: TypingExtension | None = None
         thread_subscriptions: ThreadSubscriptionsExtension | None = None
+        threads: ThreadsExtension | None = None
 
         def __bool__(self) -> bool:
             return bool(
@@ -403,6 +432,7 @@ class SlidingSyncResult:
                 or self.receipts
                 or self.typing
                 or self.thread_subscriptions
+                or self.threads
             )
 
     next_pos: SlidingSyncStreamToken
```
@@ -852,6 +882,7 @@ class PerConnectionState:
|
||||
Attributes:
|
||||
rooms: The status of each room for the events stream.
|
||||
receipts: The status of each room for the receipts stream.
|
||||
account_data: The status of each room for the account data stream.
|
||||
room_configs: Map from room_id to the `RoomSyncConfig` of all
|
||||
rooms that we have previously sent down.
|
||||
"""
|
||||
|
||||
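For orientation, a sketch of how this extension might appear in a serialized sliding sync response, written as a Python dict. This assumes the response reuses the unstable io.element.msc4360.threads key from the request model below; the actual ThreadUpdate serialization is defined by the handler and serializer elsewhere in this PR, so the values here are purely illustrative.

    sliding_sync_response_extensions = {
        "io.element.msc4360.threads": {
            "updates": {
                "!room:example.org": {
                    "$thread_root": {
                        # Only present when thread roots were requested:
                        "thread_root": {"event_id": "$thread_root", "...": "..."},
                        # Per-thread token for fetching older events in this thread:
                        "prev_batch": "s123_456",
                    },
                },
            },
            # Cross-room token, present when more thread updates remain:
            "prev_batch": "s100_400",
        },
    }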
@@ -383,6 +383,19 @@ class SlidingSyncBody(RequestBodyModel):
        enabled: StrictBool | None = False
        limit: StrictInt = 100

    class ThreadsExtension(RequestBodyModel):
        """The Threads extension (MSC4360)

        Attributes:
            enabled: Whether the threads extension is enabled.
            include_roots: Whether to include thread root events in the extension
                response.
            limit: Maximum number of thread updates to return across all joined
                rooms.
        """

        enabled: StrictBool | None = False
        include_roots: StrictBool = False
        limit: StrictInt = 100

    to_device: ToDeviceExtension | None = None
    e2ee: E2eeExtension | None = None
    account_data: AccountDataExtension | None = None
@@ -391,6 +404,9 @@ class SlidingSyncBody(RequestBodyModel):
    thread_subscriptions: ThreadSubscriptionsExtension | None = Field(
        None, alias="io.element.msc4308.thread_subscriptions"
    )
    threads: ThreadsExtension | None = Field(
        None, alias="io.element.msc4360.threads"
    )

    conn_id: StrictStr | None = None
    lists: (
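And the matching request side: a sketch of a sliding sync request body that enables the extension under its unstable alias. The field names come from the model above; the surrounding "extensions" envelope is the standard sliding sync one, and the values are illustrative.

    sliding_sync_request_body = {
        "conn_id": "main",
        "extensions": {
            "io.element.msc4360.threads": {
                "enabled": True,
                "include_roots": True,  # also return the thread root events
                "limit": 50,  # cap on thread updates across all joined rooms
            },
        },
    }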
@@ -340,6 +340,7 @@ T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
T6 = TypeVar("T6")
T7 = TypeVar("T7")


@overload
@@ -469,6 +470,30 @@ async def gather_optional_coroutines(
) -> tuple[T1 | None, T2 | None, T3 | None, T4 | None, T5 | None, T6 | None]: ...


@overload
async def gather_optional_coroutines(
    *coroutines: Unpack[
        tuple[
            Coroutine[Any, Any, T1] | None,
            Coroutine[Any, Any, T2] | None,
            Coroutine[Any, Any, T3] | None,
            Coroutine[Any, Any, T4] | None,
            Coroutine[Any, Any, T5] | None,
            Coroutine[Any, Any, T6] | None,
            Coroutine[Any, Any, T7] | None,
        ]
    ],
) -> tuple[
    T1 | None,
    T2 | None,
    T3 | None,
    T4 | None,
    T5 | None,
    T6 | None,
    T7 | None,
]: ...


async def gather_optional_coroutines(
    *coroutines: Unpack[tuple[Coroutine[Any, Any, T1] | None, ...]],
) -> tuple[T1 | None, ...]:
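The new T7 type variable and overload presumably exist so that all seven optional extension computations (to_device, e2ee, account_data, receipts, typing, thread_subscriptions, and now threads) can be gathered in one call. A minimal usage sketch; the _compute_* coroutine functions are hypothetical stand-ins:

    # Each argument may be a coroutine or None; positions are preserved, and
    # a None input yields None in the corresponding result slot.
    to_device, e2ee, account_data, receipts, typing_, thread_subs, threads = (
        await gather_optional_coroutines(
            _compute_to_device() if want_to_device else None,
            _compute_e2ee() if want_e2ee else None,
            _compute_account_data() if want_account_data else None,
            _compute_receipts() if want_receipts else None,
            _compute_typing() if want_typing else None,
            _compute_thread_subscriptions() if want_thread_subs else None,
            _compute_threads() if want_threads else None,
        )
    )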
1409	tests/rest/client/sliding_sync/test_extension_threads.py	Normal file
File diff suppressed because it is too large

957	tests/rest/client/test_thread_updates.py	Normal file
@@ -0,0 +1,957 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
import logging

from twisted.test.proto_helpers import MemoryReactor

import synapse.rest.admin
from synapse.api.constants import RelationTypes
from synapse.rest.client import login, relations, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util.clock import Clock

from tests import unittest

logger = logging.getLogger(__name__)


class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
    """
    Test the /thread_updates companion endpoint (MSC4360).
    """

    servlets = [
        synapse.rest.admin.register_servlets,
        login.register_servlets,
        room.register_servlets,
        relations.register_servlets,
    ]

    def default_config(self) -> JsonDict:
        config = super().default_config()
        config["experimental_features"] = {"msc4360_enabled": True}
        return config

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        self.store = hs.get_datastores().main

    def test_no_updates_for_new_user(self) -> None:
        """
        Test that a user with no thread updates gets an empty response.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Request thread updates
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            content={"include_roots": True},
            access_token=user1_tok,
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        # Assert empty chunk and no next_batch
        self.assertEqual(channel.json_body["chunk"], {})
        self.assertNotIn("next_batch", channel.json_body)

    def test_single_thread_update(self) -> None:
        """
        Test that a single thread with one reply appears in the response.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create thread root
        thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
        thread_root_id = thread_root_resp["event_id"]

        # Add reply to thread
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply 1",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread_root_id,
                },
            },
            tok=user1_tok,
        )

        # Request thread updates
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        # Assert thread is present
        chunk = channel.json_body["chunk"]
        self.assertIn(room_id, chunk)
        self.assertIn(thread_root_id, chunk[room_id])

        # Assert thread root is included
        thread_update = chunk[room_id][thread_root_id]
        self.assertIn("thread_root", thread_update)
        self.assertEqual(thread_update["thread_root"]["event_id"], thread_root_id)

        # Assert prev_batch is NOT present (only 1 update - the reply)
        self.assertNotIn("prev_batch", thread_update)
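
    # For reference, the response asserted above has roughly this shape
    # (values illustrative; optional fields noted):
    #
    #     {
    #         "chunk": {
    #             "!room:test": {
    #                 "$thread_root": {
    #                     "thread_root": {...},     # only with include_roots
    #                     "prev_batch": "<token>",  # only when the thread has
    #                 },                            # more updates than returned
    #             },
    #         },
    #         "next_batch": "<token>",  # only when more updates remain overall
    #     }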

    def test_multiple_threads_single_room(self) -> None:
        """
        Test that multiple threads in the same room are grouped correctly.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create two threads
        thread1_root_id = self.helper.send(room_id, body="Thread 1", tok=user1_tok)[
            "event_id"
        ]
        thread2_root_id = self.helper.send(room_id, body="Thread 2", tok=user1_tok)[
            "event_id"
        ]

        # Add replies to both threads
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply to thread 1",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread1_root_id,
                },
            },
            tok=user1_tok,
        )
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply to thread 2",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread2_root_id,
                },
            },
            tok=user1_tok,
        )

        # Request thread updates
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        # Assert both threads are in the same room
        chunk = channel.json_body["chunk"]
        self.assertIn(room_id, chunk)
        self.assertEqual(len(chunk), 1, "Should only have one room")
        self.assertEqual(len(chunk[room_id]), 2, "Should have two threads")
        self.assertIn(thread1_root_id, chunk[room_id])
        self.assertIn(thread2_root_id, chunk[room_id])

    def test_threads_across_multiple_rooms(self) -> None:
        """
        Test that threads from different rooms are grouped by room_id.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_a_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        room_b_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create threads in both rooms
        thread_a_root_id = self.helper.send(room_a_id, body="Thread A", tok=user1_tok)[
            "event_id"
        ]
        thread_b_root_id = self.helper.send(room_b_id, body="Thread B", tok=user1_tok)[
            "event_id"
        ]

        # Add replies
        self.helper.send_event(
            room_a_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply to A",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread_a_root_id,
                },
            },
            tok=user1_tok,
        )
        self.helper.send_event(
            room_b_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply to B",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread_b_root_id,
                },
            },
            tok=user1_tok,
        )

        # Request thread updates
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        # Assert both rooms are present with their threads
        chunk = channel.json_body["chunk"]
        self.assertEqual(len(chunk), 2, "Should have two rooms")
        self.assertIn(room_a_id, chunk)
        self.assertIn(room_b_id, chunk)
        self.assertIn(thread_a_root_id, chunk[room_a_id])
        self.assertIn(thread_b_root_id, chunk[room_b_id])

    def test_pagination_with_from_token(self) -> None:
        """
        Test that pagination works using the next_batch token.

        This verifies that multiple calls to /thread_updates return all thread
        updates with no duplicates and no gaps.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create several threads (more than the per-request limit used below)
        thread_ids = []
        for i in range(5):
            thread_root_id = self.helper.send(
                room_id, body=f"Thread {i}", tok=user1_tok
            )["event_id"]
            thread_ids.append(thread_root_id)

            # Add reply
            self.helper.send_event(
                room_id,
                type="m.room.message",
                content={
                    "msgtype": "m.text",
                    "body": f"Reply to thread {i}",
                    "m.relates_to": {
                        "rel_type": RelationTypes.THREAD,
                        "event_id": thread_root_id,
                    },
                },
                tok=user1_tok,
            )

        # Request first page with small limit
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        # Should have 2 threads and a next_batch token
        first_page_threads = set(channel.json_body["chunk"][room_id].keys())
        self.assertEqual(len(first_page_threads), 2)
        self.assertIn("next_batch", channel.json_body)

        next_batch = channel.json_body["next_batch"]

        # Request second page
        channel = self.make_request(
            "POST",
            f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch}",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        second_page_threads = set(channel.json_body["chunk"][room_id].keys())
        self.assertEqual(len(second_page_threads), 2)

        # Verify no overlap
        self.assertEqual(
            len(first_page_threads & second_page_threads),
            0,
            "Pages should not have overlapping threads",
        )

        # Request third page to get the remaining thread
        self.assertIn("next_batch", channel.json_body)
        next_batch_2 = channel.json_body["next_batch"]

        channel = self.make_request(
            "POST",
            f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch_2}",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        third_page_threads = set(channel.json_body["chunk"][room_id].keys())
        self.assertEqual(len(third_page_threads), 1)

        # Verify no overlap between any pages
        self.assertEqual(len(first_page_threads & third_page_threads), 0)
        self.assertEqual(len(second_page_threads & third_page_threads), 0)

        # Verify no gaps - all threads should be accounted for across all pages
        all_threads = set(thread_ids)
        combined_threads = first_page_threads | second_page_threads | third_page_threads
        self.assertEqual(
            combined_threads,
            all_threads,
            "Combined pages should include all thread updates with no gaps",
        )

    def test_invalid_dir_parameter(self) -> None:
        """
        Test that forward pagination (dir=f) is rejected with an error.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Request with forward direction should fail
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=f",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 400)

    def test_invalid_limit_parameter(self) -> None:
        """
        Test that invalid limit values are rejected.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Zero limit should fail
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=0",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 400)

        # Negative limit should fail
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=-5",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 400)

    def test_invalid_pagination_tokens(self) -> None:
        """
        Test that invalid from/to tokens are rejected with appropriate errors.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Invalid from token
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from=invalid_token",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 400)

        # Invalid to token
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to=invalid_token",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 400)

    def test_to_token_filtering(self) -> None:
        """
        Test that the to parameter correctly limits pagination to updates newer
        than the to token (since we paginate backwards from newest to oldest).

        This also verifies the to boundary is exclusive - updates at exactly
        the to position should not be included (as they were already returned
        in a previous response that synced up to that position).
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create two thread roots
        thread1_root_id = self.helper.send(room_id, body="Thread 1", tok=user1_tok)[
            "event_id"
        ]
        thread2_root_id = self.helper.send(room_id, body="Thread 2", tok=user1_tok)[
            "event_id"
        ]

        # Send replies to both threads
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply to thread 1",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread1_root_id,
                },
            },
            tok=user1_tok,
        )
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply to thread 2",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread2_root_id,
                },
            },
            tok=user1_tok,
        )

        # Request with limit=1 to get only the latest thread update
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=1",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200)
        self.assertIn("next_batch", channel.json_body)

        # next_batch points to just before the update we received
        next_batch = channel.json_body["next_batch"]
        first_response_threads = set(channel.json_body["chunk"][room_id].keys())

        # Request again with to=next_batch (the lower bound for backward
        # pagination) and no limit. This should return only the same thread
        # update as before, not the older one.
        channel = self.make_request(
            "POST",
            f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to={next_batch}",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200)

        chunk = channel.json_body["chunk"]
        self.assertIn(room_id, chunk)
        # Should have exactly one thread update
        self.assertEqual(len(chunk[room_id]), 1)

        second_response_threads = set(chunk[room_id].keys())

        # Verify no overlap - the to parameter boundary should be exclusive
        self.assertEqual(
            first_response_threads,
            second_response_threads,
            "to parameter boundary should be exclusive - both responses should be identical",
        )

    def test_bundled_aggregations_on_thread_roots(self) -> None:
        """
        Test that thread root events include bundled aggregations with the
        latest thread event.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create thread root
        thread_root_id = self.helper.send(room_id, body="Thread root", tok=user1_tok)[
            "event_id"
        ]

        # Send replies to create bundled aggregation data
        for i in range(2):
            self.helper.send_event(
                room_id,
                type="m.room.message",
                content={
                    "msgtype": "m.text",
                    "body": f"Reply {i + 1}",
                    "m.relates_to": {
                        "rel_type": RelationTypes.THREAD,
                        "event_id": thread_root_id,
                    },
                },
                tok=user1_tok,
            )

        # Request thread updates
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200)

        # Check that the thread root has bundled aggregations with the latest event
        chunk = channel.json_body["chunk"]
        thread_update = chunk[room_id][thread_root_id]
        thread_root_event = thread_update["thread_root"]

        # Should have unsigned data with the latest thread event content
        self.assertIn("unsigned", thread_root_event)
        self.assertIn("m.relations", thread_root_event["unsigned"])
        relations = thread_root_event["unsigned"]["m.relations"]
        self.assertIn(RelationTypes.THREAD, relations)

        # Check latest event is present in bundled aggregations
        thread_summary = relations[RelationTypes.THREAD]
        self.assertIn("latest_event", thread_summary)
        latest_event = thread_summary["latest_event"]
        self.assertEqual(latest_event["content"]["body"], "Reply 2")
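
    # The bundled-aggregations payload checked above lives on the thread root's
    # unsigned data and looks roughly like (abridged):
    #
    #     "unsigned": {
    #         "m.relations": {
    #             "m.thread": {
    #                 "latest_event": {...},  # the "Reply 2" event
    #                 "count": 2,
    #                 "current_user_participated": True,
    #             }
    #         }
    #     }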

    def test_only_joined_rooms(self) -> None:
        """
        Test that thread updates only include rooms where the user is currently
        joined.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        user2_id = self.register_user("user2", "pass")
        user2_tok = self.login(user2_id, "pass")

        # Create two rooms; user1 joins both
        room1_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        room2_id = self.helper.create_room_as(user2_id, tok=user2_tok)
        self.helper.join(room2_id, user1_id, tok=user1_tok)

        # Create threads in both rooms
        thread1_root_id = self.helper.send(room1_id, body="Thread 1", tok=user1_tok)[
            "event_id"
        ]
        thread2_root_id = self.helper.send(room2_id, body="Thread 2", tok=user2_tok)[
            "event_id"
        ]

        # Add replies to both threads
        self.helper.send_event(
            room1_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply to thread 1",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread1_root_id,
                },
            },
            tok=user1_tok,
        )
        self.helper.send_event(
            room2_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply to thread 2",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread2_root_id,
                },
            },
            tok=user2_tok,
        )

        # User1 leaves room2
        self.helper.leave(room2_id, user1_id, tok=user1_tok)

        # Request thread updates for user1 - should only get room1
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            access_token=user1_tok,
            content={"include_roots": True},
        )
        self.assertEqual(channel.code, 200)

        chunk = channel.json_body["chunk"]
        # Should only have room1, not room2
        self.assertIn(room1_id, chunk)
        self.assertNotIn(room2_id, chunk)
        self.assertIn(thread1_root_id, chunk[room1_id])

    def test_room_filtering_with_lists(self) -> None:
        """
        Test that room filtering works correctly using the lists parameter.

        This verifies that thread updates are only returned for rooms matching
        the provided filters.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Create an encrypted room and an unencrypted room
        encrypted_room_id = self.helper.create_room_as(
            user1_id,
            tok=user1_tok,
            extra_content={
                "initial_state": [
                    {
                        "type": "m.room.encryption",
                        "state_key": "",
                        "content": {"algorithm": "m.megolm.v1.aes-sha2"},
                    }
                ]
            },
        )
        unencrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create threads in both rooms
        enc_thread_root_id = self.helper.send(
            encrypted_room_id, body="Encrypted thread", tok=user1_tok
        )["event_id"]
        unenc_thread_root_id = self.helper.send(
            unencrypted_room_id, body="Unencrypted thread", tok=user1_tok
        )["event_id"]

        # Add replies to both threads
        self.helper.send_event(
            encrypted_room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply in encrypted room",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": enc_thread_root_id,
                },
            },
            tok=user1_tok,
        )
        self.helper.send_event(
            unencrypted_room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply in unencrypted room",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": unenc_thread_root_id,
                },
            },
            tok=user1_tok,
        )

        # Request thread updates with a filter for encrypted rooms only
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            access_token=user1_tok,
            content={
                "lists": {
                    "encrypted_list": {
                        "ranges": [[0, 99]],
                        "required_state": [["m.room.encryption", ""]],
                        "timeline_limit": 10,
                        "filters": {"is_encrypted": True},
                    }
                }
            },
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        chunk = channel.json_body["chunk"]
        # Should only include the encrypted room
        self.assertIn(encrypted_room_id, chunk)
        self.assertNotIn(unencrypted_room_id, chunk)
        self.assertIn(enc_thread_root_id, chunk[encrypted_room_id])

    def test_room_filtering_with_room_subscriptions(self) -> None:
        """
        Test that room filtering works correctly using the room_subscriptions
        parameter.

        This verifies that thread updates are only returned for explicitly
        subscribed rooms.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Create three rooms
        room1_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        room2_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        room3_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create threads in all three rooms
        thread1_root_id = self.helper.send(room1_id, body="Thread 1", tok=user1_tok)[
            "event_id"
        ]
        thread2_root_id = self.helper.send(room2_id, body="Thread 2", tok=user1_tok)[
            "event_id"
        ]
        thread3_root_id = self.helper.send(room3_id, body="Thread 3", tok=user1_tok)[
            "event_id"
        ]

        # Add replies to all threads
        for room_id, thread_root_id in [
            (room1_id, thread1_root_id),
            (room2_id, thread2_root_id),
            (room3_id, thread3_root_id),
        ]:
            self.helper.send_event(
                room_id,
                type="m.room.message",
                content={
                    "msgtype": "m.text",
                    "body": "Reply",
                    "m.relates_to": {
                        "rel_type": RelationTypes.THREAD,
                        "event_id": thread_root_id,
                    },
                },
                tok=user1_tok,
            )

        # Request thread updates with a subscription to only room1 and room2
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            access_token=user1_tok,
            content={
                "room_subscriptions": {
                    room1_id: {
                        "required_state": [["m.room.name", ""]],
                        "timeline_limit": 10,
                    },
                    room2_id: {
                        "required_state": [["m.room.name", ""]],
                        "timeline_limit": 10,
                    },
                }
            },
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        chunk = channel.json_body["chunk"]
        # Should only include room1 and room2, not room3
        self.assertIn(room1_id, chunk)
        self.assertIn(room2_id, chunk)
        self.assertNotIn(room3_id, chunk)
        self.assertIn(thread1_root_id, chunk[room1_id])
        self.assertIn(thread2_root_id, chunk[room2_id])

    def test_room_filtering_with_lists_and_room_subscriptions(self) -> None:
        """
        Test that room filtering works correctly when both lists and
        room_subscriptions are provided. The union of rooms from both should
        be included.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Create an encrypted room and two unencrypted rooms
        encrypted_room_id = self.helper.create_room_as(
            user1_id,
            tok=user1_tok,
            extra_content={
                "initial_state": [
                    {
                        "type": "m.room.encryption",
                        "state_key": "",
                        "content": {"algorithm": "m.megolm.v1.aes-sha2"},
                    }
                ]
            },
        )
        unencrypted_room1_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        unencrypted_room2_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create threads in all three rooms
        enc_thread_root_id = self.helper.send(
            encrypted_room_id, body="Encrypted thread", tok=user1_tok
        )["event_id"]
        unenc1_thread_root_id = self.helper.send(
            unencrypted_room1_id, body="Unencrypted thread 1", tok=user1_tok
        )["event_id"]
        unenc2_thread_root_id = self.helper.send(
            unencrypted_room2_id, body="Unencrypted thread 2", tok=user1_tok
        )["event_id"]

        # Add replies to all threads
        for room_id, thread_root_id in [
            (encrypted_room_id, enc_thread_root_id),
            (unencrypted_room1_id, unenc1_thread_root_id),
            (unencrypted_room2_id, unenc2_thread_root_id),
        ]:
            self.helper.send_event(
                room_id,
                type="m.room.message",
                content={
                    "msgtype": "m.text",
                    "body": "Reply",
                    "m.relates_to": {
                        "rel_type": RelationTypes.THREAD,
                        "event_id": thread_root_id,
                    },
                },
                tok=user1_tok,
            )

        # Request thread updates with:
        # - lists: filter for encrypted rooms
        # - room_subscriptions: explicitly subscribe to unencrypted_room1_id
        # Expected: should get both encrypted_room_id (from the list) and
        # unencrypted_room1_id (from the subscription), but NOT
        # unencrypted_room2_id
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
            access_token=user1_tok,
            content={
                "lists": {
                    "encrypted_list": {
                        "ranges": [[0, 99]],
                        "required_state": [["m.room.encryption", ""]],
                        "timeline_limit": 10,
                        "filters": {"is_encrypted": True},
                    }
                },
                "room_subscriptions": {
                    unencrypted_room1_id: {
                        "required_state": [["m.room.name", ""]],
                        "timeline_limit": 10,
                    }
                },
            },
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        chunk = channel.json_body["chunk"]
        # Should include encrypted_room_id (from the list filter) and
        # unencrypted_room1_id (from the subscription), but not
        # unencrypted_room2_id
        self.assertIn(encrypted_room_id, chunk)
        self.assertIn(unencrypted_room1_id, chunk)
        self.assertNotIn(unencrypted_room2_id, chunk)
        self.assertIn(enc_thread_root_id, chunk[encrypted_room_id])
        self.assertIn(unenc1_thread_root_id, chunk[unencrypted_room1_id])

    def test_threads_not_returned_after_leaving_room(self) -> None:
        """
        Test that thread updates are properly bounded when a user leaves a room.

        Users should see thread updates that occurred up to the point they left,
        but NOT updates that occurred after they left.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        user2_id = self.register_user("user2", "pass")
        user2_tok = self.login(user2_id, "pass")

        # Create a room that both users join
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        self.helper.join(room_id, user2_id, tok=user2_tok)

        # Create thread
        res = self.helper.send(room_id, body="Thread root", tok=user1_tok)
        thread_root = res["event_id"]

        # Reply in thread while user2 is joined
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply 1 while user2 joined",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread_root,
                },
            },
            tok=user1_tok,
        )

        # User2 leaves the room
        self.helper.leave(room_id, user2_id, tok=user2_tok)

        # Another reply after user2 left
        self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Reply 2 after user2 left",
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread_root,
                },
            },
            tok=user1_tok,
        )

        # User2 gets thread updates with an explicit room subscription.
        # (We need to explicitly subscribe to the room to include it after
        # leaving; otherwise only joined rooms are returned.)
        channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=100",
            {
                "room_subscriptions": {
                    room_id: {
                        "required_state": [],
                        "timeline_limit": 0,
                    }
                }
            },
            access_token=user2_tok,
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        # Assert: user2 SHOULD see Reply 1 (sent while joined) but NOT Reply 2
        # (sent after leaving)
        chunk = channel.json_body["chunk"]
        self.assertIn(
            room_id,
            chunk,
            "Thread updates should include the room user2 left",
        )
        self.assertIn(
            thread_root,
            chunk[room_id],
            "Thread root should be in the updates",
        )

        # Verify that only a single update was seen (Reply 1) by checking that
        # there's no prev_batch token. If Reply 2 were also included, there
        # would be multiple updates and a prev_batch token would be present.
        thread_update = chunk[room_id][thread_root]
        self.assertNotIn(
            "prev_batch",
            thread_update,
            "No prev_batch should be present since only one update (Reply 1) is visible",
        )