diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 63624f3e8f..6f2e03b0f9 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -313,6 +313,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             "get_unread_event_push_actions_by_room_for_user", (room_id,)
         )
 
+        self._attempt_to_invalidate_cache("_get_max_event_pos", (room_id,))
+
         # The `_get_membership_from_event_id` is immutable, except for the
         # case where we look up an event *before* persisting it.
         self._attempt_to_invalidate_cache("_get_membership_from_event_id", (event_id,))
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 4989c960a6..4d576522a9 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -50,6 +50,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Protocol,
     Set,
@@ -80,7 +81,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
@@ -1382,7 +1383,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         """
 
         min_token = end_token.stream
-        max_token = end_token.get_max_stream_pos()
         results: Dict[str, int] = {}
 
         # First, we check for the rooms in the stream change cache to see if we
@@ -1395,26 +1395,76 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             else:
                 missing_room_ids.add(room_id)
 
+        if not missing_room_ids:
+            return results
+
         # Next, we query the stream position from the DB. At first we fetch all
         # positions less than the *max* stream pos in the token, then filter
         # them down. We do this as a) this is a cheaper query, and b) the vast
         # majority of rooms will have a latest token from before the min stream
         # pos.
 
-        def bulk_get_last_event_pos_txn(
-            txn: LoggingTransaction, batch_room_ids: StrCollection
+        uncapped_results = await self._bulk_get_max_event_pos(missing_room_ids)
+
+        # Check that the stream positions for the rooms are from before the
+        # minimum position of the token. If not then we need to fetch more
+        # rows.
+        recheck_rooms: Set[str] = set()
+        for room_id, stream in uncapped_results.items():
+            if stream <= min_token:
+                results[room_id] = stream
+            else:
+                recheck_rooms.add(room_id)
+
+        if not recheck_rooms:
+            return results
+
+        for room_id in recheck_rooms:
+            result = await self.get_last_event_pos_in_room_before_stream_ordering(
+                room_id, end_token
+            )
+            if result is not None:
+                results[room_id] = result[1].stream
+
+        return results
+
+    @cached()
+    async def _get_max_event_pos(self, room_id: str) -> int:
+        raise NotImplementedError()
+
+    @cachedList(cached_method_name="_get_max_event_pos", list_name="room_ids")
+    async def _bulk_get_max_event_pos(
+        self, room_ids: StrCollection
+    ) -> Mapping[str, int]:
+        """Fetch the max position of a persisted event in the room."""
+
+        now_token = self.get_room_max_token()
+        max_pos = now_token.get_max_stream_pos()
+
+        results: Dict[str, int] = {}
+        missing_room_ids: Set[str] = set()
+        for room_id in room_ids:
+            stream_pos = self._events_stream_cache.get_max_pos_of_last_change(room_id)
+            if stream_pos is not None:
+                results[room_id] = stream_pos
+            else:
+                missing_room_ids.add(room_id)
+
+        if not missing_room_ids:
+            return results
+
+        def bulk_get_max_event_pos_txn(
+            txn: LoggingTransaction, batched_room_ids: StrCollection
         ) -> Dict[str, int]:
-            # This query fetches the latest stream position in the rooms before
-            # the given max position.
             clause, args = make_in_list_sql_clause(
-                self.database_engine, "room_id", batch_room_ids
+                self.database_engine, "room_id", batched_room_ids
             )
             sql = f"""
                 SELECT room_id, (
                     SELECT stream_ordering FROM events AS e
                     LEFT JOIN rejections USING (event_id)
                     WHERE e.room_id = r.room_id
-                        AND stream_ordering <= ?
+                        AND e.stream_ordering <= ?
                         AND NOT outlier
                         AND rejection_reason IS NULL
                     ORDER BY stream_ordering DESC
@@ -1423,72 +1473,26 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 FROM rooms AS r
                 WHERE {clause}
             """
-            txn.execute(sql, [max_token] + args)
+            txn.execute(sql, [max_pos] + args)
             return {row[0]: row[1] for row in txn}
 
         recheck_rooms: Set[str] = set()
-        for batched in batch_iter(missing_room_ids, 1000):
-            result = await self.db_pool.runInteraction(
-                "bulk_get_last_event_pos_in_room_before_stream_ordering",
-                bulk_get_last_event_pos_txn,
-                batched,
+        for batched in batch_iter(room_ids, 1000):
+            batch_results = await self.db_pool.runInteraction(
+                "_bulk_get_max_event_pos", bulk_get_max_event_pos_txn, batched
             )
-
-            # Check that the stream position for the rooms are from before the
-            # minimum position of the token. If not then we need to fetch more
-            # rows.
-            for room_id, stream in result.items():
-                if stream <= min_token:
-                    results[room_id] = stream
+            for room_id, stream_ordering in batch_results.items():
+                if stream_ordering <= now_token.stream:
+                    results[room_id] = stream_ordering
                 else:
                     recheck_rooms.add(room_id)
 
-        if not recheck_rooms:
-            return results
-
-        # For the remaining rooms we need to fetch all rows between the min and
-        # max stream positions in the end token, and filter out the rows that
-        # are after the end token.
-        #
-        # This query should be fast as the range between the min and max should
-        # be small.
-
-        def bulk_get_last_event_pos_recheck_txn(
-            txn: LoggingTransaction, batch_room_ids: StrCollection
-        ) -> Dict[str, int]:
-            clause, args = make_in_list_sql_clause(
-                self.database_engine, "room_id", batch_room_ids
+        for room_id in recheck_rooms:
+            result = await self.get_last_event_pos_in_room_before_stream_ordering(
+                room_id, now_token
             )
-            sql = f"""
-                SELECT room_id, instance_name, stream_ordering
-                FROM events
-                WHERE ? < stream_ordering AND stream_ordering <= ?
-                    AND NOT outlier
-                    AND rejection_reason IS NULL
-                    AND {clause}
-                ORDER BY stream_ordering ASC
-            """
-            txn.execute(sql, [min_token, max_token] + args)
-
-            # We take the max stream ordering that is less than the token. Since
-            # we ordered by stream ordering we just need to iterate through and
-            # take the last matching stream ordering.
-            txn_results: Dict[str, int] = {}
-            for row in txn:
-                room_id = row[0]
-                event_pos = PersistedEventPosition(row[1], row[2])
-                if not event_pos.persisted_after(end_token):
-                    txn_results[room_id] = event_pos.stream
-
-            return txn_results
-
-        for batched in batch_iter(recheck_rooms, 1000):
-            recheck_result = await self.db_pool.runInteraction(
-                "bulk_get_last_event_pos_in_room_before_stream_ordering_recheck",
-                bulk_get_last_event_pos_recheck_txn,
-                batched,
-            )
-            results.update(recheck_result)
+            if result is not None:
+                results[room_id] = result[1].stream
 
         return results
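The stream.py hunks lean on Synapse's cached/cachedList descriptor pair: `_get_max_event_pos` is only a per-room cache stub (it raises NotImplementedError), `_bulk_get_max_event_pos` does the real work and fills one cache entry per room, and the new `self._attempt_to_invalidate_cache("_get_max_event_pos", (room_id,))` call in cache.py evicts exactly the room that just persisted an event. The sketch below is a simplified, hypothetical stand-in for that contract, kept outside the diff; `PerKeyBatchCache` and `fetch_from_db` are illustrative names, not Synapse APIs.

    from typing import Callable, Dict, Iterable, List


    class PerKeyBatchCache:
        """Toy model of the keyed batch-cache contract assumed above (not Synapse code)."""

        def __init__(self, fetch_from_db: Callable[[List[str]], Dict[str, int]]) -> None:
            # `fetch_from_db` is an assumed batch loader, e.g. one SQL round trip.
            self._fetch_from_db = fetch_from_db
            self._entries: Dict[str, int] = {}  # room_id -> max event stream ordering

        def bulk_get(self, room_ids: Iterable[str]) -> Dict[str, int]:
            # Serve hits from the per-room cache, fetch all misses in one batch, and
            # store each miss under its own key so later single-room lookups (and
            # per-room invalidation) keep working.
            room_ids = list(room_ids)
            results = {r: self._entries[r] for r in room_ids if r in self._entries}
            missing = [r for r in room_ids if r not in results]
            if missing:
                fetched = self._fetch_from_db(missing)
                self._entries.update(fetched)
                results.update(fetched)
            return results

        def invalidate(self, room_id: str) -> None:
            # Mirrors the new invalidation in cache.py: only the room that just
            # persisted an event loses its cached max position.
            self._entries.pop(room_id, None)


    # Example usage:
    # cache = PerKeyBatchCache(lambda rooms: {r: 0 for r in rooms})
    # cache.bulk_get(["!a:hs", "!b:hs"])  # one batch fetch, both rooms now cached
    # cache.invalidate("!a:hs")           # only "!a:hs" is re-fetched next time

Filling the cache per room rather than per request is what lets repeat calls skip the database entirely for rooms with no new events, which the removed per-request `bulk_get_last_event_pos_txn` / recheck queries could not do.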