1
0

Implement closure of conflicted state events

This commit is contained in:
Erik Johnston
2022-11-17 16:13:23 +00:00
parent 115f0eb233
commit 0e99b0bbd0
2 changed files with 86 additions and 9 deletions

View File

@@ -61,7 +61,11 @@ class StateResolutionStore(Protocol):
...
def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]]
self,
room_id: str,
state_sets: List[Set[str]],
conflicted_state_ids: Set[str],
conflicted_boundary: Set[str],
) -> Awaitable[Set[str]]:
...
@@ -122,10 +126,12 @@ async def resolve_events_with_store(
logger.debug("%d conflicted state entries", len(conflicted_state))
logger.debug("Calculating auth chain difference")
conflicted_state_ids = set(itertools.chain.from_iterable(conflicted_state.values()))
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
auth_diff = await _get_auth_chain_difference(
room_id, state_sets, event_map, state_res_store
room_id, state_sets, event_map, conflicted_state_ids, state_res_store
)
full_conflicted_set = set(
@@ -272,6 +278,7 @@ async def _get_auth_chain_difference(
room_id: str,
state_sets: Sequence[Mapping[Any, str]],
unpersisted_events: Dict[str, EventBase],
conflicted_state_ids: Set[str],
state_res_store: StateResolutionStore,
) -> Set[str]:
"""Compare the auth chains of each state set and return the set of events
@@ -367,15 +374,46 @@ async def _get_auth_chain_difference(
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
auth_difference_unpersisted_part: Collection[str] = union - intersection
persisted_conflicted_state_ids = {
event_id for event_id in conflicted_state_ids if event_id not in union
}
boundary = state_sets_ids[0].union(*state_sets_ids[1:])
conflicted_boundary = set()
for event_id in persisted_conflicted_state_ids:
auth_chain = events_to_auth_chain.get(event_id)
if not auth_chain:
continue
conflicted_boundary != auth_chain & boundary
else:
auth_difference_unpersisted_part = ()
conflicted_boundary = set()
persisted_conflicted_state_ids = conflicted_state_ids
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
difference = await state_res_store.get_auth_chain_difference(
room_id, state_sets_ids
difference, conflicted_boundary = await state_res_store.get_auth_chain_difference(
room_id,
state_sets_ids,
persisted_conflicted_state_ids,
conflicted_boundary,
)
difference.update(auth_difference_unpersisted_part)
unpersisted_conflicted_state_ids = (
conflicted_state_ids - persisted_conflicted_state_ids
)
for boundary_event_id in conflicted_boundary:
for conflicted_id in unpersisted_conflicted_state_ids:
auth_chain = events_to_auth_chain[conflicted_id]
if boundary_event_id not in auth_chain:
continue
# TODO: Include all paths from conflicted_id -> boundary_event_id in difference.
return difference

View File

@@ -377,7 +377,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return results
async def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]]
self,
room_id: str,
state_sets: List[Set[str]],
conflicted_state_ids: Set[str],
conflicted_boundary: Set[str],
) -> Set[str]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
@@ -400,12 +404,17 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
self._get_auth_chain_difference_using_cover_index_txn,
room_id,
state_sets,
conflicted_state_ids,
conflicted_boundary,
)
except _NoChainCoverIndex:
# For whatever reason we don't actually have a chain cover index
# for the events in question, so we fall back to the old method.
pass
if conflicted_boundary:
raise NotImplementedError()
return await self.db_pool.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
@@ -413,8 +422,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
def _get_auth_chain_difference_using_cover_index_txn(
self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
) -> Set[str]:
self,
txn: LoggingTransaction,
room_id: str,
state_sets: List[Set[str]],
conflicted_state_ids: Set[str],
conflicted_boundary: Set[str],
) -> Tuple[Set[str], Set[str]]:
"""Calculates the auth chain difference using the chain index.
See docs/auth_chain_difference_algorithm.md for details
@@ -521,10 +535,35 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# pulled from the database.
chain_to_gap: Dict[int, Tuple[int, int]] = {}
conflicted_state_chain_ids: Dict[str, List[int]] = {}
for event_id in conflicted_state_ids:
chain_id, seq_no = chain_info[event_id]
conflicted_state_chain_ids.setdefault(chain_id, []).append(seq_no)
# Filter down the conflicted boundary to only include events that can
# reach conflicted state.
conflicted_boundary_reaches_conflicted = set()
for event_id in conflicted_boundary:
chain_id, seq_no = chain_info[event_id]
min_seq_nos = conflicted_state_chain_ids.get(chain_id)
if min_seq_nos is not None and seq_no >= min(min_seq_nos):
conflicted_boundary_reaches_conflicted.add(event_id)
for event_id in conflicted_boundary_reaches_conflicted:
chain_id, seq_no = chain_info[event_id]
conflicted_state_chain_ids[chain_id].append(seq_no)
for chain_id in seen_chains:
min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
# Now do the closure: widen the range so that it also covers the lowest
# and the highest sequence numbers recorded for this chain among the
# conflicted state IDs (i.e. lower the minimum and raise the maximum).
seq_nos = conflicted_state_chain_ids.get(chain_id)
for seq_no in seq_nos:
min_seq_no = min(seq_no, min_seq_no)
max_seq_no = min(seq_no, max_seq_no)
if min_seq_no < max_seq_no:
# We have a non-empty gap, try and fill it from the events that
# we have, otherwise add them to the list of gaps to pull out
@@ -539,7 +578,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
if not chain_to_gap:
# If there are no gaps to fetch, we're done!
return result
return result, conflicted_boundary_reaches_conflicted
if isinstance(self.database_engine, PostgresEngine):
# We can use `execute_values` to efficiently fetch the gaps when
@@ -569,7 +608,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (chain_id, min_no, max_no))
result.update(r for r, in txn)
return result
return result, conflicted_boundary_reaches_conflicted
def _get_auth_chain_difference_txn(
self, txn: LoggingTransaction, state_sets: List[Set[str]]