1
0
This commit is contained in:
Erik Johnston
2025-01-24 15:16:48 +00:00
parent 1e26c26ffc
commit ce41c93878
3 changed files with 185 additions and 189 deletions

View File

@@ -315,6 +315,9 @@ class StateHandler:
"""
assert not event.internal_metadata.is_outlier()
# Record the state epoch before we start calculating state groups, to
# ensure that nothing we're relying on gets deleted. See the store class
# docstring for more information.
state_epoch = await self._state_epoch_store.get_state_epoch()
#
@@ -532,15 +535,15 @@ class StateHandler:
state_group_id
)
# Check if we're trying to delete the given prev group, if so we
# pretend we didn't see it.
if prev_group:
pending_deletion = (
await self._state_epoch_store.is_state_group_pending_deletion(
prev_group
# Ensure that we still have the prev group, and ensure we don't
# delete it while we're persisting the event.
missing_state_group = (
await self._state_epoch_store.check_state_groups_and_bump_deletion(
{prev_group}
)
)
if pending_deletion:
if missing_state_group:
prev_group = None
delta_ids = None
@@ -696,20 +699,24 @@ class StateResolutionHandler:
async with self.resolve_linearizer.queue(group_names):
cache = self._state_cache.get(group_names, None)
if cache:
state_groups_to_check = []
# Check that the returned cache entry doesn't point to deleted
# state groups.
state_groups_to_check = set()
if cache.state_group is not None:
state_groups_to_check.append(cache.state_group)
state_groups_to_check.add(cache.state_group)
if cache.prev_group is not None:
state_groups_to_check.append(cache.prev_group)
state_groups_to_check.add(cache.prev_group)
pending_deletion = await state_res_store.state_epoch_store.are_state_groups_pending_deletion(
missing_state_groups = await state_res_store.state_epoch_store.check_state_groups_and_bump_deletion(
state_groups_to_check
)
if not pending_deletion:
if not missing_state_groups:
return cache
else:
# There are missing state groups, so let's remove the stale
# entry and continue as if it was a cache miss.
self._state_cache.pop(group_names, None)
logger.info(
@@ -718,15 +725,14 @@ class StateResolutionHandler:
list(group_names),
)
# We double check that none of the state groups are pending
# deletion. They shouldn't be as all these state groups should be
# referenced.
pending_deletion = await state_res_store.state_epoch_store.are_state_groups_pending_deletion(
# We double check that none of the state groups have been deleted.
# They shouldn't be as all these state groups should be referenced.
missing_state_groups = await state_res_store.state_epoch_store.check_state_groups_and_bump_deletion(
group_names
)
if pending_deletion:
if missing_state_groups:
raise Exception(
f"state groups are pending deletion: {shortstr(pending_deletion)}"
f"State groups have been deleted: {shortstr(missing_state_groups)}"
)
state_groups_histogram.observe(len(state_groups_ids))

View File

@@ -14,7 +14,14 @@
import contextlib
from typing import TYPE_CHECKING, AsyncIterator, Collection, Dict, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
AbstractSet,
AsyncIterator,
Collection,
Set,
Tuple,
)
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
@@ -35,20 +42,37 @@ if TYPE_CHECKING:
class StateEpochDataStore:
"""Manages state epochs and checks for state group deletion.
Deleting state groups is challenging as we need to ensure that any in-flight
events that are yet to be persisted do not refer to any state groups that we
want to delete.
Deleting state groups is challenging because, before we actually delete
them, we need to ensure that there are no in-flight events that refer to the
state groups that we want to delete.
To handle this, we have a concept of "state epochs", which slowly increment
over time. To delete a state group we first add it to the list of "pending
deletions" with the current epoch, and wait until a certain number of epochs
have passed before attempting to actually delete the state group. If during
this period an event that references the state group tries to be persisted,
then we check if too many state epochs have passed, if they have we reject
the attempt to persist the event, and if not we clear the state groups from
the pending deletion list (as they're now referenced).
To handle this, we take two approaches. First, before we persist any event
we ensure that the state groups still exist and mark in the
`state_groups_persisting` table that the state group is about to be used.
(Note that we have to have the extra table here as state groups and events
can be in different databases, and thus we can't check for the existence of
state groups in the persist event transaction). Once the event has been
persisted, we can remove the row from `state_groups_persisting`. So long as
we check that table before deleting state groups, we can ensure that we
never persist events that reference deleted state groups, maintaining
database integrity.
However, we want to avoid throwing exceptions so deep in the process of
persisting events. So we use a concept of `state_epochs`, where we mark
state groups as pending/proposed for deletion and wait for a certain number
of epoch increments before performing the deletion. When we come to handle
new events that reference state groups, we check if they are pending
deletion and bump the epoch they are marked with, which delays the deletion
(to give a chance for the event to be persisted, or not).
"""
# How frequently, roughly, to increment epochs.
TIME_BETWEEN_EPOCH_INCREMENTS_MS = 5 * 60 * 1000
# The number of epoch increases that must have happened between marking a
# state group as pending and actually deleting it.
NUMBER_EPOCHS_BEFORE_DELETION = 3
def __init__(
self,
database: DatabasePool,
@@ -63,7 +87,11 @@ class StateEpochDataStore:
# running instance.
if hs.config.worker.run_background_tasks:
self._clock.looping_call_now(self._advance_state_epoch, 2 * 60 * 1000)
# Add a background loop to periodically check if we should bump the
# state epoch.
self._clock.looping_call_now(
self._advance_state_epoch, self.TIME_BETWEEN_EPOCH_INCREMENTS_MS / 5
)
@wrap_as_background_process("_advance_state_epoch")
async def _advance_state_epoch(self) -> None:
@@ -72,7 +100,7 @@ class StateEpochDataStore:
"""
now = self._clock.time_msec()
update_if_before_ts = now - 10 * 60 * 1000
update_if_before_ts = now - self.TIME_BETWEEN_EPOCH_INCREMENTS_MS
def advance_state_epoch_txn(txn: LoggingTransaction) -> None:
sql = """
@@ -80,19 +108,14 @@ class StateEpochDataStore:
SET state_epoch = state_epoch + 1, updated_ts = ?
WHERE updated_ts <= ?
"""
txn.execute(
sql,
(
now,
update_if_before_ts,
),
)
txn.execute(sql, (now, update_if_before_ts))
await self.db_pool.runInteraction(
"_advance_state_epoch", advance_state_epoch_txn, db_autocommit=True
)
async def get_state_epoch(self) -> int:
"""Get the current state epoch"""
return await self.db_pool.simple_select_one_onecol(
table="state_epoch",
retcol="state_epoch",
@@ -100,143 +123,74 @@ class StateEpochDataStore:
desc="get_state_epoch",
)
def _mark_state_groups_as_used_txn(
self, txn: LoggingTransaction, state_epoch: int, state_groups: Set[int]
) -> None:
current_state_epoch = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_epoch",
retcol="state_epoch",
keyvalues={},
)
async def check_state_groups_and_bump_deletion(
self, state_groups: AbstractSet[int]
) -> Collection[int]:
"""Checks to make sure that the state groups haven't been deleted, and
if they're pending deletion we delay it (allowing time for any event
that will use them to finish persisting).
# TODO: Move to constant. Is the equality correct?
if current_state_epoch - state_epoch >= 2:
raise Exception("FOO")
Returns:
The state groups that are missing, if any.
"""
clause, values = make_in_list_sql_clause(
txn.database_engine,
"id",
return await self.db_pool.runInteraction(
"check_state_groups_and_bump_deletion",
self._check_state_groups_and_bump_deletion_txn,
state_groups,
)
def _check_state_groups_and_bump_deletion_txn(
self, txn: LoggingTransaction, state_groups: AbstractSet[int]
) -> Collection[int]:
existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups)
if state_groups - existing_state_groups:
return state_groups - existing_state_groups
clause, args = make_in_list_sql_clause(
self.db_pool.engine, "state_group", state_groups
)
sql = f"""
SELECT id, state_epoch
FROM state_groups
LEFT JOIN state_groups_pending_deletion ON (id = state_group)
UPDATE state_groups_pending_deletion
SET state_epoch = (SELECT state_epoch FROM state_epoch)
WHERE {clause}
"""
txn.execute(sql, args)
return ()
def _get_existing_groups_with_lock(
self, txn: LoggingTransaction, state_groups: Collection[int]
) -> AbstractSet[int]:
"""Return which of the given state groups are in the database, and (on
postgres) lock those rows with `KEY SHARE` to ensure they don't get
concurrently deleted."""
clause, args = make_in_list_sql_clause(self.db_pool.engine, "id", state_groups)
sql = f"""
SELECT id FROM state_groups
WHERE {clause}
"""
if isinstance(self.db_pool.engine, PostgresEngine):
# On postgres we add a row level lock to the rows to ensure that we
# conflict with any concurrent DELETEs. `FOR KEY SHARE` lock will
# not conflict with other reads.
# not conflict with other reads.
sql += """
FOR KEY SHARE OF state_groups
FOR KEY SHARE
"""
txn.execute(sql, values)
state_group_to_epoch: Dict[int, Optional[int]] = {row[0]: row[1] for row in txn}
missing_state_groups = state_groups - state_group_to_epoch.keys()
if missing_state_groups:
raise Exception(
f"state groups have been deleted: {shortstr(missing_state_groups)}"
)
for state_epoch_deletion in state_group_to_epoch.values():
if state_epoch_deletion is None:
continue
if current_state_epoch - state_epoch_deletion >= 2:
raise Exception("FOO") # TODO
self.db_pool.simple_delete_many_batch_txn(
txn,
table="state_groups_pending_deletion",
keys=("state_group",),
values=[(state_group,) for state_group in state_groups],
)
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_persisting",
keys=("state_group", "instance_name"),
values=[(state_group, self._instance_name) for state_group in state_groups],
)
async def is_state_group_pending_deletion(self, state_group: int) -> bool:
"""Check if a state group is marked as pending deletion."""
def is_state_group_pending_deletion_txn(txn: LoggingTransaction) -> bool:
sql = """
SELECT 1 FROM state_groups_pending_deletion
WHERE state_group = ?
"""
txn.execute(sql, (state_group,))
return txn.fetchone() is not None
return await self.db_pool.runInteraction(
"is_state_group_pending_deletion",
is_state_group_pending_deletion_txn,
)
async def are_state_groups_pending_deletion(
self, state_groups: Collection[int]
) -> Collection[int]:
rows = await self.db_pool.simple_select_many_batch(
table="state_groups_pending_deletion",
column="state_group",
iterable=state_groups,
retcols=("state_group",),
desc="are_state_groups_pending_deletion",
)
return {row[0] for row in rows}
async def mark_state_group_as_used(self, state_group: int) -> None:
"""Mark that a given state group is used"""
# TODO: Also assert that the state group hasn't advanced too much
await self.db_pool.simple_delete(
table="state_groups_pending_deletion",
keyvalues={"state_group": state_group},
desc="mark_state_group_as_used",
)
def check_prev_group_before_insertion_txn(
self, txn: LoggingTransaction, prev_group: int, new_groups: Collection[int]
) -> None:
sql = """
SELECT state_epoch, (SELECT state_epoch FROM state_epoch)
FROM state_groups_pending_deletion
WHERE state_group = ?
"""
txn.execute(sql, (prev_group,))
row = txn.fetchone()
if row is not None:
pending_deletion_epoch, current_epoch = row
if current_epoch - pending_deletion_epoch >= 2:
raise Exception("") # TODO
self.db_pool.simple_update_txn(
txn,
table="state_groups_pending_deletion",
keyvalues={"state_group": prev_group},
updatevalues={"state_epoch": current_epoch},
)
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_pending_deletion",
keys=("state_group", "state_epoch"),
values=[(state_group, current_epoch) for state_group in new_groups],
)
txn.execute(sql, args)
return {state_group for (state_group,) in txn}
@contextlib.asynccontextmanager
async def persisting_state_group_references(
self, event_and_contexts: Collection[Tuple[EventBase, EventContext]]
) -> AsyncIterator[None]:
"""Wraps the persistence of the given events and contexts, ensuring that
any state groups referenced still exist and that they don't get deleted
during this."""
referenced_state_groups: Set[int] = set()
state_epochs = []
for event, ctx in event_and_contexts:
@@ -259,12 +213,11 @@ class StateEpochDataStore:
return
assert state_epochs # If we have state groups we have a state epoch
min_state_epoch = min(state_epochs)
# min_state_epoch = min(state_epochs) # TODO
await self.db_pool.runInteraction(
"mark_state_groups_as_used",
self._mark_state_groups_as_used_txn,
min_state_epoch,
referenced_state_groups,
)
@@ -279,12 +232,54 @@ class StateEpochDataStore:
desc="persisting_state_group_references_delete",
)
def get_state_groups_that_can_be_purged(
def _mark_state_groups_as_used_txn(
self, txn: LoggingTransaction, state_groups: Set[int]
) -> None:
"""Marks the given state groups as used. Raises an exception if any of
the given state groups have been deleted."""
existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups)
missing_state_groups = state_groups - existing_state_groups
if missing_state_groups:
raise Exception(
f"state groups have been deleted: {shortstr(missing_state_groups)}"
)
self.db_pool.simple_delete_many_batch_txn(
txn,
table="state_groups_pending_deletion",
keys=("state_group",),
values=[(state_group,) for state_group in state_groups],
)
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_persisting",
keys=("state_group", "instance_name"),
values=[(state_group, self._instance_name) for state_group in state_groups],
)
def get_state_groups_that_can_be_purged_txn(
self, txn: LoggingTransaction, state_groups: Collection[int]
) -> Collection[int]:
"""Given a set of state groups, return which state groups can be deleted."""
if not state_groups:
return state_groups
if isinstance(self.db_pool.engine, PostgresEngine):
# On postgres we want to lock the rows FOR UPDATE as early as
# possible to help conflicts.
clause, args = make_in_list_sql_clause(
self.db_pool.engine, "id", state_groups
)
sql = """
SELECT id FROM state_groups
WHERE {clause}
FOR UPDATE
"""
txn.execute(sql, args)
current_state_epoch = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_epoch",
@@ -292,25 +287,34 @@ class StateEpochDataStore:
keyvalues={},
)
# Check the deletion status in the DB of the given state groups
clause, args = make_in_list_sql_clause(
self.db_pool.engine, column="state_group", iterable=state_groups
)
sql = f"""
SELECT state_group FROM (
SELECT state_group FROM state_groups_pending_deletion
WHERE state_epoch > ?
SELECT state_group, state_epoch FROM (
SELECT state_group, state_epoch FROM state_groups_pending_deletion
UNION
SELECT state_group FROM state_groups_persisting
SELECT state_group, null FROM state_groups_persisting
) AS s
WHERE {clause}
"""
args.insert(0, current_state_epoch - 2)
txn.execute(sql, args)
can_delete = set(state_groups)
for (state_group,) in txn:
can_delete.discard(state_group)
can_delete = set()
for state_group, state_epoch in txn:
if state_epoch is None:
# A null state epoch means that we are currently persisting
# events that reference the state group, so we don't delete
# it.
continue
if current_state_epoch - state_epoch < self.NUMBER_EPOCHS_BEFORE_DELETION:
# Not enough state epochs have occurred to allow us to delete.
continue
can_delete.add(state_group)
return can_delete

View File

@@ -474,14 +474,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
A list of state groups
"""
is_in_db = self.db_pool.simple_select_one_onecol_txn(
# We need to check that the prev group isn't about to be deleted
is_missing = self._epoch_store._check_state_groups_and_bump_deletion_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
retcol="id",
allow_none=True,
{prev_group},
)
if not is_in_db:
if is_missing:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (prev_group,)
@@ -554,11 +552,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
],
)
# We need to check that the prev group isn't about to be deleted
self._epoch_store.check_prev_group_before_insertion_txn(
txn, prev_group, state_groups
)
return events_and_context
return await self.db_pool.runInteraction(
@@ -615,14 +608,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
needs to be persisted as a full state.
"""
is_in_db = self.db_pool.simple_select_one_onecol_txn(
# We need to check that the prev group isn't about to be deleted
is_missing = self._epoch_store._check_state_groups_and_bump_deletion_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
retcol="id",
allow_none=True,
{prev_group},
)
if not is_in_db:
if is_missing:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (prev_group,)
@@ -658,11 +649,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
],
)
# We need to check that the prev group isn't about to be deleted
self._epoch_store.check_prev_group_before_insertion_txn(
txn, prev_group, [state_group]
)
return state_group
def insert_full_state_txn(