Implement closure of conflicted state events
This commit is contained in:
@@ -61,7 +61,11 @@ class StateResolutionStore(Protocol):
|
||||
...
|
||||
|
||||
def get_auth_chain_difference(
|
||||
self, room_id: str, state_sets: List[Set[str]]
|
||||
self,
|
||||
room_id: str,
|
||||
state_sets: List[Set[str]],
|
||||
conflicted_state_ids: Set[str],
|
||||
conflicted_boundary: Set[str],
|
||||
) -> Awaitable[Set[str]]:
|
||||
...
|
||||
|
||||
@@ -122,10 +126,12 @@ async def resolve_events_with_store(
|
||||
logger.debug("%d conflicted state entries", len(conflicted_state))
|
||||
logger.debug("Calculating auth chain difference")
|
||||
|
||||
conflicted_state_ids = set(itertools.chain.from_iterable(conflicted_state.values()))
|
||||
|
||||
# Also fetch all auth events that appear in only some of the state sets'
|
||||
# auth chains.
|
||||
auth_diff = await _get_auth_chain_difference(
|
||||
room_id, state_sets, event_map, state_res_store
|
||||
room_id, state_sets, event_map, conflicted_state_ids, state_res_store
|
||||
)
|
||||
|
||||
full_conflicted_set = set(
|
||||
@@ -272,6 +278,7 @@ async def _get_auth_chain_difference(
|
||||
room_id: str,
|
||||
state_sets: Sequence[Mapping[Any, str]],
|
||||
unpersisted_events: Dict[str, EventBase],
|
||||
conflicted_state_ids: Set[str],
|
||||
state_res_store: StateResolutionStore,
|
||||
) -> Set[str]:
|
||||
"""Compare the auth chains of each state set and return the set of events
|
||||
@@ -367,15 +374,46 @@ async def _get_auth_chain_difference(
|
||||
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
|
||||
|
||||
auth_difference_unpersisted_part: Collection[str] = union - intersection
|
||||
|
||||
persisted_conflicted_state_ids = {
|
||||
event_id for event_id in conflicted_state_ids if event_id not in union
|
||||
}
|
||||
|
||||
boundary = state_sets_ids[0].union(*state_sets_ids[1:])
|
||||
conflicted_boundary = set()
|
||||
|
||||
for event_id in persisted_conflicted_state_ids:
|
||||
auth_chain = events_to_auth_chain.get(event_id)
|
||||
if not auth_chain:
|
||||
continue
|
||||
|
||||
conflicted_boundary != auth_chain & boundary
|
||||
|
||||
else:
|
||||
auth_difference_unpersisted_part = ()
|
||||
conflicted_boundary = set()
|
||||
persisted_conflicted_state_ids = conflicted_state_ids
|
||||
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
|
||||
|
||||
difference = await state_res_store.get_auth_chain_difference(
|
||||
room_id, state_sets_ids
|
||||
difference, conflicted_boundary = await state_res_store.get_auth_chain_difference(
|
||||
room_id,
|
||||
state_sets_ids,
|
||||
persisted_conflicted_state_ids,
|
||||
conflicted_boundary,
|
||||
)
|
||||
difference.update(auth_difference_unpersisted_part)
|
||||
|
||||
unpersisted_conflicted_state_ids = (
|
||||
conflicted_state_ids - persisted_conflicted_state_ids
|
||||
)
|
||||
for boundary_event_id in conflicted_boundary:
|
||||
for conflicted_id in unpersisted_conflicted_state_ids:
|
||||
auth_chain = events_to_auth_chain[conflicted_id]
|
||||
if boundary_event_id not in auth_chain:
|
||||
continue
|
||||
|
||||
# TODO: Include all paths from conflicted_id -> boundary_id in difference.
|
||||
|
||||
return difference
|
||||
|
||||
|
||||
|
||||
@@ -377,7 +377,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
return results
|
||||
|
||||
async def get_auth_chain_difference(
|
||||
self, room_id: str, state_sets: List[Set[str]]
|
||||
self,
|
||||
room_id: str,
|
||||
state_sets: List[Set[str]],
|
||||
conflicted_state_ids: Set[str],
|
||||
conflicted_boundary: Set[str],
|
||||
) -> Set[str]:
|
||||
"""Given sets of state events figure out the auth chain difference (as
|
||||
per state res v2 algorithm).
|
||||
@@ -400,12 +404,17 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
self._get_auth_chain_difference_using_cover_index_txn,
|
||||
room_id,
|
||||
state_sets,
|
||||
conflicted_state_ids,
|
||||
conflicted_boundary,
|
||||
)
|
||||
except _NoChainCoverIndex:
|
||||
# For whatever reason we don't actually have a chain cover index
|
||||
# for the events in question, so we fall back to the old method.
|
||||
pass
|
||||
|
||||
if conflicted_boundary:
|
||||
raise NotImplementedError()
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_auth_chain_difference",
|
||||
self._get_auth_chain_difference_txn,
|
||||
@@ -413,8 +422,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
)
|
||||
|
||||
def _get_auth_chain_difference_using_cover_index_txn(
|
||||
self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
|
||||
) -> Set[str]:
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
state_sets: List[Set[str]],
|
||||
conflicted_state_ids: Set[str],
|
||||
conflicted_boundary: Set[str],
|
||||
) -> Tuple[Set[str], Set[str]]:
|
||||
"""Calculates the auth chain difference using the chain index.
|
||||
|
||||
See docs/auth_chain_difference_algorithm.md for details
|
||||
@@ -521,10 +535,35 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
# pulled from the database.
|
||||
chain_to_gap: Dict[int, Tuple[int, int]] = {}
|
||||
|
||||
conflicted_state_chain_ids: Dict[str, List[int]] = {}
|
||||
for event_id in conflicted_state_ids:
|
||||
chain_id, seq_no = chain_info[event_id]
|
||||
conflicted_state_chain_ids.setdefault(chain_id, []).append(seq_no)
|
||||
|
||||
# Filter down the conflicted boundary to only include events that can
|
||||
# reach conflicted state.
|
||||
conflicted_boundary_reaches_conflicted = set()
|
||||
for event_id in conflicted_boundary:
|
||||
chain_id, seq_no = chain_info[event_id]
|
||||
min_seq_nos = conflicted_state_chain_ids.get(chain_id)
|
||||
if min_seq_nos is not None and seq_no >= min(min_seq_nos):
|
||||
conflicted_boundary_reaches_conflicted.add(event_id)
|
||||
|
||||
for event_id in conflicted_boundary_reaches_conflicted:
|
||||
chain_id, seq_no = chain_info[event_id]
|
||||
conflicted_state_chain_ids[chain_id].append(seq_no)
|
||||
|
||||
for chain_id in seen_chains:
|
||||
min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
|
||||
max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
|
||||
|
||||
# Now do the closure by increasing the bounds of the range to the
|
||||
# min and max of those in the conflicted state IDs
|
||||
seq_nos = conflicted_state_chain_ids.get(chain_id)
|
||||
for seq_no in seq_nos:
|
||||
min_seq_no = min(seq_no, min_seq_no)
|
||||
max_seq_no = min(seq_no, max_seq_no)
|
||||
|
||||
if min_seq_no < max_seq_no:
|
||||
# We have a non empty gap, try and fill it from the events that
|
||||
# we have, otherwise add them to the list of gaps to pull out
|
||||
@@ -539,7 +578,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
|
||||
if not chain_to_gap:
|
||||
# If there are no gaps to fetch, we're done!
|
||||
return result
|
||||
return result, conflicted_boundary_reaches_conflicted
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# We can use `execute_values` to efficiently fetch the gaps when
|
||||
@@ -569,7 +608,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
txn.execute(sql, (chain_id, min_no, max_no))
|
||||
result.update(r for r, in txn)
|
||||
|
||||
return result
|
||||
return result, conflicted_boundary_reaches_conflicted
|
||||
|
||||
def _get_auth_chain_difference_txn(
|
||||
self, txn: LoggingTransaction, state_sets: List[Set[str]]
|
||||
|
||||
Reference in New Issue
Block a user