From ae4a6304cd5375fe77563387f2b6092e454fcb90 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 21 Jan 2025 10:56:12 +0000 Subject: [PATCH] Check for deleted state groups --- synapse/handlers/federation_event.py | 15 +++++-- synapse/state/__init__.py | 43 +++++++++++++++---- synapse/storage/controllers/persist_events.py | 4 +- synapse/storage/databases/state/store.py | 18 +++----- 4 files changed, 55 insertions(+), 25 deletions(-) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index c85deaed56..50ab8fc54f 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -151,6 +151,7 @@ class FederationEventHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() self._store = hs.get_datastores().main + self._state_store = hs.get_datastores().state self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state @@ -580,7 +581,9 @@ class FederationEventHandler: room_version.identifier, state_maps_to_resolve, event_map=None, - state_res_store=StateResolutionStore(self._store), + state_res_store=StateResolutionStore( + self._store, self._state_store + ), ) ) else: @@ -1179,7 +1182,7 @@ class FederationEventHandler: room_version, state_maps, event_map={event_id: event}, - state_res_store=StateResolutionStore(self._store), + state_res_store=StateResolutionStore(self._store, self._state_store), ) except Exception as e: @@ -1874,7 +1877,9 @@ class FederationEventHandler: room_version, [local_state_id_map, claimed_auth_events_id_map], event_map=None, - state_res_store=StateResolutionStore(self._store), + state_res_store=StateResolutionStore( + self._store, self._state_store + ), ) ) else: @@ -2014,7 +2019,9 @@ class FederationEventHandler: room_version, state_sets, event_map=None, - state_res_store=StateResolutionStore(self._store), + state_res_store=StateResolutionStore( + self._store, self._state_store + ), ) ) else: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 52770dfd18..47e67f7154 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -65,6 +65,7 @@ if TYPE_CHECKING: from synapse.server import HomeServer from synapse.storage.controllers import StateStorageController from synapse.storage.databases.main import DataStore + from synapse.storage.databases.state import StateGroupDataStore logger = logging.getLogger(__name__) metrics_logger = logging.getLogger("synapse.state.metrics") @@ -481,7 +482,10 @@ class StateHandler: @trace @measure_func() async def resolve_state_groups_for_events( - self, room_id: str, event_ids: StrCollection, await_full_state: bool = True + self, + room_id: str, + event_ids: StrCollection, + await_full_state: bool = True, ) -> _StateCacheEntry: """Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. @@ -527,7 +531,15 @@ class StateHandler: state_group_id ) - # TODO: Check for deleted state groups + # Check if we're trying to delete the given prev group, if so we + # pretend we didn't see it. + if prev_group: + pending_deletion = ( + await self._state_store.is_state_group_pending_deletion(prev_group) + ) + if pending_deletion: + prev_group = None + delta_ids = None return _StateCacheEntry( state=None, @@ -549,7 +561,7 @@ class StateHandler: room_version, state_to_resolve, None, - state_res_store=StateResolutionStore(self.store), + state_res_store=StateResolutionStore(self.store, self._state_store), ) return result @@ -681,8 +693,22 @@ class StateResolutionHandler: async with self.resolve_linearizer.queue(group_names): cache = self._state_cache.get(group_names, None) if cache: - # TODO: Check for deleted state groups - return cache + pending_deletion = False + + if cache.state_group: + pending_deletion |= await state_res_store.state_store.is_state_group_pending_deletion( + cache.state_group + ) + + if cache.prev_group: + pending_deletion |= await state_res_store.state_store.is_state_group_pending_deletion( + cache.prev_group + ) + + if not pending_deletion: + return cache + else: + self._state_cache.pop(group_names, None) logger.info( "Resolving state for %s with groups %s", @@ -903,7 +929,8 @@ class StateResolutionStore: in well defined way. """ - store: "DataStore" + main_store: "DataStore" + state_store: "StateGroupDataStore" def get_events( self, event_ids: StrCollection, allow_rejected: bool = False @@ -918,7 +945,7 @@ class StateResolutionStore: An awaitable which resolves to a dict from event_id to event. """ - return self.store.get_events( + return self.main_store.get_events( event_ids, redact_behaviour=EventRedactBehaviour.as_is, get_prev_content=False, @@ -939,4 +966,4 @@ class StateResolutionStore: An awaitable that resolves to a set of event IDs. """ - return self.store.get_auth_chain_difference(room_id, state_sets) + return self.main_store.get_auth_chain_difference(room_id, state_sets) diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 8124802418..a5d16bdde9 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -549,7 +549,7 @@ class EventsPersistenceStorageController: room_version, state_maps_by_state_group, event_map=None, - state_res_store=StateResolutionStore(self.main_store), + state_res_store=StateResolutionStore(self.main_store, self.state_store), ) return await res.get_state(self._state_controller, StateFilter.all()) @@ -976,7 +976,7 @@ class EventsPersistenceStorageController: room_version, state_groups, events_map, - state_res_store=StateResolutionStore(self.main_store), + state_res_store=StateResolutionStore(self.main_store, self.state_store), ) state_resolutions_during_persistence.inc() diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 35c1f8e407..a284f9ce48 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -237,25 +237,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): values=[(state_group,) for state_group in state_groups], ) - @cached() - async def is_state_group_pending_deletion_before( - self, state_epoch: int, state_group: int - ) -> bool: - """Check if a state group is marked as pending deletion in a previous - epoch, but does not check the current epoch.""" + async def is_state_group_pending_deletion(self, state_group: int) -> bool: + """Check if a state group is marked as pending deletion.""" - def is_state_group_pending_deletion_before_txn(txn: LoggingTransaction) -> bool: + def is_state_group_pending_deletion_txn(txn: LoggingTransaction) -> bool: sql = """ SELECT 1 FROM state_groups_pending_deletion - WHERE state_epoch < ? AND state_group = ? + WHERE state_group = ? """ - txn.execute(sql, (state_epoch, state_group)) + txn.execute(sql, (state_group,)) return txn.fetchone() is not None return await self.db_pool.runInteraction( - "is_state_group_pending_deletion_before", - is_state_group_pending_deletion_before_txn, + "is_state_group_pending_deletion", + is_state_group_pending_deletion_txn, ) async def mark_state_group_as_used(self, state_group: int) -> None: