1
0

Move get_state methods into FederationHandler (#6503)

* commit 'be294d6fd':
  Move get_state methods into FederationHandler (#6503)
This commit is contained in:
Andrew Morgan
2020-03-19 18:52:20 +00:00
2 changed files with 96 additions and 171 deletions

View File

@@ -39,7 +39,7 @@ from synapse.events import builder, room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.utils import log_function
from synapse.util import batch_iter, unwrapFirstError
from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
@@ -327,74 +327,7 @@ class FederationClient(FederationBase):
):
raise Exception("invalid response from /state_ids")
desired_events = set(state_event_ids + auth_event_ids)
event_map = yield self.get_events_from_store_or_dest(
destination, room_id, desired_events
)
failed_to_fetch = desired_events - event_map.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state/auth events for %s: %s",
room_id,
failed_to_fetch,
)
pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
auth_chain.sort(key=lambda e: e.depth)
return pdus, auth_chain
@defer.inlineCallbacks
def get_events_from_store_or_dest(self, destination, room_id, event_ids):
"""Fetch events from a remote destination, checking if we already have them.
Args:
destination (str)
room_id (str)
event_ids (Iterable[str])
Returns:
Deferred[dict[str, EventBase]]: A deferred resolving to a map
from event_id to event
"""
fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
missing_events = set(event_ids) - fetched_events.keys()
if not missing_events:
return fetched_events
logger.debug(
"Fetching unknown state/auth events %s for room %s",
missing_events,
event_ids,
)
room_version = yield self.store.get_room_version(room_id)
# XXX 20 requests at once? really?
for batch in batch_iter(missing_events, 20):
deferreds = [
run_in_background(
self.get_pdu,
destinations=[destination],
event_id=e_id,
room_version=room_version,
)
for e_id in batch
]
res = yield make_deferred_yieldable(
defer.DeferredList(deferreds, consumeErrors=True)
)
for success, result in res:
if success and result:
fetched_events[result.event_id] = result
return fetched_events
return state_event_ids, auth_event_ids
@defer.inlineCallbacks
@log_function

View File

