From 0b4eeacb585ed0eec24576e5d211382eaa362fac Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 22 Jan 2025 11:55:12 +0000 Subject: [PATCH] Move stuff --- synapse/handlers/federation_event.py | 11 +- synapse/state/__init__.py | 17 +- synapse/storage/controllers/persist_events.py | 11 +- synapse/storage/databases/__init__.py | 8 +- synapse/storage/databases/state/epochs.py | 159 ++++++++++++++++++ synapse/storage/databases/state/store.py | 122 -------------- 6 files changed, 191 insertions(+), 137 deletions(-) create mode 100644 synapse/storage/databases/state/epochs.py diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 50ab8fc54f..24684f5719 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -152,6 +152,7 @@ class FederationEventHandler: self._clock = hs.get_clock() self._store = hs.get_datastores().main self._state_store = hs.get_datastores().state + self._state_epoch_store = hs.get_datastores().state_epochs self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state @@ -582,7 +583,7 @@ class FederationEventHandler: state_maps_to_resolve, event_map=None, state_res_store=StateResolutionStore( - self._store, self._state_store + self._store, self._state_epoch_store ), ) ) @@ -1182,7 +1183,9 @@ class FederationEventHandler: room_version, state_maps, event_map={event_id: event}, - state_res_store=StateResolutionStore(self._store, self._state_store), + state_res_store=StateResolutionStore( + self._store, self._state_epoch_store + ), ) except Exception as e: @@ -1878,7 +1881,7 @@ class FederationEventHandler: [local_state_id_map, claimed_auth_events_id_map], event_map=None, state_res_store=StateResolutionStore( - self._store, self._state_store + self._store, self._state_epoch_store ), ) ) @@ -2020,7 +2023,7 @@ class FederationEventHandler: state_sets, event_map=None, state_res_store=StateResolutionStore( - self._store, self._state_store + self._store, self._state_epoch_store ), ) ) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 47e67f7154..fdf07838d9 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -65,7 +65,7 @@ if TYPE_CHECKING: from synapse.server import HomeServer from synapse.storage.controllers import StateStorageController from synapse.storage.databases.main import DataStore - from synapse.storage.databases.state import StateGroupDataStore + from synapse.storage.databases.state.epochs import StateEpochDataStore logger = logging.getLogger(__name__) metrics_logger = logging.getLogger("synapse.state.metrics") @@ -197,6 +197,7 @@ class StateHandler: self._events_shard_config = hs.config.worker.events_shard_config self._instance_name = hs.get_instance_name() self._state_store = hs.get_datastores().state + self._state_epoch_store = hs.get_datastores().state_epochs self._update_current_state_client = ( ReplicationUpdateCurrentStateRestServlet.make_client(hs) @@ -314,7 +315,7 @@ class StateHandler: """ assert not event.internal_metadata.is_outlier() - state_epoch = await self._state_store.get_state_epoch() # TODO: Get state epoch + state_epoch = await self._state_epoch_store.get_state_epoch() # # first of all, figure out the state before the event, unless we @@ -535,7 +536,9 @@ class StateHandler: # pretend we didn't see it. if prev_group: pending_deletion = ( - await self._state_store.is_state_group_pending_deletion(prev_group) + await self._state_epoch_store.is_state_group_pending_deletion( + prev_group + ) ) if pending_deletion: prev_group = None @@ -561,7 +564,7 @@ class StateHandler: room_version, state_to_resolve, None, - state_res_store=StateResolutionStore(self.store, self._state_store), + state_res_store=StateResolutionStore(self.store, self._state_epoch_store), ) return result @@ -696,12 +699,12 @@ class StateResolutionHandler: pending_deletion = False if cache.state_group: - pending_deletion |= await state_res_store.state_store.is_state_group_pending_deletion( + pending_deletion |= await state_res_store.state_epoch_store.is_state_group_pending_deletion( cache.state_group ) if cache.prev_group: - pending_deletion |= await state_res_store.state_store.is_state_group_pending_deletion( + pending_deletion |= await state_res_store.state_epoch_store.is_state_group_pending_deletion( cache.prev_group ) @@ -930,7 +933,7 @@ class StateResolutionStore: """ main_store: "DataStore" - state_store: "StateGroupDataStore" + state_epoch_store: "StateEpochDataStore" def get_events( self, event_ids: StrCollection, allow_rejected: bool = False diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index a5d16bdde9..7d37d155a4 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -332,6 +332,7 @@ class EventsPersistenceStorageController: # store for now. self.main_store = stores.main self.state_store = stores.state + self._state_epoch_store = stores.state_epochs assert stores.persist_events self.persist_events_store = stores.persist_events @@ -549,7 +550,9 @@ class EventsPersistenceStorageController: room_version, state_maps_by_state_group, event_map=None, - state_res_store=StateResolutionStore(self.main_store, self.state_store), + state_res_store=StateResolutionStore( + self.main_store, self._state_epoch_store + ), ) return await res.get_state(self._state_controller, StateFilter.all()) @@ -644,7 +647,7 @@ class EventsPersistenceStorageController: # TODO: Add a table to track what state groups we're currently # inserting? There's a race where this transaction takes so long # that we delete the state groups we're inserting. - await self.state_store.mark_state_groups_as_used(events_and_contexts) + await self._state_epoch_store.mark_state_groups_as_used(events_and_contexts) await self.persist_events_store._persist_events_and_state_updates( room_id, @@ -976,7 +979,9 @@ class EventsPersistenceStorageController: room_version, state_groups, events_map, - state_res_store=StateResolutionStore(self.main_store, self.state_store), + state_res_store=StateResolutionStore( + self.main_store, self._state_epoch_store + ), ) state_resolutions_during_persistence.inc() diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index dd9fc01fb0..d49437834d 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -26,6 +26,7 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_conn from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.databases.state import StateGroupDataStore +from synapse.storage.databases.state.epochs import StateEpochDataStore from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database @@ -49,12 +50,14 @@ class Databases(Generic[DataStoreT]): main state persist_events + state_epochs """ databases: List[DatabasePool] main: "DataStore" # FIXME: https://github.com/matrix-org/synapse/issues/11165: actually an instance of `main_store_class` state: StateGroupDataStore persist_events: Optional[PersistEventsStore] + state_epochs: StateEpochDataStore def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"): # Note we pass in the main store class here as workers use a different main @@ -63,6 +66,7 @@ class Databases(Generic[DataStoreT]): self.databases = [] main: Optional[DataStoreT] = None state: Optional[StateGroupDataStore] = None + state_epochs: Optional[StateEpochDataStore] = None persist_events: Optional[PersistEventsStore] = None for database_config in hs.config.database.databases: @@ -115,6 +119,7 @@ class Databases(Generic[DataStoreT]): raise Exception("'state' data store already configured") state = StateGroupDataStore(database, db_conn, hs) + state_epochs = StateEpochDataStore(database, db_conn, hs) db_conn.commit() @@ -135,7 +140,7 @@ class Databases(Generic[DataStoreT]): if not main: raise Exception("No 'main' database configured") - if not state: + if not state or not state_epochs: raise Exception("No 'state' database configured") # We use local variables here to ensure that the databases do not have @@ -143,3 +148,4 @@ class Databases(Generic[DataStoreT]): self.main = main # type: ignore[assignment] self.state = state self.persist_events = persist_events + self.state_epochs = state_epochs diff --git a/synapse/storage/databases/state/epochs.py b/synapse/storage/databases/state/epochs.py new file mode 100644 index 0000000000..3ac891732c --- /dev/null +++ b/synapse/storage/databases/state/epochs.py @@ -0,0 +1,159 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# + + +from typing import TYPE_CHECKING, Collection, Tuple + +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class StateEpochDataStore: + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + self._clock = hs.get_clock() + self.db_pool = database + + if hs.config.worker.run_background_tasks: + self._clock.looping_call_now(self._advance_state_epoch, 2 * 60 * 1000) + + @wrap_as_background_process("_advance_state_epoch") + async def _advance_state_epoch(self) -> None: + """Advances the state epoch, checking that we haven't advanced it too + recently. + """ + + now = self._clock.time_msec() + update_if_before_ts = now - 10 * 60 * 1000 + + def advance_state_epoch_txn(txn: LoggingTransaction) -> None: + sql = """ + UPDATE state_epoch + SET state_epoch = state_epoch + 1, updated_ts = ? + WHERE updated_ts <= ? + """ + txn.execute( + sql, + ( + now, + update_if_before_ts, + ), + ) + + await self.db_pool.runInteraction( + "_advance_state_epoch", advance_state_epoch_txn, db_autocommit=True + ) + + async def get_state_epoch(self) -> int: + return await self.db_pool.simple_select_one_onecol( + table="state_epoch", + retcol="state_epoch", + keyvalues={}, + desc="get_state_epoch", + ) + + async def mark_state_groups_as_used( + self, event_and_contexts: Collection[Tuple[EventBase, EventContext]] + ) -> None: + referenced_state_groups = [] + state_epochs = [] + for event, ctx in event_and_contexts: + if ctx.rejected or event.internal_metadata.is_outlier(): + continue + + assert ctx.state_epoch is not None + assert ctx.state_group is not None + + state_epochs.append(ctx.state_epoch) + + referenced_state_groups.append(ctx.state_group) + + if ctx.state_group_before_event: + referenced_state_groups.append(ctx.state_group_before_event) + + if not referenced_state_groups: + # We don't reference any state groups, so nothing to do + return + + assert state_epochs # If we have state groups we have a state epoch + min_state_epoch = min(state_epochs) + + await self.db_pool.runInteraction( + "mark_state_groups_as_used", + self._mark_state_groups_as_used_txn, + min_state_epoch, + referenced_state_groups, + ) + + def _mark_state_groups_as_used_txn( + self, txn: LoggingTransaction, state_epoch: int, state_groups: Collection[int] + ) -> None: + current_state_epoch = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_epoch", + retcol="state_epoch", + keyvalues={}, + ) + + # TODO: Move to constant. Is the equality correct? + if current_state_epoch - state_epoch >= 2: + raise Exception("FOO") + + self.db_pool.simple_delete_many_batch_txn( + txn, + table="state_groups_pending_deletion", + keys=("state_group",), + values=[(state_group,) for state_group in state_groups], + ) + + async def is_state_group_pending_deletion(self, state_group: int) -> bool: + """Check if a state group is marked as pending deletion.""" + + def is_state_group_pending_deletion_txn(txn: LoggingTransaction) -> bool: + sql = """ + SELECT 1 FROM state_groups_pending_deletion + WHERE state_group = ? + """ + txn.execute(sql, (state_group,)) + + return txn.fetchone() is not None + + return await self.db_pool.runInteraction( + "is_state_group_pending_deletion", + is_state_group_pending_deletion_txn, + ) + + async def mark_state_group_as_used(self, state_group: int) -> None: + """Mark that a given state group is used""" + + # TODO: Also assert that the state group hasn't advanced too much + + await self.db_pool.simple_delete( + table="state_groups_pending_deletion", + keyvalues={"state_group": state_group}, + desc="mark_state_group_as_used", + ) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index a284f9ce48..fd3f951a81 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -37,12 +37,10 @@ import attr from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.events.snapshot import ( - EventContext, UnpersistedEventContext, UnpersistedEventContextBase, ) from synapse.logging.opentracing import tag_args, trace -from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -145,126 +143,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): id_column="id", ) - if hs.config.worker.run_background_tasks: - self._clock.looping_call_now(self._advance_state_epoch, 2 * 60 * 1000) - - @wrap_as_background_process("_advance_state_epoch") - async def _advance_state_epoch(self) -> None: - """Advances the state epoch, checking that we haven't advanced it too - recently. - """ - - now = self._clock.time_msec() - update_if_before_ts = now - 10 * 60 * 1000 - - def advance_state_epoch_txn(txn: LoggingTransaction) -> None: - sql = """ - UPDATE state_epoch - SET state_epoch = state_epoch + 1, updated_ts = ? - WHERE updated_ts <= ? - """ - txn.execute( - sql, - ( - now, - update_if_before_ts, - ), - ) - - await self.db_pool.runInteraction( - "_advance_state_epoch", advance_state_epoch_txn, db_autocommit=True - ) - - async def get_state_epoch(self) -> int: - return await self.db_pool.simple_select_one_onecol( - table="state_epoch", - retcol="state_epoch", - keyvalues={}, - desc="get_state_epoch", - ) - - async def mark_state_groups_as_used( - self, event_and_contexts: Collection[Tuple[EventBase, EventContext]] - ) -> None: - referenced_state_groups = [] - state_epochs = [] - for event, ctx in event_and_contexts: - if ctx.rejected or event.internal_metadata.is_outlier(): - continue - - assert ctx.state_epoch is not None - assert ctx.state_group is not None - - state_epochs.append(ctx.state_epoch) - - referenced_state_groups.append(ctx.state_group) - - if ctx.state_group_before_event: - referenced_state_groups.append(ctx.state_group_before_event) - - if not referenced_state_groups: - # We don't reference any state groups, so nothing to do - return - - assert state_epochs # If we have state groups we have a state epoch - min_state_epoch = min(state_epochs) - - await self.db_pool.runInteraction( - "mark_state_groups_as_used", - self._mark_state_groups_as_used_txn, - min_state_epoch, - referenced_state_groups, - ) - - def _mark_state_groups_as_used_txn( - self, txn: LoggingTransaction, state_epoch: int, state_groups: Collection[int] - ) -> None: - current_state_epoch = self.db_pool.simple_select_one_onecol_txn( - txn, - table="state_epoch", - retcol="state_epoch", - keyvalues={}, - ) - - # TODO: Move to constant. Is the equality correct? - if current_state_epoch - state_epoch >= 2: - raise Exception("FOO") - - self.db_pool.simple_delete_many_batch_txn( - txn, - table="state_groups_pending_deletion", - keys=("state_group",), - values=[(state_group,) for state_group in state_groups], - ) - - async def is_state_group_pending_deletion(self, state_group: int) -> bool: - """Check if a state group is marked as pending deletion.""" - - def is_state_group_pending_deletion_txn(txn: LoggingTransaction) -> bool: - sql = """ - SELECT 1 FROM state_groups_pending_deletion - WHERE state_group = ? - """ - txn.execute(sql, (state_group,)) - - return txn.fetchone() is not None - - return await self.db_pool.runInteraction( - "is_state_group_pending_deletion", - is_state_group_pending_deletion_txn, - ) - - async def mark_state_group_as_used(self, state_group: int) -> None: - """Mark that a given state group is used""" - - # TODO: Also assert that the state group hasn't advanced too much - - await self.db_pool.simple_delete( - table="state_groups_pending_deletion", - keyvalues={"state_group": state_group}, - desc="mark_state_group_as_used", - ) - @cached(max_entries=10000, iterable=True) async def get_state_group_delta(self, state_group: int) -> _GetStateGroupDelta: """Given a state group try to return a previous group and a delta between