From d65c694a71985c2b312435f5d138fb73caa64d8f Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 29 May 2024 15:32:06 -0500 Subject: [PATCH] Filter DM from account data --- synapse/handlers/sliding_sync.py | 73 +++++++++++++++++++++++++++-- tests/handlers/test_sliding_sync.py | 69 +++++++++++++++++++++++++++ 2 files changed, 137 insertions(+), 5 deletions(-) diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 0c72dc6c40..8774a42b19 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -11,7 +11,7 @@ if TYPE_CHECKING or HAS_PYDANTIC_V2: else: from pydantic import Extra -from synapse.api.constants import Membership +from synapse.api.constants import AccountDataTypes, Membership from synapse.events import EventBase from synapse.rest.client.models import SlidingSyncBody from synapse.types import JsonMapping, Requester, RoomStreamToken, StreamToken, UserID @@ -266,11 +266,12 @@ class SlidingSyncHandler: lists: Dict[str, SlidingSyncResult.SlidingWindowList] = {} if sync_config.lists: for list_key, list_config in sync_config.lists.items(): - # TODO: Apply filters - # - # TODO: Exclude partially stated rooms unless the `required_state` has - # `["m.room.member", "$LAZY"]` + # Apply filters filtered_room_ids = room_id_set + if list_config.filters: + filtered_room_ids = await self.filter_rooms( + sync_config.user, room_id_set, list_config.filters + ) # TODO: Apply sorts sorted_room_ids = sorted(filtered_room_ids) @@ -488,3 +489,65 @@ class SlidingSyncHandler: sync_room_id_set.discard(room_id) return sync_room_id_set + + async def filter_rooms( + self, + user: UserID, + room_id_set: AbstractSet[str], + filters: SlidingSyncConfig.SlidingSyncList.Filters, + ) -> AbstractSet[str]: + """ + Filter rooms based on the sync request. + """ + user_id = user.to_string() + + # TODO: Apply filters + # + # TODO: Exclude partially stated rooms unless the `required_state` has + # `["m.room.member", "$LAZY"]` + + filtered_room_id_set = set(room_id_set) + + # Filter for Direct-Message (DM) rooms + if filters.is_dm: + dm_map = await self.store.get_global_account_data_by_type_for_user( + user_id, AccountDataTypes.DIRECT + ) + logger.warn("dm_map: %s", dm_map) + # Flatten out the map + dm_room_id_set = set() + if dm_map: + for room_ids in dm_map.values(): + from typing_extensions import reveal_type + + logger.warn("type(room_ids): %s", type(room_ids)) + reveal_type(room_ids) + dm_room_id_set.update(room_ids) + + filtered_room_id_set = filtered_room_id_set.intersection(dm_room_id_set) + + if filters.spaces: + raise NotImplementedError() + + if filters.is_encrypted: + raise NotImplementedError() + + if filters.is_invite: + raise NotImplementedError() + + if filters.room_types: + raise NotImplementedError() + + if filters.not_room_types: + raise NotImplementedError() + + if filters.room_name_like: + raise NotImplementedError() + + if filters.tags: + raise NotImplementedError() + + if filters.not_tags: + raise NotImplementedError() + + return filtered_room_id_set diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py index aebc661023..8860bf2a0d 100644 --- a/tests/handlers/test_sliding_sync.py +++ b/tests/handlers/test_sliding_sync.py @@ -7,10 +7,16 @@ from synapse.rest.client import knock, login, room from synapse.server import HomeServer from synapse.types import JsonDict, UserID from synapse.util import Clock +from synapse.handlers.sliding_sync import SlidingSyncConfig from tests.unittest import HomeserverTestCase +import logging + +logger = logging.getLogger(__name__) + + class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): """ Tests Sliding Sync handler `get_sync_room_ids_for_user()` to make sure it returns @@ -530,3 +536,66 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): room_id3, }, ) + + +class FilterRoomsTestCase(HomeserverTestCase): + """ + Tests Sliding Sync handler `filter_rooms()` to make sure it includes/excludes rooms + correctly. + """ + + servlets = [ + admin.register_servlets, + knock.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + # Enable sliding sync + config["experimental_features"] = {"msc3575_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.sliding_sync_handler = self.hs.get_sliding_sync_handler() + self.store = self.hs.get_datastores().main + + def test_TODO(self) -> None: + """ + Test TODO + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # Create a room and send an invite to the other user + room_id = self.helper.create_room_as( + user2_id, + is_public=False, + tok=user2_tok, + ) + self.helper.invite( + room_id, + src=user2_id, + targ=user1_id, + tok=user2_tok, + extra_data={"is_direct": True}, + ) + # Accept the invite + self.helper.join(room_id, user1_id, tok=user1_tok) + + filters = SlidingSyncConfig.SlidingSyncList.Filters( + is_dm=True, + ) + + filtered_room_ids = self.get_success( + self.sliding_sync_handler.filter_rooms( + UserID.from_string(user1_id), {room_id}, filters + ) + ) + + logger.warn("filtered_room_ids %s", filtered_room_ids) + + self.assertEqual(filtered_room_ids, {room_id})