diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index e1868552e5..6bc384c02c 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -247,6 +247,46 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types + def _get_state_for_group_gather_inflight_requests( + self, group: int, state_filter: StateFilter + ) -> Tuple[Sequence[ObservableDeferred[StateMap[str]]], StateFilter]: + """ + Attempts to gather in-flight requests and re-use them to retrieve state + for the given state group, filtered with the given state filter. + + Returns: + tuple ( + sequence of ObservableDeferreds to observe, + StateFilter representing what is left over (what else needs + to be requested to fulfil the request) + ) + """ + + inflight_requests = self._state_group_inflight_requests.get(group) + if inflight_requests is None: + # no requests for this group, need to retrieve it all ourselves + return (), state_filter + + state_filter_left_over = state_filter + reusable_requests = [] + for ( + request_state_filter, + request_deferred, + ) in inflight_requests.items(): + new_state_filter_left_over = state_filter_left_over.approx_difference( + request_state_filter + ) + if new_state_filter_left_over != state_filter_left_over: + # reusing this request narrows our StateFilter down a bit. + reusable_requests.append(request_deferred) + state_filter_left_over = new_state_filter_left_over + if state_filter_left_over == StateFilter.none(): + # we have managed to collect enough of the in-flight requests + # to cover our StateFilter and give us the state we need. + break + + return reusable_requests, state_filter_left_over + async def _get_state_for_group_fire_request( self, group: int, state_filter: StateFilter ) -> StateMap[str]: