1
0

Run linter

This commit is contained in:
Devon Hudson
2025-11-09 09:30:44 -07:00
parent cb82a4a687
commit a3b34dfafd
3 changed files with 41 additions and 22 deletions

View File

@@ -12,9 +12,9 @@
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
from collections import defaultdict
import itertools
import logging
from collections import defaultdict
from typing import (
TYPE_CHECKING,
AbstractSet,
@@ -1008,9 +1008,7 @@ class SlidingSyncExtensionHandler:
prev_batch=prev_batch,
)
def _extract_thread_id_from_event(
self, event: EventBase
) -> Optional[str]:
def _extract_thread_id_from_event(self, event: EventBase) -> Optional[str]:
"""Extract thread ID from event if it's a thread reply.
Args:
@@ -1082,7 +1080,9 @@ class SlidingSyncExtensionHandler:
# wants the thread root events.
threads_to_exclude: Optional[Set[str]] = None
if not threads_request.include_roots:
threads_to_exclude = self._find_threads_in_timeline(actual_room_response_map)
threads_to_exclude = self._find_threads_in_timeline(
actual_room_response_map
)
# Separate rooms into groups based on membership status.
# For LEAVE/BAN rooms, we need to bound the to_token to prevent leaking events
@@ -1130,7 +1130,9 @@ class SlidingSyncExtensionHandler:
)
# Count how many updates we fetched and reduce the remaining limit
num_updates = sum(len(updates) for updates in room_thread_updates.values())
num_updates = sum(
len(updates) for updates in room_thread_updates.values()
)
remaining_limit -= num_updates
# Merge results
@@ -1144,7 +1146,10 @@ class SlidingSyncExtensionHandler:
prev_batch_token = room_prev_batch
else:
# Take the maximum (latest) prev_batch token for backwards pagination
if room_prev_batch.room_key.stream > prev_batch_token.room_key.stream:
if (
room_prev_batch.room_key.stream
> prev_batch_token.room_key.stream
):
prev_batch_token = room_prev_batch
# Query for rooms where the user is joined, invited, or knocking, using the
@@ -1171,7 +1176,10 @@ class SlidingSyncExtensionHandler:
prev_batch_token = other_prev_batch
else:
# Take the maximum (latest) prev_batch token for backwards pagination
if other_prev_batch.room_key.stream > prev_batch_token.room_key.stream:
if (
other_prev_batch.room_key.stream
> prev_batch_token.room_key.stream
):
prev_batch_token = other_prev_batch
# Early return: no thread updates found
@@ -1211,13 +1219,17 @@ class SlidingSyncExtensionHandler:
thread_root_event_map = {}
aggregations_map = {}
if threads_request.include_roots:
thread_root_events = await self.store.get_events_as_list(filtered_updates.keys())
thread_root_events = await self.store.get_events_as_list(
filtered_updates.keys()
)
thread_root_event_map = {e.event_id: e for e in thread_root_events}
if thread_root_event_map:
aggregations_map = await self.relations_handler.get_bundled_aggregations(
thread_root_event_map.values(),
sync_config.user.to_string(),
aggregations_map = (
await self.relations_handler.get_bundled_aggregations(
thread_root_event_map.values(),
sync_config.user.to_string(),
)
)
thread_updates: Dict[str, Dict[str, _ThreadUpdate]] = {}
@@ -1238,13 +1250,16 @@ class SlidingSyncExtensionHandler:
# This makes it exclusive - /relations with dir=b won't return the latest event again.
# Use StreamToken.START as base (all other streams at 0) since only room position matters.
per_thread_prev_batch = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM, RoomStreamToken(stream=latest_update.stream_ordering - 1)
StreamKeyType.ROOM,
RoomStreamToken(stream=latest_update.stream_ordering - 1),
)
thread_updates.setdefault(latest_update.room_id, {})[thread_root] = _ThreadUpdate(
thread_root=thread_root_event_map.get(thread_root),
prev_batch=per_thread_prev_batch,
bundled_aggregations=aggregations_map.get(thread_root),
thread_updates.setdefault(latest_update.room_id, {})[thread_root] = (
_ThreadUpdate(
thread_root=thread_root_event_map.get(thread_root),
prev_batch=per_thread_prev_batch,
bundled_aggregations=aggregations_map.get(thread_root),
)
)
return SlidingSyncResult.Extensions.ThreadsExtension(

View File

@@ -18,8 +18,8 @@
#
#
from collections import defaultdict
import logging
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Collection,
@@ -1190,7 +1190,9 @@ class RelationsWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_thread_updates_for_user_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[str, str, str, int]], Optional[int]]:
room_clause, room_id_values = make_in_list_sql_clause(txn.database_engine, "e.room_id", room_ids)
room_clause, room_id_values = make_in_list_sql_clause(
txn.database_engine, "e.room_id", room_ids
)
# Generate the pagination clause, if necessary.
pagination_clause = ""
@@ -1210,7 +1212,10 @@ class RelationsWorkerStore(EventsWorkerStore, SQLBaseStore):
exclusion_args: List[str] = []
if exclude_thread_ids:
exclusion_clause, exclusion_args = make_in_list_sql_clause(
txn.database_engine, "er.relates_to_id", exclude_thread_ids, negative=True,
txn.database_engine,
"er.relates_to_id",
exclude_thread_ids,
negative=True,
)
exclusion_clause = f" AND {exclusion_clause}"

View File

@@ -138,7 +138,7 @@ class SlidingSyncThreadsExtensionTestCase(SlidingSyncBase):
"foo-list": {
"ranges": [[0, 1]],
"required_state": [],
"timeline_limit": 0, # Set to 0, otherwise events will be in timeline, not extension
"timeline_limit": 0, # Set to 0, otherwise events will be in timeline, not extension
}
},
"extensions": {
@@ -157,7 +157,6 @@ class SlidingSyncThreadsExtensionTestCase(SlidingSyncBase):
{"updates": {room_id: {thread_root_id: {}}}},
)
def test_threads_incremental_sync(self) -> None:
"""
Test new thread updates appear in incremental sync response.