1
0

Merge branch 'develop' of github.com:element-hq/synapse into matrix-org-hotfixes

This commit is contained in:
Andrew Morgan
2024-09-16 10:32:01 +01:00
41 changed files with 2580 additions and 2006 deletions
+1
View File
@@ -0,0 +1 @@
Add support for the `tags` and `not_tags` filters for simplified sliding sync.
+5
View File
@@ -0,0 +1,5 @@
Import pydantic objects from the `_pydantic_compat` module.
This allows `check_pydantic_models.py` to mock those pydantic objects
only in the synapse module, and not interfere with pydantic objects in
external dependencies.
-1
View File
@@ -1 +0,0 @@
Speed up sliding sync by reducing amount of data pulled out of the database for large rooms.
+1
View File
@@ -0,0 +1 @@
Make sure we get up-to-date state information when using the new Sliding Sync tables to derive room membership.
+1
View File
@@ -0,0 +1 @@
Use Sliding Sync tables as a bulk shortcut for getting the max `event_stream_ordering` of rooms.
+1
View File
@@ -0,0 +1 @@
Speed up sliding sync requests a bit where there are many room changes.
+1
View File
@@ -0,0 +1 @@
Refactor sliding sync filter unit tests so the sliding sync API has better test coverage.
+15 -35
View File
@@ -45,7 +45,6 @@ import traceback
import unittest.mock
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
@@ -57,30 +56,17 @@ from typing import (
)
from parameterized import parameterized
from synapse._pydantic_compat import HAS_PYDANTIC_V2
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import (
BaseModel as PydanticBaseModel,
conbytes,
confloat,
conint,
constr,
)
from pydantic.v1.typing import get_args
else:
from pydantic import (
BaseModel as PydanticBaseModel,
conbytes,
confloat,
conint,
constr,
)
from pydantic.typing import get_args
from typing_extensions import ParamSpec
from synapse._pydantic_compat import (
BaseModel as PydanticBaseModel,
conbytes,
confloat,
conint,
constr,
get_args,
)
logger = logging.getLogger(__name__)
CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG: List[Callable] = [
@@ -183,22 +169,16 @@ def monkeypatch_pydantic() -> Generator[None, None, None]:
# Most Synapse code ought to import the patched objects directly from
# `pydantic`. But we also patch their containing modules `pydantic.main` and
# `pydantic.types` for completeness.
patch_basemodel1 = unittest.mock.patch(
"pydantic.BaseModel", new=PatchedBaseModel
patch_basemodel = unittest.mock.patch(
"synapse._pydantic_compat.BaseModel", new=PatchedBaseModel
)
patch_basemodel2 = unittest.mock.patch(
"pydantic.main.BaseModel", new=PatchedBaseModel
)
patches.enter_context(patch_basemodel1)
patches.enter_context(patch_basemodel2)
patches.enter_context(patch_basemodel)
for factory in CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG:
wrapper: Callable = make_wrapper(factory)
patch1 = unittest.mock.patch(f"pydantic.{factory.__name__}", new=wrapper)
patch2 = unittest.mock.patch(
f"pydantic.types.{factory.__name__}", new=wrapper
patch = unittest.mock.patch(
f"synapse._pydantic_compat.{factory.__name__}", new=wrapper
)
patches.enter_context(patch1)
patches.enter_context(patch2)
patches.enter_context(patch)
yield
+63 -1
View File
@@ -19,6 +19,8 @@
#
#
from typing import TYPE_CHECKING
from packaging.version import Version
try:
@@ -30,4 +32,64 @@ except ImportError:
HAS_PYDANTIC_V2: bool = Version(pydantic_version).major == 2
__all__ = ("HAS_PYDANTIC_V2",)
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import (
BaseModel,
Extra,
Field,
MissingError,
PydanticValueError,
StrictBool,
StrictInt,
StrictStr,
ValidationError,
conbytes,
confloat,
conint,
constr,
parse_obj_as,
validator,
)
from pydantic.v1.error_wrappers import ErrorWrapper
from pydantic.v1.typing import get_args
else:
from pydantic import (
BaseModel,
Extra,
Field,
MissingError,
PydanticValueError,
StrictBool,
StrictInt,
StrictStr,
ValidationError,
conbytes,
confloat,
conint,
constr,
parse_obj_as,
validator,
)
from pydantic.error_wrappers import ErrorWrapper
from pydantic.typing import get_args
__all__ = (
"HAS_PYDANTIC_V2",
"BaseModel",
"constr",
"conbytes",
"conint",
"confloat",
"ErrorWrapper",
"Extra",
"Field",
"get_args",
"MissingError",
"parse_obj_as",
"PydanticValueError",
"StrictBool",
"StrictInt",
"StrictStr",
"ValidationError",
"validator",
)
+2 -8
View File
@@ -18,17 +18,11 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar
from typing import Any, Dict, Type, TypeVar
import jsonschema
from synapse._pydantic_compat import HAS_PYDANTIC_V2
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import BaseModel, ValidationError, parse_obj_as
else:
from pydantic import BaseModel, ValidationError, parse_obj_as
from synapse._pydantic_compat import BaseModel, ValidationError, parse_obj_as
from synapse.config._base import ConfigError
from synapse.types import JsonDict, StrSequence
+8 -8
View File
@@ -22,17 +22,17 @@
import argparse
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
import attr
from synapse._pydantic_compat import HAS_PYDANTIC_V2
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import BaseModel, Extra, StrictBool, StrictInt, StrictStr
else:
from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr
from synapse._pydantic_compat import (
BaseModel,
Extra,
StrictBool,
StrictInt,
StrictStr,
)
from synapse.config._base import (
Config,
ConfigError,
+2 -8
View File
@@ -19,17 +19,11 @@
#
#
import collections.abc
from typing import TYPE_CHECKING, List, Type, Union, cast
from typing import List, Type, Union, cast
import jsonschema
from synapse._pydantic_compat import HAS_PYDANTIC_V2
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import Field, StrictBool, StrictStr
else:
from pydantic import Field, StrictBool, StrictStr
from synapse._pydantic_compat import Field, StrictBool, StrictStr
from synapse.api.constants import (
MAX_ALIAS_LENGTH,
EventContentFields,
+1 -1
View File
@@ -267,7 +267,7 @@ class SlidingSyncHandler:
if relevant_rooms_to_send_map:
with start_active_span("sliding_sync.generate_room_entries"):
await concurrently_execute(handle_room, relevant_rooms_to_send_map, 10)
await concurrently_execute(handle_room, relevant_rooms_to_send_map, 20)
extensions = await self.extensions.get_extensions_response(
sync_config=sync_config,
+9 -1
View File
@@ -38,6 +38,7 @@ from synapse.types.handlers.sliding_sync import (
SlidingSyncConfig,
SlidingSyncResult,
)
from synapse.util.async_helpers import concurrently_execute
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -534,7 +535,10 @@ class SlidingSyncExtensionHandler:
# For rooms we've previously sent down, but aren't up to date, we
# need to use the from token from the room status.
if previously_rooms:
for room_id, receipt_token in previously_rooms.items():
# Fetch any missing rooms concurrently.
async def handle_previously_room(room_id: str) -> None:
receipt_token = previously_rooms[room_id]
# TODO: Limit the number of receipts we're about to send down
# for the room, if its too many we should TODO
previously_receipts = (
@@ -546,6 +550,10 @@ class SlidingSyncExtensionHandler:
)
fetched_receipts.extend(previously_receipts)
await concurrently_execute(
handle_previously_room, previously_rooms.keys(), 20
)
if initial_rooms:
# We also always send down receipts for the current user.
user_receipts = (
+86 -63
View File
@@ -27,6 +27,7 @@ from typing import (
Set,
Tuple,
Union,
cast,
)
import attr
@@ -245,6 +246,7 @@ class SlidingSyncRoomLists:
event_pos=change.event_pos,
room_version_id=change.room_version_id,
# We keep the current state of the room though
has_known_state=existing_room.has_known_state,
room_type=existing_room.room_type,
is_encrypted=existing_room.is_encrypted,
)
@@ -269,6 +271,7 @@ class SlidingSyncRoomLists:
event_id=change.event_id,
event_pos=change.event_pos,
room_version_id=change.room_version_id,
has_known_state=True,
room_type=room_type,
is_encrypted=is_encrypted,
)
@@ -304,6 +307,7 @@ class SlidingSyncRoomLists:
event_id=None,
event_pos=newly_left_room_map[room_id],
room_version_id=await self.store.get_room_version_id(room_id),
has_known_state=True,
room_type=room_type,
is_encrypted=is_encrypted,
)
@@ -355,11 +359,18 @@ class SlidingSyncRoomLists:
if list_config.ranges:
if list_config.ranges == [(0, len(filtered_sync_room_map) - 1)]:
# If we are asking for the full range, we don't need to sort the list.
sorted_room_info = list(filtered_sync_room_map.values())
sorted_room_info: List[RoomsForUserType] = list(
filtered_sync_room_map.values()
)
else:
# Sort the list
sorted_room_info = await self.sort_rooms_using_tables(
filtered_sync_room_map, to_token
sorted_room_info = await self.sort_rooms(
# Cast is safe because RoomsForUserSlidingSync is part
# of the `RoomsForUserType` union. Why can't it detect this?
cast(
Dict[str, RoomsForUserType], filtered_sync_room_map
),
to_token,
)
for range in list_config.ranges:
@@ -1513,6 +1524,8 @@ class SlidingSyncRoomLists:
A filtered dictionary of room IDs along with membership information in the
room at the time of `to_token`.
"""
user_id = user.to_string()
room_id_to_stripped_state_map: Dict[
str, Optional[StateMap[StrippedStateEvent]]
] = {}
@@ -1622,12 +1635,14 @@ class SlidingSyncRoomLists:
and room_type not in filters.room_types
):
filtered_room_id_set.remove(room_id)
continue
if (
filters.not_room_types is not None
and room_type in filters.not_room_types
):
filtered_room_id_set.remove(room_id)
continue
if filters.room_name_like is not None:
with start_active_span("filters.room_name_like"):
@@ -1644,9 +1659,36 @@ class SlidingSyncRoomLists:
# )
raise NotImplementedError()
# Filter by room tags according to the users account data
if filters.tags is not None or filters.not_tags is not None:
with start_active_span("filters.tags"):
raise NotImplementedError()
# Fetch the user tags for their rooms
room_tags = await self.store.get_tags_for_user(user_id)
room_id_to_tag_name_set: Dict[str, Set[str]] = {
room_id: set(tags.keys()) for room_id, tags in room_tags.items()
}
if filters.tags is not None:
tags_set = set(filters.tags)
filtered_room_id_set = {
room_id
for room_id in filtered_room_id_set
# Remove rooms that don't have one of the tags in the filter
if room_id_to_tag_name_set.get(room_id, set()).intersection(
tags_set
)
}
if filters.not_tags is not None:
not_tags_set = set(filters.not_tags)
filtered_room_id_set = {
room_id
for room_id in filtered_room_id_set
# Remove rooms if they have any of the tags in the filter
if not room_id_to_tag_name_set.get(room_id, set()).intersection(
not_tags_set
)
}
# Assemble a new sync room map but only with the `filtered_room_id_set`
return {room_id: sync_room_map[room_id] for room_id in filtered_room_id_set}
@@ -1670,6 +1712,7 @@ class SlidingSyncRoomLists:
filters: Filters to apply
to_token: We filter based on the state of the room at this token
dm_room_ids: Set of room IDs which are DMs
room_tags: Mapping of room ID to tags
Returns:
A filtered dictionary of room IDs along with membership information in the
@@ -1697,7 +1740,10 @@ class SlidingSyncRoomLists:
filtered_room_id_set = {
room_id
for room_id in filtered_room_id_set
if sync_room_map[room_id].is_encrypted == filters.is_encrypted
# Remove rooms if we can't figure out what the encryption status is
if sync_room_map[room_id].has_known_state
# Or remove if it doesn't match the filter
and sync_room_map[room_id].is_encrypted == filters.is_encrypted
}
# Filter for rooms that the user has been invited to
@@ -1726,6 +1772,11 @@ class SlidingSyncRoomLists:
# Make a copy so we don't run into an error: `Set changed size during
# iteration`, when we filter out and remove items
for room_id in filtered_room_id_set.copy():
# Remove rooms if we can't figure out what room type it is
if not sync_room_map[room_id].has_known_state:
filtered_room_id_set.remove(room_id)
continue
room_type = sync_room_map[room_id].room_type
if (
@@ -1733,12 +1784,14 @@ class SlidingSyncRoomLists:
and room_type not in filters.room_types
):
filtered_room_id_set.remove(room_id)
continue
if (
filters.not_room_types is not None
and room_type in filters.not_room_types
):
filtered_room_id_set.remove(room_id)
continue
if filters.room_name_like is not None:
with start_active_span("filters.room_name_like"):
@@ -1755,70 +1808,40 @@ class SlidingSyncRoomLists:
# )
raise NotImplementedError()
# Filter by room tags according to the users account data
if filters.tags is not None or filters.not_tags is not None:
with start_active_span("filters.tags"):
raise NotImplementedError()
# Fetch the user tags for their rooms
room_tags = await self.store.get_tags_for_user(user_id)
room_id_to_tag_name_set: Dict[str, Set[str]] = {
room_id: set(tags.keys()) for room_id, tags in room_tags.items()
}
if filters.tags is not None:
tags_set = set(filters.tags)
filtered_room_id_set = {
room_id
for room_id in filtered_room_id_set
# Remove rooms that don't have one of the tags in the filter
if room_id_to_tag_name_set.get(room_id, set()).intersection(
tags_set
)
}
if filters.not_tags is not None:
not_tags_set = set(filters.not_tags)
filtered_room_id_set = {
room_id
for room_id in filtered_room_id_set
# Remove rooms if they have any of the tags in the filter
if not room_id_to_tag_name_set.get(room_id, set()).intersection(
not_tags_set
)
}
# Assemble a new sync room map but only with the `filtered_room_id_set`
return {room_id: sync_room_map[room_id] for room_id in filtered_room_id_set}
@trace
async def sort_rooms_using_tables(
self,
sync_room_map: Mapping[str, RoomsForUserSlidingSync],
to_token: StreamToken,
) -> List[RoomsForUserSlidingSync]:
"""
Sort by `stream_ordering` of the last event that the user should see in the
room. `stream_ordering` is unique so we get a stable sort.
Args:
sync_room_map: Dictionary of room IDs to sort along with membership
information in the room at the time of `to_token`.
to_token: We sort based on the events in the room at this token (<= `to_token`)
Returns:
A sorted list of room IDs by `stream_ordering` along with membership information.
"""
# Assemble a map of room ID to the `stream_ordering` of the last activity that the
# user should see in the room (<= `to_token`)
last_activity_in_room_map: Dict[str, int] = {}
for room_id, room_for_user in sync_room_map.items():
if room_for_user.membership != Membership.JOIN:
# If the user has left/been invited/knocked/been banned from a
# room, they shouldn't see anything past that point.
#
# FIXME: It's possible that people should see beyond this point
# in invited/knocked cases if for example the room has
# `invite`/`world_readable` history visibility, see
# https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
last_activity_in_room_map[room_id] = room_for_user.event_pos.stream
# For fully-joined rooms, we find the latest activity at/before the
# `to_token`.
joined_room_positions = (
await self.store.bulk_get_last_event_pos_in_room_before_stream_ordering(
[
room_id
for room_id, room_for_user in sync_room_map.items()
if room_for_user.membership == Membership.JOIN
],
to_token.room_key,
)
)
last_activity_in_room_map.update(joined_room_positions)
return sorted(
sync_room_map.values(),
# Sort by the last activity (stream_ordering) in the room
key=lambda room_info: last_activity_in_room_map[room_info.room_id],
# We want descending order
reverse=True,
)
@trace
async def sort_rooms(
self,
+7 -9
View File
@@ -37,19 +37,17 @@ from typing import (
overload,
)
from synapse._pydantic_compat import HAS_PYDANTIC_V2
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import BaseModel, MissingError, PydanticValueError, ValidationError
from pydantic.v1.error_wrappers import ErrorWrapper
else:
from pydantic import BaseModel, MissingError, PydanticValueError, ValidationError
from pydantic.error_wrappers import ErrorWrapper
from typing_extensions import Literal
from twisted.web.server import Request
from synapse._pydantic_compat import (
BaseModel,
ErrorWrapper,
MissingError,
PydanticValueError,
ValidationError,
)
from synapse.api.errors import Codes, SynapseError
from synapse.http import redact_uri
from synapse.http.server import HttpServer
+1 -6
View File
@@ -27,7 +27,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr
from synapse._pydantic_compat import HAS_PYDANTIC_V2
from synapse._pydantic_compat import StrictBool
from synapse.api.constants import Direction, UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
@@ -56,11 +56,6 @@ from synapse.types.rest import RequestBodyModel
if TYPE_CHECKING:
from synapse.server import HomeServer
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import StrictBool
else:
from pydantic import StrictBool
logger = logging.getLogger(__name__)
+1 -7
View File
@@ -24,18 +24,12 @@ import random
from typing import TYPE_CHECKING, List, Optional, Tuple
from urllib.parse import urlparse
from synapse._pydantic_compat import HAS_PYDANTIC_V2
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import StrictBool, StrictStr, constr
else:
from pydantic import StrictBool, StrictStr, constr
import attr
from typing_extensions import Literal
from twisted.web.server import Request
from synapse._pydantic_compat import StrictBool, StrictStr, constr
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
+1 -7
View File
@@ -24,13 +24,7 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse._pydantic_compat import HAS_PYDANTIC_V2
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import Extra, StrictStr
else:
from pydantic import Extra, StrictStr
from synapse._pydantic_compat import Extra, StrictStr
from synapse.api import errors
from synapse.api.errors import NotFoundError, SynapseError, UnrecognizedRequestError
from synapse.handlers.device import DeviceHandler
+1 -7
View File
@@ -22,17 +22,11 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse._pydantic_compat import HAS_PYDANTIC_V2
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import StrictStr
else:
from pydantic import StrictStr
from typing_extensions import Literal
from twisted.web.server import Request
from synapse._pydantic_compat import StrictStr
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
+1 -5
View File
@@ -23,7 +23,7 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse._pydantic_compat import HAS_PYDANTIC_V2
from synapse._pydantic_compat import StrictStr
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
@@ -40,10 +40,6 @@ from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import StrictStr
else:
from pydantic import StrictStr
logger = logging.getLogger(__name__)
+1 -1
View File
@@ -1044,7 +1044,7 @@ class SlidingSyncRestServlet(RestServlet):
serialized_rooms[room_id]["heroes"] = serialized_heroes
# We should only include the `initial` key if it's `True` to save bandwidth.
# The absense of this flag means `False`.
# The absence of this flag means `False`.
if room_result.initial:
serialized_rooms[room_id]["initial"] = room_result.initial
+1 -7
View File
@@ -23,17 +23,11 @@ import logging
import re
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
from synapse._pydantic_compat import HAS_PYDANTIC_V2
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import Extra, StrictInt, StrictStr
else:
from pydantic import Extra, StrictInt, StrictStr
from signedjson.sign import sign_json
from twisted.web.server import Request
from synapse._pydantic_compat import Extra, StrictInt, StrictStr
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import HttpServer
from synapse.http.servlet import (
+1
View File
@@ -136,6 +136,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,))
self._attempt_to_invalidate_cache("get_room_type", (room_id,))
self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
self._attempt_to_invalidate_cache("get_sliding_sync_rooms_for_user", None)
def _invalidate_state_caches_all(self, room_id: str) -> None:
"""Invalidates caches that are based on the current state, but does
+1 -6
View File
@@ -40,7 +40,7 @@ from typing import (
import attr
from synapse._pydantic_compat import HAS_PYDANTIC_V2
from synapse._pydantic_compat import BaseModel
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection, Cursor
@@ -49,11 +49,6 @@ from synapse.util import Clock, json_encoder
from . import engines
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import BaseModel
else:
from pydantic import BaseModel
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.database import (
+14
View File
@@ -41,6 +41,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.events import SLIDING_SYNC_RELEVANT_STATE_SET
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util.caches.descriptors import CachedFunction
@@ -271,12 +272,20 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache(
"get_rooms_for_user", (data.state_key,)
)
self._attempt_to_invalidate_cache(
"get_sliding_sync_rooms_for_user", None
)
elif data.type == EventTypes.RoomEncryption:
self._attempt_to_invalidate_cache(
"get_room_encryption", (data.room_id,)
)
elif data.type == EventTypes.Create:
self._attempt_to_invalidate_cache("get_room_type", (data.room_id,))
if (data.type, data.state_key) in SLIDING_SYNC_RELEVANT_STATE_SET:
self._attempt_to_invalidate_cache(
"get_sliding_sync_rooms_for_user", None
)
elif row.type == EventsStreamAllStateRow.TypeId:
assert isinstance(data, EventsStreamAllStateRow)
# Similar to the above, but the entire caches are invalidated. This is
@@ -285,6 +294,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache("get_rooms_for_user", None)
self._attempt_to_invalidate_cache("get_room_type", (data.room_id,))
self._attempt_to_invalidate_cache("get_room_encryption", (data.room_id,))
self._attempt_to_invalidate_cache("get_sliding_sync_rooms_for_user", None)
else:
raise Exception("Unknown events stream row type %s" % (row.type,))
@@ -365,6 +375,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
elif etype == EventTypes.RoomEncryption:
self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
if (etype, state_key) in SLIDING_SYNC_RELEVANT_STATE_SET:
self._attempt_to_invalidate_cache("get_sliding_sync_rooms_for_user", None)
if relates_to:
self._attempt_to_invalidate_cache(
"get_relations_for_event",
@@ -477,6 +490,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache(
"get_current_hosts_in_room_ordered", (room_id,)
)
self._attempt_to_invalidate_cache("get_sliding_sync_rooms_for_user", None)
self._attempt_to_invalidate_cache("did_forget", None)
self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None)
self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
@@ -52,6 +52,7 @@ from synapse.storage.types import Cursor
from synapse.types import JsonDict, RoomStreamToken, StateMap, StrCollection
from synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES
from synapse.types.state import StateFilter
from synapse.types.storage import _BackgroundUpdates
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter
@@ -76,34 +77,6 @@ _REPLACE_STREAM_ORDERING_SQL_COMMANDS = (
)
class _BackgroundUpdates:
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
POPULATE_STREAM_ORDERING2 = "populate_stream_ordering2"
INDEX_STREAM_ORDERING2 = "index_stream_ordering2"
INDEX_STREAM_ORDERING2_CONTAINS_URL = "index_stream_ordering2_contains_url"
INDEX_STREAM_ORDERING2_ROOM_ORDER = "index_stream_ordering2_room_order"
INDEX_STREAM_ORDERING2_ROOM_STREAM = "index_stream_ordering2_room_stream"
INDEX_STREAM_ORDERING2_TS = "index_stream_ordering2_ts"
REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column"
EVENT_EDGES_DROP_INVALID_ROWS = "event_edges_drop_invalid_rows"
EVENT_EDGES_REPLACE_INDEX = "event_edges_replace_index"
EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections"
EVENTS_JUMP_TO_DATE_INDEX = "events_jump_to_date_index"
SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE = (
"sliding_sync_prefill_joined_rooms_to_recalculate_table_bg_update"
)
SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE = "sliding_sync_joined_rooms_bg_update"
SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE = (
"sliding_sync_membership_snapshots_bg_update"
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _CalculateChainCover:
"""Return value for _calculate_chain_cover_txn."""
@@ -83,6 +83,7 @@ from synapse.storage.util.id_generators import (
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.types.storage import _BackgroundUpdates
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
@@ -2468,3 +2469,14 @@ class EventsWorkerStore(SQLBaseStore):
)
self.invalidate_get_event_cache_after_txn(txn, event_id)
async def have_finished_sliding_sync_background_jobs(self) -> bool:
"""Return if it's safe to use the sliding sync membership tables."""
return await self.db_pool.updates.have_completed_background_updates(
(
_BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE,
_BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE,
_BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE,
)
)
+11 -14
View File
@@ -51,7 +51,6 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_bg_updates import _BackgroundUpdates
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine
from synapse.storage.roommember import (
@@ -1365,6 +1364,9 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
self._invalidate_cache_and_stream(
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
self._invalidate_cache_and_stream(
txn, self.get_sliding_sync_rooms_for_user, (user_id,)
)
await self.db_pool.runInteraction("forget_membership", f)
@@ -1410,10 +1412,15 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
def get_sliding_sync_rooms_for_user_txn(
txn: LoggingTransaction,
) -> Dict[str, RoomsForUserSlidingSync]:
# XXX: If you use any new columns that can change (like from
# `sliding_sync_joined_rooms` or `forgotten`), make sure to bust the
# `get_sliding_sync_rooms_for_user` cache in the appropriate places (and add
# tests).
sql = """
SELECT m.room_id, m.sender, m.membership, m.membership_event_id,
r.room_version,
m.event_instance_name, m.event_stream_ordering,
m.has_known_state,
COALESCE(j.room_type, m.room_type),
COALESCE(j.is_encrypted, m.is_encrypted)
FROM sliding_sync_membership_snapshots AS m
@@ -1431,8 +1438,9 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
event_id=row[3],
room_version_id=row[4],
event_pos=PersistedEventPosition(row[5], row[6]),
room_type=row[7],
is_encrypted=row[8],
has_known_state=bool(row[7]),
room_type=row[8],
is_encrypted=bool(row[9]),
)
for row in txn
}
@@ -1442,17 +1450,6 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
get_sliding_sync_rooms_for_user_txn,
)
async def have_finished_sliding_sync_background_jobs(self) -> bool:
"""Return if it's safe to use the sliding sync membership tables."""
return await self.db_pool.updates.have_completed_background_updates(
(
_BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE,
_BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE,
_BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE,
)
)
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
def __init__(
+30 -4
View File
@@ -1524,7 +1524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# majority of rooms will have a latest token from before the min stream
# pos.
def bulk_get_max_event_pos_txn(
def bulk_get_max_event_pos_fallback_txn(
txn: LoggingTransaction, batched_room_ids: StrCollection
) -> Dict[str, int]:
clause, args = make_in_list_sql_clause(
@@ -1547,11 +1547,37 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(sql, [max_pos] + args)
return {row[0]: row[1] for row in txn}
# It's easier to look at the `sliding_sync_joined_rooms` table and avoid all of
# the joins and sub-queries.
def bulk_get_max_event_pos_from_sliding_sync_tables_txn(
txn: LoggingTransaction, batched_room_ids: StrCollection
) -> Dict[str, int]:
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", batched_room_ids
)
sql = f"""
SELECT room_id, event_stream_ordering
FROM sliding_sync_joined_rooms
WHERE {clause}
ORDER BY event_stream_ordering DESC
"""
txn.execute(sql, args)
return {row[0]: row[1] for row in txn}
recheck_rooms: Set[str] = set()
for batched in batch_iter(room_ids, 1000):
batch_results = await self.db_pool.runInteraction(
"_bulk_get_max_event_pos", bulk_get_max_event_pos_txn, batched
)
if await self.have_finished_sliding_sync_background_jobs():
batch_results = await self.db_pool.runInteraction(
"bulk_get_max_event_pos_from_sliding_sync_tables_txn",
bulk_get_max_event_pos_from_sliding_sync_tables_txn,
batched,
)
else:
batch_results = await self.db_pool.runInteraction(
"bulk_get_max_event_pos_fallback_txn",
bulk_get_max_event_pos_fallback_txn,
batched,
)
for room_id, stream_ordering in batch_results.items():
if stream_ordering <= now_token.stream:
results.update(batch_results)
+1
View File
@@ -48,6 +48,7 @@ class RoomsForUserSlidingSync:
event_pos: PersistedEventPosition
room_version_id: str
has_known_state: bool
room_type: Optional[str]
is_encrypted: bool
+5 -8
View File
@@ -37,23 +37,20 @@ from typing import (
import attr
from synapse._pydantic_compat import HAS_PYDANTIC_V2
from synapse._pydantic_compat import Extra
from synapse.api.constants import EventTypes
from synapse.types import MultiWriterStreamToken, RoomStreamToken, StrCollection, UserID
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import Extra
else:
from pydantic import Extra
from synapse.events import EventBase
from synapse.types import (
DeviceListUpdates,
JsonDict,
JsonMapping,
MultiWriterStreamToken,
Requester,
RoomStreamToken,
SlidingSyncStreamToken,
StrCollection,
StreamToken,
UserID,
)
from synapse.types.rest.client import SlidingSyncBody
+1 -8
View File
@@ -18,14 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import TYPE_CHECKING
from synapse._pydantic_compat import HAS_PYDANTIC_V2
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import BaseModel, Extra
else:
from pydantic import BaseModel, Extra
from synapse._pydantic_compat import BaseModel, Extra
class RequestBodyModel(BaseModel):
+10 -24
View File
@@ -20,29 +20,15 @@
#
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from synapse._pydantic_compat import HAS_PYDANTIC_V2
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import (
Extra,
StrictBool,
StrictInt,
StrictStr,
conint,
constr,
validator,
)
else:
from pydantic import (
Extra,
StrictBool,
StrictInt,
StrictStr,
conint,
constr,
validator,
)
from synapse._pydantic_compat import (
Extra,
StrictBool,
StrictInt,
StrictStr,
conint,
constr,
validator,
)
from synapse.types.rest import RequestBodyModel
from synapse.util.threepids import validate_email
@@ -384,7 +370,7 @@ class SlidingSyncBody(RequestBodyModel):
receipts: Optional[ReceiptsExtension] = None
typing: Optional[TypingExtension] = None
conn_id: Optional[str]
conn_id: Optional[StrictStr]
# mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
if TYPE_CHECKING:
+47
View File
@@ -0,0 +1,47 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
class _BackgroundUpdates:
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
POPULATE_STREAM_ORDERING2 = "populate_stream_ordering2"
INDEX_STREAM_ORDERING2 = "index_stream_ordering2"
INDEX_STREAM_ORDERING2_CONTAINS_URL = "index_stream_ordering2_contains_url"
INDEX_STREAM_ORDERING2_ROOM_ORDER = "index_stream_ordering2_room_order"
INDEX_STREAM_ORDERING2_ROOM_STREAM = "index_stream_ordering2_room_stream"
INDEX_STREAM_ORDERING2_TS = "index_stream_ordering2_ts"
REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column"
EVENT_EDGES_DROP_INVALID_ROWS = "event_edges_drop_invalid_rows"
EVENT_EDGES_REPLACE_INDEX = "event_edges_replace_index"
EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections"
EVENTS_JUMP_TO_DATE_INDEX = "events_jump_to_date_index"
SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE = (
"sliding_sync_prefill_joined_rooms_to_recalculate_table_bg_update"
)
SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE = "sliding_sync_joined_rooms_bg_update"
SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE = (
"sliding_sync_membership_snapshots_bg_update"
)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1139,3 +1139,61 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
self.assertEqual(
response_body["rooms"][room_id]["bump_stamp"], invite_pos.stream
)
def test_rooms_meta_is_dm(self) -> None:
"""
Test `rooms` `is_dm` is correctly set for DM rooms.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
# Create a DM room
joined_dm_room_id = self._create_dm_room(
inviter_user_id=user1_id,
inviter_tok=user1_tok,
invitee_user_id=user2_id,
invitee_tok=user2_tok,
should_join_room=True,
)
invited_dm_room_id = self._create_dm_room(
inviter_user_id=user1_id,
inviter_tok=user1_tok,
invitee_user_id=user2_id,
invitee_tok=user2_tok,
should_join_room=False,
)
# Create a normal room
room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.join(room_id, user1_id, tok=user1_tok)
# Create a room that user1 is invited to
invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
sync_body = {
"lists": {
"foo-list": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 0,
}
}
}
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Ensure DM's are correctly marked
self.assertDictEqual(
{
room_id: room.get("is_dm")
for room_id, room in response_body["rooms"].items()
},
{
invite_room_id: None,
room_id: None,
invited_dm_room_id: True,
joined_dm_room_id: True,
},
)
@@ -23,11 +23,12 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import (
AccountDataTypes,
EventContentFields,
EventTypes,
RoomTypes,
Membership,
)
from synapse.events import EventBase
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, StrippedStateEvent, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.rest.client import devices, login, receipts, room, sync
from synapse.server import HomeServer
from synapse.types import (
@@ -141,6 +142,167 @@ class SlidingSyncBase(unittest.HomeserverTestCase):
message=str(actual_required_state),
)
def _add_new_dm_to_global_account_data(
self, source_user_id: str, target_user_id: str, target_room_id: str
) -> None:
"""
Helper to handle inserting a new DM for the source user into global account data
(handles all of the list merging).
Args:
source_user_id: The user ID of the DM mapping we're going to update
target_user_id: User ID of the person the DM is with
target_room_id: Room ID of the DM
"""
store = self.hs.get_datastores().main
# Get the current DM map
existing_dm_map = self.get_success(
store.get_global_account_data_by_type_for_user(
source_user_id, AccountDataTypes.DIRECT
)
)
# Scrutinize the account data since it has no concrete type. We're just copying
# everything into a known type. It should be a mapping from user ID to a list of
# room IDs. Ignore anything else.
new_dm_map: Dict[str, List[str]] = {}
if isinstance(existing_dm_map, dict):
for user_id, room_ids in existing_dm_map.items():
if isinstance(user_id, str) and isinstance(room_ids, list):
for room_id in room_ids:
if isinstance(room_id, str):
new_dm_map[user_id] = new_dm_map.get(user_id, []) + [
room_id
]
# Add the new DM to the map
new_dm_map[target_user_id] = new_dm_map.get(target_user_id, []) + [
target_room_id
]
# Save the DM map to global account data
self.get_success(
store.add_account_data_for_user(
source_user_id,
AccountDataTypes.DIRECT,
new_dm_map,
)
)
def _create_dm_room(
self,
inviter_user_id: str,
inviter_tok: str,
invitee_user_id: str,
invitee_tok: str,
should_join_room: bool = True,
) -> str:
"""
Helper to create a DM room as the "inviter" and invite the "invitee" user to the
room. The "invitee" user also will join the room. The `m.direct` account data
will be set for both users.
"""
# Create a room and send an invite the other user
room_id = self.helper.create_room_as(
inviter_user_id,
is_public=False,
tok=inviter_tok,
)
self.helper.invite(
room_id,
src=inviter_user_id,
targ=invitee_user_id,
tok=inviter_tok,
extra_data={"is_direct": True},
)
if should_join_room:
# Person that was invited joins the room
self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
# Mimic the client setting the room as a direct message in the global account
# data for both users.
self._add_new_dm_to_global_account_data(
invitee_user_id, inviter_user_id, room_id
)
self._add_new_dm_to_global_account_data(
inviter_user_id, invitee_user_id, room_id
)
return room_id
_remote_invite_count: int = 0
def _create_remote_invite_room_for_user(
self,
invitee_user_id: str,
unsigned_invite_room_state: Optional[List[StrippedStateEvent]],
) -> str:
"""
Create a fake invite for a remote room and persist it.
We don't have any state for these kind of rooms and can only rely on the
stripped state included in the unsigned portion of the invite event to identify
the room.
Args:
invitee_user_id: The person being invited
unsigned_invite_room_state: List of stripped state events to assist the
receiver in identifying the room.
Returns:
The room ID of the remote invite room
"""
store = self.hs.get_datastores().main
invite_room_id = f"!test_room{self._remote_invite_count}:remote_server"
invite_event_dict = {
"room_id": invite_room_id,
"sender": "@inviter:remote_server",
"state_key": invitee_user_id,
"depth": 1,
"origin_server_ts": 1,
"type": EventTypes.Member,
"content": {"membership": Membership.INVITE},
"auth_events": [],
"prev_events": [],
}
if unsigned_invite_room_state is not None:
serialized_stripped_state_events = []
for stripped_event in unsigned_invite_room_state:
serialized_stripped_state_events.append(
{
"type": stripped_event.type,
"state_key": stripped_event.state_key,
"sender": stripped_event.sender,
"content": stripped_event.content,
}
)
invite_event_dict["unsigned"] = {
"invite_room_state": serialized_stripped_state_events
}
invite_event = make_event_from_dict(
invite_event_dict,
room_version=RoomVersions.V10,
)
invite_event.internal_metadata.outlier = True
invite_event.internal_metadata.out_of_band_membership = True
self.get_success(
store.maybe_store_room_on_outlier_membership(
room_id=invite_room_id, room_version=invite_event.room_version
)
)
context = EventContext.for_outlier(self.hs.get_storage_controllers())
persist_controller = self.hs.get_storage_controllers().persistence
assert persist_controller is not None
self.get_success(persist_controller.persist_event(invite_event, context))
self._remote_invite_count += 1
return invite_room_id
def _bump_notifier_wait_for_events(
self,
user_id: str,
@@ -261,93 +423,6 @@ class SlidingSyncTestCase(SlidingSyncBase):
super().prepare(reactor, clock, hs)
def _add_new_dm_to_global_account_data(
self, source_user_id: str, target_user_id: str, target_room_id: str
) -> None:
"""
Helper to handle inserting a new DM for the source user into global account data
(handles all of the list merging).
Args:
source_user_id: The user ID of the DM mapping we're going to update
target_user_id: User ID of the person the DM is with
target_room_id: Room ID of the DM
"""
# Get the current DM map
existing_dm_map = self.get_success(
self.store.get_global_account_data_by_type_for_user(
source_user_id, AccountDataTypes.DIRECT
)
)
# Scrutinize the account data since it has no concrete type. We're just copying
# everything into a known type. It should be a mapping from user ID to a list of
# room IDs. Ignore anything else.
new_dm_map: Dict[str, List[str]] = {}
if isinstance(existing_dm_map, dict):
for user_id, room_ids in existing_dm_map.items():
if isinstance(user_id, str) and isinstance(room_ids, list):
for room_id in room_ids:
if isinstance(room_id, str):
new_dm_map[user_id] = new_dm_map.get(user_id, []) + [
room_id
]
# Add the new DM to the map
new_dm_map[target_user_id] = new_dm_map.get(target_user_id, []) + [
target_room_id
]
# Save the DM map to global account data
self.get_success(
self.store.add_account_data_for_user(
source_user_id,
AccountDataTypes.DIRECT,
new_dm_map,
)
)
def _create_dm_room(
self,
inviter_user_id: str,
inviter_tok: str,
invitee_user_id: str,
invitee_tok: str,
should_join_room: bool = True,
) -> str:
"""
Helper to create a DM room as the "inviter" and invite the "invitee" user to the
room. The "invitee" user also will join the room. The `m.direct` account data
will be set for both users.
"""
# Create a room and send an invite the other user
room_id = self.helper.create_room_as(
inviter_user_id,
is_public=False,
tok=inviter_tok,
)
self.helper.invite(
room_id,
src=inviter_user_id,
targ=invitee_user_id,
tok=inviter_tok,
extra_data={"is_direct": True},
)
if should_join_room:
# Person that was invited joins the room
self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
# Mimic the client setting the room as a direct message in the global account
# data for both users.
self._add_new_dm_to_global_account_data(
invitee_user_id, inviter_user_id, room_id
)
self._add_new_dm_to_global_account_data(
inviter_user_id, invitee_user_id, room_id
)
return room_id
def test_sync_list(self) -> None:
"""
Test that room IDs show up in the Sliding Sync `lists`
@@ -547,288 +622,52 @@ class SlidingSyncTestCase(SlidingSyncBase):
# There should be no room sent down.
self.assertFalse(channel.json_body["rooms"])
def test_filter_list(self) -> None:
def test_forgotten_up_to_date(self) -> None:
"""
Test that filters apply to `lists`
Make sure we get up-to-date `forgotten` status for rooms
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
# Create a DM room
joined_dm_room_id = self._create_dm_room(
inviter_user_id=user1_id,
inviter_tok=user1_tok,
invitee_user_id=user2_id,
invitee_tok=user2_tok,
should_join_room=True,
)
invited_dm_room_id = self._create_dm_room(
inviter_user_id=user1_id,
inviter_tok=user1_tok,
invitee_user_id=user2_id,
invitee_tok=user2_tok,
should_join_room=False,
)
# Create a normal room
room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.join(room_id, user1_id, tok=user1_tok)
# Create a room that user1 is invited to
invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
# User1 is banned from the room (was never in the room)
self.helper.ban(room_id, src=user2_id, targ=user1_id, tok=user2_tok)
# Make the Sliding Sync request
sync_body = {
"lists": {
# Absense of filters does not imply "False" values
"all": {
"foo-list": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 1,
"timeline_limit": 0,
"filters": {},
},
# Test single truthy filter
"dms": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 1,
"filters": {"is_dm": True},
},
# Test single falsy filter
"non-dms": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 1,
"filters": {"is_dm": False},
},
# Test how multiple filters should stack (AND'd together)
"room-invites": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 1,
"filters": {"is_dm": False, "is_invite": True},
},
}
}
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Make sure it has the foo-list we requested
self.assertListEqual(
list(response_body["lists"].keys()),
["all", "dms", "non-dms", "room-invites"],
response_body["lists"].keys(),
response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
self.assertIncludes(
set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
{room_id},
exact=True,
)
# Make sure the lists have the correct rooms
self.assertListEqual(
list(response_body["lists"]["all"]["ops"]),
[
{
"op": "SYNC",
"range": [0, 99],
"room_ids": [
invite_room_id,
room_id,
invited_dm_room_id,
joined_dm_room_id,
],
}
],
list(response_body["lists"]["all"]),
)
self.assertListEqual(
list(response_body["lists"]["dms"]["ops"]),
[
{
"op": "SYNC",
"range": [0, 99],
"room_ids": [invited_dm_room_id, joined_dm_room_id],
}
],
list(response_body["lists"]["dms"]),
)
self.assertListEqual(
list(response_body["lists"]["non-dms"]["ops"]),
[
{
"op": "SYNC",
"range": [0, 99],
"room_ids": [invite_room_id, room_id],
}
],
list(response_body["lists"]["non-dms"]),
)
self.assertListEqual(
list(response_body["lists"]["room-invites"]["ops"]),
[
{
"op": "SYNC",
"range": [0, 99],
"room_ids": [invite_room_id],
}
],
list(response_body["lists"]["room-invites"]),
)
# Ensure DM's are correctly marked
self.assertDictEqual(
{
room_id: room.get("is_dm")
for room_id, room in response_body["rooms"].items()
},
{
invite_room_id: None,
room_id: None,
invited_dm_room_id: True,
joined_dm_room_id: True,
},
)
def test_filter_regardless_of_membership_server_left_room(self) -> None:
"""
Test that filters apply to rooms regardless of membership. We're also
compounding the problem by having all of the local users leave the room causing
our server to leave the room.
We want to make sure that if someone is filtering rooms, and leaves, you still
get that final update down sync that you left.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
# Create a normal room
room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
self.helper.join(room_id, user1_id, tok=user1_tok)
# Create an encrypted space room
space_room_id = self.helper.create_room_as(
user2_id,
tok=user2_tok,
extra_content={
"creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
},
)
self.helper.send_state(
space_room_id,
EventTypes.RoomEncryption,
{EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
tok=user2_tok,
)
self.helper.join(space_room_id, user1_id, tok=user1_tok)
# Make an initial Sliding Sync request
# User1 forgets the room
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {
"all-list": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 0,
"filters": {},
},
"foo-list": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 1,
"filters": {
"is_encrypted": True,
"room_types": [RoomTypes.SPACE],
},
},
}
},
f"/_matrix/client/r0/rooms/{room_id}/forget",
content={},
access_token=user1_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
from_token = channel.json_body["pos"]
self.assertEqual(channel.code, 200, channel.result)
# Make sure the response has the lists we requested
self.assertListEqual(
list(channel.json_body["lists"].keys()),
["all-list", "foo-list"],
channel.json_body["lists"].keys(),
)
# Make sure the lists have the correct rooms
self.assertListEqual(
list(channel.json_body["lists"]["all-list"]["ops"]),
[
{
"op": "SYNC",
"range": [0, 99],
"room_ids": [space_room_id, room_id],
}
],
)
self.assertListEqual(
list(channel.json_body["lists"]["foo-list"]["ops"]),
[
{
"op": "SYNC",
"range": [0, 99],
"room_ids": [space_room_id],
}
],
)
# Everyone leaves the encrypted space room
self.helper.leave(space_room_id, user1_id, tok=user1_tok)
self.helper.leave(space_room_id, user2_id, tok=user2_tok)
# Make an incremental Sliding Sync request
channel = self.make_request(
"POST",
self.sync_endpoint + f"?pos={from_token}",
{
"lists": {
"all-list": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 0,
"filters": {},
},
"foo-list": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 1,
"filters": {
"is_encrypted": True,
"room_types": [RoomTypes.SPACE],
},
},
}
},
access_token=user1_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Make sure the lists have the correct rooms even though we `newly_left`
self.assertListEqual(
list(channel.json_body["lists"]["all-list"]["ops"]),
[
{
"op": "SYNC",
"range": [0, 99],
"room_ids": [space_room_id, room_id],
}
],
)
self.assertListEqual(
list(channel.json_body["lists"]["foo-list"]["ops"]),
[
{
"op": "SYNC",
"range": [0, 99],
"room_ids": [space_room_id],
}
],
# We should no longer see the forgotten room
response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
self.assertIncludes(
set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
set(),
exact=True,
)
def test_sort_list(self) -> None:
+1 -7
View File
@@ -19,18 +19,12 @@
#
#
import unittest as stdlib_unittest
from typing import TYPE_CHECKING
from typing_extensions import Literal
from synapse._pydantic_compat import HAS_PYDANTIC_V2
from synapse._pydantic_compat import BaseModel, ValidationError
from synapse.types.rest.client import EmailRequestTokenBody
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import BaseModel, ValidationError
else:
from pydantic import BaseModel, ValidationError
class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase):
class Model(BaseModel):
+1 -1
View File
@@ -34,11 +34,11 @@ from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.databases.main.events import DeltaState
from synapse.storage.databases.main.events_bg_updates import (
_BackgroundUpdates,
_resolve_stale_data_in_sliding_sync_joined_rooms_table,
_resolve_stale_data_in_sliding_sync_membership_snapshots_table,
)
from synapse.types import create_requester
from synapse.types.storage import _BackgroundUpdates
from synapse.util import Clock
from tests.test_utils.event_injection import create_event