1
0

Fix lints

This commit is contained in:
Eric Eastwood
2022-11-21 23:22:03 -06:00
parent 0459a9c42f
commit 2939eadd00
3 changed files with 46 additions and 60 deletions

View File

@@ -184,10 +184,21 @@ class StateStorageController:
@trace
@tag_args
async def _get_state_for_client_filtering_for_events(
async def _get_state_for_events_when_filtering_for_client(
self, event_ids: Collection[str], user_id_viewing_events: str
) -> Dict[str, StateMap[EventBase]]:
"""TODO"""
"""Get the state at each event that is necessary to filter
them before being displayed to clients from the perspective of the
`user_id_viewing_events`. Will fetch `m.room.history_visibility` and
`m.room.member` event of `user_id_viewing_events`.
Args:
event_ids: List of event ID's that will be displayed to the client
user_id_viewing_events: User ID that will be viewing these events
Returns:
Dict of event_id to state map.
"""
set_tag(
SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
str(len(event_ids)),
@@ -204,10 +215,6 @@ class StateStorageController:
group_to_state = await self.stores.state._get_state_for_client_filtering(
groups, user_id_viewing_events
)
# logger.info(
# "_get_state_for_client_filtering_for_events: group_to_state=%s",
# group_to_state,
# )
state_event_map = await self.stores.main.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
@@ -255,7 +262,6 @@ class StateStorageController:
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all()
)
# logger.info("get_state_for_events: group_to_state=%s", group_to_state)
state_event_map = await self.stores.main.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],

View File

@@ -260,11 +260,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
async def _get_state_groups_from_cache(
self, state_groups: Iterable[int], state_filter: StateFilter
) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]:
"""TODO
"""Given a `state_filter`, pull out the relevant cached state groups that match
the filter.
Args:
state_groups: List of state group ID's to fetch from the cache
state_filter: The relevant StateFilter to pull against
Returns:
A map from each state_group to the complete/incomplete state map (filled in by cached
values) and the set of incomplete groups
A map from each state_group ID to the complete/incomplete state map (filled
in by cached values) and the set of incomplete state_groups that still need
to be filled in.
"""
member_filter, non_member_filter = state_filter.get_member_split()
@@ -284,7 +290,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
for state_group in state_groups:
state[state_group].update(member_state[state_group])
# We may have only got one of the events for the group
# We may have only got one or none of the events for the group so mark those as
# incomplete that need fetching from the database.
incomplete_groups = incomplete_groups_m | incomplete_groups_nm
return (state, incomplete_groups)
@@ -293,15 +300,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
@trace
@tag_args
async def _get_state_for_client_filtering(
self, groups: Iterable[int], user_id_viewing_events: str
) -> Dict[int, StateMap[str]]:
"""
TODO
self, state_group_ids: Iterable[int], user_id_viewing_events: str
) -> Dict[int, MutableStateMap[str]]:
"""Get a state map for each state group ID provided that is necessary to filter
the corresponding events before being displayed to clients from the perspective
of the `user_id_viewing_events`.
Args:
state_group_ids: The state groups to fetch
user_id_viewing_events: User ID that will be viewing the events that correspond
to the state groups
Returns:
Dict of state_group ID to state map.
"""
def _get_state_for_client_filtering_txn(
txn: LoggingTransaction, groups: Iterable[int]
) -> Mapping[int, StateMap[str]]:
) -> Mapping[int, MutableStateMap[str]]:
sql = """
WITH RECURSIVE sgs(state_group) AS (
VALUES(?::bigint)
@@ -343,8 +359,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
key = (intern_string(typ), intern_string(state_key))
results[group][key] = event_id
# The results should be considered immutable because we are using
# `intern_string` (TODO: Should we? copied from _get_state_groups_from_groups_txn).
return results
# Craft a StateFilter to use with the cache
@@ -358,13 +372,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
results_from_cache,
incomplete_groups,
) = await self._get_state_groups_from_cache(
groups, state_filter_for_cache_lookup
state_group_ids, state_filter_for_cache_lookup
)
cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence
results: Dict[int, StateMap[str]] = results_from_cache
results = results_from_cache
for batch in batch_iter(incomplete_groups, 100):
group_to_state_mapping = await self.db_pool.runInteraction(
"_get_state_for_client_filtering_txn",
@@ -408,30 +422,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Dict of state group to state map.
"""
state_filter = state_filter or StateFilter.all()
# TODO: Replace with _get_state_groups_from_cache
member_filter, non_member_filter = state_filter.get_member_split()
# Now we look them up in the member and non-member caches
(
non_member_state,
incomplete_groups_nm,
) = self._get_state_for_groups_using_cache(
groups, self._state_group_cache, state_filter=non_member_filter
)
(member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache(
groups, self._state_group_members_cache, state_filter=member_filter
)
state = dict(non_member_state)
for group in groups:
state[group].update(member_state[group])
results_from_cache,
incomplete_groups,
) = await self._get_state_groups_from_cache(groups, state_filter)
# Now fetch any missing groups from the database
incomplete_groups = incomplete_groups_m | incomplete_groups_nm
state = results_from_cache
if not incomplete_groups:
return state

View File

@@ -106,35 +106,18 @@ async def filter_events_for_client(
[event.event_id for event in events],
)
non_outlier_event_ids = event_ids = frozenset(
non_outlier_event_ids = frozenset(
e.event_id for e in events if not e.internal_metadata.outlier
)
# TODO: Remove: We do this just to remove await_full_state from the comparison
await storage.state.get_state_group_for_events(
non_outlier_event_ids, await_full_state=True
)
# Grab the history visibility and membership for each of the events. That's all we
# need to know in order to filter them.
event_id_to_state = await storage.state._get_state_for_client_filtering_for_events(
event_id_to_state = await storage.state._get_state_for_events_when_filtering_for_client(
# we exclude outliers at this point, and then handle them separately later
event_ids=non_outlier_event_ids,
user_id_viewing_events=user_id,
)
# TODO: Remove comparison
# TODO: Remove cache invalidation
storage.state.stores.state._state_group_cache.invalidate_all()
storage.state.stores.state._state_group_members_cache.invalidate_all()
# logger.info("----------------------------------------------------")
# logger.info("----------------------------------------------------")
types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id))
event_id_to_state_orig = await storage.state.get_state_for_events(
non_outlier_event_ids,
state_filter=StateFilter.from_types(types),
)
# Get the users who are ignored by the requesting user.
ignore_list = await storage.main.ignored_users(user_id)