diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 64d303e330..8e12645d6b 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -24,9 +24,9 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast import attr -from synapse.api.constants import EventContentFields, RelationTypes +from synapse.api.constants import EventContentFields, Membership, RelationTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -34,9 +34,18 @@ from synapse.storage.database import ( LoggingTransaction, make_tuple_comparison_clause, ) -from synapse.storage.databases.main.events import PersistEventsStore +from synapse.storage.databases.main.events import ( + SLIDING_SYNC_RELEVANT_STATE_SET, + PersistEventsStore, + SlidingSyncMembershipInfo, + SlidingSyncMembershipSnapshotSharedInsertValues, +) +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.types import Cursor -from synapse.types import JsonDict, StrCollection +from synapse.types import JsonDict, StateMap, StrCollection +from synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES +from synapse.types.state import StateFilter if TYPE_CHECKING: from synapse.server import HomeServer @@ -78,6 +87,11 @@ class _BackgroundUpdates: EVENTS_JUMP_TO_DATE_INDEX = "events_jump_to_date_index" + SLIDING_SYNC_JOINED_ROOMS_BACKFILL = "sliding_sync_joined_rooms_backfill" + SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL = ( + "sliding_sync_membership_snapshots_backfill" + ) + @attr.s(slots=True, frozen=True, auto_attribs=True) class 
_CalculateChainCover: @@ -97,7 +111,7 @@ class _CalculateChainCover: finished_room_map: Dict[str, Tuple[int, int]] -class EventsBackgroundUpdatesStore(SQLBaseStore): +class EventsBackgroundUpdatesStore(EventsWorkerStore, SQLBaseStore): def __init__( self, database: DatabasePool, @@ -279,6 +293,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): where_clause="NOT outlier", ) + # Backfill the sliding sync tables + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BACKFILL, + self._sliding_sync_joined_rooms_backfill, + ) + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL, + self._sliding_sync_membership_snapshots_backfill, + ) + async def _background_reindex_fields_sender( self, progress: JsonDict, batch_size: int ) -> int: @@ -1073,7 +1097,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): PersistEventsStore._add_chain_cover_index( txn, self.db_pool, - self.event_chain_id_gen, # type: ignore[attr-defined] + self.event_chain_id_gen, event_to_room_id, event_to_types, cast(Dict[str, StrCollection], event_to_auth_chain), @@ -1516,3 +1540,443 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) return batch_size + + async def _sliding_sync_joined_rooms_backfill( + self, progress: JsonDict, batch_size: int + ) -> int: + """ + Handles backfilling the `sliding_sync_joined_rooms` table. + """ + last_room_id = progress.get("last_room_id", "") + + def make_sql_clause_for_get_last_event_pos_in_room( + database_engine: BaseDatabaseEngine, + event_types: Optional[StrCollection] = None, + ) -> Tuple[str, list]: + """ + Returns the ID and event position of the last event in a room at or before a + stream ordering. 
+ + Based on `get_last_event_pos_in_room_before_stream_ordering(...)` + + Args: + database_engine + event_types: Optional allowlist of event types to filter by + + Returns: + A tuple of SQL query and the args + """ + event_type_clause = "" + event_type_args: List[str] = [] + if event_types is not None and len(event_types) > 0: + event_type_clause, event_type_args = make_in_list_sql_clause( + database_engine, "type", event_types + ) + event_type_clause = f"AND {event_type_clause}" + + sql = f""" + SELECT stream_ordering + FROM events + LEFT JOIN rejections USING (event_id) + WHERE room_id = ? + {event_type_clause} + AND NOT outlier + AND rejections.event_id IS NULL + ORDER BY stream_ordering DESC + LIMIT 1 + """ + + return sql, event_type_args + + def _txn(txn: LoggingTransaction) -> int: + # Fetch the set of room IDs that we want to update + txn.execute( + """ + SELECT DISTINCT room_id FROM current_state_events + WHERE room_id > ? + ORDER BY room_id ASC + LIMIT ? + """, + (last_room_id, batch_size), + ) + + rooms_to_update_rows = txn.fetchall() + if not rooms_to_update_rows: + return 0 + + for (room_id,) in rooms_to_update_rows: + # TODO: Handle redactions + current_state_map = PersistEventsStore._get_relevant_sliding_sync_current_state_event_ids_txn( + txn, room_id + ) + # We're iterating over rooms pulled from the current_state_events table + # so we should have some current state for each room + assert current_state_map + + sliding_sync_joined_rooms_insert_map = PersistEventsStore._get_sliding_sync_insert_values_from_state_ids_map_txn( + txn, current_state_map + ) + # We should have some insert values for each room, even if they are `None` + assert sliding_sync_joined_rooms_insert_map + + ( + most_recent_event_stream_ordering_clause, + most_recent_event_stream_ordering_args, + ) = make_sql_clause_for_get_last_event_pos_in_room( + txn.database_engine, event_types=None + ) + bump_stamp_clause, bump_stamp_args = ( + make_sql_clause_for_get_last_event_pos_in_room( + 
txn.database_engine, + event_types=SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES, + ) + ) + + # Pulling keys/values separately is safe and will produce congruent + # lists + insert_keys = sliding_sync_joined_rooms_insert_map.keys() + insert_values = sliding_sync_joined_rooms_insert_map.values() + + sql = f""" + INSERT INTO sliding_sync_joined_rooms + (room_id, event_stream_ordering, bump_stamp, {", ".join(insert_keys)}) + VALUES ( + ?, + ({most_recent_event_stream_ordering_clause}), + ({bump_stamp_clause}), + {", ".join("?" for _ in insert_values)} + ) + ON CONFLICT (room_id) + DO UPDATE SET + event_stream_ordering = EXCLUDED.event_stream_ordering, + bump_stamp = EXCLUDED.bump_stamp, + {", ".join(f"{key} = EXCLUDED.{key}" for key in insert_keys)} + """ + args = ( + [room_id, room_id] + + most_recent_event_stream_ordering_args + + [room_id] + + bump_stamp_args + + list(insert_values) + ) + txn.execute(sql, args) + + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BACKFILL, + {"last_room_id": rooms_to_update_rows[-1][0]}, + ) + + return len(rooms_to_update_rows) + + count = await self.db_pool.runInteraction( + "sliding_sync_joined_rooms_backfill", _txn + ) + + if not count: + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BACKFILL + ) + + return count + + async def _sliding_sync_membership_snapshots_backfill( + self, progress: JsonDict, batch_size: int + ) -> int: + """ + Handles backfilling the `sliding_sync_membership_snapshots` table. 
+ """ + last_event_stream_ordering = progress.get( + "last_event_stream_ordering", -(1 << 31) + ) + + def _find_memberships_to_update_txn( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, str, str, str, int, bool]]: + # Fetch the set of event IDs that we want to update + txn.execute( + """ + SELECT + c.room_id, + c.user_id, + e.sender, + c.event_id, + c.membership, + c.event_stream_ordering, + e.outlier + FROM local_current_membership as c + INNER JOIN events AS e USING (event_id) + WHERE event_stream_ordering > ? + ORDER BY event_stream_ordering ASC + LIMIT ? + """, + (last_event_stream_ordering, batch_size), + ) + + memberships_to_update_rows = cast( + List[Tuple[str, str, str, str, str, int, bool]], txn.fetchall() + ) + + return memberships_to_update_rows + + memberships_to_update_rows = await self.db_pool.runInteraction( + "sliding_sync_membership_snapshots_backfill._find_memberships_to_update_txn", + _find_memberships_to_update_txn, + ) + + if not memberships_to_update_rows: + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL + ) + return 0 + + def _find_previous_membership_txn( + txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int + ) -> Tuple[str, str]: + # Find the previous invite/knock event before the leave event + txn.execute( + """ + SELECT event_id, membership + FROM room_memberships + WHERE + room_id = ? + AND user_id = ? + AND event_stream_ordering < ? + ORDER BY event_stream_ordering DESC + LIMIT 1 + """, + ( + room_id, + user_id, + stream_ordering, + ), + ) + row = txn.fetchone() + + # We should see a corresponding previous invite/knock event + assert row is not None + event_id, membership = row + + return event_id, membership + + # Map from (room_id, user_id) to ... 
+ to_insert_membership_snapshots: Dict[ + Tuple[str, str], SlidingSyncMembershipSnapshotSharedInsertValues + ] = {} + to_insert_membership_infos: Dict[Tuple[str, str], SlidingSyncMembershipInfo] = ( + {} + ) + for ( + room_id, + user_id, + sender, + membership_event_id, + membership, + membership_event_stream_ordering, + is_outlier, + ) in memberships_to_update_rows: + # We don't know how to handle `membership` values other than these. The + # code below would need to be updated. + assert membership in ( + Membership.JOIN, + Membership.INVITE, + Membership.KNOCK, + Membership.LEAVE, + Membership.BAN, + ) + + # Map of values to insert/update in the `sliding_sync_membership_snapshots` table + sliding_sync_membership_snapshots_insert_map: ( + SlidingSyncMembershipSnapshotSharedInsertValues + ) = {} + if membership == Membership.JOIN: + # If we're still joined, we can pull from current state. + current_state_ids_map: StateMap[ + str + ] = await self.hs.get_storage_controllers().state.get_current_state_ids( + room_id, + state_filter=StateFilter.from_types( + SLIDING_SYNC_RELEVANT_STATE_SET + ), + # Partially-stated rooms should have all state events except for + # remote membership events so we don't need to wait at all because + # we only want some non-membership state + await_full_state=False, + ) + # We're iterating over rooms that we are joined to so they should + # have `current_state_events` and we should have some current state + # for each room + assert current_state_ids_map + + fetched_events = await self.get_events(current_state_ids_map.values()) + + current_state_map: StateMap[EventBase] = { + state_key: fetched_events[event_id] + for state_key, event_id in current_state_ids_map.items() + } + + state_insert_values = ( + PersistEventsStore._get_sliding_sync_insert_values_from_state_map( + current_state_map + ) + ) + sliding_sync_membership_snapshots_insert_map.update(state_insert_values) + # We should have some insert values for each room, even if they are 
`None` + assert sliding_sync_membership_snapshots_insert_map + + # We have current state to work from + sliding_sync_membership_snapshots_insert_map["has_known_state"] = True + elif membership in (Membership.INVITE, Membership.KNOCK) or ( + membership == Membership.LEAVE and is_outlier + ): + invite_or_knock_event_id = membership_event_id + invite_or_knock_membership = membership + + # If the event is an `out_of_band_membership` (special case of + # `outlier`), we never had historical state so we have to pull from + # the stripped state on the previous invite/knock event. This gives + # us a consistent view of the room state regardless of your + # membership (i.e. the room shouldn't disappear if you're using the + # `is_encrypted` filter and you leave). + if membership == Membership.LEAVE and is_outlier: + invite_or_knock_event_id, invite_or_knock_membership = ( + await self.db_pool.runInteraction( + "sliding_sync_membership_snapshots_backfill._find_previous_membership", + _find_previous_membership_txn, + room_id, + user_id, + membership_event_stream_ordering, + ) + ) + + # Pull from the stripped state on the invite/knock event + invite_or_knock_event = await self.get_event(invite_or_knock_event_id) + + raw_stripped_state_events = None + if invite_or_knock_membership == Membership.INVITE: + invite_room_state = invite_or_knock_event.unsigned.get( + "invite_room_state" + ) + raw_stripped_state_events = invite_room_state + elif invite_or_knock_membership == Membership.KNOCK: + knock_room_state = invite_or_knock_event.unsigned.get( + "knock_room_state" + ) + raw_stripped_state_events = knock_room_state + + sliding_sync_membership_snapshots_insert_map = await self.db_pool.runInteraction( + "sliding_sync_membership_snapshots_backfill._get_sliding_sync_insert_values_from_stripped_state_txn", + PersistEventsStore._get_sliding_sync_insert_values_from_stripped_state_txn, + raw_stripped_state_events, + ) + + # We should have some insert values for each room, even if no + #
stripped state is on the event because we still want to record + # that we have no known state + assert sliding_sync_membership_snapshots_insert_map + elif membership in (Membership.LEAVE, Membership.BAN): + # Pull from historical state + state_ids_map = await self.hs.get_storage_controllers().state.get_state_ids_for_event( + membership_event_id, + state_filter=StateFilter.from_types( + SLIDING_SYNC_RELEVANT_STATE_SET + ), + # Partially-stated rooms should have all state events except for + # remote membership events so we don't need to wait at all because + # we only want some non-membership state + await_full_state=False, + ) + + fetched_events = await self.get_events(state_ids_map.values()) + + state_map: StateMap[EventBase] = { + state_key: fetched_events[event_id] + for state_key, event_id in state_ids_map.items() + } + + state_insert_values = ( + PersistEventsStore._get_sliding_sync_insert_values_from_state_map( + state_map + ) + ) + sliding_sync_membership_snapshots_insert_map.update(state_insert_values) + # We should have some insert values for each room, even if they are `None` + assert sliding_sync_membership_snapshots_insert_map + + # We have historical state to work from + sliding_sync_membership_snapshots_insert_map["has_known_state"] = True + else: + # We don't know how to handle this type of membership yet + # + # FIXME: We should use `assert_never` here but for some reason + # the exhaustive matching doesn't recognize the `Never` here. 
+ # assert_never(membership) + raise AssertionError( + f"Unexpected membership {membership} ({membership_event_id}) that we don't know how to handle yet" + ) + + to_insert_membership_snapshots[(room_id, user_id)] = ( + sliding_sync_membership_snapshots_insert_map + ) + to_insert_membership_infos[(room_id, user_id)] = SlidingSyncMembershipInfo( + user_id=user_id, + sender=sender, + membership_event_id=membership_event_id, + ) + + def _backfill_table_txn(txn: LoggingTransaction) -> None: + for key, insert_map in to_insert_membership_snapshots.items(): + room_id, user_id = key + membership_info = to_insert_membership_infos[key] + membership_event_id = membership_info.membership_event_id + + # Pulling keys/values separately is safe and will produce congruent + # lists + insert_keys = insert_map.keys() + insert_values = insert_map.values() + # We don't need to do anything `ON CONFLICT` because we never partially + # insert/update the snapshots + txn.execute( + f""" + INSERT INTO sliding_sync_membership_snapshots + (room_id, user_id, membership_event_id, membership, event_stream_ordering + {("," + ", ".join(insert_keys)) if insert_keys else ""}) + VALUES ( + ?, ?, ?, + (SELECT membership FROM room_memberships WHERE event_id = ?), + (SELECT stream_ordering FROM events WHERE event_id = ?) + {("," + ", ".join("?" 
for _ in insert_values)) if insert_values else ""} + ) + ON CONFLICT (room_id, user_id) + DO NOTHING + """, + [ + room_id, + user_id, + membership_event_id, + membership_event_id, + membership_event_id, + ] + + list(insert_values), + ) + + await self.db_pool.runInteraction( + "sliding_sync_membership_snapshots_backfill", _backfill_table_txn + ) + + # Update the progress + ( + _room_id, + _user_id, + _sender, + _membership_event_id, + _membership, + membership_event_stream_ordering, + _is_outlier, + ) = memberships_to_update_rows[-1] + await self.db_pool.updates._background_update_progress( + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL, + {"last_event_stream_ordering": membership_event_stream_ordering}, + ) + + return len(memberships_to_update_rows) diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 5898a0e6d4..ea7d8199a7 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -20,28 +20,17 @@ # import logging -from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union -from typing_extensions import assert_never - -from synapse.api.constants import Membership -from synapse.events import EventBase from synapse.logging.opentracing import tag_args, trace -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, ) -from synapse.storage.databases.main.events import ( - SLIDING_SYNC_RELEVANT_STATE_SET, - PersistEventsStore, - SlidingSyncMembershipInfo, - SlidingSyncMembershipSnapshotSharedInsertValues, -) -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -from synapse.types import JsonDict, MutableStateMap, StateMap, StrCollection -from 
synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES +from synapse.storage.engines import PostgresEngine +from synapse.types import MutableStateMap, StateMap from synapse.types.state import StateFilter from synapse.util.caches import intern_string @@ -54,13 +43,6 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 -class _BackgroundUpdates: - SLIDING_SYNC_JOINED_ROOMS_BACKFILL = "sliding_sync_joined_rooms_backfill" - SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL = ( - "sliding_sync_membership_snapshots_backfill" - ) - - class StateGroupBackgroundUpdateStore(SQLBaseStore): """Defines functions related to state groups needed to run the state background updates. @@ -367,16 +349,6 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): columns=["event_stream_ordering"], ) - # Backfill the sliding sync tables - self.db_pool.updates.register_background_update_handler( - _BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BACKFILL, - self._sliding_sync_joined_rooms_backfill, - ) - self.db_pool.updates.register_background_update_handler( - _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL, - self._sliding_sync_membership_snapshots_backfill, - ) - async def _background_deduplicate_state( self, progress: dict, batch_size: int ) -> int: @@ -552,439 +524,3 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): ) return 1 - - async def _sliding_sync_joined_rooms_backfill( - self, progress: JsonDict, batch_size: int - ) -> int: - """ - Handles backfilling the `sliding_sync_joined_rooms` table. - """ - last_room_id = progress.get("last_room_id", "") - - def make_sql_clause_for_get_last_event_pos_in_room( - database_engine: BaseDatabaseEngine, - event_types: Optional[StrCollection] = None, - ) -> Tuple[str, list]: - """ - Returns the ID and event position of the last event in a room at or before a - stream ordering. 
- - Based on `get_last_event_pos_in_room_before_stream_ordering(...)` - - Args: - database_engine - event_types: Optional allowlist of event types to filter by - - Returns: - A tuple of SQL query and the args - """ - event_type_clause = "" - event_type_args: List[str] = [] - if event_types is not None and len(event_types) > 0: - event_type_clause, event_type_args = make_in_list_sql_clause( - database_engine, "type", event_types - ) - event_type_clause = f"AND {event_type_clause}" - - sql = f""" - SELECT stream_ordering - FROM events - LEFT JOIN rejections USING (event_id) - WHERE room_id = ? - {event_type_clause} - AND NOT outlier - AND rejections.event_id IS NULL - ORDER BY stream_ordering DESC - LIMIT 1 - """ - - return sql, event_type_args - - def _txn(txn: LoggingTransaction) -> int: - # Fetch the set of room IDs that we want to update - txn.execute( - """ - SELECT DISTINCT room_id FROM current_state_events - WHERE room_id > ? - ORDER BY room_id ASC - LIMIT ? - """, - (last_room_id, batch_size), - ) - - rooms_to_update_rows = txn.fetchall() - if not rooms_to_update_rows: - return 0 - - for (room_id,) in rooms_to_update_rows: - # TODO: Handle redactions - current_state_map = PersistEventsStore._get_relevant_sliding_sync_current_state_event_ids_txn( - txn, room_id - ) - # We're iterating over rooms pulled from the current_state_events table - # so we should have some current state for each room - assert current_state_map - - sliding_sync_joined_rooms_insert_map = PersistEventsStore._get_sliding_sync_insert_values_from_state_ids_map_txn( - txn, current_state_map - ) - # We should have some insert values for each room, even if they are `None` - assert sliding_sync_joined_rooms_insert_map - - ( - most_recent_event_stream_ordering_clause, - most_recent_event_stream_ordering_args, - ) = make_sql_clause_for_get_last_event_pos_in_room( - txn.database_engine, event_types=None - ) - bump_stamp_clause, bump_stamp_args = ( - make_sql_clause_for_get_last_event_pos_in_room( - 
txn.database_engine, - event_types=SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES, - ) - ) - - # Pulling keys/values separately is safe and will produce congruent - # lists - insert_keys = sliding_sync_joined_rooms_insert_map.keys() - insert_values = sliding_sync_joined_rooms_insert_map.values() - - sql = f""" - INSERT INTO sliding_sync_joined_rooms - (room_id, event_stream_ordering, bump_stamp, {", ".join(insert_keys)}) - VALUES ( - ?, - ({most_recent_event_stream_ordering_clause}), - ({bump_stamp_clause}), - {", ".join("?" for _ in insert_values)} - ) - ON CONFLICT (room_id) - DO UPDATE SET - event_stream_ordering = EXCLUDED.event_stream_ordering, - bump_stamp = EXCLUDED.bump_stamp, - {", ".join(f"{key} = EXCLUDED.{key}" for key in insert_keys)} - """ - args = ( - [room_id, room_id] - + most_recent_event_stream_ordering_args - + [room_id] - + bump_stamp_args - + list(insert_values) - ) - txn.execute(sql, args) - - self.db_pool.updates._background_update_progress_txn( - txn, - _BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BACKFILL, - {"last_room_id": rooms_to_update_rows[-1][0]}, - ) - - return len(rooms_to_update_rows) - - count = await self.db_pool.runInteraction( - "sliding_sync_joined_rooms_backfill", _txn - ) - - if not count: - await self.db_pool.updates._end_background_update( - _BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BACKFILL - ) - - return count - - async def _sliding_sync_membership_snapshots_backfill( - self, progress: JsonDict, batch_size: int - ) -> int: - """ - Handles backfilling the `sliding_sync_membership_snapshots` table. 
- """ - last_event_stream_ordering = progress.get( - "last_event_stream_ordering", -(1 << 31) - ) - - def _find_memberships_to_update_txn( - txn: LoggingTransaction, - ) -> List[Tuple[str, str, str, str, str, int, bool]]: - # Fetch the set of event IDs that we want to update - txn.execute( - """ - SELECT - c.room_id, - c.user_id, - e.sender, - c.event_id, - c.membership, - c.event_stream_ordering, - e.outlier - FROM local_current_membership as c - INNER JOIN events AS e USING (event_id) - WHERE event_stream_ordering > ? - ORDER BY event_stream_ordering ASC - LIMIT ? - """, - (last_event_stream_ordering, batch_size), - ) - - memberships_to_update_rows = cast( - List[Tuple[str, str, str, str, str, int, bool]], txn.fetchall() - ) - - return memberships_to_update_rows - - memberships_to_update_rows = await self.db_pool.runInteraction( - "sliding_sync_membership_snapshots_backfill._find_memberships_to_update_txn", - _find_memberships_to_update_txn, - ) - - if not memberships_to_update_rows: - await self.db_pool.updates._end_background_update( - _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL - ) - return 0 - - store = self.hs.get_storage_controllers().main - - def _find_previous_membership_txn( - txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int - ) -> Tuple[str, str]: - # Find the previous invite/knock event before the leave event - txn.execute( - """ - SELECT event_id, membership - FROM room_memberships - WHERE - room_id = ? - AND user_id = ? - AND event_stream_ordering < ? - ORDER BY event_stream_ordering DESC - LIMIT 1 - """, - ( - room_id, - user_id, - stream_ordering, - ), - ) - row = txn.fetchone() - - # We should see a corresponding previous invite/knock event - assert row is not None - event_id, membership = row - - return event_id, membership - - # Map from (room_id, user_id) to ... 
- to_insert_membership_snapshots: Dict[ - Tuple[str, str], SlidingSyncMembershipSnapshotSharedInsertValues - ] = {} - to_insert_membership_infos: Dict[Tuple[str, str], SlidingSyncMembershipInfo] = ( - {} - ) - for ( - room_id, - user_id, - sender, - membership_event_id, - membership, - membership_event_stream_ordering, - is_outlier, - ) in memberships_to_update_rows: - # We don't know how to handle `membership` values other than these. The - # code below would need to be updated. - assert membership in ( - Membership.JOIN, - Membership.INVITE, - Membership.KNOCK, - Membership.LEAVE, - Membership.BAN, - ) - - # Map of values to insert/update in the `sliding_sync_membership_snapshots` table - sliding_sync_membership_snapshots_insert_map: ( - SlidingSyncMembershipSnapshotSharedInsertValues - ) = {} - if membership == Membership.JOIN: - # If we're still joined, we can pull from current state. - current_state_ids_map: StateMap[str] = ( - await store.get_partial_filtered_current_state_ids( - room_id, - state_filter=StateFilter.from_types( - SLIDING_SYNC_RELEVANT_STATE_SET - ), - ) - ) - # We're iterating over rooms that we are joined to so they should - # have `current_state_events` and we should have some current state - # for each room - assert current_state_ids_map - - fetched_events = await store.get_events(current_state_ids_map.values()) - - current_state_map: StateMap[EventBase] = { - state_key: fetched_events[event_id] - for state_key, event_id in current_state_ids_map.items() - } - - state_insert_values = ( - PersistEventsStore._get_sliding_sync_insert_values_from_state_map( - current_state_map - ) - ) - sliding_sync_membership_snapshots_insert_map.update(state_insert_values) - # We should have some insert values for each room, even if they are `None` - assert sliding_sync_membership_snapshots_insert_map - - # We have current state to work from - sliding_sync_membership_snapshots_insert_map["has_known_state"] = True - elif membership in (Membership.INVITE, 
Membership.KNOCK) or ( - membership == Membership.LEAVE and is_outlier - ): - invite_or_knock_event_id = membership_event_id - invite_or_knock_membership = membership - - # If the event is an `out_of_band_membership` (special case of - # `outlier`), we never had historical state so we have to pull from - # the stripped state on the previous invite/knock event. This gives - # us a consistent view of the room state regardless of your - # membership (i.e. the room shouldn't disappear if your using the - # `is_encrypted` filter and you leave). - if membership == Membership.LEAVE and is_outlier: - invite_or_knock_event_id, invite_or_knock_membership = ( - await self.db_pool.runInteraction( - "sliding_sync_membership_snapshots_backfill._find_previous_membership", - _find_previous_membership_txn, - room_id, - user_id, - membership_event_stream_ordering, - ) - ) - - # Pull from the stripped state on the invite/knock event - invite_or_knock_event = await store.get_event(invite_or_knock_event_id) - - raw_stripped_state_events = None - if invite_or_knock_membership == Membership.INVITE: - invite_room_state = invite_or_knock_event.unsigned.get( - "invite_room_state" - ) - raw_stripped_state_events = invite_room_state - elif invite_or_knock_membership == Membership.KNOCK: - knock_room_state = invite_or_knock_event.unsigned.get( - "knock_room_state" - ) - raw_stripped_state_events = knock_room_state - - sliding_sync_membership_snapshots_insert_map = await self.db_pool.runInteraction( - "sliding_sync_membership_snapshots_backfill._get_sliding_sync_insert_values_from_stripped_state_txn", - PersistEventsStore._get_sliding_sync_insert_values_from_stripped_state_txn, - raw_stripped_state_events, - ) - - # We should have some insert values for each room, even if no - # stripped state is on the event because we still want to record - # that we have no known state - assert sliding_sync_membership_snapshots_insert_map - elif membership in (Membership.LEAVE, Membership.BAN): - # Pull from 
historical state - state_group = await store._get_state_group_for_event( - membership_event_id - ) - # We should know the state for the event - assert state_group is not None - - state_by_group = await self.db_pool.runInteraction( - "sliding_sync_membership_snapshots_backfill._get_state_groups_from_groups_txn", - self._get_state_groups_from_groups_txn, - groups=[state_group], - state_filter=StateFilter.from_types( - SLIDING_SYNC_RELEVANT_STATE_SET - ), - ) - state_ids_map = state_by_group[state_group] - - fetched_events = await store.get_events(state_ids_map.values()) - - state_map: StateMap[EventBase] = { - state_key: fetched_events[event_id] - for state_key, event_id in state_ids_map.items() - } - - state_insert_values = ( - PersistEventsStore._get_sliding_sync_insert_values_from_state_map( - state_map - ) - ) - sliding_sync_membership_snapshots_insert_map.update(state_insert_values) - # We should have some insert values for each room, even if they are `None` - assert sliding_sync_membership_snapshots_insert_map - - # We have historical state to work from - sliding_sync_membership_snapshots_insert_map["has_known_state"] = True - else: - assert_never(membership) - - to_insert_membership_snapshots[(room_id, user_id)] = ( - sliding_sync_membership_snapshots_insert_map - ) - to_insert_membership_infos[(room_id, user_id)] = SlidingSyncMembershipInfo( - user_id=user_id, - sender=sender, - membership_event_id=membership_event_id, - ) - - def _backfill_table_txn(txn: LoggingTransaction) -> None: - for key, insert_map in to_insert_membership_snapshots.items(): - room_id, user_id = key - membership_info = to_insert_membership_infos[key] - membership_event_id = membership_info.membership_event_id - - # Pulling keys/values separately is safe and will produce congruent - # lists - insert_keys = insert_map.keys() - insert_values = insert_map.values() - # We don't need to do anything `ON CONFLICT` because we never partially - # insert/update the snapshots - txn.execute( - f""" 
- INSERT INTO sliding_sync_membership_snapshots - (room_id, user_id, membership_event_id, membership, event_stream_ordering - {("," + ", ".join(insert_keys)) if insert_keys else ""}) - VALUES ( - ?, ?, ?, - (SELECT membership FROM room_memberships WHERE event_id = ?), - (SELECT stream_ordering FROM events WHERE event_id = ?) - {("," + ", ".join("?" for _ in insert_values)) if insert_values else ""} - ) - ON CONFLICT (room_id, user_id) - DO NOTHING - """, - [ - room_id, - user_id, - membership_event_id, - membership_event_id, - membership_event_id, - ] - + list(insert_values), - ) - - await self.db_pool.runInteraction( - "sliding_sync_membership_snapshots_backfill", _backfill_table_txn - ) - - # Update the progress - ( - _room_id, - _user_id, - _sender, - _membership_event_id, - _membership, - membership_event_stream_ordering, - _is_outlier, - ) = memberships_to_update_rows[-1] - await self.db_pool.updates._background_update_progress( - _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL, - {"last_event_stream_ordering": membership_event_stream_ordering}, - ) - - return len(memberships_to_update_rows) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 3d3f95f29c..a2122de7ee 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -36,7 +36,7 @@ from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.storage.databases.main.events import DeltaState -from synapse.storage.databases.state.bg_updates import _BackgroundUpdates +from synapse.storage.databases.main.events_bg_updates import _BackgroundUpdates from synapse.types import StateMap from synapse.util import Clock