Check for deleted state groups
This commit is contained in:
@@ -151,6 +151,7 @@ class FederationEventHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._clock = hs.get_clock()
|
||||
self._store = hs.get_datastores().main
|
||||
self._state_store = hs.get_datastores().state
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._state_storage_controller = self._storage_controllers.state
|
||||
|
||||
@@ -580,7 +581,9 @@ class FederationEventHandler:
|
||||
room_version.identifier,
|
||||
state_maps_to_resolve,
|
||||
event_map=None,
|
||||
state_res_store=StateResolutionStore(self._store),
|
||||
state_res_store=StateResolutionStore(
|
||||
self._store, self._state_store
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -1179,7 +1182,7 @@ class FederationEventHandler:
|
||||
room_version,
|
||||
state_maps,
|
||||
event_map={event_id: event},
|
||||
state_res_store=StateResolutionStore(self._store),
|
||||
state_res_store=StateResolutionStore(self._store, self._state_store),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -1874,7 +1877,9 @@ class FederationEventHandler:
|
||||
room_version,
|
||||
[local_state_id_map, claimed_auth_events_id_map],
|
||||
event_map=None,
|
||||
state_res_store=StateResolutionStore(self._store),
|
||||
state_res_store=StateResolutionStore(
|
||||
self._store, self._state_store
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -2014,7 +2019,9 @@ class FederationEventHandler:
|
||||
room_version,
|
||||
state_sets,
|
||||
event_map=None,
|
||||
state_res_store=StateResolutionStore(self._store),
|
||||
state_res_store=StateResolutionStore(
|
||||
self._store, self._state_store
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -65,6 +65,7 @@ if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.controllers import StateStorageController
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.storage.databases.state import StateGroupDataStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
metrics_logger = logging.getLogger("synapse.state.metrics")
|
||||
@@ -481,7 +482,10 @@ class StateHandler:
|
||||
@trace
|
||||
@measure_func()
|
||||
async def resolve_state_groups_for_events(
|
||||
self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
|
||||
self,
|
||||
room_id: str,
|
||||
event_ids: StrCollection,
|
||||
await_full_state: bool = True,
|
||||
) -> _StateCacheEntry:
|
||||
"""Given a list of event_ids this method fetches the state at each
|
||||
event, resolves conflicts between them and returns them.
|
||||
@@ -527,7 +531,15 @@ class StateHandler:
|
||||
state_group_id
|
||||
)
|
||||
|
||||
# TODO: Check for deleted state groups
|
||||
# Check if we're trying to delete the given prev group, if so we
|
||||
# pretend we didn't see it.
|
||||
if prev_group:
|
||||
pending_deletion = (
|
||||
await self._state_store.is_state_group_pending_deletion(prev_group)
|
||||
)
|
||||
if pending_deletion:
|
||||
prev_group = None
|
||||
delta_ids = None
|
||||
|
||||
return _StateCacheEntry(
|
||||
state=None,
|
||||
@@ -549,7 +561,7 @@ class StateHandler:
|
||||
room_version,
|
||||
state_to_resolve,
|
||||
None,
|
||||
state_res_store=StateResolutionStore(self.store),
|
||||
state_res_store=StateResolutionStore(self.store, self._state_store),
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -681,8 +693,22 @@ class StateResolutionHandler:
|
||||
async with self.resolve_linearizer.queue(group_names):
|
||||
cache = self._state_cache.get(group_names, None)
|
||||
if cache:
|
||||
# TODO: Check for deleted state groups
|
||||
return cache
|
||||
pending_deletion = False
|
||||
|
||||
if cache.state_group:
|
||||
pending_deletion |= await state_res_store.state_store.is_state_group_pending_deletion(
|
||||
cache.state_group
|
||||
)
|
||||
|
||||
if cache.prev_group:
|
||||
pending_deletion |= await state_res_store.state_store.is_state_group_pending_deletion(
|
||||
cache.prev_group
|
||||
)
|
||||
|
||||
if not pending_deletion:
|
||||
return cache
|
||||
else:
|
||||
self._state_cache.pop(group_names, None)
|
||||
|
||||
logger.info(
|
||||
"Resolving state for %s with groups %s",
|
||||
@@ -903,7 +929,8 @@ class StateResolutionStore:
|
||||
in well defined way.
|
||||
"""
|
||||
|
||||
store: "DataStore"
|
||||
main_store: "DataStore"
|
||||
state_store: "StateGroupDataStore"
|
||||
|
||||
def get_events(
|
||||
self, event_ids: StrCollection, allow_rejected: bool = False
|
||||
@@ -918,7 +945,7 @@ class StateResolutionStore:
|
||||
An awaitable which resolves to a dict from event_id to event.
|
||||
"""
|
||||
|
||||
return self.store.get_events(
|
||||
return self.main_store.get_events(
|
||||
event_ids,
|
||||
redact_behaviour=EventRedactBehaviour.as_is,
|
||||
get_prev_content=False,
|
||||
@@ -939,4 +966,4 @@ class StateResolutionStore:
|
||||
An awaitable that resolves to a set of event IDs.
|
||||
"""
|
||||
|
||||
return self.store.get_auth_chain_difference(room_id, state_sets)
|
||||
return self.main_store.get_auth_chain_difference(room_id, state_sets)
|
||||
|
||||
@@ -549,7 +549,7 @@ class EventsPersistenceStorageController:
|
||||
room_version,
|
||||
state_maps_by_state_group,
|
||||
event_map=None,
|
||||
state_res_store=StateResolutionStore(self.main_store),
|
||||
state_res_store=StateResolutionStore(self.main_store, self.state_store),
|
||||
)
|
||||
|
||||
return await res.get_state(self._state_controller, StateFilter.all())
|
||||
@@ -976,7 +976,7 @@ class EventsPersistenceStorageController:
|
||||
room_version,
|
||||
state_groups,
|
||||
events_map,
|
||||
state_res_store=StateResolutionStore(self.main_store),
|
||||
state_res_store=StateResolutionStore(self.main_store, self.state_store),
|
||||
)
|
||||
|
||||
state_resolutions_during_persistence.inc()
|
||||
|
||||
@@ -237,25 +237,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||
values=[(state_group,) for state_group in state_groups],
|
||||
)
|
||||
|
||||
@cached()
|
||||
async def is_state_group_pending_deletion_before(
|
||||
self, state_epoch: int, state_group: int
|
||||
) -> bool:
|
||||
"""Check if a state group is marked as pending deletion in a previous
|
||||
epoch, but does not check the current epoch."""
|
||||
async def is_state_group_pending_deletion(self, state_group: int) -> bool:
|
||||
"""Check if a state group is marked as pending deletion."""
|
||||
|
||||
def is_state_group_pending_deletion_before_txn(txn: LoggingTransaction) -> bool:
|
||||
def is_state_group_pending_deletion_txn(txn: LoggingTransaction) -> bool:
|
||||
sql = """
|
||||
SELECT 1 FROM state_groups_pending_deletion
|
||||
WHERE state_epoch < ? AND state_group = ?
|
||||
WHERE state_group = ?
|
||||
"""
|
||||
txn.execute(sql, (state_epoch, state_group))
|
||||
txn.execute(sql, (state_group,))
|
||||
|
||||
return txn.fetchone() is not None
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"is_state_group_pending_deletion_before",
|
||||
is_state_group_pending_deletion_before_txn,
|
||||
"is_state_group_pending_deletion",
|
||||
is_state_group_pending_deletion_txn,
|
||||
)
|
||||
|
||||
async def mark_state_group_as_used(self, state_group: int) -> None:
|
||||
|
||||
Reference in New Issue
Block a user