async def get_visibility_for_events(
    self, room_id: str, events: Collection[EventBase]
) -> Mapping[str, Optional[str]]:
    """Look up the stored history visibility that covers each event.

    For every event, this selects the `visibility` column of the
    `history_visibility_ranges` row for `room_id` whose range contains the
    event's stream ordering, i.e. `start_range <= stream_ordering` and
    either `stream_ordering < end_range` or `end_range IS NULL`.

    Args:
        room_id: The room to look up visibility ranges for. Every event in
            `events` is matched against this room's ranges only.
        events: The events to look up. Only `event_id` and
            `internal_metadata.stream_ordering` are read from each one.

    Returns:
        A map from event ID to the `visibility` value of the matching range.
        Events that have no stream ordering, or for which no range matches,
        are left out of the map (they are *not* mapped to `None`), so callers
        must treat a missing key as "unknown".
    """
    # Work out what actually needs querying before touching the database.
    # A `None` stream ordering can never satisfy the range comparison (it is
    # bound as SQL NULL), so such events are dropped here instead of costing
    # a query each. Keying by event ID also collapses duplicate events.
    stream_orderings: Dict[str, int] = {}
    for event in events:
        stream_ordering = event.internal_metadata.stream_ordering
        if stream_ordering is not None:
            stream_orderings[event.event_id] = stream_ordering

    # Nothing to look up: avoid starting a database transaction at all.
    if not stream_orderings:
        return {}

    def get_visibility_for_events_txn(
        txn: LoggingTransaction,
    ) -> Mapping[str, Optional[str]]:
        sql = """
            SELECT visibility FROM history_visibility_ranges
            WHERE start_range <= ? AND (? < end_range OR end_range IS NULL)
            AND room_id = ?
        """

        # Built from scratch on every call, so the function stays safe to
        # run more than once.
        results: Dict[str, Optional[str]] = {}
        for event_id, stream_ordering in stream_orderings.items():
            # The stream ordering is bound twice: once for each bound of
            # the range check.
            txn.execute(sql, (stream_ordering, stream_ordering, room_id))
            row = txn.fetchone()
            if row is not None:
                results[event_id] = row[0]

        return results

    return await self.db_pool.runInteraction(
        "get_visibility_for_events", get_visibility_for_events_txn
    )
""" + if not events: + return [] + # Filter out events that have been soft failed so that we don't relay them # to clients. events_before_filtering = events @@ -117,13 +120,38 @@ async def filter_events_for_client( [event.event_id for event in events], ) - types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id)) + types = ( + _HISTORY_VIS_KEY, + (EventTypes.Member, user_id), + ) + + room_id = events[0].room_id + assert all(event.room_id == room_id for event in events) + + visibilities: Dict[str, str] = {} + memberships: Dict[str, Optional[EventBase]] = {} + events_to_fetch = {e.event_id for e in events if not e.internal_metadata.outlier} + if not is_peeking: + fetched_visibilities = await storage.main.get_visibility_for_events( + room_id, [e for e in events if not e.internal_metadata.outlier] + ) + for event_id, visibility in fetched_visibilities.items(): + if visibility in ( + HistoryVisibility.SHARED, + HistoryVisibility.WORLD_READABLE, + ): + events_to_fetch.discard(event_id) + visibilities[event_id] = visibility # we exclude outliers at this point, and then handle them separately later - event_id_to_state = await storage.state.get_state_for_events( - frozenset(e.event_id for e in events if not e.internal_metadata.outlier), - state_filter=StateFilter.from_types(types), - ) + if events_to_fetch: + event_id_to_state = await storage.state.get_state_for_events( + events_to_fetch, + state_filter=StateFilter.from_types(types), + ) + for event_id, state in event_id_to_state.items(): + visibilities[event_id] = get_effective_room_visibility_from_state(state) + memberships[event_id] = state.get((EventTypes.Member, user_id)) # Get the users who are ignored by the requesting user. 
ignore_list = await storage.main.ignored_users(user_id) @@ -140,8 +168,8 @@ async def filter_events_for_client( ] = await storage.main.get_retention_policy_for_room(room_id) def allowed(event: EventBase) -> Optional[EventBase]: - state_after_event = event_id_to_state.get(event.event_id) - filtered = _check_client_allowed_to_see_event( + # state_after_event = event_id_to_state.get(event.event_id) + filtered = _check_client_allowed_to_see_event_with_state( user_id=user_id, event=event, clock=storage.main.clock, @@ -149,9 +177,10 @@ async def filter_events_for_client( sender_ignored=event.sender in ignore_list, always_include_ids=always_include_ids, retention_policy=retention_policies[event.room_id], - state=state_after_event, is_peeking=is_peeking, sender_erased=erased_senders.get(event.sender, False), + visibility=visibilities[event.event_id], + membership_event=memberships.get(event.event_id), ) if filtered is None: return None @@ -165,11 +194,9 @@ async def filter_events_for_client( user_membership_event: Optional[EventBase] if event.type == EventTypes.Member and event.state_key == user_id: user_membership_event = event - elif state_after_event is not None: - user_membership_event = state_after_event.get((EventTypes.Member, user_id)) else: - # unreachable! - raise Exception("Missing state for event that is not user's own membership") + # TODO: Actually get the proper membership + user_membership_event = memberships.get(event_id) user_membership = ( user_membership_event.membership @@ -353,6 +380,41 @@ def _check_client_allowed_to_see_event( the original event if they can see it as normal. 
""" + + visibility = HistoryVisibility.SHARED + + if state is not None: + visibility = get_effective_room_visibility_from_state(state) + membership_event = state.get((EventTypes.Member, user_id)) if state else None + + return _check_client_allowed_to_see_event_with_state( + user_id, + event, + clock, + filter_send_to_client, + is_peeking, + always_include_ids, + sender_ignored, + retention_policy, + sender_erased, + visibility=visibility, + membership_event=membership_event, + ) + + +def _check_client_allowed_to_see_event_with_state( + user_id: str, + event: EventBase, + clock: Clock, + filter_send_to_client: bool, + is_peeking: bool, + always_include_ids: FrozenSet[str], + sender_ignored: bool, + retention_policy: RetentionPolicy, + sender_erased: bool, + visibility: str, + membership_event: Optional[EventBase], +) -> Optional[EventBase]: # Only run some checks if these events aren't about to be sent to clients. This is # because, if this is not the case, we're probably only checking if the users can # see events in the room at that point in the DAG, and that shouldn't be decided @@ -390,12 +452,6 @@ def _check_client_allowed_to_see_event( ) return None - if state is None: - raise Exception("Missing state for non-outlier event") - - # get the room_visibility at the time of the event. - visibility = get_effective_room_visibility_from_state(state) - # Check if the room has lax history visibility, allowing us to skip # membership checks. # @@ -408,6 +464,10 @@ def _check_client_allowed_to_see_event( ): return event + if membership_event: + state = {(EventTypes.Member, user_id): membership_event} + else: + state = {} membership_result = _check_membership(user_id, event, visibility, state, is_peeking) if not membership_result.allowed: filtered_event_logger.debug(