1
0

Check for deleted state groups

This commit is contained in:
Erik Johnston
2025-01-21 10:56:12 +00:00
parent c02938c670
commit ae4a6304cd
4 changed files with 55 additions and 25 deletions

View File

@@ -151,6 +151,7 @@ class FederationEventHandler:
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._store = hs.get_datastores().main
self._state_store = hs.get_datastores().state
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
@@ -580,7 +581,9 @@ class FederationEventHandler:
room_version.identifier,
state_maps_to_resolve,
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_store
),
)
)
else:
@@ -1179,7 +1182,7 @@ class FederationEventHandler:
room_version,
state_maps,
event_map={event_id: event},
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(self._store, self._state_store),
)
except Exception as e:
@@ -1874,7 +1877,9 @@ class FederationEventHandler:
room_version,
[local_state_id_map, claimed_auth_events_id_map],
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_store
),
)
)
else:
@@ -2014,7 +2019,9 @@ class FederationEventHandler:
room_version,
state_sets,
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_store
),
)
)
else:

View File

@@ -65,6 +65,7 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.controllers import StateStorageController
from synapse.storage.databases.main import DataStore
from synapse.storage.databases.state import StateGroupDataStore
logger = logging.getLogger(__name__)
metrics_logger = logging.getLogger("synapse.state.metrics")
@@ -481,7 +482,10 @@ class StateHandler:
@trace
@measure_func()
async def resolve_state_groups_for_events(
self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
self,
room_id: str,
event_ids: StrCollection,
await_full_state: bool = True,
) -> _StateCacheEntry:
"""Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@@ -527,7 +531,15 @@ class StateHandler:
state_group_id
)
# TODO: Check for deleted state groups
# Check if we're trying to delete the given prev group, if so we
# pretend we didn't see it.
if prev_group:
pending_deletion = (
await self._state_store.is_state_group_pending_deletion(prev_group)
)
if pending_deletion:
prev_group = None
delta_ids = None
return _StateCacheEntry(
state=None,
@@ -549,7 +561,7 @@ class StateHandler:
room_version,
state_to_resolve,
None,
state_res_store=StateResolutionStore(self.store),
state_res_store=StateResolutionStore(self.store, self._state_store),
)
return result
@@ -681,8 +693,22 @@ class StateResolutionHandler:
async with self.resolve_linearizer.queue(group_names):
cache = self._state_cache.get(group_names, None)
if cache:
# TODO: Check for deleted state groups
return cache
pending_deletion = False
if cache.state_group:
pending_deletion |= await state_res_store.state_store.is_state_group_pending_deletion(
cache.state_group
)
if cache.prev_group:
pending_deletion |= await state_res_store.state_store.is_state_group_pending_deletion(
cache.prev_group
)
if not pending_deletion:
return cache
else:
self._state_cache.pop(group_names, None)
logger.info(
"Resolving state for %s with groups %s",
@@ -903,7 +929,8 @@ class StateResolutionStore:
in well defined way.
"""
store: "DataStore"
main_store: "DataStore"
state_store: "StateGroupDataStore"
def get_events(
self, event_ids: StrCollection, allow_rejected: bool = False
@@ -918,7 +945,7 @@ class StateResolutionStore:
An awaitable which resolves to a dict from event_id to event.
"""
return self.store.get_events(
return self.main_store.get_events(
event_ids,
redact_behaviour=EventRedactBehaviour.as_is,
get_prev_content=False,
@@ -939,4 +966,4 @@ class StateResolutionStore:
An awaitable that resolves to a set of event IDs.
"""
return self.store.get_auth_chain_difference(room_id, state_sets)
return self.main_store.get_auth_chain_difference(room_id, state_sets)

View File

@@ -549,7 +549,7 @@ class EventsPersistenceStorageController:
room_version,
state_maps_by_state_group,
event_map=None,
state_res_store=StateResolutionStore(self.main_store),
state_res_store=StateResolutionStore(self.main_store, self.state_store),
)
return await res.get_state(self._state_controller, StateFilter.all())
@@ -976,7 +976,7 @@ class EventsPersistenceStorageController:
room_version,
state_groups,
events_map,
state_res_store=StateResolutionStore(self.main_store),
state_res_store=StateResolutionStore(self.main_store, self.state_store),
)
state_resolutions_during_persistence.inc()

View File

@@ -237,25 +237,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
values=[(state_group,) for state_group in state_groups],
)
@cached()
async def is_state_group_pending_deletion_before(
self, state_epoch: int, state_group: int
) -> bool:
"""Check if a state group is marked as pending deletion in a previous
epoch, but does not check the current epoch."""
async def is_state_group_pending_deletion(self, state_group: int) -> bool:
"""Check if a state group is marked as pending deletion."""
def is_state_group_pending_deletion_before_txn(txn: LoggingTransaction) -> bool:
def is_state_group_pending_deletion_txn(txn: LoggingTransaction) -> bool:
sql = """
SELECT 1 FROM state_groups_pending_deletion
WHERE state_epoch < ? AND state_group = ?
WHERE state_group = ?
"""
txn.execute(sql, (state_epoch, state_group))
txn.execute(sql, (state_group,))
return txn.fetchone() is not None
return await self.db_pool.runInteraction(
"is_state_group_pending_deletion_before",
is_state_group_pending_deletion_before_txn,
"is_state_group_pending_deletion",
is_state_group_pending_deletion_txn,
)
async def mark_state_group_as_used(self, state_group: int) -> None: