diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 0d480f1014..cbab7a65fe 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -232,6 +232,7 @@ class StateStorageController: self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None, + await_full_state: Optional[bool] = None, ) -> Dict[str, StateMap[str]]: """ Get the state dicts corresponding to a list of events, containing the event_ids @@ -240,6 +241,9 @@ class StateStorageController: Args: event_ids: events whose state should be returned state_filter: The state filter used to fetch state from the database. + await_full_state: if `True`, will block if we do not yet have complete state + at these events. Defaults to `True` unless `state_filter` can be + completely satisfied with partial state. Returns: A dict from event_id -> (type, state_key) -> event_id @@ -248,9 +252,13 @@ class StateStorageController: RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): - await_full_state = False + if await_full_state is None: + if state_filter and not state_filter.must_await_full_state( + self._is_mine_id + ): + await_full_state = False + else: + await_full_state = True event_to_groups = await self.get_state_group_for_events( event_ids, await_full_state=await_full_state @@ -292,7 +300,10 @@ class StateStorageController: @trace async def get_state_ids_for_event( - self, event_id: str, state_filter: Optional[StateFilter] = None + self, + event_id: str, + state_filter: Optional[StateFilter] = None, + await_full_state: Optional[bool] = None, ) -> StateMap[str]: """ Get the state dict corresponding to a particular event @@ -300,6 +311,9 @@ class StateStorageController: Args: event_id: event whose state should be returned state_filter: The state filter used to fetch state from the database. + await_full_state: if `True`, will block if we do not yet have complete state + at the event. Defaults to `True` unless `state_filter` can be completely + satisfied with partial state. Returns: A dict from (type, state_key) -> state_event_id @@ -309,7 +323,9 @@ class StateStorageController: outlier or is unknown) """ state_map = await self.get_state_ids_for_events( - [event_id], state_filter or StateFilter.all() + [event_id], + state_filter or StateFilter.all(), + await_full_state=await_full_state, ) return state_map[event_id]