diff --git a/changelog.d/14494.misc b/changelog.d/14494.misc new file mode 100644 index 0000000000..fbf48bf70f --- /dev/null +++ b/changelog.d/14494.misc @@ -0,0 +1 @@ +Speed-up `/messages` with `filter_events_for_client` optimizations. diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 2b31ce54bb..9285b64ed1 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -29,7 +29,7 @@ from typing import ( from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.logging.opentracing import tag_args, trace +from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace from synapse.storage.roommember import ProfileInfo from synapse.storage.state import StateFilter from synapse.storage.util.partial_state_events_tracker import ( @@ -182,6 +182,53 @@ class StateStorageController: return self.stores.state._get_state_groups_from_groups(groups, state_filter) + @trace + @tag_args + async def _get_state_for_client_filtering_for_events( + self, event_ids: Collection[str], user_id_viewing_events: str + ) -> Dict[str, StateMap[EventBase]]: + """TODO""" + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) + + # Since we're making decisions based on the state, we need to wait. + await_full_state = True + + event_to_groups = await self.get_state_group_for_events( + event_ids, await_full_state=await_full_state + ) + + groups = set(event_to_groups.values()) + logger.info( + "_get_state_for_client_filtering_for_events: groups=%s", + groups, + ) + group_to_state = await self.stores.state._get_state_for_client_filtering( + groups, user_id_viewing_events + ) + logger.info( + "_get_state_for_client_filtering_for_events: group_to_state=%s", + group_to_state, + ) + + state_event_map = await self.stores.main.get_events( + [ev_id for sd in group_to_state.values() for ev_id in sd.values()], + get_prev_content=False, + ) + + event_to_state = { + event_id: { + k: state_event_map[v] + for k, v in group_to_state[group].items() + if v in state_event_map + } + for event_id, group in event_to_groups.items() + } + + return {event: event_to_state[event] for event in event_ids} + @trace async def get_state_for_events( self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None @@ -209,9 +256,11 @@ class StateStorageController: ) groups = set(event_to_groups.values()) + logger.info("get_state_for_events: groups=%s", groups) group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter or StateFilter.all() ) + logger.info("get_state_for_events: group_to_state=%s", group_to_state) state_event_map = await self.stores.main.get_events( [ev_id for sd in group_to_state.values() for ev_id in sd.values()], diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 224bebf3bf..a7fcc564a9 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -101,11 +101,9 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): where_clause = " AND (%s)" % (where_clause,) if isinstance(self.database_engine, PostgresEngine): - # Suspicion start # Temporarily disable sequential scans in this transaction. This is # a temporary hack until we can add the right indices in txn.execute("SET LOCAL enable_seqscan=off") - # Suspicion end # The below query walks the state_group tree so that the "state" # table includes all state_groups in the tree. It then joins diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index ead23ac9fb..0bd4bad57b 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -13,12 +13,22 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, +) import attr from synapse.api.constants import EventTypes -from synapse.logging.tracing import SynapseTags, set_tag, tag_args, trace +from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -30,9 +40,11 @@ from synapse.storage.state import StateFilter from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator from synapse.types import MutableStateMap, StateKey, StateMap +from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.cancellation import cancellable +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -107,6 +119,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "*stateGroupMembersCache*", 500000, ) + # TODO: Remove cache invalidation + self._state_group_cache.invalidate_all() + self._state_group_members_cache.invalidate_all() def get_max_state_group_txn(txn: Cursor) -> int: txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") @@ -245,6 +260,140 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types + async def _get_state_groups_from_cache( + self, state_groups: Iterable[int], state_filter: StateFilter + ) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]: + """TODO + + Returns: + A map from each state_group to the complete/incomplete state map (filled in by cached + values) and the set of incomplete groups + """ + member_filter, non_member_filter = state_filter.get_member_split() + + # Now we look them up in the member and non-member caches + ( + non_member_state, + incomplete_groups_nm, + ) = self._get_state_for_groups_using_cache( + state_groups, self._state_group_cache, state_filter=non_member_filter + ) + + (member_state, incomplete_groups_m) = self._get_state_for_groups_using_cache( + state_groups, self._state_group_members_cache, state_filter=member_filter + ) + + state = dict(non_member_state) + for state_group in state_groups: + state[state_group].update(member_state[state_group]) + + # We may have only got one of the events for the group + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + return (state, incomplete_groups) + + @cancellable + @trace + @tag_args + async def _get_state_for_client_filtering( + self, groups: Iterable[int], user_id_viewing_events: str + ) -> Dict[int, StateMap[str]]: + """ + TODO + """ + + def _get_state_for_client_filtering_txn( + txn: LoggingTransaction, groups: Iterable[int] + ) -> Mapping[int, StateMap[str]]: + sql = """ + WITH RECURSIVE sgs(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, sgs s + WHERE s.state_group = e.state_group + ) + SELECT + type, state_key, event_id + FROM state_groups_state + WHERE + state_group IN ( + SELECT state_group FROM sgs + ) + AND (type = ? AND state_key = ?) + ORDER BY + type, + state_key, + -- Use the lastest state in the chain (highest numbered state_group in the chain) + state_group DESC + LIMIT 1 + """ + + results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} + for group in groups: + row_info_list: List[Tuple] = [] + txn.execute(sql, (group, EventTypes.RoomHistoryVisibility, "")) + history_vis_info = txn.fetchone() + if history_vis_info is not None: + row_info_list.append(history_vis_info) + + txn.execute(sql, (group, EventTypes.Member, user_id_viewing_events)) + membership_info = txn.fetchone() + if membership_info is not None: + row_info_list.append(membership_info) + + for row in row_info_list: + typ, state_key, event_id = row + key = (intern_string(typ), intern_string(state_key)) + results[group][key] = event_id + + # The results should be considered immutable because we are using + # `intern_string` (TODO: Should we? copied from _get_state_groups_from_groups_txn). + return results + + # Craft a StateFilter to use with the cache + state_filter_for_cache_lookup = StateFilter.from_types( + ( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, user_id_viewing_events), + ) + ) + ( + results_from_cache, + incomplete_groups, + ) = await self._get_state_groups_from_cache( + groups, state_filter_for_cache_lookup + ) + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + results: Dict[int, StateMap[str]] = results_from_cache + for batch in batch_iter(incomplete_groups, 100): + group_to_state_mapping = await self.db_pool.runInteraction( + "_get_state_for_client_filtering_txn", + _get_state_for_client_filtering_txn, + batch, + ) + logger.info("group_to_state_mapping=%s", group_to_state_mapping) + + # Now lets update the caches + # Help the cache hit ratio by expanding the filter a bit + state_filter_for_cache_insertion = ( + state_filter_for_cache_lookup.return_expanded() + ) + group_to_state_dict: Dict[int, StateMap[str]] = {} + group_to_state_dict.update(group_to_state_mapping) + self._insert_into_cache( + group_to_state_dict, + state_filter_for_cache_insertion, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + results.update(group_to_state_mapping) + + return results + @cancellable @trace @tag_args @@ -264,6 +413,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): """ state_filter = state_filter or StateFilter.all() + # TODO: Replace with _get_state_groups_from_cache member_filter, non_member_filter = state_filter.get_member_split() # Now we look them up in the member and non-member caches diff --git a/synapse/visibility.py b/synapse/visibility.py index ee4afd4607..37a1b1541b 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -24,9 +24,9 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event from synapse.logging.opentracing import ( - start_active_span, SynapseTags, set_tag, + start_active_span, tag_args, trace, ) @@ -108,11 +108,21 @@ async def filter_events_for_client( # Grab the history visibility and membership for each of the events. That's all we # need to know in order to filter them. - filter_types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id)) - event_id_to_state = await storage.state.get_state_for_events( + event_id_to_state = await storage.state._get_state_for_client_filtering_for_events( # we exclude outliers at this point, and then handle them separately later + event_ids=frozenset( + e.event_id for e in events if not e.internal_metadata.outlier + ), + user_id_viewing_events=user_id, + ) + + # TODO: Remove comparison + logger.info("----------------------------------------------------") + logger.info("----------------------------------------------------") + types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id)) + event_id_to_state_orig = await storage.state.get_state_for_events( frozenset(e.event_id for e in events if not e.internal_metadata.outlier), - state_filter=StateFilter.from_types(filter_types), + state_filter=StateFilter.from_types(types), ) # Get the users who are ignored by the requesting user.