diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py index 6ca4738ae5..f18f01219f 100644 --- a/synapse/handlers/sliding_sync/extensions.py +++ b/synapse/handlers/sliding_sync/extensions.py @@ -1008,7 +1008,7 @@ class SlidingSyncExtensionHandler: # We set no limit here because the client can control when they get sticky events. # Furthermore, it doesn't seem possible to set a limit with the internal API shape # as given, as we cannot manipulate the to_token.sticky_events_key sent to the client... - limit=0, + limit=None, ) all_sticky_event_ids = { ev_id for evs in room_to_event_ids.values() for ev_id in evs diff --git a/synapse/storage/databases/main/sticky_events.py b/synapse/storage/databases/main/sticky_events.py index 16a1fb1bf2..7d968fd7c2 100644 --- a/synapse/storage/databases/main/sticky_events.py +++ b/synapse/storage/databases/main/sticky_events.py @@ -123,7 +123,7 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor from_id: int, to_id: int, now: int, - limit: int, + limit: int | None, ) -> tuple[int, dict[str, set[str]]]: """ Fetch all the sticky events in the given rooms, from the given sticky stream ID. @@ -133,9 +133,9 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor from_id: The sticky stream ID that sticky events should be returned from (exclusive). to_id: The sticky stream ID that sticky events should end at (inclusive). now: The current time in unix millis, used for skipping expired events. - limit: Max sticky events to return. If <= 0, no limit is applied. + limit: Max sticky events to return, or None to apply no limit. Returns: - A tuple of (to_id, map[room_id, event_ids]) + to_id, map[room_id, event_ids] """ sticky_events_rows = await self.db_pool.runInteraction( "get_sticky_events_in_rooms", @@ -146,13 +146,19 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor now, limit, ) - new_to_id = from_id + + if not sticky_events_rows: + return from_id, {} + + # Get stream_id of the last row, which is the highest + new_to_id, _, _ = sticky_events_rows[-1] + + # room ID -> event IDs room_to_events: dict[str, set[str]] = {} - for stream_id, room_id, event_id in sticky_events_rows: - new_to_id = max(new_to_id, stream_id) - events = room_to_events.get(room_id, set()) + for _, room_id, event_id in sticky_events_rows: + events = room_to_events.setdefault(room_id, set()) events.add(event_id) - room_to_events[room_id] = events + return (new_to_id, room_to_events) def _get_sticky_events_in_rooms_txn( @@ -162,23 +168,33 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor from_id: int, to_id: int, now: int, - limit: int, + limit: int | None, ) -> list[tuple[int, str, str]]: if len(room_ids) == 0: return [] - clause, room_id_values = make_in_list_sql_clause( + room_id_in_list_clause, room_id_in_list_values = make_in_list_sql_clause( txn.database_engine, "room_id", room_ids ) - query = f""" - SELECT stream_id, room_id, event_id FROM sticky_events - WHERE soft_failed != ? AND expires_at > ? AND stream_id > ? AND stream_id <= ? AND {clause} + limit_clause = "" + limit_params: tuple[int, ...] = () + if limit is not None: + limit_clause = "LIMIT ?" + limit_params = (limit,) + txn.execute( + f""" + SELECT stream_id, room_id, event_id + FROM sticky_events + WHERE + NOT soft_failed + AND expires_at > ? + AND stream_id > ? + AND stream_id <= ? + AND {room_id_in_list_clause} ORDER BY stream_id ASC - """ - params = (True, now, from_id, to_id, *room_id_values) - if limit > 0: - query += "LIMIT ?" - params += (limit,) - txn.execute(query, params) + {limit_clause} + """, + (now, from_id, to_id, *room_id_in_list_values, *limit_params), + ) return cast(list[tuple[int, str, str]], txn.fetchall()) async def get_updated_sticky_events(