diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 7a0d149635..baf89c0fc0 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -184,10 +184,21 @@ class StateStorageController: @trace @tag_args - async def _get_state_for_client_filtering_for_events( + async def _get_state_for_events_when_filtering_for_client( self, event_ids: Collection[str], user_id_viewing_events: str ) -> Dict[str, StateMap[EventBase]]: - """TODO""" + """Get the state at each event that is necessary to filter + them before being displayed to clients from the perspective of the + `user_id_viewing_events`. Will fetch the `m.room.history_visibility` and + `m.room.member` events of `user_id_viewing_events`. + + Args: + event_ids: List of event IDs that will be displayed to the client + user_id_viewing_events: User ID that will be viewing these events + + Returns: + Dict of event_id to state map. + """ set_tag( SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", str(len(event_ids)), @@ -204,10 +215,6 @@ class StateStorageController: group_to_state = await self.stores.state._get_state_for_client_filtering( groups, user_id_viewing_events ) - # logger.info( - # "_get_state_for_client_filtering_for_events: group_to_state=%s", - # group_to_state, - # ) state_event_map = await self.stores.main.get_events( [ev_id for sd in group_to_state.values() for ev_id in sd.values()], @@ -255,7 +262,6 @@ class StateStorageController: group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter or StateFilter.all() ) - # logger.info("get_state_for_events: group_to_state=%s", group_to_state) state_event_map = await self.stores.main.get_events( [ev_id for sd in group_to_state.values() for ev_id in sd.values()], diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 8209e131ac..0b71f1e5d7 100644 --- a/synapse/storage/databases/state/store.py +++ 
b/synapse/storage/databases/state/store.py @@ -260,11 +260,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): async def _get_state_groups_from_cache( self, state_groups: Iterable[int], state_filter: StateFilter ) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]: - """TODO + """Given a `state_filter`, pull out the relevant cached state groups that match + the filter. + + Args: + state_groups: List of state group IDs to fetch from the cache + state_filter: The relevant StateFilter to match against Returns: - A map from each state_group to the complete/incomplete state map (filled in by cached - values) and the set of incomplete groups + A map from each state_group ID to the complete/incomplete state map (filled + in by cached values) and the set of incomplete state_groups that still need + to be filled in. """ member_filter, non_member_filter = state_filter.get_member_split() @@ -284,7 +290,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): for state_group in state_groups: state[state_group].update(member_state[state_group]) - # We may have only got one of the events for the group + # We may have only got one or none of the events for the group, so mark those as + # incomplete, since they still need fetching from the database. incomplete_groups = incomplete_groups_m | incomplete_groups_nm return (state, incomplete_groups) @@ -293,15 +300,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): @trace @tag_args async def _get_state_for_client_filtering( - self, groups: Iterable[int], user_id_viewing_events: str - ) -> Dict[int, StateMap[str]]: - """ - TODO + self, state_group_ids: Iterable[int], user_id_viewing_events: str - ) -> Dict[int, MutableStateMap[str]]: + """Get a state map for each state group ID provided that is necessary to filter + the corresponding events before being displayed to clients from the perspective + of the `user_id_viewing_events`. 
+ + Args: + state_group_ids: The state groups to fetch + user_id_viewing_events: User ID that will be viewing the events that correspond + to the state groups + + Returns: + Dict of state_group ID to state map. """ def _get_state_for_client_filtering_txn( txn: LoggingTransaction, groups: Iterable[int] - ) -> Mapping[int, StateMap[str]]: + ) -> Mapping[int, MutableStateMap[str]]: sql = """ WITH RECURSIVE sgs(state_group) AS ( VALUES(?::bigint) @@ -343,8 +359,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): key = (intern_string(typ), intern_string(state_key)) results[group][key] = event_id - # The results should be considered immutable because we are using - # `intern_string` (TODO: Should we? copied from _get_state_groups_from_groups_txn). return results # Craft a StateFilter to use with the cache @@ -358,13 +372,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): results_from_cache, incomplete_groups, ) = await self._get_state_groups_from_cache( - groups, state_filter_for_cache_lookup + state_group_ids, state_filter_for_cache_lookup ) cache_sequence_nm = self._state_group_cache.sequence cache_sequence_m = self._state_group_members_cache.sequence - results: Dict[int, StateMap[str]] = results_from_cache + results = results_from_cache for batch in batch_iter(incomplete_groups, 100): group_to_state_mapping = await self.db_pool.runInteraction( "_get_state_for_client_filtering_txn", @@ -408,30 +422,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): Dict of state group to state map. 
""" state_filter = state_filter or StateFilter.all() - - # TODO: Replace with _get_state_groups_from_cache - member_filter, non_member_filter = state_filter.get_member_split() - - # Now we look them up in the member and non-member caches ( - non_member_state, - incomplete_groups_nm, - ) = self._get_state_for_groups_using_cache( - groups, self._state_group_cache, state_filter=non_member_filter - ) - - (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache( - groups, self._state_group_members_cache, state_filter=member_filter - ) - - state = dict(non_member_state) - for group in groups: - state[group].update(member_state[group]) + results_from_cache, + incomplete_groups, + ) = await self._get_state_groups_from_cache(groups, state_filter) # Now fetch any missing groups from the database - - incomplete_groups = incomplete_groups_m | incomplete_groups_nm - + state = results_from_cache if not incomplete_groups: return state diff --git a/synapse/visibility.py b/synapse/visibility.py index c63da1dca8..f19793faf9 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -106,35 +106,18 @@ async def filter_events_for_client( [event.event_id for event in events], ) - non_outlier_event_ids = event_ids = frozenset( + non_outlier_event_ids = frozenset( e.event_id for e in events if not e.internal_metadata.outlier ) - # TODO: Remove: We do this just to remove await_full_state from the comparison - await storage.state.get_state_group_for_events( - non_outlier_event_ids, await_full_state=True - ) - # Grab the history visibility and membership for each of the events. That's all we # need to know in order to filter them. 
- event_id_to_state = await storage.state._get_state_for_client_filtering_for_events( + event_id_to_state = await storage.state._get_state_for_events_when_filtering_for_client( # we exclude outliers at this point, and then handle them separately later event_ids=non_outlier_event_ids, user_id_viewing_events=user_id, ) - # TODO: Remove comparison - # TODO: Remove cache invalidation - storage.state.stores.state._state_group_cache.invalidate_all() - storage.state.stores.state._state_group_members_cache.invalidate_all() - # logger.info("----------------------------------------------------") - # logger.info("----------------------------------------------------") - types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id)) - event_id_to_state_orig = await storage.state.get_state_for_events( - non_outlier_event_ids, - state_filter=StateFilter.from_types(types), - ) - # Get the users who are ignored by the requesting user. ignore_list = await storage.main.ignored_users(user_id)