@@ -65,6 +65,7 @@ from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRes
from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import UserID, get_domain_from_id
from synapse.util import batch_iter, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
@@ -237,6 +238,7 @@ class FederationHandler(BaseHandler):
return None
state = None
auth_chain = []
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
@@ -338,14 +340,9 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
logger.info(
"Event %s is missing prev_events: calculating state for a "
"backwards extremity",
event_id,
)
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
auth_chains = set()
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
@@ -363,17 +360,42 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
"Requesting state at missing prev_event %s", event_id,
"[%s %s] Requesting state at missing prev_event %s",
room_id,
event_id,
p,
)
room_version = await self.store.get_room_version(room_id)
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
(remote_state, _,) = await self._get_state_for_room(
origin, room_id, p, include_event_in_state=True
(
remote_state,
got_auth_chain,
) = await self._get_state_for_room(origin, room_id, p)
# we want the state *after* p; _get_state_for_room returns the
# state *before* p.
remote_event = await self.federation_client.get_pdu(
[origin], p, room_version, outlier=True
)
if remote_event is None:
raise Exception(
"Unable to get missing prev_event %s" % (p,)
)
if remote_event.is_state():
remote_state.append(remote_event)
# XXX hrm I'm not convinced that duplicate events will compare
# for equality, so I'm not sure this does what the author
# hoped.
auth_chains.update(got_auth_chain)
remote_state_map = {
(x.type, x.state_key): x.event_id for x in remote_state
}
@@ -382,7 +404,6 @@ class FederationHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x
room_version = await self.store.get_room_version(room_id)
state_map = await resolve_events_with_store(
room_id,
room_version,
@@ -399,11 +420,12 @@ class FederationHandler(BaseHandler):
evs = await self.store.get_events(
list(state_map.values()),
get_prev_content=False,
redact_behaviour=EventRedactBehaviour.AS_IS,
check_redacted=False,
)
event_map.update(evs)
state = [event_map[e] for e in six.itervalues(state_map)]
auth_chain = list(auth_chains)
except Exception:
logger.warning(
"[%s %s] Error attempting to resolve state at missing "
@@ -419,7 +441,9 @@ class FederationHandler(BaseHandler):
affected=event_id,
)
await self._process_received_pdu(origin, pdu, state=state)
await self._process_received_pdu(
origin, pdu, state=state, auth_chain=auth_chain
)
async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
"""
@@ -553,131 +577,99 @@ class FederationHandler(BaseHandler):
else:
raise
async def _get_state_for_room(
self,
destination: str,
room_id: str,
event_id: str,
include_event_in_state: bool = False,
) -> Tuple[List[EventBase], List[EventBase]]:
@defer.inlineCallbacks
@log_function
def _get_state_for_room(self, destination, room_id, event_id):
"""Requests all of the room state at a given event from a remote homeserver.
Args:
destination: The remote homeserver to query for the state.
room_id: The id of the room we're interested in.
event_id: The id of the event we want the state at.
include_event_in_state: if true, the event itself will be included in the
returned state event list.
destination (str): The remote homeserver to query for the state.
room_id (str): The id of the room we're interested in.
event_id (str): The id of the event we want the state at.
Returns:
A list of events in the state, possibly including the event itself, and
a list of events in the auth chain for the given event.
Deferred[Tuple[List[EventBase], List[EventBase]]]:
A list of events in the state, and a list of events in the auth chain
for the given event.
"""
(
state_event_ids,
auth_event_ids,
) = await self.federation_client.get_room_state_ids(
) = yield self.federation_client.get_room_state_ids(
destination, room_id, event_id=event_id
)
desired_events = set(state_event_ids + auth_event_ids)
if include_event_in_state:
desired_events.add(event_id)
event_map = await self._get_events_from_store_or_dest(
event_map = yield self._get_events_from_store_or_dest(
destination, room_id, desired_events
)
failed_to_fetch = desired_events - event_map.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state/auth events for %s %s",
event_id,
"Failed to fetch missing state/auth events for %s: %s",
room_id,
failed_to_fetch,
)
remote_state = [
event_map[e_id] for e_id in state_event_ids if e_id in event_map
]
if include_event_in_state:
remote_event = event_map.get(event_id)
if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,))
if remote_event.is_state() and remote_event.rejected_reason is None:
remote_state.append(remote_event)
pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
auth_chain.sort(key=lambda e: e.depth)
return remote_state, auth_chain
return pdus, auth_chain
async def _get_events_from_store_or_dest(
self, destination: str, room_id: str, event_ids: Iterable[str]
) -> Dict[str, EventBase]:
@defer.inlineCallbacks
def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
"""Fetch events from a remote destination, checking if we already have them.
Persists any events we don't already have as outliers.
If we fail to fetch any of the events, a warning will be logged, and the event
will be omitted from the result. Likewise, any events which turn out not to
be in the given room.
Args:
destination (str)
room_id (str)
event_ids (Iterable[str])
Returns:
map from event_id to event
Deferred[dict[str, EventBase]]: A deferred resolving to a map
from event_id to event
"""
fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
missing_events = set(event_ids) - fetched_events.keys()
if missing_events:
logger.debug(
"Fetching unknown state/auth events %s for room %s",
missing_events,
room_id,
)
if not missing_events:
return fetched_events
await self._get_events_and_persist(
destination=destination, room_id=room_id, events=missing_events
)
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
(await self.store.get_events(missing_events, allow_rejected=True))
)
# check for events which were in the wrong room.
#
# this can happen if a remote server claims that the state or
# auth_events at an event in room A are actually events in room B
bad_events = list(
(event_id, event.room_id)
for event_id, event in fetched_events.items()
if event.room_id != room_id
logger.debug(
"Fetching unknown state/auth events %s for room %s",
missing_events,
event_ids,
)
for bad_event_id, bad_room_id in bad_events:
# This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the
# bad events from the returned auth/state set.
logger.warning(
"Remote server %s claims event %s in room %s is an auth/state "
"event in room %s",
destination,
bad_event_id,
bad_room_id,
room_id,
room_version = yield self.store.get_room_version(room_id)
# XXX 20 requests at once? really?
for batch in batch_iter(missing_events, 20):
deferreds = [
run_in_background(
self.federation_client.get_pdu,
destinations=[destination],
event_id=e_id,
room_version=room_version,
)
for e_id in batch
]
res = yield make_deferred_yieldable(
defer.DeferredList(deferreds, consumeErrors=True)
)
del fetched_events[bad_event_id]
for success, result in res:
if success and result:
fetched_events[result.event_id] = result
return fetched_events
async def _process_received_pdu(
self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
):
@defer.inlineCallbacks
def _process_received_pdu(self, origin, event, state, auth_chain):
""" Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
@@ -696,15 +688,15 @@ class FederationHandler(BaseHandler):
logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
try:
context = await self._handle_new_event(origin, event, state=state)
context = yield self._handle_new_event(origin, event, state=state)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
room = await self.store.get_room(room_id)
room = yield self.store.get_room(room_id)
if not room:
try:
await self.store.store_room(
yield self.store.store_room(
room_id=room_id, room_creator_user_id="", is_public=False
)
except StoreError:
@@ -717,11 +709,11 @@ class FederationHandler(BaseHandler):
# changing their profile info.
newly_joined = True
prev_state_ids = await context.get_prev_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_id = prev_state_ids.get((event.type, event.state_key))
if prev_state_id:
prev_state = await self.store.get_event(
prev_state = yield self.store.get_event(
prev_state_id, allow_none=True
)
if prev_state and prev_state.membership == Membership.JOIN:
@@ -729,7 +721,7 @@ class FederationHandler(BaseHandler):
if newly_joined:
user = UserID.from_string(event.state_key)
await self.user_joined_room(user, room_id)
yield self.user_joined_room(user, room_id)
@log_function
async def backfill(self, dest, room_id, limit, extremities):