From ce41c93878075583fac06da3b7902574df64a237 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 24 Jan 2025 15:16:48 +0000 Subject: [PATCH] Clean up --- synapse/state/__init__.py | 40 +-- synapse/storage/databases/state/epochs.py | 304 +++++++++++----------- synapse/storage/databases/state/store.py | 30 +-- 3 files changed, 185 insertions(+), 189 deletions(-) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 013d1245a7..9f0fed3ff7 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -315,6 +315,9 @@ class StateHandler: """ assert not event.internal_metadata.is_outlier() + # Record the state epoch before we start calculating state groups, to + # ensure that nothing we're relying on gets deleted. See the store class + # docstring for more information. state_epoch = await self._state_epoch_store.get_state_epoch() # @@ -532,15 +535,15 @@ class StateHandler: state_group_id ) - # Check if we're trying to delete the given prev group, if so we - # pretend we didn't see it. if prev_group: - pending_deletion = ( - await self._state_epoch_store.is_state_group_pending_deletion( - prev_group + # Ensure that we still have the prev group, and ensure we don't + # delete it while we're persisting the event. + missing_state_group = ( + await self._state_epoch_store.check_state_groups_and_bump_deletion( + {prev_group} ) ) - if pending_deletion: + if missing_state_group: prev_group = None delta_ids = None @@ -696,20 +699,24 @@ class StateResolutionHandler: async with self.resolve_linearizer.queue(group_names): cache = self._state_cache.get(group_names, None) if cache: - state_groups_to_check = [] + # Check that the returned cache entry doesn't point to deleted + # state groups. 
+ state_groups_to_check = set() if cache.state_group is not None: - state_groups_to_check.append(cache.state_group) + state_groups_to_check.add(cache.state_group) if cache.prev_group is not None: - state_groups_to_check.append(cache.prev_group) + state_groups_to_check.add(cache.prev_group) - pending_deletion = await state_res_store.state_epoch_store.are_state_groups_pending_deletion( + missing_state_groups = await state_res_store.state_epoch_store.check_state_groups_and_bump_deletion( state_groups_to_check ) - if not pending_deletion: + if not missing_state_groups: return cache else: + # There are missing state groups, so let's remove the stale + # entry and continue as if it was a cache miss. self._state_cache.pop(group_names, None) logger.info( @@ -718,15 +725,14 @@ class StateResolutionHandler: list(group_names), ) - # We double check that none of the state groups are pending - # deletion. They shouldn't be as all these state groups should be - # referenced. - pending_deletion = await state_res_store.state_epoch_store.are_state_groups_pending_deletion( + # We double check that none of the state groups have been deleted. + # They shouldn't be as all these state groups should be referenced. 
+ missing_state_groups = await state_res_store.state_epoch_store.check_state_groups_and_bump_deletion( group_names ) - if pending_deletion: + if missing_state_groups: raise Exception( - f"state groups are pending deletion: {shortstr(pending_deletion)}" + f"State groups have been deleted: {shortstr(missing_state_groups)}" ) state_groups_histogram.observe(len(state_groups_ids)) diff --git a/synapse/storage/databases/state/epochs.py b/synapse/storage/databases/state/epochs.py index 5df0a4cf17..eccdf248f2 100644 --- a/synapse/storage/databases/state/epochs.py +++ b/synapse/storage/databases/state/epochs.py @@ -14,7 +14,14 @@ import contextlib -from typing import TYPE_CHECKING, AsyncIterator, Collection, Dict, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + AbstractSet, + AsyncIterator, + Collection, + Set, + Tuple, +) from synapse.events import EventBase from synapse.events.snapshot import EventContext @@ -35,20 +42,37 @@ if TYPE_CHECKING: class StateEpochDataStore: """Manages state epochs and checks for state group deletion. - Deleting state groups is challenging as we need to ensure that any in-flight - events that are yet to be persisted do not refer to any state groups that we - want to delete. + Deleting state groups is challenging as before we actually delete them we + need to ensure that there are no in-flight events that refer to the state + groups that we want to delete. - To handle this, we have a concept of "state epochs", which slowly increment - over time. To delete a state group we first add it to the list of "pending - deletions" with the current epoch, and wait until a certain number of epochs - have passed before attempting to actually delete the state group. 
If during - this period an event that references the state group tries to be persisted, - then we check if too many state epochs have passed, if they have we reject - the attempt to persist the event, and if not we clear the state groups from - the pending deletion list (as they're now referenced). + To handle this, we take two approaches. First, before we persist any event + we ensure that the state groups still exist and mark in the + `state_groups_persisting` table that the state group is about to be used. + (Note that we have to have the extra table here as state groups and events + can be in different databases, and thus we can't check for the existence of + state groups in the persist event transaction). Once the event has been + persisted, we can remove the row from `state_groups_persisting`. So long as + we check that table before deleting state groups, we can ensure that we + never persist events that reference deleted state groups, maintaining + database integrity. + + However, we want to avoid throwing exceptions so deep in the process of + persisting events. So we use a concept of `state_epochs`, where we mark + state groups as pending/proposed for deletion and wait for a certain number of + epoch increments before performing the deletion. When we come to handle new + events that reference state groups, we check if they are pending deletion + and bump the epoch when they'll be deleted in (to give a chance for the + event to be persisted, or not). """ + # How frequently, roughly, to increment epochs. + TIME_BETWEEN_EPOCH_INCREMENTS_MS = 5 * 60 * 1000 + + # The number of epoch increases that must have happened between marking a + # state group as pending and actually deleting it. + NUMBER_EPOCHS_BEFORE_DELETION = 3 + def __init__( self, database: DatabasePool, @@ -63,7 +87,11 @@ class StateEpochDataStore: # running instance. 
if hs.config.worker.run_background_tasks: - self._clock.looping_call_now(self._advance_state_epoch, 2 * 60 * 1000) + # Add a background loop to periodically check if we should bump + # state epoch. + self._clock.looping_call_now( + self._advance_state_epoch, self.TIME_BETWEEN_EPOCH_INCREMENTS_MS / 5 + ) @wrap_as_background_process("_advance_state_epoch") async def _advance_state_epoch(self) -> None: @@ -72,7 +100,7 @@ class StateEpochDataStore: """ now = self._clock.time_msec() - update_if_before_ts = now - 10 * 60 * 1000 + update_if_before_ts = now - self.TIME_BETWEEN_EPOCH_INCREMENTS_MS def advance_state_epoch_txn(txn: LoggingTransaction) -> None: sql = """ @@ -80,19 +108,14 @@ class StateEpochDataStore: SET state_epoch = state_epoch + 1, updated_ts = ? WHERE updated_ts <= ? """ - txn.execute( - sql, - ( - now, - update_if_before_ts, - ), - ) + txn.execute(sql, (now, update_if_before_ts)) await self.db_pool.runInteraction( "_advance_state_epoch", advance_state_epoch_txn, db_autocommit=True ) async def get_state_epoch(self) -> int: + """Get the current state epoch""" return await self.db_pool.simple_select_one_onecol( table="state_epoch", retcol="state_epoch", @@ -100,143 +123,74 @@ class StateEpochDataStore: desc="get_state_epoch", ) - def _mark_state_groups_as_used_txn( - self, txn: LoggingTransaction, state_epoch: int, state_groups: Set[int] - ) -> None: - current_state_epoch = self.db_pool.simple_select_one_onecol_txn( - txn, - table="state_epoch", - retcol="state_epoch", - keyvalues={}, - ) + async def check_state_groups_and_bump_deletion( + self, state_groups: AbstractSet[int] + ) -> Collection[int]: + """Checks to make sure that the state groups haven't been deleted, and + if they're pending deletion we delay it (allowing time for any event + that will use them to finish persisting). - # TODO: Move to constant. Is the equality correct? 
- if current_state_epoch - state_epoch >= 2: - raise Exception("FOO") + Returns: + The state groups that are missing, if any. + """ - clause, values = make_in_list_sql_clause( - txn.database_engine, - "id", + return await self.db_pool.runInteraction( + "check_state_groups_and_bump_deletion", + self._check_state_groups_and_bump_deletion_txn, state_groups, ) + + def _check_state_groups_and_bump_deletion_txn( + self, txn: LoggingTransaction, state_groups: AbstractSet[int] + ) -> Collection[int]: + existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups) + if state_groups - existing_state_groups: + return state_groups - existing_state_groups + + clause, args = make_in_list_sql_clause( + self.db_pool.engine, "state_group", state_groups + ) sql = f""" - SELECT id, state_epoch - FROM state_groups - LEFT JOIN state_groups_pending_deletion ON (id = state_group) + UPDATE state_groups_pending_deletion + SET state_epoch = (SELECT state_epoch FROM state_epoch) WHERE {clause} """ + txn.execute(sql, args) + + return () + + def _get_existing_groups_with_lock( + self, txn: LoggingTransaction, state_groups: Collection[int] + ) -> AbstractSet[int]: + """Return which of the given state groups are in the database, and locks + those rows with `KEY SHARE` to ensure they don't get concurrently + deleted.""" + clause, args = make_in_list_sql_clause(self.db_pool.engine, "id", state_groups) + + sql = f""" + SELECT id FROM state_groups + WHERE {clause} + """ if isinstance(self.db_pool.engine, PostgresEngine): # On postgres we add a row level lock to the rows to ensure that we # conflict with any concurrent DELETEs. `FOR KEY SHARE` lock will - # not conflict with other reads. 
+ # not conflict with other reads. sql += """ - FOR KEY SHARE OF state_groups + FOR KEY SHARE """ - txn.execute(sql, values) - - state_group_to_epoch: Dict[int, Optional[int]] = {row[0]: row[1] for row in txn} - - missing_state_groups = state_groups - state_group_to_epoch.keys() - if missing_state_groups: - raise Exception( - f"state groups have been deleted: {shortstr(missing_state_groups)}" - ) - - for state_epoch_deletion in state_group_to_epoch.values(): - if state_epoch_deletion is None: - continue - - if current_state_epoch - state_epoch_deletion >= 2: - raise Exception("FOO") # TODO - - self.db_pool.simple_delete_many_batch_txn( - txn, - table="state_groups_pending_deletion", - keys=("state_group",), - values=[(state_group,) for state_group in state_groups], - ) - - self.db_pool.simple_insert_many_txn( - txn, - table="state_groups_persisting", - keys=("state_group", "instance_name"), - values=[(state_group, self._instance_name) for state_group in state_groups], - ) - - async def is_state_group_pending_deletion(self, state_group: int) -> bool: - """Check if a state group is marked as pending deletion.""" - - def is_state_group_pending_deletion_txn(txn: LoggingTransaction) -> bool: - sql = """ - SELECT 1 FROM state_groups_pending_deletion - WHERE state_group = ? 
- """ - txn.execute(sql, (state_group,)) - - return txn.fetchone() is not None - - return await self.db_pool.runInteraction( - "is_state_group_pending_deletion", - is_state_group_pending_deletion_txn, - ) - - async def are_state_groups_pending_deletion( - self, state_groups: Collection[int] - ) -> Collection[int]: - rows = await self.db_pool.simple_select_many_batch( - table="state_groups_pending_deletion", - column="state_group", - iterable=state_groups, - retcols=("state_group",), - desc="are_state_groups_pending_deletion", - ) - return {row[0] for row in rows} - - async def mark_state_group_as_used(self, state_group: int) -> None: - """Mark that a given state group is used""" - - # TODO: Also assert that the state group hasn't advanced too much - - await self.db_pool.simple_delete( - table="state_groups_pending_deletion", - keyvalues={"state_group": state_group}, - desc="mark_state_group_as_used", - ) - - def check_prev_group_before_insertion_txn( - self, txn: LoggingTransaction, prev_group: int, new_groups: Collection[int] - ) -> None: - sql = """ - SELECT state_epoch, (SELECT state_epoch FROM state_epoch) - FROM state_groups_pending_deletion - WHERE state_group = ? 
- """ - txn.execute(sql, (prev_group,)) - row = txn.fetchone() - if row is not None: - pending_deletion_epoch, current_epoch = row - if current_epoch - pending_deletion_epoch >= 2: - raise Exception("") # TODO - - self.db_pool.simple_update_txn( - txn, - table="state_groups_pending_deletion", - keyvalues={"state_group": prev_group}, - updatevalues={"state_epoch": current_epoch}, - ) - self.db_pool.simple_insert_many_txn( - txn, - table="state_groups_pending_deletion", - keys=("state_group", "state_epoch"), - values=[(state_group, current_epoch) for state_group in new_groups], - ) + txn.execute(sql, args) + return {state_group for (state_group,) in txn} @contextlib.asynccontextmanager async def persisting_state_group_references( self, event_and_contexts: Collection[Tuple[EventBase, EventContext]] ) -> AsyncIterator[None]: + """Wraps the persistence of the given events and contexts, ensuring that + any state groups referenced still exist and that they don't get deleted + during this.""" + referenced_state_groups: Set[int] = set() state_epochs = [] for event, ctx in event_and_contexts: @@ -259,12 +213,11 @@ class StateEpochDataStore: return assert state_epochs # If we have state groups we have a state epoch - min_state_epoch = min(state_epochs) + # min_state_epoch = min(state_epochs) # TODO await self.db_pool.runInteraction( "mark_state_groups_as_used", self._mark_state_groups_as_used_txn, - min_state_epoch, referenced_state_groups, ) @@ -279,12 +232,54 @@ class StateEpochDataStore: desc="persisting_state_group_references_delete", ) - def get_state_groups_that_can_be_purged( + def _mark_state_groups_as_used_txn( + self, txn: LoggingTransaction, state_groups: Set[int] + ) -> None: + """Marks the given state groups as used. 
Also checks that the given + state epoch is not too old.""" + + existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups) + missing_state_groups = state_groups - existing_state_groups + if missing_state_groups: + raise Exception( + f"state groups have been deleted: {shortstr(missing_state_groups)}" + ) + + self.db_pool.simple_delete_many_batch_txn( + txn, + table="state_groups_pending_deletion", + keys=("state_group",), + values=[(state_group,) for state_group in state_groups], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_persisting", + keys=("state_group", "instance_name"), + values=[(state_group, self._instance_name) for state_group in state_groups], + ) + + def get_state_groups_that_can_be_purged_txn( self, txn: LoggingTransaction, state_groups: Collection[int] ) -> Collection[int]: + """Given a set of state groups, return which state groups can be deleted.""" + if not state_groups: return state_groups + if isinstance(self.db_pool.engine, PostgresEngine): + # On postgres we want to lock the rows FOR UPDATE as early as + # possible to help conflicts. + clause, args = make_in_list_sql_clause( + self.db_pool.engine, "id", state_groups + ) + sql = """ + SELECT id FROM state_groups + WHERE {clause} + FOR UPDATE + """ + txn.execute(sql, args) + current_state_epoch = self.db_pool.simple_select_one_onecol_txn( txn, table="state_epoch", @@ -292,25 +287,34 @@ class StateEpochDataStore: keyvalues={}, ) + # Check the deletion status in the DB of the given state groups clause, args = make_in_list_sql_clause( self.db_pool.engine, column="state_group", iterable=state_groups ) sql = f""" - SELECT state_group FROM ( - SELECT state_group FROM state_groups_pending_deletion - WHERE state_epoch > ? 
+ SELECT state_group, state_epoch FROM ( + SELECT state_group, state_epoch FROM state_groups_pending_deletion UNION - SELECT state_group FROM state_groups_persisting + SELECT state_group, null FROM state_groups_persisting ) AS s WHERE {clause} """ - args.insert(0, current_state_epoch - 2) txn.execute(sql, args) - can_delete = set(state_groups) - for (state_group,) in txn: - can_delete.discard(state_group) + can_delete = set() + for state_group, state_epoch in txn: + if state_epoch is None: + # A null state epoch means that we are currently persisting + # events that reference the state group, so we don't delete + # them. + continue + + if current_state_epoch - state_epoch < self.NUMBER_EPOCHS_BEFORE_DELETION: + # Not enough state epochs have occurred to allow us to delete. + continue + + can_delete.add(state_group) return can_delete diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index f073ba5473..b2ef3703c3 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -474,14 +474,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): A list of state groups """ - is_in_db = self.db_pool.simple_select_one_onecol_txn( + # We need to check that the prev group isn't about to be deleted + is_missing = self._epoch_store._check_state_groups_and_bump_deletion_txn( txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, + {prev_group}, ) - if not is_in_db: + if is_missing: raise Exception( "Trying to persist state with unpersisted prev_group: %r" % (prev_group,) @@ -554,11 +552,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ], ) - # We need to check that the prev group isn't about to be deleted - self._epoch_store.check_prev_group_before_insertion_txn( - txn, prev_group, state_groups - ) - return events_and_context return await self.db_pool.runInteraction( @@ -615,14 +608,12 @@ class 
StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): needs to be persisted as a full state. """ - is_in_db = self.db_pool.simple_select_one_onecol_txn( + # We need to check that the prev group isn't about to be deleted + is_missing = self._epoch_store._check_state_groups_and_bump_deletion_txn( txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, + {prev_group}, ) - if not is_in_db: + if is_missing: raise Exception( "Trying to persist state with unpersisted prev_group: %r" % (prev_group,) @@ -658,11 +649,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ], ) - # We need to check that the prev group isn't about to be deleted - self._epoch_store.check_prev_group_before_insertion_txn( - txn, prev_group, [state_group] - ) - return state_group def insert_full_state_txn(