Refactor thread updates to use the same logic between endpoint and extension
@@ -31,7 +31,7 @@ from typing import (

 import attr

-from synapse.api.constants import Direction, EventTypes, RelationTypes
+from synapse.api.constants import Direction, EventTypes, Membership, RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase, relation_from_event
 from synapse.events.utils import SerializeEventConfig
@@ -43,12 +43,22 @@ from synapse.storage.databases.main.relations import (
     _RelatedEvent,
 )
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, UserID
+from synapse.types import (
+    JsonDict,
+    Requester,
+    RoomStreamToken,
+    StreamKeyType,
+    StreamToken,
+    UserID,
+)
 from synapse.util.async_helpers import gather_results
 from synapse.visibility import filter_events_for_client

 if TYPE_CHECKING:
+    from synapse.events.utils import EventClientSerializer
+    from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
     from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore


 logger = logging.getLogger(__name__)
@@ -59,6 +69,22 @@ ThreadRootsMap = dict[str, EventBase]
 AggregationsMap = dict[str, "BundledAggregations"]


+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThreadUpdate:
+    """
+    Data for a single thread update.
+
+    Attributes:
+        thread_root: The thread root event, or None if not requested/not visible
+        prev_batch: Per-thread pagination token for fetching older events in this thread
+        bundled_aggregations: Bundled aggregations for the thread root event
+    """
+
+    thread_root: EventBase | None
+    prev_batch: StreamToken | None
+    bundled_aggregations: "BundledAggregations | None" = None
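
A minimal orientation sketch of what one of these updates carries (`root_event`, `token`, and `aggs` are hypothetical stand-ins, not values from this diff):

    update = ThreadUpdate(
        thread_root=root_event,    # EventBase when roots were requested and visible, else None
        prev_batch=token,          # StreamToken for paging older events in this thread, or None
        bundled_aggregations=aggs, # aggregations for the root; only set when thread_root is fetched
    )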
+
+
 class ThreadsListInclude(str, enum.Enum):
     """Valid values for the 'include' flag of /threads."""

@@ -554,7 +580,7 @@ class RelationsHandler:

         return results

-    async def process_thread_updates_for_visibility(
+    async def _filter_thread_updates_for_user(
         self,
         all_thread_updates: ThreadUpdatesMap,
         user_id: str,
@@ -596,37 +622,324 @@ class RelationsHandler:

         return filtered_updates

-    async def fetch_thread_roots_and_aggregations(
+    def _build_thread_updates_response(
         self,
-        thread_ids: Collection[str],
-        user_id: str,
-    ) -> tuple[ThreadRootsMap, AggregationsMap]:
-        """Fetch thread root events and their bundled aggregations.
+        filtered_updates: ThreadUpdatesMap,
+        thread_root_event_map: ThreadRootsMap,
+        aggregations_map: AggregationsMap,
+        global_prev_batch_token: StreamToken | None,
+    ) -> dict[str, dict[str, ThreadUpdate]]:
+        """Build thread update response structure with per-thread prev_batch tokens.

         Args:
-            thread_ids: The thread root event IDs to fetch
-            user_id: The user ID requesting the aggregations
+            filtered_updates: Map of thread_root_id to list of ThreadUpdateInfo
+            thread_root_event_map: Map of thread_root_id to EventBase
+            aggregations_map: Map of thread_root_id to BundledAggregations
+            global_prev_batch_token: Global pagination token, or None if no more results

         Returns:
+            Map of room_id to thread_root_id to ThreadUpdate
         """
+        thread_updates: dict[str, dict[str, ThreadUpdate]] = {}
+
+        for thread_root_id, updates in filtered_updates.items():
+            # We only care about the latest update for the thread.
+            # Updates are already sorted by stream_ordering DESC from the database query,
+            # and filter_events_for_client preserves order, so updates[0] is guaranteed to be
+            # the latest event for each thread.
+            latest_update = updates[0]
+            room_id = latest_update.room_id
+
+            # Generate per-thread prev_batch token if this thread has multiple visible updates
+            # or if we hit the global limit.
+            # When we hit the global limit, we generate prev_batch tokens for all threads, even if
+            # we only saw 1 update for them. This is to cover the case where we only saw
+            # a single update for a given thread, but the global limit prevents us from
+            # obtaining other updates which would have otherwise been included in the range.
+            per_thread_prev_batch = None
+            if len(updates) > 1 or global_prev_batch_token is not None:
+                # Create a token pointing to one position before the latest event's stream position.
+                # This makes it exclusive - /relations with dir=b won't return the latest event again.
+                # Use StreamToken.START as base (all other streams at 0) since only the room position matters.
+                per_thread_prev_batch = StreamToken.START.copy_and_replace(
+                    StreamKeyType.ROOM,
+                    RoomStreamToken(stream=latest_update.stream_ordering - 1),
+                )
+
+            if room_id not in thread_updates:
+                thread_updates[room_id] = {}
+
+            thread_updates[room_id][thread_root_id] = ThreadUpdate(
+                thread_root=thread_root_event_map.get(thread_root_id),
+                prev_batch=per_thread_prev_batch,
+                bundled_aggregations=aggregations_map.get(thread_root_id),
+            )
+
+        return thread_updates
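
A worked example of the exclusive per-thread token, with a hypothetical stream position: if the latest visible update of a thread sits at stream_ordering 120, the token is anchored one position earlier:

    per_thread_prev_batch = StreamToken.START.copy_and_replace(
        StreamKeyType.ROOM,
        RoomStreamToken(stream=119),  # latest_update.stream_ordering - 1
    )

Paginating /relations with dir=b from this token starts strictly below 120, so the update already delivered in this response is not returned again.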

+    async def _fetch_thread_updates(
+        self,
+        room_ids: frozenset[str],
+        room_membership_map: Mapping[str, "RoomsForUserType"],
+        from_token: StreamToken | None,
+        to_token: StreamToken,
+        limit: int,
+        exclude_thread_ids: set[str] | None = None,
+    ) -> tuple[ThreadUpdatesMap, StreamToken | None]:
+        """Fetch thread updates across multiple rooms, handling membership states properly.
+
+        This method separates rooms based on membership status (LEAVE/BAN vs others)
+        and queries them appropriately to prevent data leaks. For rooms where the user
+        has left or been banned, we bound the query to their leave/ban event position.
+
+        Args:
+            room_ids: The set of room IDs to fetch thread updates for
+            room_membership_map: Map of room_id to RoomsForUserType containing membership info
+            from_token: Lower bound (exclusive) for the query, or None for no lower bound
+            to_token: Upper bound for the query (for joined/invited/knocking rooms)
+            limit: Maximum number of thread updates to return across all rooms
+            exclude_thread_ids: Optional set of thread IDs to exclude from results
+
+        Returns:
+            A tuple of:
-            - Map of event_id to EventBase for thread root events
-            - Map of event_id to BundledAggregations for those events
+            - Map of thread_id to list of ThreadUpdateInfo objects
+            - Global prev_batch token if there are more results, None otherwise
+        """
-        # Fetch thread root events
-        thread_root_events = await self._main_store.get_events_as_list(thread_ids)
-        thread_root_event_map: ThreadRootsMap = {
-            e.event_id: e for e in thread_root_events
-        }
+        # Separate rooms based on membership to handle LEAVE/BAN rooms specially
+        leave_ban_rooms: set[str] = set()
+        other_rooms: set[str] = set()

-        # Fetch bundled aggregations for the thread roots
-        aggregations_map: AggregationsMap = {}
-        if thread_root_event_map:
-            aggregations_map = await self.get_bundled_aggregations(
-                thread_root_event_map.values(),
-                user_id,
+        for room_id in room_ids:
+            membership_info = room_membership_map.get(room_id)
+            if membership_info and membership_info.membership in (
+                Membership.LEAVE,
+                Membership.BAN,
+            ):
+                leave_ban_rooms.add(room_id)
+            else:
+                other_rooms.add(room_id)
+
+        # Fetch thread updates from storage, handling LEAVE/BAN rooms separately
+        all_thread_updates: ThreadUpdatesMap = {}
+        prev_batch_token: StreamToken | None = None
+        remaining_limit = limit
+
+        # Query LEAVE/BAN rooms with bounded to_token to prevent data leaks
+        if leave_ban_rooms:
+            for room_id in leave_ban_rooms:
+                if remaining_limit <= 0:
+                    # We've hit the limit; set prev_batch to indicate more results
+                    prev_batch_token = to_token
+                    break
+
+                membership_info = room_membership_map[room_id]
+                bounded_to_token = membership_info.event_pos.to_room_stream_token()
+
+                (
+                    room_thread_updates,
+                    room_prev_batch,
+                ) = await self._main_store.get_thread_updates_for_rooms(
+                    room_ids={room_id},
+                    from_token=from_token.room_key if from_token else None,
+                    to_token=bounded_to_token,
+                    limit=remaining_limit,
+                    exclude_thread_ids=exclude_thread_ids,
+                )
+
+                # Count updates and reduce remaining limit
+                num_updates = sum(
+                    len(updates) for updates in room_thread_updates.values()
+                )
+                remaining_limit -= num_updates
+
+                # Merge updates
+                for thread_id, updates in room_thread_updates.items():
+                    all_thread_updates.setdefault(thread_id, []).extend(updates)
+
+                # Merge prev_batch tokens (take the maximum for backward pagination)
+                if room_prev_batch is not None:
+                    if prev_batch_token is None:
+                        prev_batch_token = room_prev_batch
+                    elif (
+                        room_prev_batch.room_key.stream
+                        > prev_batch_token.room_key.stream
+                    ):
+                        prev_batch_token = room_prev_batch
+
+        # Query other rooms (joined/invited/knocking) with normal to_token
+        if other_rooms and remaining_limit > 0:
+            (
+                other_thread_updates,
+                other_prev_batch,
+            ) = await self._main_store.get_thread_updates_for_rooms(
+                room_ids=other_rooms,
+                from_token=from_token.room_key if from_token else None,
+                to_token=to_token.room_key,
+                limit=remaining_limit,
+                exclude_thread_ids=exclude_thread_ids,
+            )

-        return thread_root_event_map, aggregations_map
+            # Merge updates
+            for thread_id, updates in other_thread_updates.items():
+                all_thread_updates.setdefault(thread_id, []).extend(updates)
+
+            # Merge prev_batch tokens
+            if other_prev_batch is not None:
+                if prev_batch_token is None:
+                    prev_batch_token = other_prev_batch
+                elif (
+                    other_prev_batch.room_key.stream > prev_batch_token.room_key.stream
+                ):
+                    prev_batch_token = other_prev_batch
+
+        return all_thread_updates, prev_batch_token
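
To make the membership bounding concrete, a sketch with invented stream positions: suppose the user's leave event sits at position 40 and the caller's to_token at position 55.

    # LEAVE/BAN room: the query is capped at the leave/ban event position
    bounded_to_token = membership_info.event_pos.to_room_stream_token()  # stream=40
    # joined/invited/knocking room: the normal upper bound applies
    normal_to_token = to_token.room_key  # stream=55

Thread replies written at positions 41-55 in the left room are therefore never fetched for this user, while a room they are still joined to is queried all the way up to 55.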

+    async def get_thread_updates_for_rooms(
+        self,
+        room_ids: frozenset[str],
+        room_membership_map: Mapping[str, "RoomsForUserType"],
+        user_id: str,
+        from_token: StreamToken | None,
+        to_token: StreamToken,
+        limit: int,
+        include_roots: bool = False,
+        exclude_thread_ids: set[str] | None = None,
+    ) -> tuple[dict[str, dict[str, ThreadUpdate]], StreamToken | None]:
+        """Get thread updates across multiple rooms with full processing pipeline.
+
+        This is the main entry point for fetching thread updates. It handles:
+        - Fetching updates with membership-based security
+        - Filtering for visibility
+        - Optionally fetching thread roots and aggregations
+        - Building the response structure
+
+        Args:
+            room_ids: The set of room IDs to fetch updates for
+            room_membership_map: Map of room_id to RoomsForUserType for membership info
+            user_id: The user requesting the updates
+            from_token: Lower bound (exclusive) for the query
+            to_token: Upper bound for the query
+            limit: Maximum number of updates to return
+            include_roots: Whether to fetch and include thread root events (default: False)
+            exclude_thread_ids: Optional set of thread IDs to exclude
+
+        Returns:
+            A tuple of:
+            - Map of room_id to thread_root_id to ThreadUpdate
+            - Global prev_batch token if there are more results, None otherwise
+        """
+        # Fetch thread updates with membership handling
+        all_thread_updates, prev_batch_token = await self._fetch_thread_updates(
+            room_ids=room_ids,
+            room_membership_map=room_membership_map,
+            from_token=from_token,
+            to_token=to_token,
+            limit=limit,
+            exclude_thread_ids=exclude_thread_ids,
+        )
+
+        if not all_thread_updates:
+            return {}, prev_batch_token
+
+        # Filter thread updates for visibility
+        filtered_updates = await self._filter_thread_updates_for_user(
+            all_thread_updates, user_id
+        )
+
+        if not filtered_updates:
+            return {}, prev_batch_token
+
+        # Optionally fetch thread root events and their bundled aggregations
+        thread_root_event_map: ThreadRootsMap = {}
+        aggregations_map: AggregationsMap = {}
+        if include_roots:
+            # Fetch thread root events
+            thread_root_events = await self._main_store.get_events_as_list(
+                filtered_updates.keys()
+            )
+            thread_root_event_map = {e.event_id: e for e in thread_root_events}
+
+            # Fetch bundled aggregations for the thread roots
+            if thread_root_event_map:
+                aggregations_map = await self.get_bundled_aggregations(
+                    thread_root_event_map.values(),
+                    user_id,
+                )
+
+        # Build response structure with per-thread prev_batch tokens
+        thread_updates = self._build_thread_updates_response(
+            filtered_updates=filtered_updates,
+            thread_root_event_map=thread_root_event_map,
+            aggregations_map=aggregations_map,
+            global_prev_batch_token=prev_batch_token,
+        )
+
+        return thread_updates, prev_batch_token
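
A hypothetical call site (the membership map and token here are assumed to come from the sliding sync room-list machinery; the names are invented):

    updates_by_room, prev_batch = await relations_handler.get_thread_updates_for_rooms(
        room_ids=frozenset(membership_map.keys()),
        room_membership_map=membership_map,
        user_id="@alice:example.org",
        from_token=None,          # no lower bound
        to_token=current_token,
        limit=100,
        include_roots=True,
    )
    for room_id, threads in updates_by_room.items():
        for thread_root_id, update in threads.items():
            ...  # each `update` is a ThreadUpdate as built above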

+    @staticmethod
+    async def serialize_thread_updates(
+        thread_updates: Mapping[str, Mapping[str, ThreadUpdate]],
+        prev_batch_token: StreamToken | None,
+        event_serializer: "EventClientSerializer",
+        time_now: int,
+        store: "DataStore",
+        serialize_options: SerializeEventConfig,
+    ) -> JsonDict:
+        """
+        Serialize thread updates to JSON format.
+
+        This helper handles serialization of ThreadUpdate objects for both the
+        companion endpoint and the sliding sync extension.
+
+        Args:
+            thread_updates: Map of room_id to thread_root_id to ThreadUpdate
+            prev_batch_token: Global pagination token for fetching more updates
+            event_serializer: The event serializer to use
+            time_now: Current time in milliseconds for event serialization
+            store: Datastore for serializing stream tokens
+            serialize_options: Serialization config
+
+        Returns:
+            JSON-serializable dict with "updates" and optionally "prev_batch"
+        """
+        updates_dict: JsonDict = {}
+
+        for room_id, room_threads in thread_updates.items():
+            room_updates: JsonDict = {}
+            for thread_root_id, update in room_threads.items():
+                update_dict: JsonDict = {}
+
+                # Serialize thread_root event if present
+                if update.thread_root is not None:
+                    bundle_aggs_map = (
+                        {thread_root_id: update.bundled_aggregations}
+                        if update.bundled_aggregations is not None
+                        else None
+                    )
+                    serialized_events = await event_serializer.serialize_events(
+                        [update.thread_root],
+                        time_now,
+                        config=serialize_options,
+                        bundle_aggregations=bundle_aggs_map,
+                    )
+                    if serialized_events:
+                        update_dict["thread_root"] = serialized_events[0]
+
+                # Add per-thread prev_batch if present
+                if update.prev_batch is not None:
+                    update_dict["prev_batch"] = await update.prev_batch.to_string(store)
+
+                room_updates[thread_root_id] = update_dict
+
+            updates_dict[room_id] = room_updates
+
+        result: JsonDict = {"updates": updates_dict}
+
+        # Add global prev_batch token if present
+        if prev_batch_token is not None:
+            result["prev_batch"] = await prev_batch_token.to_string(store)
+
+        return result
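
The serialized shape, with hypothetical room/event IDs and token strings:

    {
        "updates": {
            "!room:example.org": {
                "$thread_root": {
                    "thread_root": { ...serialized root event... },
                    "prev_batch": "s119_0_0_0_0_0_0_0_0_0_0"
                }
            }
        },
        "prev_batch": "s99_0_0_0_0_0_0_0_0_0_0"
    }

Both "prev_batch" keys are optional: the inner one appears per thread, the outer one only when more updates remain.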
+
     async def get_threads(
         self,
@@ -30,28 +30,20 @@ from synapse.api.constants import (
     AccountDataTypes,
     EduTypes,
     EventContentFields,
     Membership,
     MRelatesToFields,
     RelationTypes,
 )
 from synapse.events import EventBase
 from synapse.handlers.receipts import ReceiptEventSource
-from synapse.handlers.relations import (
-    AggregationsMap,
-    ThreadRootsMap,
-)
 from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
 from synapse.logging.opentracing import trace
 from synapse.storage.databases.main.receipts import ReceiptInRoom
-from synapse.storage.databases.main.relations import ThreadUpdateInfo
 from synapse.types import (
     DeviceListUpdates,
     JsonMapping,
     MultiWriterStreamToken,
-    RoomStreamToken,
     SlidingSyncStreamToken,
     StrCollection,
-    StreamKeyType,
     StreamToken,
     ThreadSubscriptionsToken,
 )
@@ -74,7 +66,6 @@ _ThreadSubscription: TypeAlias = (
 _ThreadUnsubscription: TypeAlias = (
     SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
 )
-_ThreadUpdate: TypeAlias = SlidingSyncResult.Extensions.ThreadsExtension.ThreadUpdate

 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -1042,42 +1033,6 @@ class SlidingSyncExtensionHandler:
             threads_in_timeline.add(thread_id)
         return threads_in_timeline

-    def _merge_prev_batch_token(
-        self,
-        current_token: StreamToken | None,
-        new_token: StreamToken | None,
-    ) -> StreamToken | None:
-        """Merge two prev_batch tokens, taking the maximum (latest) for backwards pagination.
-
-        Args:
-            current_token: The current prev_batch token (may be None)
-            new_token: The new prev_batch token to merge (may be None)
-
-        Returns:
-            The merged token (maximum of the two, or None if both are None)
-        """
-        if new_token is None:
-            return current_token
-        if current_token is None:
-            return new_token
-        if new_token.room_key.stream > current_token.room_key.stream:
-            return new_token
-        return current_token
-
-    def _merge_thread_updates(
-        self,
-        target: dict[str, list[ThreadUpdateInfo]],
-        source: dict[str, list[ThreadUpdateInfo]],
-    ) -> None:
-        """Merge thread updates from source into target.
-
-        Args:
-            target: The target dict to merge into (modified in place)
-            source: The source dict to merge from
-        """
-        for thread_id, updates in source.items():
-            target.setdefault(thread_id, []).extend(updates)
-
     async def get_threads_extension_response(
         self,
         sync_config: SlidingSyncConfig,
@@ -1119,141 +1074,26 @@ class SlidingSyncExtensionHandler:
             actual_room_response_map
         )

-        # Separate rooms into groups based on membership status.
-        # For LEAVE/BAN rooms, we need to bound the to_token to prevent leaking events
-        # that occurred after the user left/was banned.
-        leave_ban_rooms: set[str] = set()
-        other_rooms: set[str] = set()
-
-        for room_id in actual_room_ids:
-            membership_info = room_membership_for_user_at_to_token_map.get(room_id)
-            if membership_info and membership_info.membership in (
-                Membership.LEAVE,
-                Membership.BAN,
-            ):
-                leave_ban_rooms.add(room_id)
-            else:
-                other_rooms.add(room_id)
-
-        # Fetch thread updates, handling LEAVE/BAN rooms separately to avoid data leaks.
-        all_thread_updates: dict[str, list[ThreadUpdateInfo]] = {}
-        prev_batch_token: StreamToken | None = None
-        remaining_limit = threads_request.limit
-
-        # Query for rooms where the user has left or been banned, using their leave/ban
-        # event position as the upper bound to prevent seeing events after they left.
-        if leave_ban_rooms:
-            for room_id in leave_ban_rooms:
-                if remaining_limit <= 0:
-                    # We've already fetched enough updates, but we still need to set
-                    # prev_batch to indicate there are more results.
-                    prev_batch_token = to_token
-                    break
-
-                membership_info = room_membership_for_user_at_to_token_map[room_id]
-                bounded_to_token = membership_info.event_pos.to_room_stream_token()
-
-                (
-                    room_thread_updates,
-                    room_prev_batch,
-                ) = await self.store.get_thread_updates_for_rooms(
-                    room_ids={room_id},
-                    from_token=from_token.stream_token.room_key if from_token else None,
-                    to_token=bounded_to_token,
-                    limit=remaining_limit,
-                    exclude_thread_ids=threads_to_exclude,
-                )
-
-                # Count how many updates we fetched and reduce the remaining limit
-                num_updates = sum(
-                    len(updates) for updates in room_thread_updates.values()
-                )
-                remaining_limit -= num_updates
-
-                self._merge_thread_updates(all_thread_updates, room_thread_updates)
-                prev_batch_token = self._merge_prev_batch_token(
-                    prev_batch_token, room_prev_batch
-                )
-
-        # Query for rooms where the user is joined, invited, or knocking, using the
-        # normal to_token as the upper bound.
-        if other_rooms and remaining_limit > 0:
-            (
-                other_thread_updates,
-                other_prev_batch,
-            ) = await self.store.get_thread_updates_for_rooms(
-                room_ids=other_rooms,
-                from_token=from_token.stream_token.room_key if from_token else None,
-                to_token=to_token.room_key,
-                limit=remaining_limit,
-                exclude_thread_ids=threads_to_exclude,
-            )
-
-            self._merge_thread_updates(all_thread_updates, other_thread_updates)
-            prev_batch_token = self._merge_prev_batch_token(
-                prev_batch_token, other_prev_batch
-            )
-
-        if len(all_thread_updates) == 0:
-            return None
-
-        # Filter thread updates for visibility
+        # Get thread updates using unified helper
         user_id = sync_config.user.to_string()
-        filtered_updates = (
-            await self.relations_handler.process_thread_updates_for_visibility(
-                all_thread_updates, user_id
-            )
-        )
+        (
+            thread_updates_response,
+            prev_batch_token,
+        ) = await self.relations_handler.get_thread_updates_for_rooms(
+            room_ids=frozenset(actual_room_ids),
+            room_membership_map=room_membership_for_user_at_to_token_map,
+            user_id=user_id,
+            from_token=from_token.stream_token if from_token else None,
+            to_token=to_token,
+            limit=threads_request.limit,
+            include_roots=threads_request.include_roots,
+            exclude_thread_ids=threads_to_exclude,
+        )

-        if not filtered_updates:
+        if not thread_updates_response:
             return None

-        # Note: Updates are already sorted by stream_ordering DESC from the database query,
-        # and filter_events_for_client preserves order, so updates[0] is guaranteed to be
-        # the latest event for each thread.
-
-        # Optionally fetch thread root events and their bundled aggregations
-        thread_root_event_map: ThreadRootsMap = {}
-        aggregations_map: AggregationsMap = {}
-        if threads_request.include_roots:
-            (
-                thread_root_event_map,
-                aggregations_map,
-            ) = await self.relations_handler.fetch_thread_roots_and_aggregations(
-                filtered_updates.keys(), user_id
-            )
-
-        thread_updates: dict[str, dict[str, _ThreadUpdate]] = {}
-        for thread_root, updates in filtered_updates.items():
-            # We only care about the latest update for the thread.
-            # After sorting above, updates[0] is guaranteed to be the latest (highest stream_ordering).
-            latest_update = updates[0]
-
-            # Generate per-thread prev_batch token if this thread has multiple visible updates.
-            # When we hit the global limit, we generate prev_batch tokens for all threads, even if
-            # we only saw 1 update for them. This is to cover the case where we only saw
-            # a single update for a given thread, but the global limit prevents us from
-            # obtaining other updates which would have otherwise been included in the
-            # range.
-            per_thread_prev_batch = None
-            if len(updates) > 1 or prev_batch_token is not None:
-                # Create a token pointing to one position before the latest event's stream position.
-                # This makes it exclusive - /relations with dir=b won't return the latest event again.
-                # Use StreamToken.START as base (all other streams at 0) since only room position matters.
-                per_thread_prev_batch = StreamToken.START.copy_and_replace(
-                    StreamKeyType.ROOM,
-                    RoomStreamToken(stream=latest_update.stream_ordering - 1),
-                )
-
-            thread_updates.setdefault(latest_update.room_id, {})[thread_root] = (
-                _ThreadUpdate(
-                    thread_root=thread_root_event_map.get(thread_root),
-                    prev_batch=per_thread_prev_batch,
-                    bundled_aggregations=aggregations_map.get(thread_root),
-                )
-            )
-
         return SlidingSyncResult.Extensions.ThreadsExtension(
-            updates=thread_updates,
+            updates=thread_updates_response,
             prev_batch=prev_batch_token,
         )

@@ -22,7 +22,7 @@ import logging
 import re
 from typing import TYPE_CHECKING

-from synapse.api.constants import Direction
+from synapse.api.constants import Direction, Membership
 from synapse.api.errors import SynapseError
 from synapse.events.utils import SerializeEventConfig
 from synapse.handlers.relations import ThreadsListInclude
@@ -41,7 +41,8 @@ from synapse.streams.config import (
     PaginationConfig,
     extract_stream_token_from_pagination_token,
 )
-from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
+from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
+from synapse.types.handlers.sliding_sync import PerConnectionState, SlidingSyncConfig
 from synapse.types.rest.client import ThreadUpdatesBody

 if TYPE_CHECKING:
@@ -166,8 +167,7 @@ class ThreadUpdatesServlet(RestServlet):
         self.relations_handler = hs.get_relations_handler()
         self.event_serializer = hs.get_event_client_serializer()
         self._storage_controllers = hs.get_storage_controllers()
-        # TODO: Get sliding sync handler for filter_rooms logic
-        # self.sliding_sync_handler = hs.get_sliding_sync_handler()
+        self.sliding_sync_handler = hs.get_sliding_sync_handler()

     async def on_POST(self, request: SynapseRequest) -> tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
@@ -191,135 +191,116 @@ class ThreadUpdatesServlet(RestServlet):
         to_token_str = parse_string(request, "to")

         # Parse pagination tokens
-        from_token: RoomStreamToken | None = None
-        to_token: RoomStreamToken | None = None
+        from_token: StreamToken | None = None
+        to_token: StreamToken | None = None

         if from_token_str:
             try:
                 stream_token_str = extract_stream_token_from_pagination_token(
                     from_token_str
                 )
-                stream_token = await StreamToken.from_string(
-                    self.store, stream_token_str
-                )
-                from_token = stream_token.room_key
-            except Exception:
-                raise SynapseError(400, "'from' parameter is invalid")
+                from_token = await StreamToken.from_string(self.store, stream_token_str)
+            except Exception as e:
+                logger.exception("Error parsing 'from' token: %s", from_token_str)
+                raise SynapseError(400, "'from' parameter is invalid") from e

         if to_token_str:
             try:
                 stream_token_str = extract_stream_token_from_pagination_token(
                     to_token_str
                 )
-                stream_token = await StreamToken.from_string(
-                    self.store, stream_token_str
-                )
-                to_token = stream_token.room_key
+                to_token = await StreamToken.from_string(self.store, stream_token_str)
             except Exception:
                 raise SynapseError(400, "'to' parameter is invalid")

-        # Get the list of rooms to fetch thread updates for
-        # Start with all joined rooms, then apply filters if provided
         user_id = requester.user.to_string()
-        room_ids = await self.store.get_rooms_for_user(user_id)
+        user = UserID.from_string(user_id)

-        if body.filters is not None:
-            # TODO: Apply filters using sliding sync room filter logic
-            # For now, if filters are provided, we need to call the sliding sync
-            # filter_rooms method to get the applicable room IDs
-            raise SynapseError(501, "Room filters not yet implemented")
-
-        # Fetch thread updates from storage
-        # For backward pagination:
-        # - 'from' (upper bound, exclusive) maps to 'to_token' (inclusive with <=)
-        #   Since next_batch is (last_returned - 1), <= excludes the last returned item
-        # - 'to' (lower bound, exclusive) maps to 'from_token' (exclusive with >)
-        (
-            all_thread_updates,
-            prev_batch_token,
-        ) = await self.store.get_thread_updates_for_rooms(
-            room_ids=room_ids,
-            from_token=to_token,
-            to_token=from_token,
-            limit=limit,
-        )
-
-        if len(all_thread_updates) == 0:
-            return 200, {"chunk": {}}
-
-        # Filter thread updates for visibility
-        filtered_updates = (
-            await self.relations_handler.process_thread_updates_for_visibility(
-                all_thread_updates, user_id
-            )
-        )
+        # Get the current stream token for membership lookup
+        if from_token is None:
+            max_stream_ordering = self.store.get_room_max_stream_ordering()
+            current_token = StreamToken.START.copy_and_replace(
+                StreamKeyType.ROOM, RoomStreamToken(stream=max_stream_ordering)
+            )
+        else:
+            current_token = from_token
+
+        # Get room membership information to properly handle LEAVE/BAN rooms
+        (
+            room_membership_for_user_at_to_token_map,
+            _,
+            _,
+        ) = await self.sliding_sync_handler.room_lists.get_room_membership_for_user_at_to_token(
+            user=user,
+            to_token=current_token,
+            from_token=None,
+        )
+
+        # Determine which rooms to fetch updates for based on lists/room_subscriptions
+        if body.lists is not None or body.room_subscriptions is not None:
+            # Use sliding sync room selection logic
+            sync_config = SlidingSyncConfig(
+                user=user,
+                requester=requester,
+                lists=body.lists,
+                room_subscriptions=body.room_subscriptions,
+            )
+
+            # Use the sliding sync room list handler to get the same set of rooms
+            interested_rooms = (
+                await self.sliding_sync_handler.room_lists.compute_interested_rooms(
+                    sync_config=sync_config,
+                    previous_connection_state=PerConnectionState(),
+                    to_token=current_token,
+                    from_token=None,
+                )
+            )
+
+            room_ids = frozenset(interested_rooms.relevant_room_map.keys())
+        else:
+            # No lists or room_subscriptions, use only joined rooms
+            room_ids = frozenset(
+                room_id
+                for room_id, membership_info in room_membership_for_user_at_to_token_map.items()
+                if membership_info.membership == Membership.JOIN
+            )
+
+        # Get thread updates using unified helper
+        (
+            thread_updates,
+            prev_batch_token,
+        ) = await self.relations_handler.get_thread_updates_for_rooms(
+            room_ids=room_ids,
+            room_membership_map=room_membership_for_user_at_to_token_map,
+            user_id=user_id,
+            from_token=to_token,
+            to_token=from_token if from_token else current_token,
+            limit=limit,
+            include_roots=body.include_roots,
+        )

-        if not filtered_updates:
+        if not thread_updates:
             return 200, {"chunk": {}}

-        # Fetch thread root events and their bundled aggregations
-        (
-            thread_root_event_map,
-            aggregations_map,
-        ) = await self.relations_handler.fetch_thread_roots_and_aggregations(
-            filtered_updates.keys(), user_id
-        )
-
-        # Build response with per-thread data
-        # Updates are already sorted by stream_ordering DESC from the database query,
-        # and filter_events_for_client preserves order, so updates[0] is guaranteed to be
-        # the latest event for each thread.
+        # Serialize thread updates using shared helper
         time_now = self.clock.time_msec()
         serialize_options = SerializeEventConfig(requester=requester)
-        chunk: dict[str, dict[str, JsonDict]] = {}
-
-        for thread_root_id, updates in filtered_updates.items():
-            # We only care about the latest update for the thread
-            latest_update = updates[0]
-            room_id = latest_update.room_id
-
-            if room_id not in chunk:
-                chunk[room_id] = {}
-
-            update_dict: JsonDict = {}
-
-            # Add thread root if present
-            thread_root_event = thread_root_event_map.get(thread_root_id)
-            if thread_root_event is not None:
-                bundle_aggs_map = (
-                    {thread_root_id: aggregations_map[thread_root_id]}
-                    if thread_root_id in aggregations_map
-                    else None
-                )
-                serialized_events = await self.event_serializer.serialize_events(
-                    [thread_root_event],
-                    time_now,
-                    config=serialize_options,
-                    bundle_aggregations=bundle_aggs_map,
-                )
-                if serialized_events:
-                    update_dict["thread_root"] = serialized_events[0]
-
-            # Add per-thread prev_batch if this thread has multiple visible updates
-            if len(updates) > 1:
-                # Create a token pointing to one position before the latest event's stream position.
-                # This makes it exclusive - /relations with dir=b won't return the latest event again.
-                per_thread_prev_batch = StreamToken.START.copy_and_replace(
-                    StreamKeyType.ROOM,
-                    RoomStreamToken(stream=latest_update.stream_ordering - 1),
-                )
-                update_dict["prev_batch"] = await per_thread_prev_batch.to_string(
-                    self.store
-                )
-
-            chunk[room_id][thread_root_id] = update_dict
-
-        # Build response
-        response: JsonDict = {"chunk": chunk}
-
-        # Add next_batch token for pagination
-        if prev_batch_token is not None:
-            response["next_batch"] = await prev_batch_token.to_string(self.store)
+        serialized = await self.relations_handler.serialize_thread_updates(
+            thread_updates=thread_updates,
+            prev_batch_token=prev_batch_token,
+            event_serializer=self.event_serializer,
+            time_now=time_now,
+            store=self.store,
+            serialize_options=serialize_options,
+        )
+
+        # Build response with "chunk" wrapper and "next_batch" key
+        # (companion endpoint uses different key names than sliding sync)
+        response: JsonDict = {"chunk": serialized["updates"]}
+        if "prev_batch" in serialized:
+            response["next_batch"] = serialized["prev_batch"]

         return 200, response
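
For the companion endpoint the same serialized payload comes back under renamed keys, e.g. (IDs and tokens hypothetical):

    {
        "chunk": { "!room:example.org": { "$thread_root": { ... } } },
        "next_batch": "s99_0_0_0_0_0_0_0_0_0_0"
    }
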
@@ -37,6 +37,7 @@ from synapse.events.utils import (
     format_event_raw,
 )
 from synapse.handlers.presence import format_user_presence_state
+from synapse.handlers.relations import RelationsHandler
 from synapse.handlers.sliding_sync import SlidingSyncConfig, SlidingSyncResult
 from synapse.handlers.sync import (
     ArchivedSyncResult,
@@ -1107,6 +1108,7 @@ class SlidingSyncRestServlet(RestServlet):
             time_now,
             extensions.threads,
             self.store,
+            requester,
         )

     return serialized_extensions
@@ -1150,6 +1152,7 @@ async def _serialise_threads(
     time_now: int,
     threads: SlidingSyncResult.Extensions.ThreadsExtension,
     store: "DataStore",
+    requester: Requester,
 ) -> JsonDict:
     """
     Serialize the threads extension response for sliding sync.
@@ -1159,6 +1162,7 @@ async def _serialise_threads(
         time_now: The current time in milliseconds, used for event serialization.
         threads: The threads extension data containing thread updates and pagination tokens.
         store: The datastore, needed for serializing stream tokens.
+        requester: The user making the request, used for transaction_id inclusion.

     Returns:
         A JSON-serializable dict containing:
@@ -1169,46 +1173,24 @@ async def _serialise_threads(
             - "prev_batch": A pagination token for fetching older events in the thread.
         - "prev_batch": A pagination token for fetching older thread updates (if available).
     """
-    out: JsonDict = {}
-    if not threads.updates:
-        out: JsonDict = {}
-        if threads.prev_batch:
-            out["prev_batch"] = await threads.prev_batch.to_string(store)
-        return out
-
-    if threads.updates:
-        updates_dict: JsonDict = {}
-        for room_id, thread_updates in threads.updates.items():
-            room_updates: JsonDict = {}
-            for thread_root_id, update in thread_updates.items():
-                # Serialize the update
-                update_dict: JsonDict = {}
+    # Create serialization config to include transaction_id for requester's events
+    serialize_options = SerializeEventConfig(requester=requester)

-                # Serialize the thread_root event if present
-                if update.thread_root is not None:
-                    # Create a mapping of event_id to bundled_aggregations
-                    bundle_aggs_map = (
-                        {thread_root_id: update.bundled_aggregations}
-                        if update.bundled_aggregations
-                        else None
-                    )
-                    serialized_events = await event_serializer.serialize_events(
-                        [update.thread_root],
-                        time_now,
-                        bundle_aggregations=bundle_aggs_map,
-                    )
-                    if serialized_events:
-                        update_dict["thread_root"] = serialized_events[0]
-
-                # Add prev_batch if present
-                if update.prev_batch is not None:
-                    update_dict["prev_batch"] = await update.prev_batch.to_string(store)
-
-                room_updates[thread_root_id] = update_dict
-
-            updates_dict[room_id] = room_updates
-
-        out["updates"] = updates_dict
-
-    if threads.prev_batch:
-        out["prev_batch"] = await threads.prev_batch.to_string(store)
-
-    return out
+    # Use shared serialization helper (static method)
+    return await RelationsHandler.serialize_thread_updates(
+        thread_updates=threads.updates,
+        prev_batch_token=threads.prev_batch,
+        event_serializer=event_serializer,
+        time_now=time_now,
+        store=store,
+        serialize_options=serialize_options,
+    )


 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:

@@ -37,7 +37,8 @@ from synapse.api.constants import EventTypes
 from synapse.events import EventBase

 if TYPE_CHECKING:
-    from synapse.handlers.relations import BundledAggregations
+    from synapse.handlers.relations import BundledAggregations, ThreadUpdate

 from synapse.types import (
     DeviceListUpdates,
     JsonDict,
@@ -409,30 +410,7 @@ class SlidingSyncResult:
             to paginate through older thread updates.
             """

-            @attr.s(slots=True, frozen=True, auto_attribs=True)
-            class ThreadUpdate:
-                """Information about a single thread that has new activity.
-
-                Attributes:
-                    thread_root: The thread root event, if requested via include_roots in the
-                        request. This is the event that started the thread.
-                    prev_batch: A pagination token (exclusive) for fetching older events in this
-                        specific thread. Only present if the thread has multiple updates in the
-                        sync window. This token can be used with the /relations endpoint with
-                        dir=b to paginate backwards through the thread's history.
-                    bundled_aggregations: Bundled aggregations for the thread root event,
-                        including the latest_event in the thread (found in
-                        unsigned.m.relations.m.thread). Only present if thread_root is included.
-                """
-
-                thread_root: EventBase | None
-                prev_batch: StreamToken | None
-                bundled_aggregations: "BundledAggregations | None" = None
-
-                def __bool__(self) -> bool:
-                    return bool(self.thread_root) or bool(self.prev_batch)
-
-            updates: Mapping[str, Mapping[str, ThreadUpdate]] | None
+            updates: Mapping[str, Mapping[str, "ThreadUpdate"]] | None
             prev_batch: StreamToken | None

             def __bool__(self) -> bool:

@@ -433,14 +433,30 @@ class ThreadUpdatesBody(RequestBodyModel):
     """
     Thread updates companion endpoint request body (MSC4360).

-    Allows filtering thread updates using the same filter criteria as sliding sync lists.
-    This enables clients to paginate thread updates using the same room filters that
-    were applied when generating the prev_batch token.
+    Allows paginating thread updates using the same room selection as a sliding sync
+    request. This enables clients to fetch thread updates for the same set of rooms
+    that were included in their sliding sync response.

     Attributes:
-        filters: Optional room filters to apply, using the same structure as
-            SlidingSyncList.Filters. If not provided, thread updates from all
-            joined rooms are returned.
+        lists: Sliding window API lists, using the same structure as SlidingSyncBody.lists.
+            If provided along with room_subscriptions, the union of rooms from both will
+            be used.
+        room_subscriptions: Room subscription API rooms, using the same structure as
+            SlidingSyncBody.room_subscriptions. If provided along with lists, the union
+            of rooms from both will be used.
+        include_roots: Whether to include the thread root events in the response.
+            Defaults to False.
+
+    If neither lists nor room_subscriptions are provided, thread updates from all
+    joined rooms are returned.
     """

-    filters: SlidingSyncBody.SlidingSyncList.Filters | None = None
+    lists: (
+        dict[
+            Annotated[str, StringConstraints(max_length=64, strict=True)],
+            SlidingSyncBody.SlidingSyncList,
+        ]
+        | None
+    ) = None
+    room_subscriptions: dict[StrictStr, SlidingSyncBody.RoomSubscription] | None = None
+    include_roots: StrictBool = False
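
An illustrative request body (list name and room ID invented):

    {
        "lists": {
            "encrypted": {
                "ranges": [[0, 99]],
                "required_state": [["m.room.encryption", ""]],
                "timeline_limit": 0,
                "filters": {"is_encrypted": true}
            }
        },
        "room_subscriptions": {
            "!room:example.org": {"required_state": [], "timeline_limit": 0}
        },
        "include_roots": true
    }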
@@ -58,7 +58,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
-            content={},
+            content={"include_roots": True},
             access_token=user1_tok,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -99,7 +99,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
             "POST",
             "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
             access_token=user1_tok,
-            content={},
+            content={"include_roots": True},
         )
         self.assertEqual(channel.code, 200, channel.json_body)

@@ -165,7 +165,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
             "POST",
             "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
             access_token=user1_tok,
-            content={},
+            content={"include_roots": True},
         )
         self.assertEqual(channel.code, 200, channel.json_body)

@@ -227,7 +227,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
             "POST",
             "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
             access_token=user1_tok,
-            content={},
+            content={"include_roots": True},
         )
         self.assertEqual(channel.code, 200, channel.json_body)

@@ -277,7 +277,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
             "POST",
             "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2",
             access_token=user1_tok,
-            content={},
+            content={"include_roots": True},
         )
         self.assertEqual(channel.code, 200, channel.json_body)

@@ -293,7 +293,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
             "POST",
             f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch}",
             access_token=user1_tok,
-            content={},
+            content={"include_roots": True},
         )
         self.assertEqual(channel.code, 200, channel.json_body)

@@ -315,7 +315,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
             "POST",
             f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch_2}",
             access_token=user1_tok,
-            content={},
+            content={"include_roots": True},
         )
         self.assertEqual(channel.code, 200, channel.json_body)

@@ -347,7 +347,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
             "POST",
             "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=f",
             access_token=user1_tok,
-            content={},
+            content={"include_roots": True},
         )
         self.assertEqual(channel.code, 400)

@@ -363,7 +363,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
             "POST",
             "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=0",
             access_token=user1_tok,
-            content={},
+            content={"include_roots": True},
        )
         self.assertEqual(channel.code, 400)

@@ -372,7 +372,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
             "POST",
             "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=-5",
             access_token=user1_tok,
-            content={},
+            content={"include_roots": True},
         )
         self.assertEqual(channel.code, 400)

@@ -388,7 +388,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
             "POST",
             "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from=invalid_token",
             access_token=user1_tok,
-            content={},
+            content={"include_roots": True},
         )
         self.assertEqual(channel.code, 400)

@@ -397,7 +397,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
             "POST",
             "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to=invalid_token",
             access_token=user1_tok,
-            content={},
+            content={"include_roots": True},
         )
         self.assertEqual(channel.code, 400)

@@ -454,7 +454,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
             "POST",
             "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=1",
             access_token=user1_tok,
-            content={},
+            content={"include_roots": True},
         )
         self.assertEqual(channel.code, 200)
         self.assertIn("next_batch", channel.json_body)
@@ -471,7 +471,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
             "POST",
             f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to={next_batch}",
             access_token=user1_tok,
-            content={},
+            content={"include_roots": True},
         )
         self.assertEqual(channel.code, 200)

@@ -523,7 +523,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
             "POST",
             "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
             access_token=user1_tok,
-            content={},
+            content={"include_roots": True},
         )
         self.assertEqual(channel.code, 200)

@@ -602,7 +602,7 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
             "POST",
             "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
             access_token=user1_tok,
-            content={},
+            content={"include_roots": True},
         )
         self.assertEqual(channel.code, 200)

@@ -611,3 +611,347 @@ class ThreadUpdatesTestCase(unittest.HomeserverTestCase):
         self.assertIn(room1_id, chunk)
         self.assertNotIn(room2_id, chunk)
         self.assertIn(thread1_root_id, chunk[room1_id])
+
+    def test_room_filtering_with_lists(self) -> None:
+        """
+        Test that room filtering works correctly using the lists parameter.
+        This verifies that thread updates are only returned for rooms matching
+        the provided filters.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+
+        # Create an encrypted room and an unencrypted room
+        encrypted_room_id = self.helper.create_room_as(
+            user1_id,
+            tok=user1_tok,
+            extra_content={
+                "initial_state": [
+                    {
+                        "type": "m.room.encryption",
+                        "state_key": "",
+                        "content": {"algorithm": "m.megolm.v1.aes-sha2"},
+                    }
+                ]
+            },
+        )
+        unencrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+        # Create threads in both rooms
+        enc_thread_root_id = self.helper.send(
+            encrypted_room_id, body="Encrypted thread", tok=user1_tok
+        )["event_id"]
+        unenc_thread_root_id = self.helper.send(
+            unencrypted_room_id, body="Unencrypted thread", tok=user1_tok
+        )["event_id"]
+
+        # Add replies to both threads
+        self.helper.send_event(
+            encrypted_room_id,
+            type="m.room.message",
+            content={
+                "msgtype": "m.text",
+                "body": "Reply in encrypted room",
+                "m.relates_to": {
+                    "rel_type": RelationTypes.THREAD,
+                    "event_id": enc_thread_root_id,
+                },
+            },
+            tok=user1_tok,
+        )
+        self.helper.send_event(
+            unencrypted_room_id,
+            type="m.room.message",
+            content={
+                "msgtype": "m.text",
+                "body": "Reply in unencrypted room",
+                "m.relates_to": {
+                    "rel_type": RelationTypes.THREAD,
+                    "event_id": unenc_thread_root_id,
+                },
+            },
+            tok=user1_tok,
+        )
+
+        # Request thread updates with filter for encrypted rooms only
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
+            access_token=user1_tok,
+            content={
+                "lists": {
+                    "encrypted_list": {
+                        "ranges": [[0, 99]],
+                        "required_state": [["m.room.encryption", ""]],
+                        "timeline_limit": 10,
+                        "filters": {"is_encrypted": True},
+                    }
+                }
+            },
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        chunk = channel.json_body["chunk"]
+        # Should only include the encrypted room
+        self.assertIn(encrypted_room_id, chunk)
+        self.assertNotIn(unencrypted_room_id, chunk)
+        self.assertIn(enc_thread_root_id, chunk[encrypted_room_id])
+
+    def test_room_filtering_with_room_subscriptions(self) -> None:
+        """
+        Test that room filtering works correctly using the room_subscriptions parameter.
+        This verifies that thread updates are only returned for explicitly subscribed rooms.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+
+        # Create three rooms
+        room1_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+        room2_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+        room3_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+        # Create threads in all three rooms
+        thread1_root_id = self.helper.send(room1_id, body="Thread 1", tok=user1_tok)[
+            "event_id"
+        ]
+        thread2_root_id = self.helper.send(room2_id, body="Thread 2", tok=user1_tok)[
+            "event_id"
+        ]
+        thread3_root_id = self.helper.send(room3_id, body="Thread 3", tok=user1_tok)[
+            "event_id"
+        ]
+
+        # Add replies to all threads
+        for room_id, thread_root_id in [
+            (room1_id, thread1_root_id),
+            (room2_id, thread2_root_id),
+            (room3_id, thread3_root_id),
+        ]:
+            self.helper.send_event(
+                room_id,
+                type="m.room.message",
+                content={
+                    "msgtype": "m.text",
+                    "body": "Reply",
+                    "m.relates_to": {
+                        "rel_type": RelationTypes.THREAD,
+                        "event_id": thread_root_id,
+                    },
+                },
+                tok=user1_tok,
+            )
+
+        # Request thread updates with subscription to only room1 and room2
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
+            access_token=user1_tok,
+            content={
+                "room_subscriptions": {
+                    room1_id: {
+                        "required_state": [["m.room.name", ""]],
+                        "timeline_limit": 10,
+                    },
+                    room2_id: {
+                        "required_state": [["m.room.name", ""]],
+                        "timeline_limit": 10,
+                    },
+                }
+            },
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        chunk = channel.json_body["chunk"]
+        # Should only include room1 and room2, not room3
+        self.assertIn(room1_id, chunk)
+        self.assertIn(room2_id, chunk)
+        self.assertNotIn(room3_id, chunk)
+        self.assertIn(thread1_root_id, chunk[room1_id])
+        self.assertIn(thread2_root_id, chunk[room2_id])
+
+    def test_room_filtering_with_lists_and_room_subscriptions(self) -> None:
+        """
+        Test that room filtering works correctly when both lists and room_subscriptions
+        are provided. The union of rooms from both should be included.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+
+        # Create an encrypted room and two unencrypted rooms
+        encrypted_room_id = self.helper.create_room_as(
+            user1_id,
+            tok=user1_tok,
+            extra_content={
+                "initial_state": [
+                    {
+                        "type": "m.room.encryption",
+                        "state_key": "",
+                        "content": {"algorithm": "m.megolm.v1.aes-sha2"},
+                    }
+                ]
+            },
+        )
+        unencrypted_room1_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+        unencrypted_room2_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+        # Create threads in all three rooms
+        enc_thread_root_id = self.helper.send(
+            encrypted_room_id, body="Encrypted thread", tok=user1_tok
+        )["event_id"]
+        unenc1_thread_root_id = self.helper.send(
+            unencrypted_room1_id, body="Unencrypted thread 1", tok=user1_tok
+        )["event_id"]
+        unenc2_thread_root_id = self.helper.send(
+            unencrypted_room2_id, body="Unencrypted thread 2", tok=user1_tok
+        )["event_id"]
+
+        # Add replies to all threads
+        for room_id, thread_root_id in [
+            (encrypted_room_id, enc_thread_root_id),
+            (unencrypted_room1_id, unenc1_thread_root_id),
+            (unencrypted_room2_id, unenc2_thread_root_id),
+        ]:
+            self.helper.send_event(
+                room_id,
+                type="m.room.message",
+                content={
+                    "msgtype": "m.text",
+                    "body": "Reply",
+                    "m.relates_to": {
+                        "rel_type": RelationTypes.THREAD,
+                        "event_id": thread_root_id,
+                    },
+                },
+                tok=user1_tok,
+            )
+
+        # Request thread updates with:
+        # - lists: filter for encrypted rooms
+        # - room_subscriptions: explicitly subscribe to unencrypted_room1_id
+        # Expected: should get both encrypted_room_id (from list) and unencrypted_room1_id
+        # (from subscription), but NOT unencrypted_room2_id
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b",
+            access_token=user1_tok,
+            content={
+                "lists": {
+                    "encrypted_list": {
+                        "ranges": [[0, 99]],
+                        "required_state": [["m.room.encryption", ""]],
+                        "timeline_limit": 10,
+                        "filters": {"is_encrypted": True},
+                    }
+                },
+                "room_subscriptions": {
+                    unencrypted_room1_id: {
+                        "required_state": [["m.room.name", ""]],
+                        "timeline_limit": 10,
+                    }
+                },
+            },
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        chunk = channel.json_body["chunk"]
+        # Should include encrypted_room_id (from list filter) and unencrypted_room1_id
+        # (from subscription), but not unencrypted_room2_id
+        self.assertIn(encrypted_room_id, chunk)
+        self.assertIn(unencrypted_room1_id, chunk)
+        self.assertNotIn(unencrypted_room2_id, chunk)
+        self.assertIn(enc_thread_root_id, chunk[encrypted_room_id])
+        self.assertIn(unenc1_thread_root_id, chunk[unencrypted_room1_id])
+
+    def test_threads_not_returned_after_leaving_room(self) -> None:
+        """
+        Test that thread updates are properly bounded when a user leaves a room.
+
+        Users should see thread updates that occurred up to the point they left,
+        but NOT updates that occurred after they left.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        user2_id = self.register_user("user2", "pass")
+        user2_tok = self.login(user2_id, "pass")
+
+        # Create room and both users join
+        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+        self.helper.join(room_id, user2_id, tok=user2_tok)
+
+        # Create thread
+        res = self.helper.send(room_id, body="Thread root", tok=user1_tok)
+        thread_root = res["event_id"]
+
+        # Reply in thread while user2 is joined
+        self.helper.send_event(
+            room_id,
+            type="m.room.message",
+            content={
+                "msgtype": "m.text",
+                "body": "Reply 1 while user2 joined",
+                "m.relates_to": {
+                    "rel_type": RelationTypes.THREAD,
+                    "event_id": thread_root,
+                },
+            },
+            tok=user1_tok,
+        )
+
+        # User2 leaves the room
+        self.helper.leave(room_id, user2_id, tok=user2_tok)
+
+        # Another reply after user2 left
+        self.helper.send_event(
+            room_id,
+            type="m.room.message",
+            content={
+                "msgtype": "m.text",
+                "body": "Reply 2 after user2 left",
+                "m.relates_to": {
+                    "rel_type": RelationTypes.THREAD,
+                    "event_id": thread_root,
+                },
+            },
+            tok=user1_tok,
+        )
+
+        # User2 gets thread updates with an explicit room subscription
+        # (We need to explicitly subscribe to the room to include it after leaving;
+        # otherwise only joined rooms are returned)
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=100",
+            {
+                "room_subscriptions": {
+                    room_id: {
+                        "required_state": [],
+                        "timeline_limit": 0,
+                    }
+                }
+            },
+            access_token=user2_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Assert: User2 SHOULD see Reply 1 (happened while joined) but NOT Reply 2 (after leaving)
+        chunk = channel.json_body["chunk"]
+        self.assertIn(
+            room_id,
+            chunk,
+            "Thread updates should include the room user2 left",
+        )
+        self.assertIn(
+            thread_root,
+            chunk[room_id],
+            "Thread root should be in the updates",
+        )
+
+        # Verify that only a single update was seen (Reply 1) by checking that there's
+        # no prev_batch token. If Reply 2 was also included, there would be multiple
+        # updates and a prev_batch token would be present.
+        thread_update = chunk[room_id][thread_root]
+        self.assertNotIn(
+            "prev_batch",
+            thread_update,
+            "No prev_batch should be present since only one update (Reply 1) is visible",
+        )