Filter DM from account data
This commit is contained in:
@@ -11,7 +11,7 @@ if TYPE_CHECKING or HAS_PYDANTIC_V2:
|
||||
else:
|
||||
from pydantic import Extra
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.constants import AccountDataTypes, Membership
|
||||
from synapse.events import EventBase
|
||||
from synapse.rest.client.models import SlidingSyncBody
|
||||
from synapse.types import JsonMapping, Requester, RoomStreamToken, StreamToken, UserID
|
||||
@@ -266,11 +266,12 @@ class SlidingSyncHandler:
|
||||
lists: Dict[str, SlidingSyncResult.SlidingWindowList] = {}
|
||||
if sync_config.lists:
|
||||
for list_key, list_config in sync_config.lists.items():
|
||||
# TODO: Apply filters
|
||||
#
|
||||
# TODO: Exclude partially stated rooms unless the `required_state` has
|
||||
# `["m.room.member", "$LAZY"]`
|
||||
# Apply filters
|
||||
filtered_room_ids = room_id_set
|
||||
if list_config.filters:
|
||||
filtered_room_ids = await self.filter_rooms(
|
||||
sync_config.user, room_id_set, list_config.filters
|
||||
)
|
||||
# TODO: Apply sorts
|
||||
sorted_room_ids = sorted(filtered_room_ids)
|
||||
|
||||
@@ -488,3 +489,65 @@ class SlidingSyncHandler:
|
||||
sync_room_id_set.discard(room_id)
|
||||
|
||||
return sync_room_id_set
|
||||
|
||||
async def filter_rooms(
|
||||
self,
|
||||
user: UserID,
|
||||
room_id_set: AbstractSet[str],
|
||||
filters: SlidingSyncConfig.SlidingSyncList.Filters,
|
||||
) -> AbstractSet[str]:
|
||||
"""
|
||||
Filter rooms based on the sync request.
|
||||
"""
|
||||
user_id = user.to_string()
|
||||
|
||||
# TODO: Apply filters
|
||||
#
|
||||
# TODO: Exclude partially stated rooms unless the `required_state` has
|
||||
# `["m.room.member", "$LAZY"]`
|
||||
|
||||
filtered_room_id_set = set(room_id_set)
|
||||
|
||||
# Filter for Direct-Message (DM) rooms
|
||||
if filters.is_dm:
|
||||
dm_map = await self.store.get_global_account_data_by_type_for_user(
|
||||
user_id, AccountDataTypes.DIRECT
|
||||
)
|
||||
logger.warn("dm_map: %s", dm_map)
|
||||
# Flatten out the map
|
||||
dm_room_id_set = set()
|
||||
if dm_map:
|
||||
for room_ids in dm_map.values():
|
||||
from typing_extensions import reveal_type
|
||||
|
||||
logger.warn("type(room_ids): %s", type(room_ids))
|
||||
reveal_type(room_ids)
|
||||
dm_room_id_set.update(room_ids)
|
||||
|
||||
filtered_room_id_set = filtered_room_id_set.intersection(dm_room_id_set)
|
||||
|
||||
if filters.spaces:
|
||||
raise NotImplementedError()
|
||||
|
||||
if filters.is_encrypted:
|
||||
raise NotImplementedError()
|
||||
|
||||
if filters.is_invite:
|
||||
raise NotImplementedError()
|
||||
|
||||
if filters.room_types:
|
||||
raise NotImplementedError()
|
||||
|
||||
if filters.not_room_types:
|
||||
raise NotImplementedError()
|
||||
|
||||
if filters.room_name_like:
|
||||
raise NotImplementedError()
|
||||
|
||||
if filters.tags:
|
||||
raise NotImplementedError()
|
||||
|
||||
if filters.not_tags:
|
||||
raise NotImplementedError()
|
||||
|
||||
return filtered_room_id_set
|
||||
|
||||
@@ -7,10 +7,16 @@ from synapse.rest.client import knock, login, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util import Clock
|
||||
from synapse.handlers.sliding_sync import SlidingSyncConfig
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
|
||||
"""
|
||||
Tests Sliding Sync handler `get_sync_room_ids_for_user()` to make sure it returns
|
||||
@@ -530,3 +536,66 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
|
||||
room_id3,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class FilterRoomsTestCase(HomeserverTestCase):
|
||||
"""
|
||||
Tests Sliding Sync handler `filter_rooms()` to make sure it includes/excludes rooms
|
||||
correctly.
|
||||
"""
|
||||
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
knock.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self) -> JsonDict:
|
||||
config = super().default_config()
|
||||
# Enable sliding sync
|
||||
config["experimental_features"] = {"msc3575_enabled": True}
|
||||
return config
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
|
||||
self.store = self.hs.get_datastores().main
|
||||
|
||||
def test_TODO(self) -> None:
|
||||
"""
|
||||
Test TODO
|
||||
"""
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass")
|
||||
user2_id = self.register_user("user2", "pass")
|
||||
user2_tok = self.login(user2_id, "pass")
|
||||
|
||||
# Create a room and send an invite to the other user
|
||||
room_id = self.helper.create_room_as(
|
||||
user2_id,
|
||||
is_public=False,
|
||||
tok=user2_tok,
|
||||
)
|
||||
self.helper.invite(
|
||||
room_id,
|
||||
src=user2_id,
|
||||
targ=user1_id,
|
||||
tok=user2_tok,
|
||||
extra_data={"is_direct": True},
|
||||
)
|
||||
# Accept the invite
|
||||
self.helper.join(room_id, user1_id, tok=user1_tok)
|
||||
|
||||
filters = SlidingSyncConfig.SlidingSyncList.Filters(
|
||||
is_dm=True,
|
||||
)
|
||||
|
||||
filtered_room_ids = self.get_success(
|
||||
self.sliding_sync_handler.filter_rooms(
|
||||
UserID.from_string(user1_id), {room_id}, filters
|
||||
)
|
||||
)
|
||||
|
||||
logger.warn("filtered_room_ids %s", filtered_room_ids)
|
||||
|
||||
self.assertEqual(filtered_room_ids, {room_id})
|
||||
|
||||
Reference in New Issue
Block a user