[^/]*)/?"
- PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3030"
async def on_GET(
self,
@@ -235,9 +234,10 @@ class FederationTimestampLookupServlet(BaseFederationServerServlet):
room_id: str,
) -> Tuple[int, JsonDict]:
timestamp = parse_integer_from_args(query, "ts", required=True)
- direction = parse_string_from_args(
- query, "dir", default="f", allowed_values=["f", "b"], required=True
+ direction_str = parse_string_from_args(
+ query, "dir", allowed_values=["f", "b"], required=True
)
+ direction = Direction(direction_str)
return await self.handler.on_timestamp_to_event_request(
origin, room_id, timestamp, direction
@@ -423,7 +423,7 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
- self._msc3706_enabled = hs.config.experimental.msc3706_enabled
+ self._read_msc3706_query_param = hs.config.experimental.msc3706_enabled
async def on_PUT(
self,
@@ -437,10 +437,16 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
# match those given in content
partial_state = False
- if self._msc3706_enabled:
+ # The stable query parameter wins, if it disagrees with the unstable
+ # parameter for some reason.
+ stable_param = parse_boolean_from_args(query, "omit_members", default=None)
+ if stable_param is not None:
+ partial_state = stable_param
+ elif self._read_msc3706_query_param:
partial_state = parse_boolean_from_args(
query, "org.matrix.msc3706.partial_state", default=False
)
+
result = await self.handler.on_send_join_request(
origin, content, room_id, caller_supports_partial_state=partial_state
)
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index fc21d58001..67e789eef7 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -14,16 +14,19 @@
# limitations under the License.
import logging
import random
-from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
+from synapse.api.constants import AccountDataTypes
from synapse.replication.http.account_data import (
+ ReplicationAddRoomAccountDataRestServlet,
ReplicationAddTagRestServlet,
+ ReplicationAddUserAccountDataRestServlet,
+ ReplicationRemoveRoomAccountDataRestServlet,
ReplicationRemoveTagRestServlet,
- ReplicationRoomAccountDataRestServlet,
- ReplicationUserAccountDataRestServlet,
+ ReplicationRemoveUserAccountDataRestServlet,
)
from synapse.streams import EventSource
-from synapse.types import JsonDict, StreamKeyType, UserID
+from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -41,8 +44,18 @@ class AccountDataHandler:
self._instance_name = hs.get_instance_name()
self._notifier = hs.get_notifier()
- self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs)
- self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs)
+ self._add_user_data_client = (
+ ReplicationAddUserAccountDataRestServlet.make_client(hs)
+ )
+ self._remove_user_data_client = (
+ ReplicationRemoveUserAccountDataRestServlet.make_client(hs)
+ )
+ self._add_room_data_client = (
+ ReplicationAddRoomAccountDataRestServlet.make_client(hs)
+ )
+ self._remove_room_data_client = (
+ ReplicationRemoveRoomAccountDataRestServlet.make_client(hs)
+ )
self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
self._account_data_writers = hs.config.worker.writers.account_data
@@ -112,7 +125,7 @@ class AccountDataHandler:
return max_stream_id
else:
- response = await self._room_data_client(
+ response = await self._add_room_data_client(
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
room_id=room_id,
@@ -121,15 +134,59 @@ class AccountDataHandler:
)
return response["max_stream_id"]
+ async def remove_account_data_for_room(
+ self, user_id: str, room_id: str, account_data_type: str
+ ) -> Optional[int]:
+ """
+ Deletes the room account data for the given user and account data type.
+
+ "Deleting" account data merely means setting the content of the account data
+ to an empty JSON object: {}.
+
+ Args:
+ user_id: The user ID to remove room account data for.
+ room_id: The room ID to target.
+ account_data_type: The account data type to remove.
+
+ Returns:
+ The maximum stream ID, or None if the room account data item did not exist.
+ """
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.remove_account_data_for_room(
+ user_id, room_id, account_data_type
+ )
+ if max_stream_id is None:
+ # The referenced account data did not exist, so no delete occurred.
+ return None
+
+ self._notifier.on_new_event(
+ StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
+ )
+
+ # Notify Synapse modules that the content of the type has changed to an
+ # empty dictionary.
+ await self._notify_modules(user_id, room_id, account_data_type, {})
+
+ return max_stream_id
+ else:
+ response = await self._remove_room_data_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ room_id=room_id,
+ account_data_type=account_data_type,
+ content={},
+ )
+ return response["max_stream_id"]
+
async def add_account_data_for_user(
self, user_id: str, account_data_type: str, content: JsonDict
) -> int:
"""Add some global account_data for a user.
Args:
- user_id: The user to add a tag for.
+ user_id: The user to add some account data for.
account_data_type: The type of account_data to add.
- content: A json object to associate with the tag.
+ content: The content json dictionary.
Returns:
The maximum stream ID.
@@ -148,7 +205,7 @@ class AccountDataHandler:
return max_stream_id
else:
- response = await self._user_data_client(
+ response = await self._add_user_data_client(
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
account_data_type=account_data_type,
@@ -156,6 +213,45 @@ class AccountDataHandler:
)
return response["max_stream_id"]
+ async def remove_account_data_for_user(
+ self, user_id: str, account_data_type: str
+ ) -> Optional[int]:
+ """Removes a piece of global account_data for a user.
+
+ Args:
+ user_id: The user to remove account data for.
+ account_data_type: The type of account_data to remove.
+
+ Returns:
+ The maximum stream ID, or None if the room account data item did not exist.
+ """
+
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.remove_account_data_for_user(
+ user_id, account_data_type
+ )
+ if max_stream_id is None:
+ # The referenced account data did not exist, so no delete occurred.
+ return None
+
+ self._notifier.on_new_event(
+ StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
+ )
+
+ # Notify Synapse modules that the content of the type has changed to an
+ # empty dictionary.
+ await self._notify_modules(user_id, None, account_data_type, {})
+
+ return max_stream_id
+ else:
+ response = await self._remove_user_data_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ account_data_type=account_data_type,
+ content={},
+ )
+ return response["max_stream_id"]
+
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
) -> int:
@@ -218,7 +314,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
- def get_current_key(self, direction: str = "f") -> int:
+ def get_current_key(self) -> int:
return self.store.get_max_account_data_stream_id()
async def get_new_events(
@@ -226,7 +322,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
user: UserID,
from_key: int,
limit: int,
- room_ids: Collection[str],
+ room_ids: StrCollection,
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
@@ -240,7 +336,11 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
for room_id, room_tags in tags.items():
results.append(
- {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
+ {
+ "type": AccountDataTypes.TAG,
+ "content": {"tags": room_tags},
+ "room_id": room_id,
+ }
)
(
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 5bf8e86387..b03c214b14 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -16,7 +16,7 @@ import abc
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
-from synapse.api.constants import Membership
+from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.visibility import filter_events_for_client
@@ -30,6 +30,7 @@ logger = logging.getLogger(__name__)
class AdminHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
+ self._device_handler = hs.get_device_handler()
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
@@ -197,7 +198,7 @@ class AdminHandler:
# efficient method perhaps but it does guarantee we get everything.
while True:
events, _ = await self.store.paginate_room_events(
- room_id, from_key, to_key, limit=100, direction="f"
+ room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
)
if not events:
break
@@ -247,6 +248,21 @@ class AdminHandler:
)
writer.write_state(room_id, event_id, state)
+ # Get the user profile
+ profile = await self.get_user(UserID.from_string(user_id))
+ if profile is not None:
+ writer.write_profile(profile)
+
+ # Get all devices the user has
+ devices = await self._device_handler.get_devices_by_user(user_id)
+ writer.write_devices(devices)
+
+ # Get all connections the user has
+ connections = await self.get_whois(UserID.from_string(user_id))
+ writer.write_connections(
+ connections["devices"][""]["sessions"][0]["connections"]
+ )
+
return writer.finished()
@@ -297,6 +313,33 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
"""
raise NotImplementedError()
+ @abc.abstractmethod
+ def write_profile(self, profile: JsonDict) -> None:
+ """Write the profile of a user.
+
+ Args:
+ profile: The user profile.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def write_devices(self, devices: List[JsonDict]) -> None:
+ """Write the devices of a user.
+
+ Args:
+ devices: The list of devices.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def write_connections(self, connections: List[JsonDict]) -> None:
+ """Write the connections of a user.
+
+ Args:
+ connections: The list of connections / sessions.
+ """
+ raise NotImplementedError()
+
@abc.abstractmethod
def finished(self) -> Any:
"""Called when all data has successfully been exported and written.
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 4bb33df7ff..b4a3ad217a 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -578,9 +578,6 @@ class ApplicationServicesHandler:
device_id,
), messages in recipient_device_to_messages.items():
for message_json in messages:
- # Remove 'message_id' from the to-device message, as it's an internal ID
- message_json.pop("message_id", None)
-
message_payload.append(
{
"to_user_id": user_id,
@@ -615,8 +612,8 @@ class ApplicationServicesHandler:
)
# Fetch the users who have modified their device list since then.
- users_with_changed_device_lists = (
- await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
+ users_with_changed_device_lists = await self.store.get_all_devices_changed(
+ from_key, to_key=new_key
)
# Filter out any users the application service is not interested in
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 8b9ef25d29..30f2d46c3c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2031,7 +2031,7 @@ class PasswordAuthProvider:
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
# Mapping from login type to login parameters
- self._supported_login_types: Dict[str, Iterable[str]] = {}
+ self._supported_login_types: Dict[str, Tuple[str, ...]] = {}
# Mapping from login type to auth checker callbacks
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 4c3fb063b4..ba58f150d1 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import Codes, Requester, UserID, create_requester
@@ -76,6 +77,9 @@ class DeactivateAccountHandler:
True if identity server supports removing threepids, otherwise False.
"""
+ # This can only be called on the main process.
+ assert isinstance(self._device_handler, DeviceHandler)
+
# Check if this user can be deactivated
if not await self._third_party_rules.check_can_deactivate_user(
user_id, by_admin
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 55ac7ef612..08afcbeefa 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,10 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
- Collection,
Dict,
Iterable,
List,
@@ -33,6 +33,7 @@ from synapse.api.errors import (
Codes,
FederationDeniedError,
HttpResponseException,
+ InvalidAPICallError,
RequestSendFailed,
SynapseError,
)
@@ -43,8 +44,10 @@ from synapse.metrics.background_process_metrics import (
)
from synapse.types import (
JsonDict,
+ StrCollection,
StreamKeyType,
StreamToken,
+ UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
@@ -65,6 +68,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
class DeviceWorkerHandler:
+ device_list_updater: "DeviceListWorkerUpdater"
+
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.hs = hs
@@ -76,6 +81,8 @@ class DeviceWorkerHandler:
self.server_name = hs.hostname
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
+ self.device_list_updater = DeviceListWorkerUpdater(hs)
+
@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
"""
@@ -99,6 +106,19 @@ class DeviceWorkerHandler:
log_kv(device_map)
return devices
+ async def get_dehydrated_device(
+ self, user_id: str
+ ) -> Optional[Tuple[str, JsonDict]]:
+ """Retrieve the information for a dehydrated device.
+
+ Args:
+ user_id: the user whose dehydrated device we are looking for
+ Returns:
+ a tuple whose first item is the device ID, and the second item is
+ the dehydrated device information
+ """
+ return await self.store.get_dehydrated_device(user_id)
+
@trace
async def get_device(self, user_id: str, device_id: str) -> JsonDict:
"""Retrieve the given device
@@ -126,8 +146,8 @@ class DeviceWorkerHandler:
@cancellable
async def get_device_changes_in_shared_rooms(
- self, user_id: str, room_ids: Collection[str], from_token: StreamToken
- ) -> Collection[str]:
+ self, user_id: str, room_ids: StrCollection, from_token: StreamToken
+ ) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
"""
@@ -320,10 +340,13 @@ class DeviceWorkerHandler:
class DeviceHandler(DeviceWorkerHandler):
+ device_list_updater: "DeviceListUpdater"
+
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.federation_sender = hs.get_federation_sender()
+ self._account_data_handler = hs.get_account_data_handler()
self._storage_controllers = hs.get_storage_controllers()
self.device_list_updater = DeviceListUpdater(hs, self)
@@ -480,7 +503,7 @@ class DeviceHandler(DeviceWorkerHandler):
else:
raise
- # Delete access tokens and e2e keys for each device. Not optimised as it is not
+ # Delete data specific to each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
await self._auth_handler.delete_access_tokens_for_user(
@@ -490,6 +513,14 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id, device_id=device_id
)
+ if self.hs.config.experimental.msc3890_enabled:
+ # Remove any local notification settings for this device in accordance
+ # with MSC3890.
+ await self._account_data_handler.remove_account_data_for_user(
+ user_id,
+ f"org.matrix.msc3890.local_notification_settings.{device_id}",
+ )
+
await self.notify_device_update(user_id, device_ids)
async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
@@ -520,7 +551,7 @@ class DeviceHandler(DeviceWorkerHandler):
@trace
@measure_func("notify_device_update")
async def notify_device_update(
- self, user_id: str, device_ids: Collection[str]
+ self, user_id: str, device_ids: StrCollection
) -> None:
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.
@@ -606,19 +637,6 @@ class DeviceHandler(DeviceWorkerHandler):
await self.delete_devices(user_id, [old_device_id])
return device_id
- async def get_dehydrated_device(
- self, user_id: str
- ) -> Optional[Tuple[str, JsonDict]]:
- """Retrieve the information for a dehydrated device.
-
- Args:
- user_id: the user whose dehydrated device we are looking for
- Returns:
- a tuple whose first item is the device ID, and the second item is
- the dehydrated device information
- """
- return await self.store.get_dehydrated_device(user_id)
-
async def rehydrate_device(
self, user_id: str, access_token: str, device_id: str
) -> dict:
@@ -682,13 +700,33 @@ class DeviceHandler(DeviceWorkerHandler):
hosts_already_sent_to: Set[str] = set()
try:
+ stream_id, room_id = await self.store.get_device_change_last_converted_pos()
+
while True:
self._handle_new_device_update_new_data = False
- rows = await self.store.get_uncoverted_outbound_room_pokes()
+ max_stream_id = self.store.get_device_stream_token()
+ rows = await self.store.get_uncoverted_outbound_room_pokes(
+ stream_id, room_id
+ )
if not rows:
# If the DB returned nothing then there is nothing left to
# do, *unless* a new device list update happened during the
# DB query.
+
+ # Advance `(stream_id, room_id)`.
+ # `max_stream_id` comes from *before* the query for unconverted
+ # rows, which means that any unconverted rows must have a larger
+ # stream ID.
+ if max_stream_id > stream_id:
+ stream_id, room_id = max_stream_id, ""
+ await self.store.set_device_change_last_converted_pos(
+ stream_id, room_id
+ )
+ else:
+ assert max_stream_id == stream_id
+ # Avoid moving `room_id` backwards.
+ pass
+
if self._handle_new_device_update_new_data:
continue
else:
@@ -718,7 +756,6 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id,
device_id=device_id,
room_id=room_id,
- stream_id=stream_id,
hosts=hosts,
context=opentracing_context,
)
@@ -752,6 +789,12 @@ class DeviceHandler(DeviceWorkerHandler):
hosts_already_sent_to.update(hosts)
current_stream_id = stream_id
+ # Advance `(stream_id, room_id)`.
+ _, _, room_id, stream_id, _ = rows[-1]
+ await self.store.set_device_change_last_converted_pos(
+ stream_id, room_id
+ )
+
finally:
self._handle_new_device_update_is_processing = False
@@ -816,6 +859,7 @@ class DeviceHandler(DeviceWorkerHandler):
known_hosts_at_join = await self.store.get_partial_state_servers_at_join(
room_id
)
+ assert known_hosts_at_join is not None
potentially_changed_hosts.difference_update(known_hosts_at_join)
potentially_changed_hosts.discard(self.server_name)
@@ -834,7 +878,6 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id,
device_id=device_id,
room_id=room_id,
- stream_id=None,
hosts=potentially_changed_hosts,
context=None,
)
@@ -858,7 +901,73 @@ def _update_device_from_client_ips(
)
-class DeviceListUpdater:
+class DeviceListWorkerUpdater:
+ "Handles incoming device list updates from federation and contacts the main process over replication"
+
+ def __init__(self, hs: "HomeServer"):
+ from synapse.replication.http.devices import (
+ ReplicationMultiUserDevicesResyncRestServlet,
+ ReplicationUserDevicesResyncRestServlet,
+ )
+
+ self._user_device_resync_client = (
+ ReplicationUserDevicesResyncRestServlet.make_client(hs)
+ )
+ self._multi_user_device_resync_client = (
+ ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
+ )
+
+ async def multi_user_device_resync(
+ self, user_ids: List[str], mark_failed_as_stale: bool = True
+ ) -> Dict[str, Optional[JsonDict]]:
+ """
+ Like `user_device_resync` but operates on multiple users **from the same origin**
+ at once.
+
+ Returns:
+ Dict from User ID to the same Dict as `user_device_resync`.
+ """
+ # mark_failed_as_stale is not sent. Ensure this doesn't break expectations.
+ assert mark_failed_as_stale
+
+ if not user_ids:
+ # Shortcut empty requests
+ return {}
+
+ try:
+ return await self._multi_user_device_resync_client(user_ids=user_ids)
+ except SynapseError as err:
+ if not (
+ err.code == HTTPStatus.NOT_FOUND and err.errcode == Codes.UNRECOGNIZED
+ ):
+ raise
+
+ # Fall back to single requests
+ result: Dict[str, Optional[JsonDict]] = {}
+ for user_id in user_ids:
+ result[user_id] = await self._user_device_resync_client(user_id=user_id)
+ return result
+
+ async def user_device_resync(
+ self, user_id: str, mark_failed_as_stale: bool = True
+ ) -> Optional[JsonDict]:
+ """Fetches all devices for a user and updates the device cache with them.
+
+ Args:
+ user_id: The user's id whose device_list will be updated.
+ mark_failed_as_stale: Whether to mark the user's device list as stale
+ if the attempt to resync failed.
+ Returns:
+ A dict with device info as under the "devices" in the result of this
+ request:
+ https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+ None when we weren't able to fetch the device info for some reason,
+ e.g. due to a connection problem.
+ """
+ return (await self.multi_user_device_resync([user_id]))[user_id]
+
+
+class DeviceListUpdater(DeviceListWorkerUpdater):
"Handles incoming device list updates from federation and updates the DB"
def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
@@ -866,6 +975,7 @@ class DeviceListUpdater:
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
self.device_handler = device_handler
+ self._notifier = hs.get_notifier()
self._remote_edu_linearizer = Linearizer(name="remote_device_list")
@@ -937,7 +1047,7 @@ class DeviceListUpdater:
# Check if we are partially joining any rooms. If so we need to store
# all device list updates so that we can handle them correctly once we
# know who is in the room.
- # TODO(faster joins): this fetches and processes a bunch of data that we don't
+ # TODO(faster_joins): this fetches and processes a bunch of data that we don't
# use. Could be replaced by a tighter query e.g.
# SELECT EXISTS(SELECT 1 FROM partial_state_rooms)
partial_rooms = await self.store.get_partial_state_room_resync_info()
@@ -946,6 +1056,7 @@ class DeviceListUpdater:
user_id,
device_id,
)
+ self._notifier.notify_replication()
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
@@ -1101,19 +1212,66 @@ class DeviceListUpdater:
# Allow future calls to retry resyncinc out of sync device lists.
self._resync_retry_in_progress = False
+ async def multi_user_device_resync(
+ self, user_ids: List[str], mark_failed_as_stale: bool = True
+ ) -> Dict[str, Optional[JsonDict]]:
+ """
+ Like `user_device_resync` but operates on multiple users **from the same origin**
+ at once.
+
+ Returns:
+ Dict from User ID to the same Dict as `user_device_resync`.
+ """
+ if not user_ids:
+ return {}
+
+ origins = {UserID.from_string(user_id).domain for user_id in user_ids}
+
+ if len(origins) != 1:
+ raise InvalidAPICallError(f"Only one origin permitted, got {origins!r}")
+
+ result = {}
+ failed = set()
+ # TODO(Perf): Actually batch these up
+ for user_id in user_ids:
+ user_result, user_failed = await self._user_device_resync_returning_failed(
+ user_id
+ )
+ result[user_id] = user_result
+ if user_failed:
+ failed.add(user_id)
+
+ if mark_failed_as_stale:
+ await self.store.mark_remote_users_device_caches_as_stale(failed)
+
+ return result
+
async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
) -> Optional[JsonDict]:
+ result, failed = await self._user_device_resync_returning_failed(user_id)
+
+ if failed and mark_failed_as_stale:
+ # Mark the remote user's device list as stale so we know we need to retry
+ # it later.
+ await self.store.mark_remote_users_device_caches_as_stale((user_id,))
+
+ return result
+
+ async def _user_device_resync_returning_failed(
+ self, user_id: str
+ ) -> Tuple[Optional[JsonDict], bool]:
"""Fetches all devices for a user and updates the device cache with them.
Args:
user_id: The user's id whose device_list will be updated.
- mark_failed_as_stale: Whether to mark the user's device list as stale
- if the attempt to resync failed.
Returns:
- A dict with device info as under the "devices" in the result of this
- request:
- https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+ - A dict with device info as under the "devices" in the result of this
+ request:
+ https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+ None when we weren't able to fetch the device info for some reason,
+ e.g. due to a connection problem.
+ - True iff the resync failed and the device list should be marked as stale.
"""
logger.debug("Attempting to resync the device list for %s", user_id)
log_kv({"message": "Doing resync to update device list."})
@@ -1122,12 +1280,7 @@ class DeviceListUpdater:
try:
result = await self.federation.query_user_devices(origin, user_id)
except NotRetryingDestination:
- if mark_failed_as_stale:
- # Mark the remote user's device list as stale so we know we need to retry
- # it later.
- await self.store.mark_remote_user_device_cache_as_stale(user_id)
-
- return None
+ return None, True
except (RequestSendFailed, HttpResponseException) as e:
logger.warning(
"Failed to handle device list update for %s: %s",
@@ -1135,23 +1288,18 @@ class DeviceListUpdater:
e,
)
- if mark_failed_as_stale:
- # Mark the remote user's device list as stale so we know we need to retry
- # it later.
- await self.store.mark_remote_user_device_cache_as_stale(user_id)
-
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
# is out of date. If we bail then we will retry the resync
# next time we get a device list update for this user_id.
# This makes it more likely that the device lists will
# eventually become consistent.
- return None
+ return None, True
except FederationDeniedError as e:
set_tag("error", True)
log_kv({"reason": "FederationDeniedError"})
logger.info(e)
- return None
+ return None, False
except Exception as e:
set_tag("error", True)
log_kv(
@@ -1159,12 +1307,7 @@ class DeviceListUpdater:
)
logger.exception("Failed to handle device list update for %s", user_id)
- if mark_failed_as_stale:
- # Mark the remote user's device list as stale so we know we need to retry
- # it later.
- await self.store.mark_remote_user_device_cache_as_stale(user_id)
-
- return None
+ return None, True
log_kv({"result": result})
stream_id = result["stream_id"]
devices = result["devices"]
@@ -1246,7 +1389,7 @@ class DeviceListUpdater:
# point.
self._seen_updates[user_id] = {stream_id}
- return result
+ return result, False
async def process_cross_signing_key_update(
self,
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 1d5de0f29a..22f298d445 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Any, Dict
-from synapse.api.constants import EduTypes, ToDeviceEventTypes
+from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes
from synapse.api.errors import SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background
@@ -195,7 +195,7 @@ class DeviceMessageHandler:
sender_user_id,
unknown_devices,
)
- await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
+ await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,))
# Immediately attempt a resync in the background
run_in_background(self._user_device_resync, user_id=sender_user_id) # type: ignore[unused-awaitable]
@@ -216,14 +216,24 @@ class DeviceMessageHandler:
"""
sender_user_id = requester.user.to_string()
- message_id = random_string(16)
- set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
-
- log_kv({"number_of_to_device_messages": len(messages)})
- set_tag("sender", sender_user_id)
+ set_tag(SynapseTags.TO_DEVICE_TYPE, message_type)
+ set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id)
local_messages = {}
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
+ # add an opentracing log entry for each message
+ for device_id, message_content in by_device.items():
+ log_kv(
+ {
+ "event": "send_to_device_message",
+ "user_id": user_id,
+ "device_id": device_id,
+ EventContentFields.TO_DEVICE_MSGID: message_content.get(
+ EventContentFields.TO_DEVICE_MSGID
+ ),
+ }
+ )
+
# Ratelimit local cross-user key requests by the sending device.
if (
message_type == ToDeviceEventTypes.RoomKeyRequest
@@ -233,6 +243,7 @@ class DeviceMessageHandler:
requester, (sender_user_id, requester.device_id)
)
if not allowed:
+ log_kv({"message": f"dropping key requests to {user_id}"})
logger.info(
"Dropping room_key_request from %s to %s due to rate limit",
sender_user_id,
@@ -247,18 +258,11 @@ class DeviceMessageHandler:
"content": message_content,
"type": message_type,
"sender": sender_user_id,
- "message_id": message_id,
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
- log_kv(
- {
- "user_id": user_id,
- "device_id": list(messages_by_device),
- }
- )
else:
destination = get_domain_from_id(user_id)
remote_messages.setdefault(destination, {})[user_id] = by_device
@@ -267,7 +271,11 @@ class DeviceMessageHandler:
remote_edu_contents = {}
for destination, messages in remote_messages.items():
- log_kv({"destination": destination})
+ # The EDU contains a "message_id" property which is used for
+ # idempotence. Make up a random one.
+ message_id = random_string(16)
+ log_kv({"destination": destination, "message_id": message_id})
+
remote_edu_contents[destination] = {
"messages": messages,
"sender": sender_user_id,
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index bf1221f523..d2188ca08f 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -27,17 +27,17 @@ from twisted.internet import defer
from synapse.api.constants import EduTypes
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
-from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
JsonDict,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
-from synapse.util import json_decoder, unwrapFirstError
-from synapse.util.async_helpers import Linearizer, delay_cancellation
+from synapse.util import json_decoder
+from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.cancellation import cancellable
from synapse.util.retryutils import NotRetryingDestination
@@ -56,27 +56,23 @@ class E2eKeysHandler:
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
- self._edu_updater = SigningKeyEduUpdater(hs, self)
-
federation_registry = hs.get_federation_registry()
- self._is_master = hs.config.worker.worker_app is None
- if not self._is_master:
- self._user_device_resync_client = (
- ReplicationUserDevicesResyncRestServlet.make_client(hs)
- )
- else:
+ is_master = hs.config.worker.worker_app is None
+ if is_master:
+ edu_updater = SigningKeyEduUpdater(hs)
+
# Only register this edu handler on master as it requires writing
# device updates to the db
federation_registry.register_edu_handler(
EduTypes.SIGNING_KEY_UPDATE,
- self._edu_updater.incoming_signing_key_update,
+ edu_updater.incoming_signing_key_update,
)
# also handle the unstable version
# FIXME: remove this when enough servers have upgraded
federation_registry.register_edu_handler(
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
- self._edu_updater.incoming_signing_key_update,
+ edu_updater.incoming_signing_key_update,
)
# doesn't really work as part of the generic query API, because the
@@ -242,24 +238,28 @@ class E2eKeysHandler:
# Now fetch any devices that we don't have in our cache
# TODO It might make sense to propagate cancellations into the
# deferreds which are querying remote homeservers.
- await make_deferred_yieldable(
- delay_cancellation(
- defer.gatherResults(
- [
- run_in_background(
- self._query_devices_for_destination,
- results,
- cross_signing_keys,
- failures,
- destination,
- queries,
- timeout,
- )
- for destination, queries in remote_queries_not_in_cache.items()
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ logger.debug(
+ "%d destinations to query devices for", len(remote_queries_not_in_cache)
+ )
+
+ async def _query(
+ destination_queries: Tuple[str, Dict[str, Iterable[str]]]
+ ) -> None:
+ destination, queries = destination_queries
+ return await self._query_devices_for_destination(
+ results,
+ cross_signing_keys,
+ failures,
+ destination,
+ queries,
+ timeout,
)
+
+ await concurrently_execute(
+ _query,
+ remote_queries_not_in_cache.items(),
+ 10,
+ delay_cancellation=True,
)
ret = {"device_keys": results, "failures": failures}
@@ -304,29 +304,41 @@ class E2eKeysHandler:
# queries. We use the more efficient batched query_client_keys for all
# remaining users
user_ids_updated = []
- for (user_id, device_list) in destination_query.items():
- if user_id in user_ids_updated:
- continue
- if device_list:
- continue
+ # Perform a user device resync for each user only once and only as long as:
+ # - they have an empty device_list
+ # - they are in some rooms that this server can see
+ users_to_resync_devices = {
+ user_id
+ for (user_id, device_list) in destination_query.items()
+ if (not device_list) and (await self.store.get_rooms_for_user(user_id))
+ }
- room_ids = await self.store.get_rooms_for_user(user_id)
- if not room_ids:
- continue
+ logger.debug(
+ "%d users to resync devices for from destination %s",
+ len(users_to_resync_devices),
+ destination,
+ )
- # We've decided we're sharing a room with this user and should
- # probably be tracking their device lists. However, we haven't
- # done an initial sync on the device list so we do it now.
- try:
- if self._is_master:
- resync_results = await self.device_handler.device_list_updater.user_device_resync(
- user_id
- )
- else:
- resync_results = await self._user_device_resync_client(
- user_id=user_id
+ try:
+ user_resync_results = (
+ await self.device_handler.device_list_updater.multi_user_device_resync(
+ list(users_to_resync_devices)
+ )
+ )
+ for user_id in users_to_resync_devices:
+ resync_results = user_resync_results[user_id]
+
+ if resync_results is None:
+ # TODO: It's weird that we'll store a failure against a
+ # destination, yet continue processing users from that
+ # destination.
+ # We might want to consider changing this, but for now
+ # I'm leaving it as I found it.
+ failures[destination] = _exception_to_failure(
+ ValueError(f"Device resync failed for {user_id!r}")
)
+ continue
# Add the device keys to the results.
user_devices = resync_results["devices"]
@@ -344,8 +356,8 @@ class E2eKeysHandler:
if self_signing_key:
cross_signing_keys["self_signing_keys"][user_id] = self_signing_key
- except Exception as e:
- failures[destination] = _exception_to_failure(e)
+ except Exception as e:
+ failures[destination] = _exception_to_failure(e)
if len(destination_query) == len(user_ids_updated):
# We've updated all the users in the query and we do not need to
@@ -605,6 +617,8 @@ class E2eKeysHandler:
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
time_now = self.clock.time_msec()
@@ -732,6 +746,8 @@ class E2eKeysHandler:
user_id: the user uploading the keys
keys: the signing keys
"""
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
# if a master key is uploaded, then check it. Otherwise, load the
# stored master key, to check signatures on other keys
@@ -823,6 +839,9 @@ class E2eKeysHandler:
Raises:
SynapseError: if the signatures dict is not valid.
"""
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
+
failures = {}
# signatures to be stored. Each item will be a SignatureListItem
@@ -1200,6 +1219,9 @@ class E2eKeysHandler:
A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
If the key cannot be retrieved, all values in the tuple will instead be None.
"""
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
+
try:
remote_result = await self.federation.query_user_devices(
user.domain, user.to_string()
@@ -1396,11 +1418,14 @@ class SignatureListItem:
class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB"""
- def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
- self.e2e_keys_handler = e2e_keys_handler
+
+ device_handler = hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
+ self._device_handler = device_handler
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
@@ -1445,9 +1470,6 @@ class SigningKeyEduUpdater:
user_id: the user whose updates we are processing
"""
- device_handler = self.e2e_keys_handler.device_handler
- device_list_updater = device_handler.device_list_updater
-
async with self._remote_edu_linearizer.queue(user_id):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
@@ -1459,13 +1481,11 @@ class SigningKeyEduUpdater:
logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates:
- new_device_ids = (
- await device_list_updater.process_cross_signing_key_update(
- user_id,
- master_key,
- self_signing_key,
- )
+ new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update(
+ user_id,
+ master_key,
+ self_signing_key,
)
device_ids = device_ids + new_device_ids
- await device_handler.notify_device_update(user_id, device_ids)
+ await self._device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 3bbad0271b..a23a8ce2a1 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union
+from typing import TYPE_CHECKING, List, Mapping, Optional, Union
from synapse import event_auth
from synapse.api.constants import (
@@ -29,7 +29,7 @@ from synapse.event_auth import (
)
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
-from synapse.types import StateMap, get_domain_from_id
+from synapse.types import StateMap, StrCollection, get_domain_from_id
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -45,6 +45,7 @@ class EventAuthHandler:
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._store = hs.get_datastores().main
+ self._state_storage_controller = hs.get_storage_controllers().state
self._server_name = hs.hostname
async def check_auth_rules_from_context(
@@ -179,17 +180,22 @@ class EventAuthHandler:
this function may return an incorrect result as we are not able to fully
track server membership in a room without full state.
"""
- if not allow_partial_state_rooms and await self._store.is_partial_state_room(
- room_id
- ):
- raise AuthError(
- 403,
- "Unable to authorise you right now; room is partial-stated here.",
- errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE,
- )
-
- if not await self.is_host_in_room(room_id, host):
- raise AuthError(403, "Host not in room.")
+ if await self._store.is_partial_state_room(room_id):
+ if allow_partial_state_rooms:
+ current_hosts = await self._state_storage_controller.get_current_hosts_in_room_or_partial_state_approximation(
+ room_id
+ )
+ if host not in current_hosts:
+ raise AuthError(403, "Host not in room (partial-state approx).")
+ else:
+ raise AuthError(
+ 403,
+ "Unable to authorise you right now; room is partial-stated here.",
+ errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE,
+ )
+ else:
+ if not await self.is_host_in_room(room_id, host):
+ raise AuthError(403, "Host not in room.")
async def check_restricted_join_rules(
self,
@@ -284,7 +290,7 @@ class EventAuthHandler:
async def get_rooms_that_allow_join(
self, state_ids: StateMap[str]
- ) -> Collection[str]:
+ ) -> StrCollection:
"""
Generate a list of rooms in which membership allows access to a room.
@@ -325,7 +331,7 @@ class EventAuthHandler:
return result
- async def is_user_in_rooms(self, room_ids: Collection[str], user_id: str) -> bool:
+ async def is_user_in_rooms(self, room_ids: StrCollection, user_id: str) -> bool:
"""
Check whether a user is a member of any of the provided rooms.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ffc45473c2..16057c030c 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -22,11 +22,12 @@ from enum import Enum
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
- Collection,
+ AbstractSet,
Dict,
Iterable,
List,
Optional,
+ Set,
Tuple,
Union,
)
@@ -47,7 +48,6 @@ from synapse.api.errors import (
FederationError,
FederationPullAttemptBackoffError,
HttpResponseException,
- LimitExceededError,
NotFoundError,
RequestSendFailed,
SynapseError,
@@ -70,8 +70,8 @@ from synapse.replication.http.federation import (
)
from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, get_domain_from_id
+from synapse.types import JsonDict, StrCollection, get_domain_from_id
+from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
from synapse.visibility import filter_events_for_server
@@ -152,6 +152,7 @@ class FederationHandler:
self._federation_event_handler = hs.get_federation_event_handler()
self._device_handler = hs.get_device_handler()
self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
+ self._notifier = hs.get_notifier()
self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
hs
@@ -170,12 +171,29 @@ class FederationHandler:
self.third_party_event_rules = hs.get_third_party_event_rules()
+ # Tracks running partial state syncs by room ID.
+ # Partial state syncs currently only run on the main process, so it's okay to
+ # track them in-memory for now.
+ self._active_partial_state_syncs: Set[str] = set()
+ # Tracks partial state syncs we may want to restart.
+ # A dictionary mapping room IDs to (initial destination, other destinations)
+ # tuples.
+ self._partial_state_syncs_maybe_needing_restart: Dict[
+ str, Tuple[Optional[str], AbstractSet[str]]
+ ] = {}
+ # A lock guarding the partial state flag for rooms.
+ # When the lock is held for a given room, no other concurrent code may
+ # partial state or un-partial state the room.
+ self._is_partial_state_room_linearizer = Linearizer(
+ name="_is_partial_state_room_linearizer"
+ )
+
# if this is the main process, fire off a background process to resume
# any partial-state-resync operations which were in flight when we
# were shut down.
if not hs.config.worker.worker_app:
- run_as_background_process( # type: ignore[unused-awaitable]
- "resume_sync_partial_state_room", self._resume_sync_partial_state_room
+ run_as_background_process(
+ "resume_sync_partial_state_room", self._resume_partial_state_room_sync
)
@trace
@@ -419,7 +437,7 @@ class FederationHandler:
)
)
- async def try_backfill(domains: Collection[str]) -> bool:
+ async def try_backfill(domains: StrCollection) -> bool:
# TODO: Should we try multiple of these at a time?
# Number of contacted remote homeservers that have denied our backfill
@@ -586,7 +604,23 @@ class FederationHandler:
self._federation_event_handler.room_queues[room_id] = []
- await self._clean_room_for_join(room_id)
+ is_host_joined = await self.store.is_host_joined(room_id, self.server_name)
+
+ if not is_host_joined:
+ # We may have old forward extremities lying around if the homeserver left
+ # the room completely in the past. Clear them out.
+ #
+ # Note that this check-then-clear is subject to races where
+ # * the homeserver is in the room and stops being in the room just after
+ # the check. We won't reset the forward extremities, but that's okay,
+ # since they will be almost up to date.
+ # * the homeserver is not in the room and starts being in the room just
+ # after the check. This can't happen, since `RoomMemberHandler` has a
+ # linearizer lock which prevents concurrent remote joins into the same
+ # room.
+ # In short, the races either have an acceptable outcome or should be
+ # impossible.
+ await self._clean_room_for_join(room_id)
try:
# Try the host we successfully got a response to /make_join/
@@ -598,94 +632,116 @@ class FederationHandler:
except ValueError:
pass
- ret = await self.federation_client.send_join(
- host_list, event, room_version_obj
- )
+ async with self._is_partial_state_room_linearizer.queue(room_id):
+ already_partial_state_room = await self.store.is_partial_state_room(
+ room_id
+ )
- event = ret.event
- origin = ret.origin
- state = ret.state
- auth_chain = ret.auth_chain
- auth_chain.sort(key=lambda e: e.depth)
+ ret = await self.federation_client.send_join(
+ host_list,
+ event,
+ room_version_obj,
+ # Perform a full join when we are already in the room and it is a
+ # full state room, since we are not allowed to persist a partial
+ # state join event in a full state room. In the future, we could
+ # optimize this by always performing a partial state join and
+ # computing the state ourselves or retrieving it from the remote
+ # homeserver if necessary.
+ #
+ # There's a race where we leave the room, then perform a full join
+ # anyway. This should end up being fast anyway, since we would
+ # already have the full room state and auth chain persisted.
+ partial_state=not is_host_joined or already_partial_state_room,
+ )
- logger.debug("do_invite_join auth_chain: %s", auth_chain)
- logger.debug("do_invite_join state: %s", state)
+ event = ret.event
+ origin = ret.origin
+ state = ret.state
+ auth_chain = ret.auth_chain
+ auth_chain.sort(key=lambda e: e.depth)
- logger.debug("do_invite_join event: %s", event)
+ logger.debug("do_invite_join auth_chain: %s", auth_chain)
+ logger.debug("do_invite_join state: %s", state)
- # if this is the first time we've joined this room, it's time to add
- # a row to `rooms` with the correct room version. If there's already a
- # row there, we should override it, since it may have been populated
- # based on an invite request which lied about the room version.
- #
- # federation_client.send_join has already checked that the room
- # version in the received create event is the same as room_version_obj,
- # so we can rely on it now.
- #
- await self.store.upsert_room_on_join(
- room_id=room_id,
- room_version=room_version_obj,
- state_events=state,
- )
+ logger.debug("do_invite_join event: %s", event)
- if ret.partial_state:
- # Mark the room as having partial state.
- # The background process is responsible for unmarking this flag,
- # even if the join fails.
- await self.store.store_partial_state_room(
+ # if this is the first time we've joined this room, it's time to add
+ # a row to `rooms` with the correct room version. If there's already a
+ # row there, we should override it, since it may have been populated
+ # based on an invite request which lied about the room version.
+ #
+ # federation_client.send_join has already checked that the room
+ # version in the received create event is the same as room_version_obj,
+ # so we can rely on it now.
+ #
+ await self.store.upsert_room_on_join(
room_id=room_id,
- servers=ret.servers_in_room,
- device_lists_stream_id=self.store.get_device_stream_token(),
- joined_via=origin,
+ room_version=room_version_obj,
+ state_events=state,
)
- try:
- max_stream_id = (
- await self._federation_event_handler.process_remote_join(
- origin,
- room_id,
- auth_chain,
- state,
- event,
- room_version_obj,
- partial_state=ret.partial_state,
- )
- )
- except PartialStateConflictError as e:
- # The homeserver was already in the room and it is no longer partial
- # stated. We ought to be doing a local join instead. Turn the error into
- # a 429, as a hint to the client to try again.
- # TODO(faster_joins): `_should_perform_remote_join` suggests that we may
- # do a remote join for restricted rooms even if we have full state.
- logger.error(
- "Room %s was un-partial stated while processing remote join.",
- room_id,
- )
- raise LimitExceededError(msg=e.msg, errcode=e.errcode, retry_after_ms=0)
- else:
- # Record the join event id for future use (when we finish the full
- # join). We have to do this after persisting the event to keep foreign
- # key constraints intact.
- if ret.partial_state:
- await self.store.write_partial_state_rooms_join_event_id(
- room_id, event.event_id
- )
- finally:
- # Always kick off the background process that asynchronously fetches
- # state for the room.
- # If the join failed, the background process is responsible for
- # cleaning up — including unmarking the room as a partial state room.
- if ret.partial_state:
- # Kick off the process of asynchronously fetching the state for this
- # room.
- run_as_background_process( # type: ignore[unused-awaitable]
- desc="sync_partial_state_room",
- func=self._sync_partial_state_room,
- initial_destination=origin,
- other_destinations=ret.servers_in_room,
+ if ret.partial_state and not already_partial_state_room:
+ # Mark the room as having partial state.
+ # The background process is responsible for unmarking this flag,
+ # even if the join fails.
+ # TODO(faster_joins):
+ # We may want to reset the partial state info if it's from an
+ # old, failed partial state join.
+ # https://github.com/matrix-org/synapse/issues/13000
+ await self.store.store_partial_state_room(
room_id=room_id,
+ servers=ret.servers_in_room,
+ device_lists_stream_id=self.store.get_device_stream_token(),
+ joined_via=origin,
)
+ try:
+ max_stream_id = (
+ await self._federation_event_handler.process_remote_join(
+ origin,
+ room_id,
+ auth_chain,
+ state,
+ event,
+ room_version_obj,
+ partial_state=ret.partial_state,
+ )
+ )
+ except PartialStateConflictError:
+ # This should be impossible, since we hold the lock on the room's
+ # partial statedness.
+ logger.error(
+ "Room %s was un-partial stated while processing remote join.",
+ room_id,
+ )
+ raise
+ else:
+ # Record the join event id for future use (when we finish the full
+ # join). We have to do this after persisting the event to keep
+ # foreign key constraints intact.
+ if ret.partial_state and not already_partial_state_room:
+ # TODO(faster_joins):
+ # We may want to reset the partial state info if it's from
+ # an old, failed partial state join.
+ # https://github.com/matrix-org/synapse/issues/13000
+ await self.store.write_partial_state_rooms_join_event_id(
+ room_id, event.event_id
+ )
+ finally:
+ # Always kick off the background process that asynchronously fetches
+ # state for the room.
+ # If the join failed, the background process is responsible for
+ # cleaning up — including unmarking the room as a partial state
+ # room.
+ if ret.partial_state:
+ # Kick off the process of asynchronously fetching the state for
+ # this room.
+ self._start_partial_state_room_sync(
+ initial_destination=origin,
+ other_destinations=ret.servers_in_room,
+ room_id=room_id,
+ )
+
# We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches.
await self._replication.wait_for_stream_position(
@@ -1342,32 +1398,53 @@ class FederationHandler:
)
EventValidator().validate_builder(builder)
- event, context = await self.event_creation_handler.create_new_client_event(
- builder=builder
- )
- event, context = await self.add_display_name_to_third_party_invite(
- room_version_obj, event_dict, event, context
- )
+ # Try several times, it could fail with PartialStateConflictError
+ # in send_membership_event, cf comment in except block.
+ max_retries = 5
+ for i in range(max_retries):
+ try:
+ (
+ event,
+ context,
+ ) = await self.event_creation_handler.create_new_client_event(
+ builder=builder
+ )
- EventValidator().validate_new(event, self.config)
+ event, context = await self.add_display_name_to_third_party_invite(
+ room_version_obj, event_dict, event, context
+ )
- # We need to tell the transaction queue to send this out, even
- # though the sender isn't a local user.
- event.internal_metadata.send_on_behalf_of = self.hs.hostname
+ EventValidator().validate_new(event, self.config)
- try:
- validate_event_for_room_version(event)
- await self._event_auth_handler.check_auth_rules_from_context(event)
- except AuthError as e:
- logger.warning("Denying new third party invite %r because %s", event, e)
- raise e
+ # We need to tell the transaction queue to send this out, even
+ # though the sender isn't a local user.
+ event.internal_metadata.send_on_behalf_of = self.hs.hostname
- await self._check_signature(event, context)
+ try:
+ validate_event_for_room_version(event)
+ await self._event_auth_handler.check_auth_rules_from_context(
+ event
+ )
+ except AuthError as e:
+ logger.warning(
+ "Denying new third party invite %r because %s", event, e
+ )
+ raise e
- # We retrieve the room member handler here as to not cause a cyclic dependency
- member_handler = self.hs.get_room_member_handler()
- await member_handler.send_membership_event(None, event, context)
+ await self._check_signature(event, context)
+
+ # We retrieve the room member handler here as to not cause a cyclic dependency
+ member_handler = self.hs.get_room_member_handler()
+ await member_handler.send_membership_event(None, event, context)
+
+ break
+ except PartialStateConflictError as e:
+ # Persisting couldn't happen because the room got un-partial stated
+ # in the meantime and context needs to be recomputed, so let's do so.
+ if i == max_retries - 1:
+ raise e
+ pass
else:
destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
@@ -1399,28 +1476,46 @@ class FederationHandler:
room_version_obj, event_dict
)
- event, context = await self.event_creation_handler.create_new_client_event(
- builder=builder
- )
- event, context = await self.add_display_name_to_third_party_invite(
- room_version_obj, event_dict, event, context
- )
+ # Try several times, it could fail with PartialStateConflictError
+ # in send_membership_event, cf comment in except block.
+ max_retries = 5
+ for i in range(max_retries):
+ try:
+ (
+ event,
+ context,
+ ) = await self.event_creation_handler.create_new_client_event(
+ builder=builder
+ )
+ event, context = await self.add_display_name_to_third_party_invite(
+ room_version_obj, event_dict, event, context
+ )
- try:
- validate_event_for_room_version(event)
- await self._event_auth_handler.check_auth_rules_from_context(event)
- except AuthError as e:
- logger.warning("Denying third party invite %r because %s", event, e)
- raise e
- await self._check_signature(event, context)
+ try:
+ validate_event_for_room_version(event)
+ await self._event_auth_handler.check_auth_rules_from_context(event)
+ except AuthError as e:
+ logger.warning("Denying third party invite %r because %s", event, e)
+ raise e
+ await self._check_signature(event, context)
- # We need to tell the transaction queue to send this out, even
- # though the sender isn't a local user.
- event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
+ # We need to tell the transaction queue to send this out, even
+ # though the sender isn't a local user.
+ event.internal_metadata.send_on_behalf_of = get_domain_from_id(
+ event.sender
+ )
- # We retrieve the room member handler here as to not cause a cyclic dependency
- member_handler = self.hs.get_room_member_handler()
- await member_handler.send_membership_event(None, event, context)
+ # We retrieve the room member handler here as to not cause a cyclic dependency
+ member_handler = self.hs.get_room_member_handler()
+ await member_handler.send_membership_event(None, event, context)
+
+ break
+ except PartialStateConflictError as e:
+ # Persisting couldn't happen because the room got un-partial stated
+ # in the meantime and context needs to be recomputed, so let's do so.
+ if i == max_retries - 1:
+ raise e
+ pass
async def add_display_name_to_third_party_invite(
self,
@@ -1620,24 +1715,104 @@ class FederationHandler:
# well.
return None
- async def _resume_sync_partial_state_room(self) -> None:
+ async def _resume_partial_state_room_sync(self) -> None:
"""Resumes resyncing of all partial-state rooms after a restart."""
assert not self.config.worker.worker_app
partial_state_rooms = await self.store.get_partial_state_room_resync_info()
for room_id, resync_info in partial_state_rooms.items():
- run_as_background_process( # type: ignore[unused-awaitable]
- desc="sync_partial_state_room",
- func=self._sync_partial_state_room,
+ self._start_partial_state_room_sync(
initial_destination=resync_info.joined_via,
other_destinations=resync_info.servers_in_room,
room_id=room_id,
)
+ def _start_partial_state_room_sync(
+ self,
+ initial_destination: Optional[str],
+ other_destinations: AbstractSet[str],
+ room_id: str,
+ ) -> None:
+ """Starts the background process to resync the state of a partial state room,
+ if it is not already running.
+
+ Args:
+ initial_destination: the initial homeserver to pull the state from
+ other_destinations: other homeservers to try to pull the state from, if
+ `initial_destination` is unavailable
+ room_id: room to be resynced
+ """
+
+ async def _sync_partial_state_room_wrapper() -> None:
+ if room_id in self._active_partial_state_syncs:
+ # Another local user has joined the room while there is already a
+ # partial state sync running. This implies that there is a new join
+ # event to un-partial state. We might find ourselves in one of a few
+ # scenarios:
+ # 1. There is an existing partial state sync. The partial state sync
+ # un-partial states the new join event before completing and all is
+ # well.
+ # 2. Before the latest join, the homeserver was no longer in the room
+ # and there is an existing partial state sync from our previous
+ # membership of the room. The partial state sync may have:
+ # a) succeeded, but not yet terminated. The room will not be
+ # un-partial stated again unless we restart the partial state
+ # sync.
+ # b) failed, because we were no longer in the room and remote
+ # homeservers were refusing our requests, but not yet
+ # terminated. After the latest join, remote homeservers may
+ # start answering our requests again, so we should restart the
+ # partial state sync.
+ # In the cases where we would want to restart the partial state sync,
+ # the room would have the partial state flag when the partial state sync
+ # terminates.
+ self._partial_state_syncs_maybe_needing_restart[room_id] = (
+ initial_destination,
+ other_destinations,
+ )
+ return
+
+ self._active_partial_state_syncs.add(room_id)
+
+ try:
+ await self._sync_partial_state_room(
+ initial_destination=initial_destination,
+ other_destinations=other_destinations,
+ room_id=room_id,
+ )
+ finally:
+ # Read the room's partial state flag while we still hold the claim to
+ # being the active partial state sync (so that another partial state
+ # sync can't come along and mess with it under us).
+ # Normally, the partial state flag will be gone. If it isn't, then we
+ # may find ourselves in scenario 2a or 2b as described in the comment
+ # above, where we want to restart the partial state sync.
+ is_still_partial_state_room = await self.store.is_partial_state_room(
+ room_id
+ )
+ self._active_partial_state_syncs.remove(room_id)
+
+ if room_id in self._partial_state_syncs_maybe_needing_restart:
+ (
+ restart_initial_destination,
+ restart_other_destinations,
+ ) = self._partial_state_syncs_maybe_needing_restart.pop(room_id)
+
+ if is_still_partial_state_room:
+ self._start_partial_state_room_sync(
+ initial_destination=restart_initial_destination,
+ other_destinations=restart_other_destinations,
+ room_id=room_id,
+ )
+
+ run_as_background_process(
+ desc="sync_partial_state_room", func=_sync_partial_state_room_wrapper
+ )
+
async def _sync_partial_state_room(
self,
initial_destination: Optional[str],
- other_destinations: Collection[str],
+ other_destinations: AbstractSet[str],
room_id: str,
) -> None:
"""Background process to resync the state of a partial-state room
@@ -1648,6 +1823,12 @@ class FederationHandler:
`initial_destination` is unavailable
room_id: room to be resynced
"""
+ # Assume that we run on the main process for now.
+ # TODO(faster_joins,multiple workers)
+ # When moving the sync to workers, we need to ensure that
+ # * `_start_partial_state_room_sync` still prevents duplicate resyncs
+ # * `_is_partial_state_room_linearizer` correctly guards partial state flags
+ # for rooms between the workers doing remote joins and resync.
assert not self.config.worker.worker_app
# TODO(faster_joins): do we need to lock to avoid races? What happens if other
@@ -1685,17 +1866,19 @@ class FederationHandler:
logger.info("Handling any pending device list updates")
await self._device_handler.handle_room_un_partial_stated(room_id)
- logger.info("Clearing partial-state flag for %s", room_id)
- success = await self.store.clear_partial_state_room(room_id)
- if success:
+ async with self._is_partial_state_room_linearizer.queue(room_id):
+ logger.info("Clearing partial-state flag for %s", room_id)
+ new_stream_id = await self.store.clear_partial_state_room(room_id)
+
+ if new_stream_id is not None:
logger.info("State resync complete for %s", room_id)
self._storage_controllers.state.notify_room_un_partial_stated(
room_id
)
- # TODO(faster_joins) update room stats and user directory?
- # https://github.com/matrix-org/synapse/issues/12814
- # https://github.com/matrix-org/synapse/issues/12815
+ await self._notifier.on_un_partial_stated_room(
+ room_id, new_stream_id
+ )
return
# we raced against more events arriving with partial state. Go round
@@ -1766,9 +1949,9 @@ class FederationHandler:
def _prioritise_destinations_for_partial_state_resync(
initial_destination: Optional[str],
- other_destinations: Collection[str],
+ other_destinations: AbstractSet[str],
room_id: str,
-) -> Collection[str]:
+) -> StrCollection:
"""Work out the order in which we should ask servers to resync events.
If an `initial_destination` is given, it takes top priority. Otherwise
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index c8175d66b2..2e19df0976 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -43,6 +43,7 @@ from synapse.api.constants import (
from synapse.api.errors import (
AuthError,
Codes,
+ EventSizeError,
FederationError,
FederationPullAttemptBackoffError,
HttpResponseException,
@@ -75,14 +76,15 @@ from synapse.replication.http.federation import (
from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.state import StateFilter
from synapse.types import (
PersistedEventPosition,
RoomStreamToken,
StateMap,
+ StrCollection,
UserID,
get_domain_from_id,
)
+from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.iterutils import batch_iter
from synapse.util.retryutils import NotRetryingDestination
@@ -609,10 +611,12 @@ class FederationEventHandler:
self._state_storage_controller.notify_event_un_partial_stated(
event.event_id
)
+ # Notify that there's a new row in the un_partial_stated_events stream.
+ self._notifier.notify_replication()
@trace
async def backfill(
- self, dest: str, room_id: str, limit: int, extremities: Collection[str]
+ self, dest: str, room_id: str, limit: int, extremities: StrCollection
) -> None:
"""Trigger a backfill request to `dest` for the given `room_id`
@@ -1420,7 +1424,7 @@ class FederationEventHandler:
"""
try:
- await self._store.mark_remote_user_device_cache_as_stale(sender)
+ await self._store.mark_remote_users_device_caches_as_stale((sender,))
# Immediately attempt a resync in the background
if self._config.worker.worker_app:
@@ -1562,7 +1566,7 @@ class FederationEventHandler:
@trace
@tag_args
async def _get_events_and_persist(
- self, destination: str, room_id: str, event_ids: Collection[str]
+ self, destination: str, room_id: str, event_ids: StrCollection
) -> None:
"""Fetch the given events from a server, and persist them as outliers.
@@ -1736,6 +1740,15 @@ class FederationEventHandler:
except AuthError as e:
logger.warning("Rejecting %r because %s", event, e)
context.rejected = RejectedReason.AUTH_ERROR
+ except EventSizeError as e:
+ if e.unpersistable:
+ # This event is completely unpersistable.
+ raise e
+ # Otherwise, we are somewhat lenient and just persist the event
+ # as rejected, for moderate compatibility with older Synapse
+ # versions.
+ logger.warning("While validating received event %r: %s", event, e)
+ context.rejected = RejectedReason.OVERSIZED_EVENT
events_and_contexts_to_persist.append((event, context))
@@ -1781,6 +1794,16 @@ class FederationEventHandler:
# TODO: use a different rejected reason here?
context.rejected = RejectedReason.AUTH_ERROR
return
+ except EventSizeError as e:
+ if e.unpersistable:
+ # This event is completely unpersistable.
+ raise e
+ # Otherwise, we are somewhat lenient and just persist the event
+ # as rejected, for moderate compatibility with older Synapse
+ # versions.
+ logger.warning("While validating received event %r: %s", event, e)
+ context.rejected = RejectedReason.OVERSIZED_EVENT
+ return
# next, check that we have all of the event's auth events.
#
@@ -2237,6 +2260,10 @@ class FederationEventHandler:
event_and_contexts, backfilled=backfilled
)
+ # After persistence we always need to notify replication there may
+ # be new data.
+ self._notifier.notify_replication()
+
if self._ephemeral_messages_enabled:
for event in events:
# If there's an expiry timestamp on the event, schedule its expiry.
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 9c335e6863..191529bd8e 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -15,7 +15,13 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
-from synapse.api.constants import EduTypes, EventTypes, Membership
+from synapse.api.constants import (
+ AccountDataTypes,
+ Direction,
+ EduTypes,
+ EventTypes,
+ Membership,
+)
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig
@@ -57,7 +63,13 @@ class InitialSyncHandler:
self.validator = EventValidator()
self.snapshot_cache: ResponseCache[
Tuple[
- str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool
+ str,
+ Optional[StreamToken],
+ Optional[StreamToken],
+ Direction,
+ int,
+ bool,
+ bool,
]
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
@@ -239,7 +251,7 @@ class InitialSyncHandler:
tags = tags_by_room.get(event.room_id)
if tags:
account_data_events.append(
- {"type": "m.tag", "content": {"tags": tags}}
+ {"type": AccountDataTypes.TAG, "content": {"tags": tags}}
)
account_data = account_data_by_room.get(event.room_id, {})
@@ -326,7 +338,9 @@ class InitialSyncHandler:
account_data_events = []
tags = await self.store.get_tags_for_room(user_id, room_id)
if tags:
- account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
+ account_data_events.append(
+ {"type": AccountDataTypes.TAG, "content": {"tags": tags}}
+ )
account_data = await self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index f8599bab2f..7dfebdc4aa 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -37,7 +37,6 @@ from synapse.api.errors import (
AuthError,
Codes,
ConsentNotGivenError,
- LimitExceededError,
NotFoundError,
ShadowBanError,
SynapseError,
@@ -50,6 +49,7 @@ from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase, relation_from_event
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
+from synapse.events.utils import maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
from synapse.logging import opentracing
@@ -59,7 +59,6 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.replication.http.send_events import ReplicationSendEventsRestServlet
from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.state import StateFilter
from synapse.types import (
MutableStateMap,
PersistedEventPosition,
@@ -70,6 +69,7 @@ from synapse.types import (
UserID,
create_requester,
)
+from synapse.types.state import StateFilter
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
from synapse.util.async_helpers import Linearizer, gather_results
from synapse.util.caches.expiringcache import ExpiringCache
@@ -377,7 +377,7 @@ class MessageHandler:
"""
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
- if not isinstance(expiry_ts, int) or event.is_state():
+ if type(expiry_ts) is not int or event.is_state():
return
# _schedule_expiry_for_event won't actually schedule anything if there's already
@@ -998,60 +998,73 @@ class EventCreationHandler:
event.internal_metadata.stream_ordering,
)
- event, context = await self.create_event(
- requester,
- event_dict,
- txn_id=txn_id,
- allow_no_prev_events=allow_no_prev_events,
- prev_event_ids=prev_event_ids,
- state_event_ids=state_event_ids,
- outlier=outlier,
- historical=historical,
- depth=depth,
- )
-
- assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
- event.sender,
- )
-
- spam_check_result = await self.spam_checker.check_event_for_spam(event)
- if spam_check_result != self.spam_checker.NOT_SPAM:
- if isinstance(spam_check_result, tuple):
- try:
- [code, dict] = spam_check_result
- raise SynapseError(
- 403,
- "This message had been rejected as probable spam",
- code,
- dict,
- )
- except ValueError:
- logger.error(
- "Spam-check module returned invalid error value. Expecting [code, dict], got %s",
- spam_check_result,
- )
-
- raise SynapseError(
- 403,
- "This message has been rejected as probable spam",
- Codes.FORBIDDEN,
- )
-
- # Backwards compatibility: if the return value is not an error code, it
- # means the module returned an error message to be included in the
- # SynapseError (which is now deprecated).
- raise SynapseError(
- 403,
- spam_check_result,
- Codes.FORBIDDEN,
+ # Try several times, it could fail with PartialStateConflictError
+ # in handle_new_client_event, cf comment in except block.
+ max_retries = 5
+ for i in range(max_retries):
+ try:
+ event, context = await self.create_event(
+ requester,
+ event_dict,
+ txn_id=txn_id,
+ allow_no_prev_events=allow_no_prev_events,
+ prev_event_ids=prev_event_ids,
+ state_event_ids=state_event_ids,
+ outlier=outlier,
+ historical=historical,
+ depth=depth,
)
- ev = await self.handle_new_client_event(
- requester=requester,
- events_and_context=[(event, context)],
- ratelimit=ratelimit,
- ignore_shadow_ban=ignore_shadow_ban,
- )
+ assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
+ event.sender,
+ )
+
+ spam_check_result = await self.spam_checker.check_event_for_spam(event)
+ if spam_check_result != self.spam_checker.NOT_SPAM:
+ if isinstance(spam_check_result, tuple):
+ try:
+ [code, dict] = spam_check_result
+ raise SynapseError(
+ 403,
+ "This message had been rejected as probable spam",
+ code,
+ dict,
+ )
+ except ValueError:
+ logger.error(
+ "Spam-check module returned invalid error value. Expecting [code, dict], got %s",
+ spam_check_result,
+ )
+
+ raise SynapseError(
+ 403,
+ "This message has been rejected as probable spam",
+ Codes.FORBIDDEN,
+ )
+
+ # Backwards compatibility: if the return value is not an error code, it
+ # means the module returned an error message to be included in the
+ # SynapseError (which is now deprecated).
+ raise SynapseError(
+ 403,
+ spam_check_result,
+ Codes.FORBIDDEN,
+ )
+
+ ev = await self.handle_new_client_event(
+ requester=requester,
+ events_and_context=[(event, context)],
+ ratelimit=ratelimit,
+ ignore_shadow_ban=ignore_shadow_ban,
+ )
+
+ break
+ except PartialStateConflictError as e:
+ # Persisting couldn't happen because the room got un-partial stated
+ # in the meantime and context needs to be recomputed, so let's do so.
+ if i == max_retries - 1:
+ raise e
+ pass
# we know it was persisted, so must have a stream ordering
assert ev.internal_metadata.stream_ordering
@@ -1135,11 +1148,13 @@ class EventCreationHandler:
)
state_events = await self.store.get_events_as_list(state_event_ids)
# Create a StateMap[str]
- state_map = {(e.type, e.state_key): e.event_id for e in state_events}
+ current_state_ids = {
+ (e.type, e.state_key): e.event_id for e in state_events
+ }
# Actually strip down and only use the necessary auth events
auth_event_ids = self._event_auth_handler.compute_auth_events(
event=temp_event,
- current_state_ids=state_map,
+ current_state_ids=current_state_ids,
for_verification=False,
)
@@ -1353,7 +1368,7 @@ class EventCreationHandler:
Raises:
ShadowBanError if the requester has been shadow-banned.
- SynapseError(503) if attempting to persist a partial state event in
+ PartialStateConflictError if attempting to persist a partial state event in
a room that has been un-partial stated.
"""
extra_users = extra_users or []
@@ -1415,34 +1430,23 @@ class EventCreationHandler:
# We now persist the event (and update the cache in parallel, since we
# don't want to block on it).
event, context = events_and_context[0]
- try:
- result, _ = await make_deferred_yieldable(
- gather_results(
- (
- run_in_background(
- self._persist_events,
- requester=requester,
- events_and_context=events_and_context,
- ratelimit=ratelimit,
- extra_users=extra_users,
- ),
- run_in_background(
- self.cache_joined_hosts_for_events, events_and_context
- ).addErrback(
- log_failure, "cache_joined_hosts_for_event failed"
- ),
+ result, _ = await make_deferred_yieldable(
+ gather_results(
+ (
+ run_in_background(
+ self._persist_events,
+ requester=requester,
+ events_and_context=events_and_context,
+ ratelimit=ratelimit,
+ extra_users=extra_users,
),
- consumeErrors=True,
- )
- ).addErrback(unwrapFirstError)
- except PartialStateConflictError as e:
- # The event context needs to be recomputed.
- # Turn the error into a 429, as a hint to the client to try again.
- logger.info(
- "Room %s was un-partial stated while persisting client event.",
- event.room_id,
+ run_in_background(
+ self.cache_joined_hosts_for_events, events_and_context
+ ).addErrback(log_failure, "cache_joined_hosts_for_event failed"),
+ ),
+ consumeErrors=True,
)
- raise LimitExceededError(msg=e.msg, errcode=e.errcode, retry_after_ms=0)
+ ).addErrback(unwrapFirstError)
return result
@@ -1527,12 +1531,23 @@ class EventCreationHandler:
external federation senders don't have to recalculate it themselves.
"""
- for event, _ in events_and_context:
- if not self._external_cache.is_enabled():
- return
+ if not self._external_cache.is_enabled():
+ return
- # If external cache is enabled we should always have this.
- assert self._external_cache_joined_hosts_updates is not None
+ # If external cache is enabled we should always have this.
+ assert self._external_cache_joined_hosts_updates is not None
+
+ for event, event_context in events_and_context:
+ if event_context.partial_state:
+ # To populate the cache for a partial-state event, we either have to
+ # block until full state, which the code below does, or change the
+ # meaning of cache values to be the list of hosts to which we plan to
+ # send events and calculate that instead.
+ #
+ # The federation senders don't use the external cache when sending
+ # events in partial-state rooms anyway, so let's not bother populating
+ # the cache.
+ continue
# We actually store two mappings, event ID -> prev state group,
# state group -> joined hosts, which is much more space efficient
@@ -1737,12 +1752,15 @@ class EventCreationHandler:
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
- event.unsigned[
- "invite_room_state"
- ] = await self.store.get_stripped_room_state_from_event_context(
- context,
- self.room_prejoin_state_types,
- membership_user_id=event.sender,
+ maybe_upsert_event_field(
+ event,
+ event.unsigned,
+ "invite_room_state",
+ await self.store.get_stripped_room_state_from_event_context(
+ context,
+ self.room_prejoin_state_types,
+ membership_user_id=event.sender,
+ ),
)
invitee = UserID.from_string(event.state_key)
@@ -1760,11 +1778,14 @@ class EventCreationHandler:
event.signatures.update(returned_invite.signatures)
if event.content["membership"] == Membership.KNOCK:
- event.unsigned[
- "knock_room_state"
- ] = await self.store.get_stripped_room_state_from_event_context(
- context,
- self.room_prejoin_state_types,
+ maybe_upsert_event_field(
+ event,
+ event.unsigned,
+ "knock_room_state",
+ await self.store.get_stripped_room_state_from_event_context(
+ context,
+ self.room_prejoin_state_types,
+ ),
)
if event.type == EventTypes.Redaction:
@@ -1918,7 +1939,9 @@ class EventCreationHandler:
if event.type == EventTypes.Message:
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
- run_in_background(self._bump_active_time, requester.user) # type: ignore[unused-awaitable]
+ run_as_background_process(
+ "bump_presence_active_time", self._bump_active_time, requester.user
+ )
async def _notify() -> None:
try:
@@ -2003,26 +2026,39 @@ class EventCreationHandler:
for user_id in members:
requester = create_requester(user_id, authenticated_entity=self.server_name)
try:
- event, context = await self.create_event(
- requester,
- {
- "type": EventTypes.Dummy,
- "content": {},
- "room_id": room_id,
- "sender": user_id,
- },
- )
+ # Try several times, it could fail with PartialStateConflictError
+ # in handle_new_client_event, cf comment in except block.
+ max_retries = 5
+ for i in range(max_retries):
+ try:
+ event, context = await self.create_event(
+ requester,
+ {
+ "type": EventTypes.Dummy,
+ "content": {},
+ "room_id": room_id,
+ "sender": user_id,
+ },
+ )
- event.internal_metadata.proactively_send = False
+ event.internal_metadata.proactively_send = False
- # Since this is a dummy-event it is OK if it is sent by a
- # shadow-banned user.
- await self.handle_new_client_event(
- requester,
- events_and_context=[(event, context)],
- ratelimit=False,
- ignore_shadow_ban=True,
- )
+ # Since this is a dummy-event it is OK if it is sent by a
+ # shadow-banned user.
+ await self.handle_new_client_event(
+ requester,
+ events_and_context=[(event, context)],
+ ratelimit=False,
+ ignore_shadow_ban=True,
+ )
+
+ break
+ except PartialStateConflictError as e:
+ # Persisting couldn't happen because the room got un-partial stated
+ # in the meantime and context needs to be recomputed, so let's do so.
+ if i == max_retries - 1:
+ raise e
+ pass
return True
except AuthError:
logger.info(
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 41c675f408..0fc829acf7 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -36,6 +36,7 @@ from authlib.jose import JsonWebToken, JWTClaims
from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
+from authlib.oauth2.rfc7636.challenge import create_s256_code_challenge
from authlib.oidc.core import CodeIDToken, UserInfo
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
@@ -475,6 +476,16 @@ class OidcProvider:
)
)
+ # If PKCE support is advertised ensure the wanted method is available.
+ if m.get("code_challenge_methods_supported") is not None:
+ m.validate_code_challenge_methods_supported()
+ if "S256" not in m["code_challenge_methods_supported"]:
+ raise ValueError(
+ '"S256" not in "code_challenge_methods_supported" ({supported!r})'.format(
+ supported=m["code_challenge_methods_supported"],
+ )
+ )
+
if m.get("response_types_supported") is not None:
m.validate_response_types_supported()
@@ -602,6 +613,11 @@ class OidcProvider:
if self._config.jwks_uri:
metadata["jwks_uri"] = self._config.jwks_uri
+ if self._config.pkce_method == "always":
+ metadata["code_challenge_methods_supported"] = ["S256"]
+ elif self._config.pkce_method == "never":
+ metadata.pop("code_challenge_methods_supported", None)
+
self._validate_metadata(metadata)
return metadata
@@ -653,7 +669,7 @@ class OidcProvider:
return jwk_set
- async def _exchange_code(self, code: str) -> Token:
+ async def _exchange_code(self, code: str, code_verifier: str) -> Token:
"""Exchange an authorization code for a token.
This calls the ``token_endpoint`` with the authorization code we
@@ -666,6 +682,7 @@ class OidcProvider:
Args:
code: The authorization code we got from the callback.
+ code_verifier: The PKCE code verifier to send, blank if unused.
Returns:
A dict containing various tokens.
@@ -696,6 +713,8 @@ class OidcProvider:
"code": code,
"redirect_uri": self._callback_url,
}
+ if code_verifier:
+ args["code_verifier"] = code_verifier
body = urlencode(args, True)
# Fill the body/headers with credentials
@@ -914,11 +933,14 @@ class OidcProvider:
- ``scope``: the list of scopes set in ``oidc_config.scopes``
- ``state``: a random string
- ``nonce``: a random string
+ - ``code_challenge``: a RFC7636 code challenge (if PKCE is supported)
- In addition generating a redirect URL, we are setting a cookie with
- a signed macaroon token containing the state, the nonce and the
- client_redirect_url params. Those are then checked when the client
- comes back from the provider.
+ In addition to generating a redirect URL, we are setting a cookie with
+ a signed macaroon token containing the state, the nonce, the
+ client_redirect_url, and (optionally) the code_verifier params. The state,
+ nonce, and client_redirect_url are then checked when the client comes back
+ from the provider. The code_verifier is passed back to the server during
+ the token exchange and compared to the code_challenge sent in this request.
Args:
request: the incoming request from the browser.
@@ -935,10 +957,25 @@ class OidcProvider:
state = generate_token()
nonce = generate_token()
+ code_verifier = ""
if not client_redirect_url:
client_redirect_url = b""
+ metadata = await self.load_metadata()
+
+ # Automatically enable PKCE if it is supported.
+ extra_grant_values = {}
+ if metadata.get("code_challenge_methods_supported"):
+ code_verifier = generate_token(48)
+
+ # Note that we verified the server supports S256 earlier (in
+ # OidcProvider._validate_metadata).
+ extra_grant_values = {
+ "code_challenge_method": "S256",
+ "code_challenge": create_s256_code_challenge(code_verifier),
+ }
+
cookie = self._macaroon_generaton.generate_oidc_session_token(
state=state,
session_data=OidcSessionData(
@@ -946,6 +983,7 @@ class OidcProvider:
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
ui_auth_session_id=ui_auth_session_id or "",
+ code_verifier=code_verifier,
),
)
@@ -966,7 +1004,6 @@ class OidcProvider:
)
)
- metadata = await self.load_metadata()
authorization_endpoint = metadata.get("authorization_endpoint")
return prepare_grant_uri(
authorization_endpoint,
@@ -976,6 +1013,7 @@ class OidcProvider:
scope=self._scopes,
state=state,
nonce=nonce,
+ **extra_grant_values,
)
async def handle_oidc_callback(
@@ -1003,7 +1041,9 @@ class OidcProvider:
# Exchange the code with the provider
try:
logger.debug("Exchanging OAuth2 code for a token")
- token = await self._exchange_code(code)
+ token = await self._exchange_code(
+ code, code_verifier=session_data.code_verifier
+ )
except OidcError as e:
logger.warning("Could not exchange OAuth2 code: %s", e)
self._sso_handler.render_error(request, e.error, e.error_description)
@@ -1435,6 +1475,7 @@ class UserAttributeDict(TypedDict):
localpart: Optional[str]
confirm_localpart: bool
display_name: Optional[str]
+ picture: Optional[str] # may be omitted by older `OidcMappingProviders`
emails: List[str]
@@ -1519,7 +1560,8 @@ env.filters.update(
@attr.s(slots=True, frozen=True, auto_attribs=True)
class JinjaOidcMappingConfig:
- subject_claim: str
+ subject_template: Template
+ picture_template: Template
localpart_template: Optional[Template]
display_name_template: Optional[Template]
email_template: Optional[Template]
@@ -1538,7 +1580,23 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@staticmethod
def parse_config(config: dict) -> JinjaOidcMappingConfig:
- subject_claim = config.get("subject_claim", "sub")
+ def parse_template_config_with_claim(
+ option_name: str, default_claim: str
+ ) -> Template:
+ template_name = f"{option_name}_template"
+ template = config.get(template_name)
+ if not template:
+ # Convert the legacy subject_claim into a template.
+ claim = config.get(f"{option_name}_claim", default_claim)
+ template = "{{ user.%s }}" % (claim,)
+
+ try:
+ return env.from_string(template)
+ except Exception as e:
+ raise ConfigError("invalid jinja template", path=[template_name]) from e
+
+ subject_template = parse_template_config_with_claim("subject", "sub")
+ picture_template = parse_template_config_with_claim("picture", "picture")
def parse_template_config(option_name: str) -> Optional[Template]:
if option_name not in config:
@@ -1571,7 +1629,8 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
raise ConfigError("must be a bool", path=["confirm_localpart"])
return JinjaOidcMappingConfig(
- subject_claim=subject_claim,
+ subject_template=subject_template,
+ picture_template=picture_template,
localpart_template=localpart_template,
display_name_template=display_name_template,
email_template=email_template,
@@ -1580,7 +1639,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
)
def get_remote_user_id(self, userinfo: UserInfo) -> str:
- return userinfo[self._config.subject_claim]
+ return self._config.subject_template.render(user=userinfo).strip()
async def map_user_attributes(
self, userinfo: UserInfo, token: Token, failures: int
@@ -1611,10 +1670,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if email:
emails.append(email)
+ picture = self._config.picture_template.render(user=userinfo).strip()
+
return UserAttributeDict(
localpart=localpart,
display_name=display_name,
emails=emails,
+ picture=picture,
confirm_localpart=self._config.confirm_localpart,
)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 8972f58241..f2095ce164 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, List, Optional, Set
import attr
from twisted.python.failure import Failure
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import SerializeEventConfig
@@ -27,9 +27,9 @@ from synapse.handlers.room import ShutdownRoomResponse
from synapse.logging.opentracing import trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, StreamKeyType
+from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType
+from synapse.types.state import StateFilter
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
@@ -391,7 +391,7 @@ class PaginationHandler:
"""
return self._delete_by_id.get(delete_id)
- def get_delete_ids_by_room(self, room_id: str) -> Optional[Collection[str]]:
+ def get_delete_ids_by_room(self, room_id: str) -> Optional[StrCollection]:
"""Get all active delete ids by room
Args:
@@ -448,6 +448,12 @@ class PaginationHandler:
if pagin_config.from_token:
from_token = pagin_config.from_token
+ elif pagin_config.direction == Direction.FORWARDS:
+ from_token = (
+ await self.hs.get_event_sources().get_start_token_for_pagination(
+ room_id
+ )
+ )
else:
from_token = (
await self.hs.get_event_sources().get_current_token_for_pagination(
@@ -470,7 +476,7 @@ class PaginationHandler:
room_id, requester, allow_departed_users=True
)
- if pagin_config.direction == "b":
+ if pagin_config.direction == Direction.BACKWARDS:
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 7baa45f495..b4c0577e4d 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -64,7 +64,13 @@ from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
from synapse.storage.databases.main import DataStore
from synapse.streams import EventSource
-from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id
+from synapse.types import (
+ JsonDict,
+ StrCollection,
+ StreamKeyType,
+ UserID,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import Linearizer
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -320,7 +326,7 @@ class BasePresenceHandler(abc.ABC):
for destination, host_states in hosts_to_states.items():
self._federation.send_presence_to_destinations(host_states, [destination])
- async def send_full_presence_to_users(self, user_ids: Collection[str]) -> None:
+ async def send_full_presence_to_users(self, user_ids: StrCollection) -> None:
"""
Adds to the list of users who should receive a full snapshot of presence
upon their next sync. Note that this only works for local users.
@@ -1601,7 +1607,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
# Having a default limit doesn't match the EventSource API, but some
# callers do not provide it. It is unused in this class.
limit: int = 0,
- room_ids: Optional[Collection[str]] = None,
+ room_ids: Optional[StrCollection] = None,
is_guest: bool = False,
explicit_room_id: Optional[str] = None,
include_offline: bool = True,
@@ -1688,14 +1694,16 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
# The set of users that we're interested in and that have had a presence update.
# We'll actually pull the presence updates for these users at the end.
- interested_and_updated_users: Collection[str]
+ interested_and_updated_users: StrCollection
if from_key is not None:
# First get all users that have had a presence update
- updated_users = stream_change_cache.get_all_entities_changed(from_key)
+ result = stream_change_cache.get_all_entities_changed(from_key)
# Cross-reference users we're interested in with those that have had updates.
- if updated_users is not None:
+ if result.hit:
+ updated_users = result.entities
+
# If we have the full list of changes for presence we can
# simply check which ones share a room with the user.
get_updates_counter.labels("stream").inc()
@@ -1764,14 +1772,14 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
Returns:
A list of presence states for the given user to receive.
"""
+ updated_users = None
if from_key:
# Only return updates since the last sync
- updated_users = self.store.presence_stream_cache.get_all_entities_changed(
- from_key
- )
- if not updated_users:
- updated_users = []
+ result = self.store.presence_stream_cache.get_all_entities_changed(from_key)
+ if result.hit:
+ updated_users = result.entities
+ if updated_users is not None:
# Get the actual presence update for each change
users_to_state = await self.get_presence_handler().current_state_for_users(
updated_users
@@ -2118,7 +2126,7 @@ class PresenceFederationQueue:
# stream_id, destinations, user_ids)`. We don't store the full states
# for efficiency, and remote workers will already have the full states
# cached.
- self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = []
+ self._queue: List[Tuple[int, int, StrCollection, Set[str]]] = []
self._next_id = 1
@@ -2140,7 +2148,7 @@ class PresenceFederationQueue:
self._queue = self._queue[index:]
def send_presence_to_destinations(
- self, states: Collection[UserPresenceState], destinations: Collection[str]
+ self, states: Collection[UserPresenceState], destinations: StrCollection
) -> None:
"""Send the presence states to the given destinations.
@@ -2153,6 +2161,11 @@ class PresenceFederationQueue:
# This should only be called on a presence writer.
assert self._presence_writer
+ if not states or not destinations:
+ # Ignore calls which either don't have any new states or don't need
+ # to be sent anywhere.
+ return
+
if self._federation:
self._federation.send_presence_to_destinations(
states=states,
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index ac01582442..04c61ae3dd 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -92,7 +92,6 @@ class ReceiptsHandler:
continue
# Check if these receipts apply to a thread.
- thread_id = None
data = user_values.get("data", {})
thread_id = data.get("thread_id")
# If the thread ID is invalid, consider it missing.
@@ -316,5 +315,5 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
return events, to_key
- def get_current_key(self, direction: str = "f") -> int:
+ def get_current_key(self) -> int:
return self.store.get_max_receipt_stream_id()
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ca1c7a1866..c611efb760 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -38,6 +38,7 @@ from synapse.api.errors import (
)
from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved
+from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.replication.http.register import (
@@ -45,8 +46,8 @@ from synapse.replication.http.register import (
ReplicationRegisterServlet,
)
from synapse.spam_checker_api import RegistrationBehaviour
-from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
+from synapse.types.state import StateFilter
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -841,6 +842,9 @@ class RegistrationHandler:
refresh_token = None
refresh_token_id = None
+ # This can only run on the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
+
registered_device_id = await self.device_handler.check_device_registered(
user_id,
device_id,
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 8e71dda970..0fb15391e0 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -13,17 +13,19 @@
# limitations under the License.
import enum
import logging
-from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional
import attr
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, StreamToken, UserID
+from synapse.types import JsonDict, Requester, UserID
+from synapse.util.async_helpers import gather_results
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
@@ -172,40 +174,6 @@ class RelationsHandler:
return return_value
- async def get_relations_for_event(
- self,
- event_id: str,
- event: EventBase,
- room_id: str,
- relation_type: str,
- ignored_users: FrozenSet[str] = frozenset(),
- ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
- """Get a list of events which relate to an event, ordered by topological ordering.
-
- Args:
- event_id: Fetch events that relate to this event ID.
- event: The matching EventBase to event_id.
- room_id: The room the event belongs to.
- relation_type: The type of relation.
- ignored_users: The users ignored by the requesting user.
-
- Returns:
- List of event IDs that match relations requested. The rows are of
- the form `{"event_id": "..."}`.
- """
-
- # Call the underlying storage method, which is cached.
- related_events, next_token = await self._main_store.get_relations_for_event(
- event_id, event, room_id, relation_type, direction="f"
- )
-
- # Filter out ignored users and convert to the expected format.
- related_events = [
- event for event in related_events if event.sender not in ignored_users
- ]
-
- return related_events, next_token
-
async def redact_events_related_to(
self,
requester: Requester,
@@ -259,51 +227,107 @@ class RelationsHandler:
e.msg,
)
- async def get_annotations_for_event(
- self,
- event_id: str,
- room_id: str,
- limit: int = 5,
- ignored_users: FrozenSet[str] = frozenset(),
- ) -> List[JsonDict]:
- """Get a list of annotations on the event, grouped by event type and
+ async def get_annotations_for_events(
+ self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
+ ) -> Dict[str, List[JsonDict]]:
+ """Get a list of annotations to the given events, grouped by event type and
aggregation key, sorted by count.
- This is used e.g. to get the what and how many reactions have happend
+ This is used e.g. to get the what and how many reactions have happened
on an event.
Args:
- event_id: Fetch events that relate to this event ID.
- room_id: The room the event belongs to.
- limit: Only fetch the `limit` groups.
+ event_ids: Fetch events that relate to these event IDs.
ignored_users: The users ignored by the requesting user.
Returns:
- List of groups of annotations that match. Each row is a dict with
- `type`, `key` and `count` fields.
+ A map of event IDs to a list of groups of annotations that match.
+ Each entry is a dict with `type`, `key` and `count` fields.
"""
# Get the base results for all users.
- full_results = await self._main_store.get_aggregation_groups_for_event(
- event_id, room_id, limit
+ full_results = await self._main_store.get_aggregation_groups_for_events(
+ event_ids
)
+ # Avoid additional logic if there are no ignored users.
+ if not ignored_users:
+ return {
+ event_id: results
+ for event_id, results in full_results.items()
+ if results
+ }
+
# Then subtract off the results for any ignored users.
ignored_results = await self._main_store.get_aggregation_groups_for_users(
- event_id, room_id, limit, ignored_users
+ [event_id for event_id, results in full_results.items() if results],
+ ignored_users,
)
- filtered_results = []
- for result in full_results:
- key = (result["type"], result["key"])
- if key in ignored_results:
- result = result.copy()
- result["count"] -= ignored_results[key]
- if result["count"] <= 0:
- continue
- filtered_results.append(result)
+ filtered_results = {}
+ for event_id, results in full_results.items():
+ # If no annotations, skip.
+ if not results:
+ continue
+
+ # If there are not ignored results for this event, copy verbatim.
+ if event_id not in ignored_results:
+ filtered_results[event_id] = results
+ continue
+
+ # Otherwise, subtract out the ignored results.
+ event_ignored_results = ignored_results[event_id]
+ for result in results:
+ key = (result["type"], result["key"])
+ if key in event_ignored_results:
+ # Ensure to not modify the cache.
+ result = result.copy()
+ result["count"] -= event_ignored_results[key]
+ if result["count"] <= 0:
+ continue
+ filtered_results.setdefault(event_id, []).append(result)
return filtered_results
+ async def get_references_for_events(
+ self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
+ ) -> Dict[str, List[_RelatedEvent]]:
+ """Get a list of references to the given events.
+
+ Args:
+ event_ids: Fetch events that relate to this event ID.
+ ignored_users: The users ignored by the requesting user.
+
+ Returns:
+ A map of event IDs to a list related events.
+ """
+
+ related_events = await self._main_store.get_references_for_events(event_ids)
+
+ # Avoid additional logic if there are no ignored users.
+ if not ignored_users:
+ return {
+ event_id: results
+ for event_id, results in related_events.items()
+ if results
+ }
+
+ # Filter out ignored users.
+ results = {}
+ for event_id, events in related_events.items():
+ # If no references, skip.
+ if not events:
+ continue
+
+ # Filter ignored users out.
+ events = [event for event in events if event.sender not in ignored_users]
+ # If there are no events left, skip this event.
+ if not events:
+ continue
+
+ results[event_id] = events
+
+ return results
+
async def _get_threads_for_events(
self,
events_by_id: Dict[str, EventBase],
@@ -366,59 +390,70 @@ class RelationsHandler:
results = {}
for event_id, summary in summaries.items():
- if summary:
- thread_count, latest_thread_event = summary
+ # If no thread, skip.
+ if not summary:
+ continue
- # Subtract off the count of any ignored users.
- for ignored_user in ignored_users:
- thread_count -= ignored_results.get((event_id, ignored_user), 0)
+ thread_count, latest_thread_event = summary
- # This is gnarly, but if the latest event is from an ignored user,
- # attempt to find one that isn't from an ignored user.
- if latest_thread_event.sender in ignored_users:
- room_id = latest_thread_event.room_id
+ # Subtract off the count of any ignored users.
+ for ignored_user in ignored_users:
+ thread_count -= ignored_results.get((event_id, ignored_user), 0)
- # If the root event is not found, something went wrong, do
- # not include a summary of the thread.
- event = await self._event_handler.get_event(user, room_id, event_id)
- if event is None:
- continue
+ # This is gnarly, but if the latest event is from an ignored user,
+ # attempt to find one that isn't from an ignored user.
+ if latest_thread_event.sender in ignored_users:
+ room_id = latest_thread_event.room_id
- potential_events, _ = await self.get_relations_for_event(
- event_id,
- event,
- room_id,
- RelationTypes.THREAD,
- ignored_users,
- )
+ # If the root event is not found, something went wrong, do
+ # not include a summary of the thread.
+ event = await self._event_handler.get_event(user, room_id, event_id)
+ if event is None:
+ continue
- # If all found events are from ignored users, do not include
- # a summary of the thread.
- if not potential_events:
- continue
-
- # The *last* event returned is the one that is cared about.
- event = await self._event_handler.get_event(
- user, room_id, potential_events[-1].event_id
- )
- # It is unexpected that the event will not exist.
- if event is None:
- logger.warning(
- "Unable to fetch latest event in a thread with event ID: %s",
- potential_events[-1].event_id,
- )
- continue
- latest_thread_event = event
-
- results[event_id] = _ThreadAggregation(
- latest_event=latest_thread_event,
- count=thread_count,
- # If there's a thread summary it must also exist in the
- # participated dictionary.
- current_user_participated=events_by_id[event_id].sender == user_id
- or participated[event_id],
+ # Attempt to find another event to use as the latest event.
+ potential_events, _ = await self._main_store.get_relations_for_event(
+ event_id,
+ event,
+ room_id,
+ RelationTypes.THREAD,
+ direction=Direction.FORWARDS,
)
+ # Filter out ignored users.
+ potential_events = [
+ event
+ for event in potential_events
+ if event.sender not in ignored_users
+ ]
+
+ # If all found events are from ignored users, do not include
+ # a summary of the thread.
+ if not potential_events:
+ continue
+
+ # The *last* event returned is the one that is cared about.
+ event = await self._event_handler.get_event(
+ user, room_id, potential_events[-1].event_id
+ )
+ # It is unexpected that the event will not exist.
+ if event is None:
+ logger.warning(
+ "Unable to fetch latest event in a thread with event ID: %s",
+ potential_events[-1].event_id,
+ )
+ continue
+ latest_thread_event = event
+
+ results[event_id] = _ThreadAggregation(
+ latest_event=latest_thread_event,
+ count=thread_count,
+ # If there's a thread summary it must also exist in the
+ # participated dictionary.
+ current_user_participated=events_by_id[event_id].sender == user_id
+ or participated[event_id],
+ )
+
return results
@trace
@@ -496,49 +531,56 @@ class RelationsHandler:
# (as that is what makes it part of the thread).
relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD
- # Fetch other relations per event.
- for event in events_by_id.values():
- # Fetch any annotations (ie, reactions) to bundle with this event.
- annotations = await self.get_annotations_for_event(
- event.event_id, event.room_id, ignored_users=ignored_users
+ async def _fetch_annotations() -> None:
+ """Fetch any annotations (ie, reactions) to bundle with this event."""
+ annotations_by_event_id = await self.get_annotations_for_events(
+ events_by_id.keys(), ignored_users=ignored_users
)
- if annotations:
- results.setdefault(
- event.event_id, BundledAggregations()
- ).annotations = {"chunk": annotations}
+ for event_id, annotations in annotations_by_event_id.items():
+ if annotations:
+ results.setdefault(event_id, BundledAggregations()).annotations = {
+ "chunk": annotations
+ }
- # Fetch any references to bundle with this event.
- references, next_token = await self.get_relations_for_event(
- event.event_id,
- event,
- event.room_id,
- RelationTypes.REFERENCE,
- ignored_users=ignored_users,
+ async def _fetch_references() -> None:
+ """Fetch any references to bundle with this event."""
+ references_by_event_id = await self.get_references_for_events(
+ events_by_id.keys(), ignored_users=ignored_users
)
- if references:
- aggregations = results.setdefault(event.event_id, BundledAggregations())
- aggregations.references = {
- "chunk": [{"event_id": ev.event_id} for ev in references]
- }
+ for event_id, references in references_by_event_id.items():
+ if references:
+ results.setdefault(event_id, BundledAggregations()).references = {
+ "chunk": [{"event_id": ev.event_id} for ev in references]
+ }
- if next_token:
- aggregations.references["next_batch"] = await next_token.to_string(
- self._main_store
- )
+ async def _fetch_edits() -> None:
+ """
+ Fetch any edits (but not for redacted events).
- # Fetch any edits (but not for redacted events).
- #
- # Note that there is no use in limiting edits by ignored users since the
- # parent event should be ignored in the first place if the user is ignored.
- edits = await self._main_store.get_applicable_edits(
- [
- event_id
- for event_id, event in events_by_id.items()
- if not event.internal_metadata.is_redacted()
- ]
+ Note that there is no use in limiting edits by ignored users since the
+ parent event should be ignored in the first place if the user is ignored.
+ """
+ edits = await self._main_store.get_applicable_edits(
+ [
+ event_id
+ for event_id, event in events_by_id.items()
+ if not event.internal_metadata.is_redacted()
+ ]
+ )
+ for event_id, edit in edits.items():
+ results.setdefault(event_id, BundledAggregations()).replace = edit
+
+ # Parallelize the calls for annotations, references, and edits since they
+ # are unrelated.
+ await make_deferred_yieldable(
+ gather_results(
+ (
+ run_in_background(_fetch_annotations),
+ run_in_background(_fetch_references),
+ run_in_background(_fetch_edits),
+ )
+ )
)
- for event_id, edit in edits.items():
- results.setdefault(event_id, BundledAggregations()).replace = edit
return results
@@ -571,7 +613,7 @@ class RelationsHandler:
room_id, requester, allow_departed_users=True
)
- # Note that ignored users are not passed into get_relations_for_event
+ # Note that ignored users are not passed into get_threads
# below. Ignored users are handled in filter_events_for_client (and by
# not passing them in here we should get a better cache hit rate).
thread_roots, next_batch = await self._main_store.get_threads(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 6dcfd86fdf..7ba7c4ff07 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -20,22 +20,14 @@ import random
import string
from collections import OrderedDict
from http import HTTPStatus
-from typing import (
- TYPE_CHECKING,
- Any,
- Awaitable,
- Collection,
- Dict,
- List,
- Optional,
- Tuple,
-)
+from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
import attr
from typing_extensions import TypedDict
import synapse.events.snapshot
from synapse.api.constants import (
+ Direction,
EventContentFields,
EventTypes,
GuestAccess,
@@ -62,7 +54,7 @@ from synapse.events.utils import copy_and_fixup_power_levels_contents
from synapse.handlers.relations import BundledAggregations
from synapse.module_api import NOT_SPAM
from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.state import StateFilter
+from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.streams import EventSource
from synapse.types import (
JsonDict,
@@ -72,11 +64,13 @@ from synapse.types import (
RoomID,
RoomStreamToken,
StateMap,
+ StrCollection,
StreamKeyType,
StreamToken,
UserID,
create_requester,
)
+from synapse.types.state import StateFilter
from synapse.util import stringutils
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_and_validate_server_name
@@ -207,46 +201,64 @@ class RoomCreationHandler:
new_room_id = self._generate_room_id()
- # Check whether the user has the power level to carry out the upgrade.
- # `check_auth_rules_from_context` will check that they are in the room and have
- # the required power level to send the tombstone event.
- (
- tombstone_event,
- tombstone_context,
- ) = await self.event_creation_handler.create_event(
- requester,
- {
- "type": EventTypes.Tombstone,
- "state_key": "",
- "room_id": old_room_id,
- "sender": user_id,
- "content": {
- "body": "This room has been replaced",
- "replacement_room": new_room_id,
- },
- },
- )
- validate_event_for_room_version(tombstone_event)
- await self._event_auth_handler.check_auth_rules_from_context(tombstone_event)
+ # Try several times, it could fail with PartialStateConflictError
+ # in _upgrade_room, cf comment in except block.
+ max_retries = 5
+ for i in range(max_retries):
+ try:
+ # Check whether the user has the power level to carry out the upgrade.
+ # `check_auth_rules_from_context` will check that they are in the room and have
+ # the required power level to send the tombstone event.
+ (
+ tombstone_event,
+ tombstone_context,
+ ) = await self.event_creation_handler.create_event(
+ requester,
+ {
+ "type": EventTypes.Tombstone,
+ "state_key": "",
+ "room_id": old_room_id,
+ "sender": user_id,
+ "content": {
+ "body": "This room has been replaced",
+ "replacement_room": new_room_id,
+ },
+ },
+ )
+ validate_event_for_room_version(tombstone_event)
+ await self._event_auth_handler.check_auth_rules_from_context(
+ tombstone_event
+ )
- # Upgrade the room
- #
- # If this user has sent multiple upgrade requests for the same room
- # and one of them is not complete yet, cache the response and
- # return it to all subsequent requests
- ret = await self._upgrade_response_cache.wrap(
- (old_room_id, user_id),
- self._upgrade_room,
- requester,
- old_room_id,
- old_room, # args for _upgrade_room
- new_room_id,
- new_version,
- tombstone_event,
- tombstone_context,
- )
+ # Upgrade the room
+ #
+ # If this user has sent multiple upgrade requests for the same room
+ # and one of them is not complete yet, cache the response and
+ # return it to all subsequent requests
+ ret = await self._upgrade_response_cache.wrap(
+ (old_room_id, user_id),
+ self._upgrade_room,
+ requester,
+ old_room_id,
+ old_room, # args for _upgrade_room
+ new_room_id,
+ new_version,
+ tombstone_event,
+ tombstone_context,
+ )
- return ret
+ return ret
+ except PartialStateConflictError as e:
+ # Clean up the cache so we can retry properly
+ self._upgrade_response_cache.unset((old_room_id, user_id))
+ # Persisting couldn't happen because the room got un-partial stated
+ # in the meantime and context needs to be recomputed, so let's do so.
+ if i == max_retries - 1:
+ raise e
+ pass
+
+ # This is to satisfy mypy and should never happen
+ raise PartialStateConflictError()
async def _upgrade_room(
self,
@@ -1476,7 +1488,7 @@ class TimestampLookupHandler:
requester: Requester,
room_id: str,
timestamp: int,
- direction: str,
+ direction: Direction,
) -> Tuple[str, int]:
"""Find the closest event to the given timestamp in the given direction.
If we can't find an event locally or the event we have locally is next to a gap,
@@ -1487,7 +1499,7 @@ class TimestampLookupHandler:
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
- direction: ["f"|"b"] to indicate whether we should navigate forward
+ direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
@@ -1522,13 +1534,13 @@ class TimestampLookupHandler:
local_event_id, allow_none=False, allow_rejected=False
)
- if direction == "f":
+ if direction == Direction.FORWARDS:
# We only need to check for a backward gap if we're looking forwards
# to ensure there is nothing in between.
is_event_next_to_backward_gap = (
await self.store.is_event_next_to_backward_gap(local_event)
)
- elif direction == "b":
+ elif direction == Direction.BACKWARDS:
# We only need to check for a forward gap if we're looking backwards
# to ensure there is nothing in between
is_event_next_to_forward_gap = (
@@ -1625,7 +1637,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
user: UserID,
from_key: RoomStreamToken,
limit: int,
- room_ids: Collection[str],
+ room_ids: StrCollection,
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[EventBase], RoomStreamToken]:
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 411a6fb22f..c73d2adaad 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -375,6 +375,8 @@ class RoomBatchHandler:
# Events are sorted by (topological_ordering, stream_ordering)
# where topological_ordering is just depth.
for (event, context) in reversed(events_to_persist):
+ # This call can't raise `PartialStateConflictError` since we forbid
+ # use of the historical batch API during partial state
await self.event_creation_handler.handle_new_client_event(
await self.create_requester_for_user_id_from_app_service(
event.sender, app_service_requester.app_service
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 6ad2b38b8f..d236cc09b5 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -34,7 +34,7 @@ from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.logging import opentracing
from synapse.module_api import NOT_SPAM
-from synapse.storage.state import StateFilter
+from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.types import (
JsonDict,
Requester,
@@ -45,6 +45,7 @@ from synapse.types import (
create_requester,
get_domain_from_id,
)
+from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_left_room
@@ -392,60 +393,81 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
event_pos = await self.store.get_position_for_event(existing_event_id)
return existing_event_id, event_pos.stream
- event, context = await self.event_creation_handler.create_event(
- requester,
- {
- "type": EventTypes.Member,
- "content": content,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- "state_key": user_id,
- # For backwards compatibility:
- "membership": membership,
- "origin_server_ts": origin_server_ts,
- },
- txn_id=txn_id,
- allow_no_prev_events=allow_no_prev_events,
- prev_event_ids=prev_event_ids,
- state_event_ids=state_event_ids,
- depth=depth,
- require_consent=require_consent,
- outlier=outlier,
- historical=historical,
- )
-
- prev_state_ids = await context.get_prev_state_ids(
- StateFilter.from_types([(EventTypes.Member, None)])
- )
-
- prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
-
- if event.membership == Membership.JOIN:
- newly_joined = True
- if prev_member_event_id:
- prev_member_event = await self.store.get_event(prev_member_event_id)
- newly_joined = prev_member_event.membership != Membership.JOIN
-
- # Only rate-limit if the user actually joined the room, otherwise we'll end
- # up blocking profile updates.
- if newly_joined and ratelimit:
- await self._join_rate_limiter_local.ratelimit(requester)
- await self._join_rate_per_room_limiter.ratelimit(
- requester, key=room_id, update=False
+ # Try several times, it could fail with PartialStateConflictError,
+ # in handle_new_client_event, cf comment in except block.
+ max_retries = 5
+ for i in range(max_retries):
+ try:
+ event, context = await self.event_creation_handler.create_event(
+ requester,
+ {
+ "type": EventTypes.Member,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ "state_key": user_id,
+ # For backwards compatibility:
+ "membership": membership,
+ "origin_server_ts": origin_server_ts,
+ },
+ txn_id=txn_id,
+ allow_no_prev_events=allow_no_prev_events,
+ prev_event_ids=prev_event_ids,
+ state_event_ids=state_event_ids,
+ depth=depth,
+ require_consent=require_consent,
+ outlier=outlier,
+ historical=historical,
)
- with opentracing.start_active_span("handle_new_client_event"):
- result_event = await self.event_creation_handler.handle_new_client_event(
- requester,
- events_and_context=[(event, context)],
- extra_users=[target],
- ratelimit=ratelimit,
- )
- if event.membership == Membership.LEAVE:
- if prev_member_event_id:
- prev_member_event = await self.store.get_event(prev_member_event_id)
- if prev_member_event.membership == Membership.JOIN:
- await self._user_left_room(target, room_id)
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types([(EventTypes.Member, None)])
+ )
+
+ prev_member_event_id = prev_state_ids.get(
+ (EventTypes.Member, user_id), None
+ )
+
+ if event.membership == Membership.JOIN:
+ newly_joined = True
+ if prev_member_event_id:
+ prev_member_event = await self.store.get_event(
+ prev_member_event_id
+ )
+ newly_joined = prev_member_event.membership != Membership.JOIN
+
+ # Only rate-limit if the user actually joined the room, otherwise we'll end
+ # up blocking profile updates.
+ if newly_joined and ratelimit:
+ await self._join_rate_limiter_local.ratelimit(requester)
+ await self._join_rate_per_room_limiter.ratelimit(
+ requester, key=room_id, update=False
+ )
+ with opentracing.start_active_span("handle_new_client_event"):
+ result_event = (
+ await self.event_creation_handler.handle_new_client_event(
+ requester,
+ events_and_context=[(event, context)],
+ extra_users=[target],
+ ratelimit=ratelimit,
+ )
+ )
+
+ if event.membership == Membership.LEAVE:
+ if prev_member_event_id:
+ prev_member_event = await self.store.get_event(
+ prev_member_event_id
+ )
+ if prev_member_event.membership == Membership.JOIN:
+ await self._user_left_room(target, room_id)
+
+ break
+ except PartialStateConflictError as e:
+ # Persisting couldn't happen because the room got un-partial stated
+ # in the meantime and context needs to be recomputed, so let's do so.
+ if i == max_retries - 1:
+ raise e
+ pass
# we know it was persisted, so should have a stream ordering
assert result_event.internal_metadata.stream_ordering
@@ -1234,6 +1256,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ratelimit: Whether to rate limit this request.
Raises:
SynapseError if there was a problem changing the membership.
+ PartialStateConflictError: if attempting to persist a partial state event in
+ a room that has been un-partial stated.
"""
target_user = UserID.from_string(event.state_key)
room_id = event.room_id
@@ -1863,21 +1887,37 @@ class RoomMemberMasterHandler(RoomMemberHandler):
list(previous_membership_event.auth_event_ids()) + prev_event_ids
)
- event, context = await self.event_creation_handler.create_event(
- requester,
- event_dict,
- txn_id=txn_id,
- prev_event_ids=prev_event_ids,
- auth_event_ids=auth_event_ids,
- outlier=True,
- )
- event.internal_metadata.out_of_band_membership = True
+ # Try several times, it could fail with PartialStateConflictError
+ # in handle_new_client_event, cf comment in except block.
+ max_retries = 5
+ for i in range(max_retries):
+ try:
+ event, context = await self.event_creation_handler.create_event(
+ requester,
+ event_dict,
+ txn_id=txn_id,
+ prev_event_ids=prev_event_ids,
+ auth_event_ids=auth_event_ids,
+ outlier=True,
+ )
+ event.internal_metadata.out_of_band_membership = True
+
+ result_event = (
+ await self.event_creation_handler.handle_new_client_event(
+ requester,
+ events_and_context=[(event, context)],
+ extra_users=[UserID.from_string(target_user)],
+ )
+ )
+
+ break
+ except PartialStateConflictError as e:
+ # Persisting couldn't happen because the room got un-partial stated
+ # in the meantime and context needs to be recomputed, so let's do so.
+ if i == max_retries - 1:
+ raise e
+ pass
- result_event = await self.event_creation_handler.handle_new_client_event(
- requester,
- events_and_context=[(event, context)],
- extra_users=[UserID.from_string(target_user)],
- )
# we know it was persisted, so must have a stream ordering
assert result_event.internal_metadata.stream_ordering
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 8d08625237..4472019fbc 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -20,7 +20,6 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Set,
import attr
from synapse.api.constants import (
- EventContentFields,
EventTypes,
HistoryVisibility,
JoinRules,
@@ -37,7 +36,7 @@ from synapse.api.errors import (
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
-from synapse.types import JsonDict, Requester
+from synapse.types import JsonDict, Requester, StrCollection
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
@@ -701,13 +700,6 @@ class RoomSummaryHandler:
# there should always be an entry
assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
- current_state_ids = await self._storage_controllers.state.get_current_state_ids(
- room_id
- )
- create_event = await self._store.get_event(
- current_state_ids[(EventTypes.Create, "")]
- )
-
entry = {
"room_id": stats["room_id"],
"name": stats["name"],
@@ -720,7 +712,7 @@ class RoomSummaryHandler:
stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
),
"guest_can_join": stats["guest_access"] == "can_join",
- "room_type": create_event.content.get(EventContentFields.ROOM_TYPE),
+ "room_type": stats["room_type"],
}
if self._msc3266_enabled:
@@ -730,7 +722,11 @@ class RoomSummaryHandler:
# Federation requests need to provide additional information so the
# requested server is able to filter the response appropriately.
if for_federation:
+ current_state_ids = (
+ await self._storage_controllers.state.get_current_state_ids(room_id)
+ )
room_version = await self._store.get_room_version(room_id)
+
if await self._event_auth_handler.has_restricted_join_rules(
current_state_ids, room_version
):
@@ -874,7 +870,7 @@ class _RoomQueueEntry:
# The room ID of this entry.
room_id: str
# The server to query if the room is not known locally.
- via: Sequence[str]
+ via: StrCollection
# The minimum number of hops necessary to get to this room (compared to the
# originally requested room).
depth: int = 0
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index bcab98c6d5..9bbf83047d 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -14,7 +14,7 @@
import itertools
import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
import attr
from unpaddedbase64 import decode_base64, encode_base64
@@ -23,8 +23,8 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.events import EventBase
-from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, StreamKeyType, UserID
+from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
+from synapse.types.state import StateFilter
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
@@ -275,7 +275,7 @@ class SearchHandler:
)
room_ids = {r.room_id for r in rooms}
- # If doing a subset of all rooms seearch, check if any of the rooms
+ # If doing a subset of all rooms search, check if any of the rooms
# are from an upgraded room, and search their contents as well
if search_filter.rooms:
historical_room_ids: List[str] = []
@@ -418,7 +418,7 @@ class SearchHandler:
async def _search_by_rank(
self,
user: UserID,
- room_ids: Collection[str],
+ room_ids: StrCollection,
search_term: str,
keys: Iterable[str],
search_filter: Filter,
@@ -491,7 +491,7 @@ class SearchHandler:
async def _search_by_recent(
self,
user: UserID,
- room_ids: Collection[str],
+ room_ids: StrCollection,
search_term: str,
keys: Iterable[str],
search_filter: Filter,
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 73861bbd40..bd9d0bb34b 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -15,6 +15,7 @@ import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import Codes, StoreError, SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.types import Requester
if TYPE_CHECKING:
@@ -29,7 +30,10 @@ class SetPasswordHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._auth_handler = hs.get_auth_handler()
- self._device_handler = hs.get_device_handler()
+ # This can only be instantiated on the main process.
+ device_handler = hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
+ self._device_handler = device_handler
async def set_password(
self,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 749d7e93b0..4a27c0f051 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
+import hashlib
+import io
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
- Collection,
Dict,
Iterable,
List,
@@ -37,6 +38,7 @@ from twisted.web.server import Request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
+from synapse.handlers.device import DeviceHandler
from synapse.handlers.register import init_counters_for_auth_provider
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
@@ -44,6 +46,7 @@ from synapse.http.server import respond_with_html, respond_with_redirect
from synapse.http.site import SynapseRequest
from synapse.types import (
JsonDict,
+ StrCollection,
UserID,
contains_invalid_mxid_characters,
create_requester,
@@ -137,7 +140,9 @@ class UserAttributes:
localpart: Optional[str]
confirm_localpart: bool = False
display_name: Optional[str] = None
- emails: Collection[str] = attr.Factory(list)
+ picture: Optional[str] = None
+ # mypy thinks these are incompatible for some reason.
+ emails: StrCollection = attr.Factory(list) # type: ignore[assignment]
@attr.s(slots=True, auto_attribs=True)
@@ -155,7 +160,7 @@ class UsernameMappingSession:
# attributes returned by the ID mapper
display_name: Optional[str]
- emails: Collection[str]
+ emails: StrCollection
# An optional dictionary of extra attributes to be provided to the client in the
# login response.
@@ -170,7 +175,7 @@ class UsernameMappingSession:
# choices made by the user
chosen_localpart: Optional[str] = None
use_display_name: bool = True
- emails_to_use: Collection[str] = ()
+ emails_to_use: StrCollection = ()
terms_accepted_version: Optional[str] = None
@@ -195,6 +200,10 @@ class SsoHandler:
self._error_template = hs.config.sso.sso_error_template
self._bad_user_template = hs.config.sso.sso_auth_bad_user_template
self._profile_handler = hs.get_profile_handler()
+ self._media_repo = (
+ hs.get_media_repository() if hs.config.media.can_load_media_repo else None
+ )
+ self._http_client = hs.get_proxied_blacklisted_http_client()
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
@@ -494,6 +503,8 @@ class SsoHandler:
await self._profile_handler.set_displayname(
user_id_obj, requester, attributes.display_name, True
)
+ if attributes.picture:
+ await self.set_avatar(user_id, attributes.picture)
await self._auth_handler.complete_sso_login(
user_id,
@@ -702,8 +713,110 @@ class SsoHandler:
await self._store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
)
+
+ # Set avatar, if available
+ if attributes.picture:
+ await self.set_avatar(registered_user_id, attributes.picture)
+
return registered_user_id
+ async def set_avatar(self, user_id: str, picture_https_url: str) -> bool:
+ """Set avatar of the user.
+
+ This downloads the image file from the URL provided, stores that in
+ the media repository and then sets the avatar on the user's profile.
+
+ It can detect if the same image is being saved again and bails early by storing
+ the hash of the file in the `upload_name` of the avatar image.
+
+ Currently, it only supports server configurations which run the media repository
+ within the same process.
+
+ It silently fails and logs a warning by raising an exception and catching it
+ internally if:
+ * it is unable to fetch the image itself (non 200 status code) or
+ * the image supplied is bigger than max allowed size or
+ * the image type is not one of the allowed image types.
+
+ Args:
+ user_id: matrix user ID in the form @localpart:domain as a string.
+
+ picture_https_url: HTTPS url for the picture image file.
+
+ Returns: `True` if the user's avatar has been successfully set to the image at
+ `picture_https_url`.
+ """
+ if self._media_repo is None:
+ logger.info(
+ "failed to set user avatar because out-of-process media repositories "
+ "are not supported yet "
+ )
+ return False
+
+ try:
+ uid = UserID.from_string(user_id)
+
+ def is_allowed_mime_type(content_type: str) -> bool:
+ if (
+ self._profile_handler.allowed_avatar_mimetypes
+ and content_type
+ not in self._profile_handler.allowed_avatar_mimetypes
+ ):
+ return False
+ return True
+
+ # download picture, enforcing size limit & mime type check
+ picture = io.BytesIO()
+
+ content_length, headers, uri, code = await self._http_client.get_file(
+ url=picture_https_url,
+ output_stream=picture,
+ max_size=self._profile_handler.max_avatar_size,
+ is_allowed_content_type=is_allowed_mime_type,
+ )
+
+ if code != 200:
+ raise Exception(
+ "GET request to download sso avatar image returned {}".format(code)
+ )
+
+ # upload name includes hash of the image file's content so that we can
+ # easily check if it requires an update or not, the next time user logs in
+ upload_name = "sso_avatar_" + hashlib.sha256(picture.read()).hexdigest()
+
+ # bail if user already has the same avatar
+ profile = await self._profile_handler.get_profile(user_id)
+ if profile["avatar_url"] is not None:
+ server_name = profile["avatar_url"].split("/")[-2]
+ media_id = profile["avatar_url"].split("/")[-1]
+ if server_name == self._server_name:
+ media = await self._media_repo.store.get_local_media(media_id)
+ if media is not None and upload_name == media["upload_name"]:
+ logger.info("skipping saving the user avatar")
+ return True
+
+ # store it in media repository
+ avatar_mxc_url = await self._media_repo.create_content(
+ media_type=headers[b"Content-Type"][0].decode("utf-8"),
+ upload_name=upload_name,
+ content=picture,
+ content_length=content_length,
+ auth_user=uid,
+ )
+
+ # save it as user avatar
+ await self._profile_handler.set_avatar_url(
+ uid,
+ create_requester(uid),
+ str(avatar_mxc_url),
+ )
+
+ logger.info("successfully saved the user avatar")
+ return True
+ except Exception:
+ logger.warning("failed to save the user avatar")
+ return False
+
async def complete_sso_ui_auth_request(
self,
auth_provider_id: str,
@@ -1035,6 +1148,8 @@ class SsoHandler:
) -> None:
"""Revoke any devices and in-flight logins tied to a provider session.
+ Can only be called from the main process.
+
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
@@ -1042,6 +1157,12 @@ class SsoHandler:
expected_user_id: The user we're expecting to logout. If set, it will ignore
sessions belonging to other users and log an error.
"""
+
+ # It is expected that this is the main process.
+ assert isinstance(
+ self._device_handler, DeviceHandler
+ ), "revoking SSO sessions can only be called on the main process"
+
# Invalidate any running user-mapping sessions
to_delete = []
for session_id, session in self._username_mapping_sessions.items():
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 259456b55d..3566537894 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -17,7 +17,6 @@ from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
- Collection,
Dict,
FrozenSet,
List,
@@ -31,19 +30,30 @@ from typing import (
import attr
from prometheus_client import Counter
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import (
+ AccountDataTypes,
+ EventContentFields,
+ EventTypes,
+ Membership,
+)
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.handlers.relations import BundledAggregations
+from synapse.logging import issue9533_logger
from synapse.logging.context import current_context
-from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
+from synapse.logging.opentracing import (
+ SynapseTags,
+ log_kv,
+ set_tag,
+ start_active_span,
+ trace,
+)
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import RoomNotifCounts
from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
from synapse.storage.roommember import MemberSummary
-from synapse.storage.state import StateFilter
from synapse.types import (
DeviceListUpdates,
JsonDict,
@@ -51,10 +61,12 @@ from synapse.types import (
Requester,
RoomStreamToken,
StateMap,
+ StrCollection,
StreamKeyType,
StreamToken,
UserID,
)
+from synapse.types.state import StateFilter
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.lrucache import LruCache
@@ -278,7 +290,7 @@ class SyncHandler:
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
)
- self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync
+ self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
async def wait_for_sync_for_user(
self,
@@ -1167,7 +1179,7 @@ class SyncHandler:
async def _find_missing_partial_state_memberships(
self,
room_id: str,
- members_to_fetch: Collection[str],
+ members_to_fetch: StrCollection,
events_with_membership_auth: Mapping[str, EventBase],
found_state_ids: StateMap[str],
) -> StateMap[str]:
@@ -1328,7 +1340,10 @@ class SyncHandler:
membership_change_events = []
if since_token:
membership_change_events = await self.store.get_membership_changes_for_user(
- user_id, since_token.room_key, now_token.room_key, self.rooms_to_exclude
+ user_id,
+ since_token.room_key,
+ now_token.room_key,
+ self.rooms_to_exclude_globally,
)
mem_last_change_by_room_id: Dict[str, EventBase] = {}
@@ -1363,12 +1378,49 @@ class SyncHandler:
else:
mutable_joined_room_ids.discard(room_id)
+ # Tweak the set of rooms to return to the client for eager (non-lazy) syncs.
+ mutable_rooms_to_exclude = set(self.rooms_to_exclude_globally)
+ if not sync_config.filter_collection.lazy_load_members():
+ # Non-lazy syncs should never include partially stated rooms.
+ # Exclude all partially stated rooms from this sync.
+ results = await self.store.is_partial_state_room_batched(
+ mutable_joined_room_ids
+ )
+ mutable_rooms_to_exclude.update(
+ room_id
+ for room_id, is_partial_state in results.items()
+ if is_partial_state
+ )
+
+ # Incremental eager syncs should additionally include rooms that
+ # - we are joined to
+ # - are full-stated
+ # - became fully-stated at some point during the sync period
+ # (These rooms will have been omitted during a previous eager sync.)
+ forced_newly_joined_room_ids: Set[str] = set()
+ if since_token and not sync_config.filter_collection.lazy_load_members():
+ un_partial_stated_rooms = (
+ await self.store.get_un_partial_stated_rooms_between(
+ since_token.un_partial_stated_rooms_key,
+ now_token.un_partial_stated_rooms_key,
+ mutable_joined_room_ids,
+ )
+ )
+ results = await self.store.is_partial_state_room_batched(
+ un_partial_stated_rooms
+ )
+ forced_newly_joined_room_ids.update(
+ room_id
+ for room_id, is_partial_state in results.items()
+ if not is_partial_state
+ )
+
# Now we have our list of joined room IDs, exclude as configured and freeze
joined_room_ids = frozenset(
(
room_id
for room_id in mutable_joined_room_ids
- if room_id not in self.rooms_to_exclude
+ if room_id not in mutable_rooms_to_exclude
)
)
@@ -1385,6 +1437,8 @@ class SyncHandler:
since_token=since_token,
now_token=now_token,
joined_room_ids=joined_room_ids,
+ excluded_room_ids=frozenset(mutable_rooms_to_exclude),
+ forced_newly_joined_room_ids=frozenset(forced_newly_joined_room_ids),
membership_change_events=membership_change_events,
)
@@ -1394,46 +1448,77 @@ class SyncHandler:
sync_result_builder
)
- logger.debug("Fetching room data")
-
- res = await self._generate_sync_entry_for_rooms(
- sync_result_builder, account_data_by_room
+ # Presence data is included if the server has it enabled and not filtered out.
+ include_presence_data = bool(
+ self.hs_config.server.use_presence
+ and not sync_config.filter_collection.blocks_all_presence()
)
- newly_joined_rooms, newly_joined_or_invited_or_knocked_users, _, _ = res
- _, _, newly_left_rooms, newly_left_users = res
+ # Device list updates are sent if a since token is provided.
+ include_device_list_updates = bool(since_token and since_token.device_list_key)
- block_all_presence_data = (
- since_token is None and sync_config.filter_collection.blocks_all_presence()
- )
- if self.hs_config.server.use_presence and not block_all_presence_data:
- logger.debug("Fetching presence data")
- await self._generate_sync_entry_for_presence(
- sync_result_builder,
+ # If we do not care about the rooms or things which depend on the room
+ # data (namely presence and device list updates), then we can skip
+ # this process completely.
+ device_lists = DeviceListUpdates()
+ if (
+ not sync_result_builder.sync_config.filter_collection.blocks_all_rooms()
+ or include_presence_data
+ or include_device_list_updates
+ ):
+ logger.debug("Fetching room data")
+
+ # Note that _generate_sync_entry_for_rooms sets sync_result_builder.joined, which
+ # is used in calculate_user_changes below.
+ (
newly_joined_rooms,
- newly_joined_or_invited_or_knocked_users,
+ newly_left_rooms,
+ ) = await self._generate_sync_entry_for_rooms(
+ sync_result_builder, account_data_by_room
)
+ # Work out which users have joined or left rooms we're in. We use this
+ # to build the presence and device_list parts of the sync response in
+ # `_generate_sync_entry_for_presence` and
+ # `_generate_sync_entry_for_device_list` respectively.
+ if include_presence_data or include_device_list_updates:
+ # This uses the sync_result_builder.joined which is set in
+ # `_generate_sync_entry_for_rooms`, if that didn't find any joined
+ # rooms for some reason it is a no-op.
+ (
+ newly_joined_or_invited_or_knocked_users,
+ newly_left_users,
+ ) = sync_result_builder.calculate_user_changes()
+
+ if include_presence_data:
+ logger.debug("Fetching presence data")
+ await self._generate_sync_entry_for_presence(
+ sync_result_builder,
+ newly_joined_rooms,
+ newly_joined_or_invited_or_knocked_users,
+ )
+
+ if include_device_list_updates:
+ device_lists = await self._generate_sync_entry_for_device_list(
+ sync_result_builder,
+ newly_joined_rooms=newly_joined_rooms,
+ newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
+ newly_left_rooms=newly_left_rooms,
+ newly_left_users=newly_left_users,
+ )
+
logger.debug("Fetching to-device data")
await self._generate_sync_entry_for_to_device(sync_result_builder)
- device_lists = await self._generate_sync_entry_for_device_list(
- sync_result_builder,
- newly_joined_rooms=newly_joined_rooms,
- newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
- newly_left_rooms=newly_left_rooms,
- newly_left_users=newly_left_users,
- )
-
logger.debug("Fetching OTK data")
device_id = sync_config.device_id
- one_time_key_counts: JsonDict = {}
+ one_time_keys_count: JsonDict = {}
unused_fallback_key_types: List[str] = []
if device_id:
# TODO: We should have a way to let clients differentiate between the states of:
# * no change in OTK count since the provided since token
# * the server has zero OTKs left for this device
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
- one_time_key_counts = await self.store.count_e2e_one_time_keys(
+ one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
unused_fallback_key_types = (
@@ -1463,7 +1548,7 @@ class SyncHandler:
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
- device_one_time_keys_count=one_time_key_counts,
+ device_one_time_keys_count=one_time_keys_count,
device_unused_fallback_key_types=unused_fallback_key_types,
next_batch=sync_result_builder.now_token,
)
@@ -1492,6 +1577,7 @@ class SyncHandler:
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
+ assert since_token is not None
# Take a copy since these fields will be mutated later.
newly_joined_or_invited_or_knocked_users = set(
@@ -1499,91 +1585,87 @@ class SyncHandler:
)
newly_left_users = set(newly_left_users)
- if since_token and since_token.device_list_key:
- # We want to figure out what user IDs the client should refetch
- # device keys for, and which users we aren't going to track changes
- # for anymore.
- #
- # For the first step we check:
- # a. if any users we share a room with have updated their devices,
- # and
- # b. we also check if we've joined any new rooms, or if a user has
- # joined a room we're in.
- #
- # For the second step we just find any users we no longer share a
- # room with by looking at all users that have left a room plus users
- # that were in a room we've left.
+ # We want to figure out what user IDs the client should refetch
+ # device keys for, and which users we aren't going to track changes
+ # for anymore.
+ #
+ # For the first step we check:
+ # a. if any users we share a room with have updated their devices,
+ # and
+ # b. we also check if we've joined any new rooms, or if a user has
+ # joined a room we're in.
+ #
+ # For the second step we just find any users we no longer share a
+ # room with by looking at all users that have left a room plus users
+ # that were in a room we've left.
- users_that_have_changed = set()
+ users_that_have_changed = set()
- joined_rooms = sync_result_builder.joined_room_ids
+ joined_rooms = sync_result_builder.joined_room_ids
- # Step 1a, check for changes in devices of users we share a room
- # with
- #
- # We do this in two different ways depending on what we have cached.
- # If we already have a list of all the user that have changed since
- # the last sync then it's likely more efficient to compare the rooms
- # they're in with the rooms the syncing user is in.
- #
- # If we don't have that info cached then we get all the users that
- # share a room with our user and check if those users have changed.
- changed_users = self.store.get_cached_device_list_changes(
- since_token.device_list_key
- )
- if changed_users is not None:
- result = await self.store.get_rooms_for_users(changed_users)
+ # Step 1a, check for changes in devices of users we share a room
+ # with
+ #
+ # We do this in two different ways depending on what we have cached.
+ # If we already have a list of all the user that have changed since
+ # the last sync then it's likely more efficient to compare the rooms
+ # they're in with the rooms the syncing user is in.
+ #
+ # If we don't have that info cached then we get all the users that
+ # share a room with our user and check if those users have changed.
+ cache_result = self.store.get_cached_device_list_changes(
+ since_token.device_list_key
+ )
+ if cache_result.hit:
+ changed_users = cache_result.entities
- for changed_user_id, entries in result.items():
- # Check if the changed user shares any rooms with the user,
- # or if the changed user is the syncing user (as we always
- # want to include device list updates of their own devices).
- if user_id == changed_user_id or any(
- rid in joined_rooms for rid in entries
- ):
- users_that_have_changed.add(changed_user_id)
- else:
- users_that_have_changed = (
- await self._device_handler.get_device_changes_in_shared_rooms(
- user_id,
- sync_result_builder.joined_room_ids,
- from_token=since_token,
- )
- )
+ result = await self.store.get_rooms_for_users(changed_users)
- # Step 1b, check for newly joined rooms
- for room_id in newly_joined_rooms:
- joined_users = await self.store.get_users_in_room(room_id)
- newly_joined_or_invited_or_knocked_users.update(joined_users)
-
- # TODO: Check that these users are actually new, i.e. either they
- # weren't in the previous sync *or* they left and rejoined.
- users_that_have_changed.update(newly_joined_or_invited_or_knocked_users)
-
- user_signatures_changed = (
- await self.store.get_users_whose_signatures_changed(
- user_id, since_token.device_list_key
- )
- )
- users_that_have_changed.update(user_signatures_changed)
-
- # Now find users that we no longer track
- for room_id in newly_left_rooms:
- left_users = await self.store.get_users_in_room(room_id)
- newly_left_users.update(left_users)
-
- # Remove any users that we still share a room with.
- left_users_rooms = await self.store.get_rooms_for_users(newly_left_users)
- for user_id, entries in left_users_rooms.items():
- if any(rid in joined_rooms for rid in entries):
- newly_left_users.discard(user_id)
-
- return DeviceListUpdates(
- changed=users_that_have_changed, left=newly_left_users
- )
+ for changed_user_id, entries in result.items():
+ # Check if the changed user shares any rooms with the user,
+ # or if the changed user is the syncing user (as we always
+ # want to include device list updates of their own devices).
+ if user_id == changed_user_id or any(
+ rid in joined_rooms for rid in entries
+ ):
+ users_that_have_changed.add(changed_user_id)
else:
- return DeviceListUpdates()
+ users_that_have_changed = (
+ await self._device_handler.get_device_changes_in_shared_rooms(
+ user_id,
+ sync_result_builder.joined_room_ids,
+ from_token=since_token,
+ )
+ )
+ # Step 1b, check for newly joined rooms
+ for room_id in newly_joined_rooms:
+ joined_users = await self.store.get_users_in_room(room_id)
+ newly_joined_or_invited_or_knocked_users.update(joined_users)
+
+ # TODO: Check that these users are actually new, i.e. either they
+ # weren't in the previous sync *or* they left and rejoined.
+ users_that_have_changed.update(newly_joined_or_invited_or_knocked_users)
+
+ user_signatures_changed = await self.store.get_users_whose_signatures_changed(
+ user_id, since_token.device_list_key
+ )
+ users_that_have_changed.update(user_signatures_changed)
+
+ # Now find users that we no longer track
+ for room_id in newly_left_rooms:
+ left_users = await self.store.get_users_in_room(room_id)
+ newly_left_users.update(left_users)
+
+ # Remove any users that we still share a room with.
+ left_users_rooms = await self.store.get_rooms_for_users(newly_left_users)
+ for user_id, entries in left_users_rooms.items():
+ if any(rid in joined_rooms for rid in entries):
+ newly_left_users.discard(user_id)
+
+ return DeviceListUpdates(changed=users_that_have_changed, left=newly_left_users)
+
+ @trace
async def _generate_sync_entry_for_to_device(
self, sync_result_builder: "SyncResultBuilder"
) -> None:
@@ -1603,19 +1685,29 @@ class SyncHandler:
)
for message in messages:
- # We pop here as we shouldn't be sending the message ID down
- # `/sync`
- message_id = message.pop("message_id", None)
- if message_id:
- set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
+ log_kv(
+ {
+ "event": "to_device_message",
+ "sender": message["sender"],
+ "type": message["type"],
+ EventContentFields.TO_DEVICE_MSGID: message["content"].get(
+ EventContentFields.TO_DEVICE_MSGID
+ ),
+ }
+ )
- logger.debug(
- "Returning %d to-device messages between %d and %d (current token: %d)",
- len(messages),
- since_stream_id,
- stream_id,
- now_token.to_device_key,
- )
+ if messages and issue9533_logger.isEnabledFor(logging.DEBUG):
+ issue9533_logger.debug(
+ "Returning to-device messages with stream_ids (%d, %d]; now: %d;"
+ " msgids: %s",
+ since_stream_id,
+ stream_id,
+ now_token.to_device_key,
+ [
+ message["content"].get(EventContentFields.TO_DEVICE_MSGID)
+ for message in messages
+ ],
+ )
sync_result_builder.now_token = now_token.copy_and_replace(
StreamKeyType.TO_DEVICE, stream_id
)
@@ -1648,6 +1740,7 @@ class SyncHandler:
since_token = sync_result_builder.since_token
if since_token and not sync_result_builder.full_state:
+ # TODO Do not fetch room account data if it will be unused.
(
global_account_data,
account_data_by_room,
@@ -1664,6 +1757,7 @@ class SyncHandler:
sync_config.user
)
else:
+ # TODO Do not fetch room account data if it will be unused.
(
global_account_data,
account_data_by_room,
@@ -1746,7 +1840,7 @@ class SyncHandler:
self,
sync_result_builder: "SyncResultBuilder",
account_data_by_room: Dict[str, Dict[str, JsonDict]],
- ) -> Tuple[AbstractSet[str], AbstractSet[str], AbstractSet[str], AbstractSet[str]]:
+ ) -> Tuple[AbstractSet[str], AbstractSet[str]]:
"""Generates the rooms portion of the sync response. Populates the
`sync_result_builder` with the result.
@@ -1759,25 +1853,21 @@ class SyncHandler:
account_data_by_room: Dictionary of per room account data
Returns:
- Returns a 4-tuple describing rooms the user has joined or left, and users who've
- joined or left rooms any rooms the user is in. This gets used later in
- `_generate_sync_entry_for_device_list`.
+ Returns a 2-tuple describing rooms the user has joined or left.
Its entries are:
- newly_joined_rooms
- - newly_joined_or_invited_or_knocked_users
- newly_left_rooms
- - newly_left_users
"""
+
since_token = sync_result_builder.since_token
+ user_id = sync_result_builder.sync_config.user.to_string()
# 1. Start by fetching all ephemeral events in rooms we've joined (if required).
- user_id = sync_result_builder.sync_config.user.to_string()
block_all_room_ephemeral = (
- since_token is None
- and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
+ sync_result_builder.sync_config.filter_collection.blocks_all_rooms()
+ or sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
)
-
if block_all_room_ephemeral:
ephemeral_by_room: Dict[str, List[JsonDict]] = {}
else:
@@ -1800,19 +1890,21 @@ class SyncHandler:
)
if not tags_by_room:
logger.debug("no-oping sync")
- return set(), set(), set(), set()
+ return set(), set()
# 3. Work out which rooms need reporting in the sync response.
ignored_users = await self.store.ignored_users(user_id)
if since_token:
- room_changes = await self._get_rooms_changed(
+ room_changes = await self._get_room_changes_for_incremental_sync(
sync_result_builder, ignored_users
)
tags_by_room = await self.store.get_updated_tags(
user_id, since_token.account_data_key
)
else:
- room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
+ room_changes = await self._get_room_changes_for_initial_sync(
+ sync_result_builder, ignored_users
+ )
tags_by_room = await self.store.get_tags_for_user(user_id)
log_kv({"rooms_changed": len(room_changes.room_entries)})
@@ -1827,6 +1919,7 @@ class SyncHandler:
# joined or archived).
async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None:
logger.debug("Generating room entry for %s", room_entry.room_id)
+ # Note that this mutates sync_result_builder.{joined,archived}.
await self._generate_room_entry(
sync_result_builder,
room_entry,
@@ -1843,20 +1936,7 @@ class SyncHandler:
sync_result_builder.invited.extend(invited)
sync_result_builder.knocked.extend(knocked)
- # 5. Work out which users have joined or left rooms we're in. We use this
- # to build the device_list part of the sync response in
- # `_generate_sync_entry_for_device_list`.
- (
- newly_joined_or_invited_or_knocked_users,
- newly_left_users,
- ) = sync_result_builder.calculate_user_changes()
-
- return (
- set(newly_joined_rooms),
- newly_joined_or_invited_or_knocked_users,
- set(newly_left_rooms),
- newly_left_users,
- )
+ return set(newly_joined_rooms), set(newly_left_rooms)
async def _have_rooms_changed(
self, sync_result_builder: "SyncResultBuilder"
@@ -1871,7 +1951,7 @@ class SyncHandler:
assert since_token
- if membership_change_events:
+ if membership_change_events or sync_result_builder.forced_newly_joined_room_ids:
return True
stream_id = since_token.room_key.stream
@@ -1880,7 +1960,7 @@ class SyncHandler:
return True
return False
- async def _get_rooms_changed(
+ async def _get_room_changes_for_incremental_sync(
self,
sync_result_builder: "SyncResultBuilder",
ignored_users: FrozenSet[str],
@@ -1918,7 +1998,9 @@ class SyncHandler:
for event in membership_change_events:
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
- newly_joined_rooms: List[str] = []
+ newly_joined_rooms: List[str] = list(
+ sync_result_builder.forced_newly_joined_room_ids
+ )
newly_left_rooms: List[str] = []
room_entries: List[RoomSyncResultBuilder] = []
invited: List[InvitedSyncResult] = []
@@ -2124,7 +2206,7 @@ class SyncHandler:
newly_left_rooms,
)
- async def _get_all_rooms(
+ async def _get_room_changes_for_initial_sync(
self,
sync_result_builder: "SyncResultBuilder",
ignored_users: FrozenSet[str],
@@ -2149,7 +2231,7 @@ class SyncHandler:
room_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id,
membership_list=Membership.LIST,
- excluded_rooms=self.rooms_to_exclude,
+ excluded_rooms=sync_result_builder.excluded_room_ids,
)
room_entries = []
@@ -2307,7 +2389,9 @@ class SyncHandler:
account_data_events = []
if tags is not None:
- account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
+ account_data_events.append(
+ {"type": AccountDataTypes.TAG, "content": {"tags": tags}}
+ )
for account_data_type, content in account_data.items():
account_data_events.append(
@@ -2518,6 +2602,13 @@ class SyncResultBuilder:
since_token: The token supplied by user, or None.
now_token: The token to sync up to.
joined_room_ids: List of rooms the user is joined to
+ excluded_room_ids: Set of room ids we should omit from the /sync response.
+ forced_newly_joined_room_ids:
+ Rooms that should be presented in the /sync response as if they were
+ newly joined during the sync period, even if that's not the case.
+ (This is useful if the room was previously excluded from a /sync response,
+ and now the client should be made aware of it.)
+ Only used by incremental syncs.
# The following mirror the fields in a sync response
presence
@@ -2534,6 +2625,8 @@ class SyncResultBuilder:
since_token: Optional[StreamToken]
now_token: StreamToken
joined_room_ids: FrozenSet[str]
+ excluded_room_ids: FrozenSet[str]
+ forced_newly_joined_room_ids: FrozenSet[str]
membership_change_events: List[EventBase]
presence: List[UserPresenceState] = attr.Factory(list)
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index b59d309606..b38b5ff495 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -420,11 +420,11 @@ class TypingWriterHandler(FollowerTypingHandler):
if last_id == current_id:
return [], current_id, False
- changed_rooms: Optional[
- Iterable[str]
- ] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
+ result = self._typing_stream_change_cache.get_all_entities_changed(last_id)
- if changed_rooms is None:
+ if result.hit:
+ changed_rooms: Iterable[str] = result.entities
+ else:
changed_rooms = self._room_serials
rows = []
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 1124d0d0ce..404570e8c3 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -577,7 +577,24 @@ def _unrecognised_request_handler(request: Request) -> NoReturn:
Args:
request: Unused, but passed in to match the signature of ServletCallback.
"""
- raise UnrecognizedRequestError()
+ raise UnrecognizedRequestError(code=404)
+
+
+class UnrecognizedRequestResource(resource.Resource):
+ """
+ Similar to twisted.web.resource.NoResource, but returns a JSON 404 with an
+ errcode of M_UNRECOGNIZED.
+ """
+
+ def render(self, request: SynapseRequest) -> int:
+ f = failure.Failure(UnrecognizedRequestError(code=404))
+ return_json_error(f, request, None)
+ # A response has already been sent but Twisted requires either NOT_DONE_YET
+ # or the response bytes as a return value.
+ return NOT_DONE_YET
+
+ def getChild(self, name: str, request: Request) -> resource.Resource:
+ return self
class RootRedirect(resource.Resource):
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index dead02cd5c..0070bd2940 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -13,6 +13,7 @@
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
+import enum
import logging
from http import HTTPStatus
from typing import (
@@ -362,6 +363,7 @@ def parse_string(
request: Request,
name: str,
*,
+ default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
@@ -413,6 +415,74 @@ def parse_string(
)
+EnumT = TypeVar("EnumT", bound=enum.Enum)
+
+
+@overload
+def parse_enum(
+ request: Request,
+ name: str,
+ E: Type[EnumT],
+ default: EnumT,
+) -> EnumT:
+ ...
+
+
+@overload
+def parse_enum(
+ request: Request,
+ name: str,
+ E: Type[EnumT],
+ *,
+ required: Literal[True],
+) -> EnumT:
+ ...
+
+
+def parse_enum(
+ request: Request,
+ name: str,
+ E: Type[EnumT],
+ default: Optional[EnumT] = None,
+ required: bool = False,
+) -> Optional[EnumT]:
+ """
+ Parse an enum parameter from the request query string.
+
+ Note that the enum *must only have string values*.
+
+ Args:
+ request: the twisted HTTP request.
+ name: the name of the query parameter.
+ E: the enum which represents valid values
+ default: enum value to use if the parameter is absent, defaults to None.
+ required: whether to raise a 400 SynapseError if the
+ parameter is absent, defaults to False.
+
+ Returns:
+ An enum value.
+
+ Raises:
+ SynapseError if the parameter is absent and required, or if the
+ parameter is present, must be one of a list of allowed values and
+ is not one of those allowed values.
+ """
+ # Assert the enum values are strings.
+ assert all(
+ isinstance(e.value, str) for e in E
+ ), "parse_enum only works with string values"
+ str_value = parse_string(
+ request,
+ name,
+ default=default.value if default is not None else None,
+ required=required,
+ allowed_values=[e.value for e in E],
+ )
+ if str_value is None:
+ return None
+ return E(str_value)
+
+
def _parse_string_value(
value: bytes,
allowed_values: Optional[Iterable[str]],
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 96d1c9ae9e..524a0536b0 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -292,8 +292,15 @@ logger = logging.getLogger(__name__)
class SynapseTags:
- # The message ID of any to_device message processed
- TO_DEVICE_MESSAGE_ID = "to_device.message_id"
+ # The message ID of any to_device EDU processed
+ TO_DEVICE_EDU_ID = "to_device.edu_id"
+
+ # Details about to-device messages
+ TO_DEVICE_TYPE = "to_device.type"
+ TO_DEVICE_SENDER = "to_device.sender"
+ TO_DEVICE_RECIPIENT = "to_device.recipient"
+ TO_DEVICE_RECIPIENT_DEVICE = "to_device.recipient_device"
+ TO_DEVICE_MSGID = "to_device.msgid" # client-generated ID
# Whether the sync response has new data to be returned to the client.
SYNC_RESULT = "sync.new_data"
@@ -315,6 +322,11 @@ class SynapseTags:
# The name of the external cache
CACHE_NAME = "cache.name"
+ # Boolean. Present on /v2/send_join requests, omitted from all others.
+ # True iff partial state was requested and we provided (or intended to provide)
+ # partial state in the response.
+ SEND_JOIN_RESPONSE_IS_PARTIAL_STATE = "send_join.partial_state_response"
+
# Used to tag function arguments
#
# Tag a named arg. The name of the argument should be appended to this prefix.
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index c3d3daf877..b01372565d 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -47,11 +47,7 @@ from twisted.python.threadpool import ThreadPool
# This module is imported for its side effects; flake8 needn't warn that it's unused.
import synapse.metrics._reactor_metrics # noqa: F401
from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
-from synapse.metrics._legacy_exposition import (
- MetricsResource,
- generate_latest,
- start_http_server,
-)
+from synapse.metrics._twisted_exposition import MetricsResource, generate_latest
from synapse.metrics._types import Collector
from synapse.util import SYNAPSE_VERSION
@@ -474,7 +470,6 @@ __all__ = [
"Collector",
"MetricsResource",
"generate_latest",
- "start_http_server",
"LaterGauge",
"InFlightGauge",
"GaugeBucketCollector",
diff --git a/synapse/metrics/_legacy_exposition.py b/synapse/metrics/_legacy_exposition.py
deleted file mode 100644
index 1459f9d224..0000000000
--- a/synapse/metrics/_legacy_exposition.py
+++ /dev/null
@@ -1,288 +0,0 @@
-# Copyright 2015-2019 Prometheus Python Client Developers
-# Copyright 2019 Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This code is based off `prometheus_client/exposition.py` from version 0.7.1.
-
-Due to the renaming of metrics in prometheus_client 0.4.0, this customised
-vendoring of the code will emit both the old versions that Synapse dashboards
-expect, and the newer "best practice" version of the up-to-date official client.
-"""
-import logging
-import math
-import threading
-from http.server import BaseHTTPRequestHandler, HTTPServer
-from socketserver import ThreadingMixIn
-from typing import Any, Dict, List, Type, Union
-from urllib.parse import parse_qs, urlparse
-
-from prometheus_client import REGISTRY, CollectorRegistry
-from prometheus_client.core import Sample
-
-from twisted.web.resource import Resource
-from twisted.web.server import Request
-
-logger = logging.getLogger(__name__)
-CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"
-
-
-def floatToGoString(d: Union[int, float]) -> str:
- d = float(d)
- if d == math.inf:
- return "+Inf"
- elif d == -math.inf:
- return "-Inf"
- elif math.isnan(d):
- return "NaN"
- else:
- s = repr(d)
- dot = s.find(".")
- # Go switches to exponents sooner than Python.
- # We only need to care about positive values for le/quantile.
- if d > 0 and dot > 6:
- mantissa = f"{s[0]}.{s[1:dot]}{s[dot + 1 :]}".rstrip("0.")
- return f"{mantissa}e+0{dot - 1}"
- return s
-
-
-def sample_line(line: Sample, name: str) -> str:
- if line.labels:
- labelstr = "{{{0}}}".format(
- ",".join(
- [
- '{}="{}"'.format(
- k,
- v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""),
- )
- for k, v in sorted(line.labels.items())
- ]
- )
- )
- else:
- labelstr = ""
- timestamp = ""
- if line.timestamp is not None:
- # Convert to milliseconds.
- timestamp = f" {int(float(line.timestamp) * 1000):d}"
- return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp)
-
-
-# Mapping from new metric names to legacy metric names.
-# We translate these back to their old names when exposing them through our
-# legacy vendored exporter.
-# Only this legacy exposition module applies these name changes.
-LEGACY_METRIC_NAMES = {
- "synapse_util_caches_cache_hits": "synapse_util_caches_cache:hits",
- "synapse_util_caches_cache_size": "synapse_util_caches_cache:size",
- "synapse_util_caches_cache_evicted_size": "synapse_util_caches_cache:evicted_size",
- "synapse_util_caches_cache": "synapse_util_caches_cache:total",
- "synapse_util_caches_response_cache_size": "synapse_util_caches_response_cache:size",
- "synapse_util_caches_response_cache_hits": "synapse_util_caches_response_cache:hits",
- "synapse_util_caches_response_cache_evicted_size": "synapse_util_caches_response_cache:evicted_size",
- "synapse_util_caches_response_cache": "synapse_util_caches_response_cache:total",
- "synapse_federation_client_sent_pdu_destinations": "synapse_federation_client_sent_pdu_destinations:total",
- "synapse_federation_client_sent_pdu_destinations_count": "synapse_federation_client_sent_pdu_destinations:count",
- "synapse_admin_mau_current": "synapse_admin_mau:current",
- "synapse_admin_mau_max": "synapse_admin_mau:max",
- "synapse_admin_mau_registered_reserved_users": "synapse_admin_mau:registered_reserved_users",
-}
-
-
-def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes:
- """
- Generate metrics in legacy format. Modern metrics are generated directly
- by prometheus-client.
- """
-
- output = []
-
- for metric in registry.collect():
- if not metric.samples:
- # No samples, don't bother.
- continue
-
- # Translate to legacy metric name if it has one.
- mname = LEGACY_METRIC_NAMES.get(metric.name, metric.name)
- mnewname = metric.name
- mtype = metric.type
-
- # OpenMetrics -> Prometheus
- if mtype == "counter":
- mnewname = mnewname + "_total"
- elif mtype == "info":
- mtype = "gauge"
- mnewname = mnewname + "_info"
- elif mtype == "stateset":
- mtype = "gauge"
- elif mtype == "gaugehistogram":
- mtype = "histogram"
- elif mtype == "unknown":
- mtype = "untyped"
-
- # Output in the old format for compatibility.
- if emit_help:
- output.append(
- "# HELP {} {}\n".format(
- mname,
- metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
- )
- )
- output.append(f"# TYPE {mname} {mtype}\n")
-
- om_samples: Dict[str, List[str]] = {}
- for s in metric.samples:
- for suffix in ["_created", "_gsum", "_gcount"]:
- if s.name == mname + suffix:
- # OpenMetrics specific sample, put in a gauge at the end.
- # (these come from gaugehistograms which don't get renamed,
- # so no need to faff with mnewname)
- om_samples.setdefault(suffix, []).append(sample_line(s, s.name))
- break
- else:
- newname = s.name.replace(mnewname, mname)
- if ":" in newname and newname.endswith("_total"):
- newname = newname[: -len("_total")]
- output.append(sample_line(s, newname))
-
- for suffix, lines in sorted(om_samples.items()):
- if emit_help:
- output.append(
- "# HELP {}{} {}\n".format(
- mname,
- suffix,
- metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
- )
- )
- output.append(f"# TYPE {mname}{suffix} gauge\n")
- output.extend(lines)
-
- # Get rid of the weird colon things while we're at it
- if mtype == "counter":
- mnewname = mnewname.replace(":total", "")
- mnewname = mnewname.replace(":", "_")
-
- if mname == mnewname:
- continue
-
- # Also output in the new format, if it's different.
- if emit_help:
- output.append(
- "# HELP {} {}\n".format(
- mnewname,
- metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
- )
- )
- output.append(f"# TYPE {mnewname} {mtype}\n")
-
- for s in metric.samples:
- # Get rid of the OpenMetrics specific samples (we should already have
- # dealt with them above anyway.)
- for suffix in ["_created", "_gsum", "_gcount"]:
- if s.name == mname + suffix:
- break
- else:
- sample_name = LEGACY_METRIC_NAMES.get(s.name, s.name)
- output.append(
- sample_line(s, sample_name.replace(":total", "").replace(":", "_"))
- )
-
- return "".join(output).encode("utf-8")
-
-
-class MetricsHandler(BaseHTTPRequestHandler):
- """HTTP handler that gives metrics from ``REGISTRY``."""
-
- registry = REGISTRY
-
- def do_GET(self) -> None:
- registry = self.registry
- params = parse_qs(urlparse(self.path).query)
-
- if "help" in params:
- emit_help = True
- else:
- emit_help = False
-
- try:
- output = generate_latest(registry, emit_help=emit_help)
- except Exception:
- self.send_error(500, "error generating metric output")
- raise
- try:
- self.send_response(200)
- self.send_header("Content-Type", CONTENT_TYPE_LATEST)
- self.send_header("Content-Length", str(len(output)))
- self.end_headers()
- self.wfile.write(output)
- except BrokenPipeError as e:
- logger.warning(
- "BrokenPipeError when serving metrics (%s). Did Prometheus restart?", e
- )
-
- def log_message(self, format: str, *args: Any) -> None:
- """Log nothing."""
-
- @classmethod
- def factory(cls, registry: CollectorRegistry) -> Type:
- """Returns a dynamic MetricsHandler class tied
- to the passed registry.
- """
- # This implementation relies on MetricsHandler.registry
- # (defined above and defaulted to REGISTRY).
-
- # As we have unicode_literals, we need to create a str()
- # object for type().
- cls_name = str(cls.__name__)
- MyMetricsHandler = type(cls_name, (cls, object), {"registry": registry})
- return MyMetricsHandler
-
-
-class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer):
- """Thread per request HTTP server."""
-
- # Make worker threads "fire and forget". Beginning with Python 3.7 this
- # prevents a memory leak because ``ThreadingMixIn`` starts to gather all
- # non-daemon threads in a list in order to join on them at server close.
- # Enabling daemon threads virtually makes ``_ThreadingSimpleServer`` the
- # same as Python 3.7's ``ThreadingHTTPServer``.
- daemon_threads = True
-
-
-def start_http_server(
- port: int, addr: str = "", registry: CollectorRegistry = REGISTRY
-) -> None:
- """Starts an HTTP server for prometheus metrics as a daemon thread"""
- CustomMetricsHandler = MetricsHandler.factory(registry)
- httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler)
- t = threading.Thread(target=httpd.serve_forever)
- t.daemon = True
- t.start()
-
-
-class MetricsResource(Resource):
- """
- Twisted ``Resource`` that serves prometheus metrics.
- """
-
- isLeaf = True
-
- def __init__(self, registry: CollectorRegistry = REGISTRY):
- self.registry = registry
-
- def render_GET(self, request: Request) -> bytes:
- request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii"))
- response = generate_latest(self.registry)
- request.setHeader(b"Content-Length", str(len(response)))
- return response
diff --git a/synapse/metrics/_twisted_exposition.py b/synapse/metrics/_twisted_exposition.py
new file mode 100644
index 0000000000..0abcd14953
--- /dev/null
+++ b/synapse/metrics/_twisted_exposition.py
@@ -0,0 +1,38 @@
+# Copyright 2015-2019 Prometheus Python Client Developers
+# Copyright 2019 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from prometheus_client import REGISTRY, CollectorRegistry, generate_latest
+
+from twisted.web.resource import Resource
+from twisted.web.server import Request
+
+CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"
+
+
+class MetricsResource(Resource):
+ """
+ Twisted ``Resource`` that serves prometheus metrics.
+ """
+
+ isLeaf = True
+
+ def __init__(self, registry: CollectorRegistry = REGISTRY):
+ self.registry = registry
+
+ def render_GET(self, request: Request) -> bytes:
+ request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii"))
+ response = generate_latest(self.registry)
+ request.setHeader(b"Content-Length", str(len(response)))
+ return response
diff --git a/synapse/metrics/common_usage_metrics.py b/synapse/metrics/common_usage_metrics.py
index 0a22ea3d92..6e05b043d3 100644
--- a/synapse/metrics/common_usage_metrics.py
+++ b/synapse/metrics/common_usage_metrics.py
@@ -54,7 +54,9 @@ class CommonUsageMetricsManager:
async def setup(self) -> None:
"""Keep the gauges for common usage metrics up to date."""
- await self._update_gauges()
+ run_as_background_process(
+ desc="common_usage_metrics_update_gauges", func=self._update_gauges
+ )
self._clock.looping_call(
run_as_background_process,
5 * 60 * 1000,
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 1adc1fd64f..d22dd19d38 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -18,6 +18,7 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
+ Collection,
Dict,
Generator,
Iterable,
@@ -86,6 +87,7 @@ from synapse.handlers.auth import (
ON_LOGGED_OUT_CALLBACK,
AuthHandler,
)
+from synapse.handlers.device import DeviceHandler
from synapse.handlers.push_rules import RuleSpec, check_actions
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
@@ -110,7 +112,6 @@ from synapse.storage.background_updates import (
)
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo
-from synapse.storage.state import StateFilter
from synapse.types import (
DomainSpecificString,
JsonDict,
@@ -123,9 +124,10 @@ from synapse.types import (
UserProfile,
create_requester,
)
+from synapse.types.state import StateFilter
from synapse.util import Clock
from synapse.util.async_helpers import maybe_awaitable
-from synapse.util.caches.descriptors import CachedFunction, cached
+from synapse.util.caches.descriptors import CachedFunction, cached as _cached
from synapse.util.frozenutils import freeze
if TYPE_CHECKING:
@@ -135,6 +137,7 @@ if TYPE_CHECKING:
T = TypeVar("T")
P = ParamSpec("P")
+F = TypeVar("F", bound=Callable[..., Any])
"""
This package defines the 'stable' API which can be used by extension modules which
@@ -184,6 +187,42 @@ class UserIpAndAgent:
last_seen: int
+def cached(
+ *,
+ max_entries: int = 1000,
+ num_args: Optional[int] = None,
+ uncached_args: Optional[Collection[str]] = None,
+) -> Callable[[F], CachedFunction[F]]:
+ """Returns a decorator that applies a memoizing cache around the function. This
+ decorator behaves similarly to functools.lru_cache.
+
+ Example:
+
+ @cached()
+ def foo('a', 'b'):
+ ...
+
+ Added in Synapse v1.74.0.
+
+ Args:
+ max_entries: The maximum number of entries in the cache. If the cache is full
+ and a new entry is added, the least recently accessed entry will be evicted
+ from the cache.
+ num_args: The number of positional arguments (excluding `self`) to use as cache
+ keys. Defaults to all named args of the function.
+ uncached_args: A list of argument names to not use as the cache key. (`self` is
+ always ignored.) Cannot be used with num_args.
+
+ Returns:
+ A decorator that applies a memoizing cache around the function.
+ """
+ return _cached(
+ max_entries=max_entries,
+ num_args=num_args,
+ uncached_args=uncached_args,
+ )
+
+
class ModuleApi:
"""A proxy object that gets passed to various plugin modules so they
can register new users etc if necessary.
@@ -207,6 +246,7 @@ class ModuleApi:
self._registration_handler = hs.get_registration_handler()
self._send_email_handler = hs.get_send_email_handler()
self._push_rules_handler = hs.get_push_rules_handler()
+ self._device_handler = hs.get_device_handler()
self.custom_template_dir = hs.config.server.custom_template_directory
try:
@@ -784,6 +824,8 @@ class ModuleApi:
) -> Generator["defer.Deferred[Any]", Any, None]:
"""Invalidate an access token for a user
+ Can only be called from the main process.
+
Added in Synapse v0.25.0.
Args:
@@ -796,6 +838,10 @@ class ModuleApi:
Raises:
synapse.api.errors.AuthError: the access token is invalid
"""
+ assert isinstance(
+ self._device_handler, DeviceHandler
+ ), "invalidate_access_token can only be called on the main process"
+
# see if the access token corresponds to a device
user_info = yield defer.ensureDeferred(
self._auth.get_user_by_access_token(access_token)
@@ -805,7 +851,7 @@ class ModuleApi:
if device_id:
# delete the device, which will also delete its access tokens
yield defer.ensureDeferred(
- self._hs.get_device_handler().delete_devices(user_id, [device_id])
+ self._device_handler.delete_devices(user_id, [device_id])
)
else:
# no associated device. Just delete the access token.
@@ -1112,7 +1158,7 @@ class ModuleApi:
# Send to remote destinations.
destination = UserID.from_string(user).domain
presence_handler.get_federation_queue().send_presence_to_destinations(
- presence_events, destination
+ presence_events, [destination]
)
def looping_background_call(
@@ -1539,6 +1585,33 @@ class ModuleApi:
return room_id_and_alias["room_id"], room_id_and_alias.get("room_alias", None)
+ async def set_displayname(
+ self,
+ user_id: UserID,
+ new_displayname: str,
+ deactivation: bool = False,
+ ) -> None:
+ """Sets a user's display name.
+
+ Added in Synapse v1.76.0.
+
+ Args:
+ user_id:
+ The user whose display name is to be changed.
+ new_displayname:
+ The new display name to give the user.
+ deactivation:
+ Whether this change was made while deactivating the user.
+ """
+ requester = create_requester(user_id)
+ await self._hs.get_profile_handler().set_displayname(
+ target_user=user_id,
+ requester=requester,
+ new_displayname=new_displayname,
+ by_admin=True,
+ deactivation=deactivation,
+ )
+
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 26b97cf766..a8832a3f8e 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -46,6 +46,7 @@ from synapse.types import (
JsonDict,
PersistedEventPosition,
RoomStreamToken,
+ StrCollection,
StreamKeyType,
StreamToken,
UserID,
@@ -226,8 +227,7 @@ class Notifier:
self.store = hs.get_datastores().main
self.pending_new_room_events: List[_PendingRoomEventEntry] = []
- # Called when there are new things to stream over replication
- self.replication_callbacks: List[Callable[[], None]] = []
+ self._replication_notifier = hs.get_replication_notifier()
self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = []
self._federation_client = hs.get_federation_http_client()
@@ -279,7 +279,7 @@ class Notifier:
it needs to do any asynchronous work, a background thread should be started and
wrapped with run_as_background_process.
"""
- self.replication_callbacks.append(cb)
+ self._replication_notifier.add_replication_callback(cb)
def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None:
"""Add a callback that will be called when a user joins a room.
@@ -315,6 +315,32 @@ class Notifier:
event_entries.append((entry, event.event_id))
await self.notify_new_room_events(event_entries, max_room_stream_token)
+ async def on_un_partial_stated_room(
+ self,
+ room_id: str,
+ new_token: int,
+ ) -> None:
+ """Used by the resync background processes to wake up all listeners
+ of this room when it is un-partial-stated.
+
+ It will also notify replication listeners of the change in stream.
+ """
+
+ # Wake up all related user stream notifiers
+ user_streams = self.room_to_user_streams.get(room_id, set())
+ time_now_ms = self.clock.time_msec()
+ for user_stream in user_streams:
+ try:
+ user_stream.notify(
+ StreamKeyType.UN_PARTIAL_STATED_ROOMS, new_token, time_now_ms
+ )
+ except Exception:
+ logger.exception("Failed to notify listener")
+
+ # Poke the replication so that other workers also see the write to
+ # the un-partial-stated rooms stream.
+ self.notify_replication()
+
async def notify_new_room_events(
self,
event_entries: List[Tuple[_PendingRoomEventEntry, str]],
@@ -691,7 +717,7 @@ class Notifier:
async def _get_room_ids(
self, user: UserID, explicit_room_id: Optional[str]
- ) -> Tuple[Collection[str], bool]:
+ ) -> Tuple[StrCollection, bool]:
joined_room_ids = await self.store.get_rooms_for_user(user.to_string())
if explicit_room_id:
if explicit_room_id in joined_room_ids:
@@ -741,8 +767,7 @@ class Notifier:
def notify_replication(self) -> None:
"""Notify the any replication listeners that there's a new event"""
- for cb in self.replication_callbacks:
- cb()
+ self._replication_notifier.notify_replication()
def notify_user_joined_room(self, event_id: str, room_id: str) -> None:
for cb in self._new_join_in_room_callbacks:
@@ -759,3 +784,26 @@ class Notifier:
# Tell the federation client about the fact the server is back up, so
# that any in flight requests can be immediately retried.
self._federation_client.wake_destination(server)
+
+
+@attr.s(auto_attribs=True)
+class ReplicationNotifier:
+ """Tracks callbacks for things that need to know about stream changes.
+
+ This is separate from the notifier to avoid circular dependencies.
+ """
+
+ _replication_callbacks: List[Callable[[], None]] = attr.Factory(list)
+
+ def add_replication_callback(self, cb: Callable[[], None]) -> None:
+ """Add a callback that will be called when some new data is available.
+ Callback is not given any arguments. It should *not* return a Deferred - if
+ it needs to do any asynchronous work, a background thread should be started and
+ wrapped with run_as_background_process.
+ """
+ self._replication_callbacks.append(cb)
+
+ def notify_replication(self) -> None:
+ """Notify the any replication listeners that there's a new event"""
+ for cb in self._replication_callbacks:
+ cb()
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 75b7e126ca..20369f3dfe 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -22,20 +22,28 @@ from typing import (
List,
Mapping,
Optional,
+ Set,
Tuple,
Union,
)
from prometheus_client import Counter
-from synapse.api.constants import MAIN_TIMELINE, EventTypes, Membership, RelationTypes
+from synapse.api.constants import (
+ MAIN_TIMELINE,
+ EventContentFields,
+ EventTypes,
+ Membership,
+ RelationTypes,
+)
+from synapse.api.room_versions import PushRuleRoomFlag, RoomVersion
from synapse.event_auth import auth_types_for_event, get_user_power_level
from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
-from synapse.storage.state import StateFilter
from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator
+from synapse.types.state import StateFilter
from synapse.util.caches import register_cache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_event_for_clients_with_state
@@ -61,6 +69,9 @@ STATE_EVENT_TYPES_TO_MARK_UNREAD = {
}
+SENTINEL = object()
+
+
def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
# Exclude rejected and soft-failed events.
if context.rejected or event.internal_metadata.is_soft_failed():
@@ -105,8 +116,12 @@ class BulkPushRuleEvaluator:
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
self._event_auth_handler = hs.get_event_auth_handler()
+ self.should_calculate_push_rules = self.hs.config.push.enable_push
self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled
+ self._intentional_mentions_enabled = (
+ self.hs.config.experimental.msc3952_intentional_mentions
+ )
self.room_push_rule_cache_metrics = register_cache(
"cache",
@@ -268,6 +283,8 @@ class BulkPushRuleEvaluator:
for each event, check if the message should increment the unread count, and
insert the results into the event_push_actions_staging table.
"""
+ if not self.should_calculate_push_rules:
+ return
# For batched events the power level events may not have been persisted yet,
# so we pass in the batched events. Thus if the event cannot be found in the
# database we can check in the batch.
@@ -332,19 +349,51 @@ class BulkPushRuleEvaluator:
related_events = await self._related_events(event)
# It's possible that old room versions have non-integer power levels (floats or
- # strings). Workaround this by explicitly converting to int.
+ # strings; even the occasional `null`). For old rooms, we interpret these as if
+ # they were integers. Do this here for the `@room` power level threshold.
+ # Note that this is done automatically for the sender's power level by
+ # _get_power_levels_and_sender_level in its call to get_user_power_level
+ # (even for room V10.)
notification_levels = power_levels.get("notifications", {})
if not event.room_version.msc3667_int_only_power_levels:
- for user_id, level in notification_levels.items():
- notification_levels[user_id] = int(level)
+ keys = list(notification_levels.keys())
+ for key in keys:
+ level = notification_levels.get(key, SENTINEL)
+ if level is not SENTINEL and type(level) is not int:
+ try:
+ notification_levels[key] = int(level)
+ except (TypeError, ValueError):
+ del notification_levels[key]
+
+ # Pull out any user and room mentions.
+ mentions = event.content.get(EventContentFields.MSC3952_MENTIONS)
+ has_mentions = self._intentional_mentions_enabled and isinstance(mentions, dict)
+ user_mentions: Set[str] = set()
+ room_mention = False
+ if has_mentions:
+ # mypy seems to have lost the type even though it must be a dict here.
+ assert isinstance(mentions, dict)
+ # Remove out any non-string items and convert to a set.
+ user_mentions_raw = mentions.get("user_ids")
+ if isinstance(user_mentions_raw, list):
+ user_mentions = set(
+ filter(lambda item: isinstance(item, str), user_mentions_raw)
+ )
+ # Room mention is only true if the value is exactly true.
+ room_mention = mentions.get("room") is True
evaluator = PushRuleEvaluator(
- _flatten_dict(event),
+ _flatten_dict(event, room_version=event.room_version),
+ has_mentions,
+ user_mentions,
+ room_mention,
room_member_count,
sender_power_level,
notification_levels,
related_events,
self._related_event_match_enabled,
+ event.room_version.msc3931_push_features,
+ self.hs.config.experimental.msc1767_enabled, # MSC3931 flag
)
users = rules_by_user.keys()
@@ -420,9 +469,33 @@ StateGroup = Union[object, int]
def _flatten_dict(
d: Union[EventBase, Mapping[str, Any]],
+ room_version: Optional[RoomVersion] = None,
prefix: Optional[List[str]] = None,
result: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
+ """
+ Given a JSON dictionary (or event) which might contain sub dictionaries,
+ flatten it into a single layer dictionary by combining the keys & sub-keys.
+
+ Any (non-dictionary), non-string value is dropped.
+
+ Transforms:
+
+ {"foo": {"bar": "test"}}
+
+ To:
+
+ {"foo.bar": "test"}
+
+ Args:
+ d: The event or content to continue flattening.
+ room_version: The room version object.
+ prefix: The key prefix (from outer dictionaries).
+ result: The result to mutate.
+
+ Returns:
+ The resulting dictionary.
+ """
if prefix is None:
prefix = []
if result is None:
@@ -431,6 +504,31 @@ def _flatten_dict(
if isinstance(value, str):
result[".".join(prefix + [key])] = value.lower()
elif isinstance(value, Mapping):
+ # do not set `room_version` due to recursion considerations below
_flatten_dict(value, prefix=(prefix + [key]), result=result)
+ # `room_version` should only ever be set when looking at the top level of an event
+ if (
+ room_version is not None
+ and PushRuleRoomFlag.EXTENSIBLE_EVENTS in room_version.msc3931_push_features
+ and isinstance(d, EventBase)
+ ):
+ # Room supports extensible events: replace `content.body` with the plain text
+ # representation from `m.markup`, as per MSC1767.
+ markup = d.get("content").get("m.markup")
+ if room_version.identifier.startswith("org.matrix.msc1767."):
+ markup = d.get("content").get("org.matrix.msc1767.markup")
+ if markup is not None and isinstance(markup, list):
+ text = ""
+ for rep in markup:
+ if not isinstance(rep, dict):
+ # invalid markup - skip all processing
+ break
+ if rep.get("mimetype", "text/plain") == "text/plain":
+ rep_text = rep.get("body")
+ if rep_text is not None and isinstance(rep_text, str):
+ text = rep_text.lower()
+ break
+ result["content.body"] = text
+
return result
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 622a1e35c5..bb76c169c6 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -26,10 +26,7 @@ def format_push_rules_for_user(
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
- rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = {
- "global": {},
- "device": {},
- }
+ rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = {"global": {}}
rules["global"] = _add_empty_priority_class_arrays(rules["global"])
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index c2575ba3d9..93b255ced5 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -37,8 +37,8 @@ from synapse.push.push_types import (
TemplateVars,
)
from synapse.storage.databases.main.event_push_actions import EmailPushAction
-from synapse.storage.state import StateFilter
from synapse.types import StateMap, UserID
+from synapse.types.state import StateFilter
from synapse.util.async_helpers import concurrently_execute
from synapse.visibility import filter_events_for_client
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index edeba27a45..7ee07e4bee 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -17,7 +17,6 @@ from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
-from synapse.util.async_helpers import concurrently_execute
async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
@@ -26,23 +25,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
badge = len(invites)
- room_notifs = []
-
- async def get_room_unread_count(room_id: str) -> None:
- room_notifs.append(
- await store.get_unread_event_push_actions_by_room_for_user(
- room_id,
- user_id,
- )
- )
-
- await concurrently_execute(get_room_unread_count, joins, 10)
-
- for notifs in room_notifs:
- # Combine the counts from all the threads.
- notify_count = notifs.main_timeline.notify_count + sum(
- n.notify_count for n in notifs.threads.values()
- )
+ room_to_count = await store.get_unread_counts_by_room_for_user(user_id)
+ for room_id, notify_count in room_to_count.items():
+ # room_to_count may include rooms which the user has left,
+ # ignore those.
+ if room_id not in joins:
+ continue
if notify_count == 0:
continue
@@ -51,8 +39,10 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
# return one badge count per conversation
badge += 1
else:
- # increment the badge count by the number of unread messages in the room
+ # Increase badge by number of notifications in room
+ # NOTE: this includes threaded and unthreaded notifications.
badge += notify_count
+
return badge
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 3f4d3fc51a..908f3f1db7 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -17,7 +17,7 @@ import logging
import re
import urllib.parse
from inspect import signature
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, ClassVar, Dict, List, Tuple
from prometheus_client import Counter, Gauge
@@ -27,6 +27,7 @@ from twisted.web.server import Request
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError
from synapse.http.server import HttpServer
+from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing
from synapse.logging.opentracing import trace_with_opname
@@ -53,6 +54,9 @@ _outgoing_request_counter = Counter(
)
+_STREAM_POSITION_KEY = "_INT_STREAM_POS"
+
+
class ReplicationEndpoint(metaclass=abc.ABCMeta):
"""Helper base class for defining new replication HTTP endpoints.
@@ -94,6 +98,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
a connection error is received.
RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when
receiving connection errors, each will backoff exponentially longer.
+ WAIT_FOR_STREAMS (bool): Whether to wait for replication streams to
+ catch up before processing the request and/or response. Defaults to
+ True.
"""
NAME: str = abc.abstractproperty() # type: ignore
@@ -104,6 +111,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
RETRY_ON_CONNECT_ERROR = True
RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5 # =63s (2^6-1)
+ WAIT_FOR_STREAMS: ClassVar[bool] = True
+
def __init__(self, hs: "HomeServer"):
if self.CACHE:
self.response_cache: ResponseCache[str] = ResponseCache(
@@ -126,6 +135,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if hs.config.worker.worker_replication_secret:
self._replication_secret = hs.config.worker.worker_replication_secret
+ self._streams = hs.get_replication_command_handler().get_streams_to_replicate()
+ self._replication = hs.get_replication_data_handler()
+ self._instance_name = hs.get_instance_name()
+
def _check_auth(self, request: Request) -> None:
# Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
@@ -160,7 +173,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def _handle_request(
- self, request: Request, **kwargs: Any
+ self, request: Request, content: JsonDict, **kwargs: Any
) -> Tuple[int, JsonDict]:
"""Handle incoming request.
@@ -201,6 +214,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
@trace_with_opname("outgoing_replication_request")
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
+ # We have to pull these out here to avoid circular dependencies...
+ streams = hs.get_replication_command_handler().get_streams_to_replicate()
+ replication = hs.get_replication_data_handler()
+
with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
@@ -219,6 +236,24 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
data = await cls._serialize_payload(**kwargs)
+ if cls.METHOD != "GET" and cls.WAIT_FOR_STREAMS:
+ # Include the current stream positions that we write to. We
+ # don't do this for GETs as they don't have a body, and we
+ # generally assume that a GET won't rely on data we have
+ # written.
+ if _STREAM_POSITION_KEY in data:
+ raise Exception(
+ "data to send contains %r key", _STREAM_POSITION_KEY
+ )
+
+ data[_STREAM_POSITION_KEY] = {
+ "streams": {
+ stream.NAME: stream.current_token(local_instance_name)
+ for stream in streams
+ },
+ "instance_name": local_instance_name,
+ }
+
url_args = [
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
]
@@ -308,6 +343,17 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
) from e
_outgoing_request_counter.labels(cls.NAME, 200).inc()
+
+ # Wait on any streams that the remote may have written to.
+ for stream_name, position in result.get(
+ _STREAM_POSITION_KEY, {}
+ ).items():
+ await replication.wait_for_stream_position(
+ instance_name=instance_name,
+ stream_name=stream_name,
+ position=position,
+ )
+
return result
return send_request
@@ -353,6 +399,22 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if self._replication_secret:
self._check_auth(request)
+ if self.METHOD == "GET":
+ # GET APIs always have an empty body.
+ content = {}
+ else:
+ content = parse_json_object_from_request(request)
+
+ # Wait on any streams that the remote may have written to.
+ for stream_name, position in content.get(_STREAM_POSITION_KEY, {"streams": {}})[
+ "streams"
+ ].items():
+ await self._replication.wait_for_stream_position(
+ instance_name=content[_STREAM_POSITION_KEY]["instance_name"],
+ stream_name=stream_name,
+ position=position,
+ )
+
if self.CACHE:
txn_id = kwargs.pop("txn_id")
@@ -361,13 +423,28 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# correctly yet. In particular, there may be issues to do with logging
# context lifetimes.
- return await self.response_cache.wrap(
- txn_id, self._handle_request, request, **kwargs
+ code, response = await self.response_cache.wrap(
+ txn_id, self._handle_request, request, content, **kwargs
+ )
+ else:
+ # The `@cancellable` decorator may be applied to `_handle_request`. But we
+ # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
+ # so we have to set up the cancellable flag ourselves.
+ request.is_render_cancellable = is_function_cancellable(
+ self._handle_request
)
- # The `@cancellable` decorator may be applied to `_handle_request`. But we
- # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
- # so we have to set up the cancellable flag ourselves.
- request.is_render_cancellable = is_function_cancellable(self._handle_request)
+ code, response = await self._handle_request(request, content, **kwargs)
- return await self._handle_request(request, **kwargs)
+ # Return streams we may have written to in the course of processing this
+ # request.
+ if _STREAM_POSITION_KEY in response:
+ raise Exception("data to send contains %r key", _STREAM_POSITION_KEY)
+
+ if self.WAIT_FOR_STREAMS:
+ response[_STREAM_POSITION_KEY] = {
+ stream.NAME: stream.current_token(self._instance_name)
+ for stream in self._streams
+ }
+
+ return code, response
diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
index 310f609153..2374f810c9 100644
--- a/synapse/replication/http/account_data.py
+++ b/synapse/replication/http/account_data.py
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
@@ -28,7 +27,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
+class ReplicationAddUserAccountDataRestServlet(ReplicationEndpoint):
"""Add user account data on the appropriate account data worker.
Request format:
@@ -49,7 +48,6 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
- self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload( # type: ignore[override]
@@ -62,10 +60,8 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
return payload
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str, account_data_type: str
+ self, request: Request, content: JsonDict, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
max_stream_id = await self.handler.add_account_data_for_user(
user_id, account_data_type, content["content"]
)
@@ -73,7 +69,45 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
return 200, {"max_stream_id": max_stream_id}
-class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
+class ReplicationRemoveUserAccountDataRestServlet(ReplicationEndpoint):
+ """Remove user account data on the appropriate account data worker.
+
+ Request format:
+
+ POST /_synapse/replication/remove_user_account_data/:user_id/:type
+
+ {
+ "content": { ... },
+ }
+
+ """
+
+ NAME = "remove_user_account_data"
+ PATH_ARGS = ("user_id", "account_data_type")
+ CACHE = False
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self.handler = hs.get_account_data_handler()
+
+ @staticmethod
+ async def _serialize_payload( # type: ignore[override]
+ user_id: str, account_data_type: str
+ ) -> JsonDict:
+ return {}
+
+ async def _handle_request( # type: ignore[override]
+ self, request: Request, content: JsonDict, user_id: str, account_data_type: str
+ ) -> Tuple[int, JsonDict]:
+ max_stream_id = await self.handler.remove_account_data_for_user(
+ user_id, account_data_type
+ )
+
+ return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationAddRoomAccountDataRestServlet(ReplicationEndpoint):
"""Add room account data on the appropriate account data worker.
Request format:
@@ -94,7 +128,6 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
- self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload( # type: ignore[override]
@@ -107,10 +140,13 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
return payload
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str, room_id: str, account_data_type: str
+ self,
+ request: Request,
+ content: JsonDict,
+ user_id: str,
+ room_id: str,
+ account_data_type: str,
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
max_stream_id = await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, content["content"]
)
@@ -118,6 +154,49 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
return 200, {"max_stream_id": max_stream_id}
+class ReplicationRemoveRoomAccountDataRestServlet(ReplicationEndpoint):
+ """Remove room account data on the appropriate account data worker.
+
+ Request format:
+
+ POST /_synapse/replication/remove_room_account_data/:user_id/:room_id/:account_data_type
+
+ {
+ "content": { ... },
+ }
+
+ """
+
+ NAME = "remove_room_account_data"
+ PATH_ARGS = ("user_id", "room_id", "account_data_type")
+ CACHE = False
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self.handler = hs.get_account_data_handler()
+
+ @staticmethod
+ async def _serialize_payload( # type: ignore[override]
+ user_id: str, room_id: str, account_data_type: str, content: JsonDict
+ ) -> JsonDict:
+ return {}
+
+ async def _handle_request( # type: ignore[override]
+ self,
+ request: Request,
+ content: JsonDict,
+ user_id: str,
+ room_id: str,
+ account_data_type: str,
+ ) -> Tuple[int, JsonDict]:
+ max_stream_id = await self.handler.remove_account_data_for_room(
+ user_id, room_id, account_data_type
+ )
+
+ return 200, {"max_stream_id": max_stream_id}
+
+
class ReplicationAddTagRestServlet(ReplicationEndpoint):
"""Add tag on the appropriate account data worker.
@@ -139,7 +218,6 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
- self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload( # type: ignore[override]
@@ -152,10 +230,8 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint):
return payload
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str, room_id: str, tag: str
+ self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
max_stream_id = await self.handler.add_tag_to_room(
user_id, room_id, tag, content["content"]
)
@@ -186,7 +262,6 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
- self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override]
@@ -194,7 +269,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
return {}
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str, room_id: str, tag: str
+ self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_tag_from_room(
user_id,
@@ -206,7 +281,11 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
- ReplicationUserAccountDataRestServlet(hs).register(http_server)
- ReplicationRoomAccountDataRestServlet(hs).register(http_server)
+ ReplicationAddUserAccountDataRestServlet(hs).register(http_server)
+ ReplicationAddRoomAccountDataRestServlet(hs).register(http_server)
ReplicationAddTagRestServlet(hs).register(http_server)
ReplicationRemoveTagRestServlet(hs).register(http_server)
+
+ if hs.config.experimental.msc3391_enabled:
+ ReplicationRemoveUserAccountDataRestServlet(hs).register(http_server)
+ ReplicationRemoveRoomAccountDataRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index c21629def8..ecea6fc915 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -13,12 +13,12 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.logging.opentracing import active_span
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
@@ -63,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.device_list_updater = hs.get_device_handler().device_list_updater
+ from synapse.handlers.device import DeviceHandler
+
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_list_updater = handler.device_list_updater
+
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
@@ -72,13 +77,82 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
return {}
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str
- ) -> Tuple[int, JsonDict]:
+ self, request: Request, content: JsonDict, user_id: str
+ ) -> Tuple[int, Optional[JsonDict]]:
user_devices = await self.device_list_updater.user_device_resync(user_id)
return 200, user_devices
+class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
+ """Ask master to resync the device list for multiple users from the same
+ remote server by contacting their server.
+
+ This must happen on master so that the results can be correctly cached in
+ the database and streamed to workers.
+
+ Request format:
+
+ POST /_synapse/replication/multi_user_device_resync
+
+ {
+ "user_ids": ["@alice:example.org", "@bob:example.org", ...]
+ }
+
+ Response is roughly equivalent to ` /_matrix/federation/v1/user/devices/:user_id`
+ response, but there is a map from user ID to response, e.g.:
+
+ {
+ "@alice:example.org": {
+ "devices": [
+ {
+ "device_id": "JLAFKJWSCS",
+ "keys": { ... },
+ "device_display_name": "Alice's Mobile Phone"
+ }
+ ]
+ },
+ ...
+ }
+ """
+
+ NAME = "multi_user_device_resync"
+ PATH_ARGS = ()
+ CACHE = False
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ from synapse.handlers.device import DeviceHandler
+
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_list_updater = handler.device_list_updater
+
+ self.store = hs.get_datastores().main
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload(user_ids: List[str]) -> JsonDict: # type: ignore[override]
+ return {"user_ids": user_ids}
+
+ async def _handle_request( # type: ignore[override]
+ self, request: Request, content: JsonDict
+ ) -> Tuple[int, Dict[str, Optional[JsonDict]]]:
+ user_ids: List[str] = content["user_ids"]
+
+ logger.info("Resync for %r", user_ids)
+ span = active_span()
+ if span:
+ span.set_tag("user_ids", f"{user_ids!r}")
+
+ multi_user_devices = await self.device_list_updater.multi_user_device_resync(
+ user_ids
+ )
+
+ return 200, multi_user_devices
+
+
class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
"""Ask master to upload keys for the user and send them out over federation to
update other servers.
@@ -129,10 +203,8 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore[override]
- self, request: Request
+ self, request: Request, content: JsonDict
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
user_id = content["user_id"]
device_id = content["device_id"]
keys = content["keys"]
@@ -146,4 +218,5 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
+ ReplicationMultiUserDevicesResyncRestServlet(hs).register(http_server)
ReplicationUploadKeysForUserRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index d3abafed28..53ad327030 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
from synapse.util.metrics import Measure
@@ -114,10 +113,8 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
return payload
- async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # type: ignore[override]
+ async def _handle_request(self, request: Request, content: JsonDict) -> Tuple[int, JsonDict]: # type: ignore[override]
with Measure(self.clock, "repl_fed_send_events_parse"):
- content = parse_json_object_from_request(request)
-
room_id = content["room_id"]
backfilled = content["backfilled"]
@@ -181,13 +178,10 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
return {"origin": origin, "content": content}
async def _handle_request( # type: ignore[override]
- self, request: Request, edu_type: str
+ self, request: Request, content: JsonDict, edu_type: str
) -> Tuple[int, JsonDict]:
- with Measure(self.clock, "repl_fed_send_edu_parse"):
- content = parse_json_object_from_request(request)
-
- origin = content["origin"]
- edu_content = content["content"]
+ origin = content["origin"]
+ edu_content = content["content"]
logger.info("Got %r edu from %s", edu_type, origin)
@@ -231,13 +225,10 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
return {"args": args}
async def _handle_request( # type: ignore[override]
- self, request: Request, query_type: str
+ self, request: Request, content: JsonDict, query_type: str
) -> Tuple[int, JsonDict]:
- with Measure(self.clock, "repl_fed_query_parse"):
- content = parse_json_object_from_request(request)
-
- args = content["args"]
- args["origin"] = content["origin"]
+ args = content["args"]
+ args["origin"] = content["origin"]
logger.info("Got %r query from %s", query_type, args["origin"])
@@ -274,7 +265,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
return {}
async def _handle_request( # type: ignore[override]
- self, request: Request, room_id: str
+ self, request: Request, content: JsonDict, room_id: str
) -> Tuple[int, JsonDict]:
await self.store.clean_room_for_join(room_id)
@@ -307,9 +298,8 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
return {"room_version": room_version.identifier}
async def _handle_request( # type: ignore[override]
- self, request: Request, room_id: str
+ self, request: Request, content: JsonDict, room_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
return 200, {}
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index c68e18da12..6ad6cb1bfe 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional, Tuple, cast
from twisted.web.server import Request
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
@@ -73,10 +72,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str
+ self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
device_id = content["device_id"]
initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"]
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 663bff5738..9fa1060d48 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester, UserID
@@ -79,10 +78,8 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore[override]
- self, request: SynapseRequest, room_id: str, user_id: str
+ self, request: SynapseRequest, content: JsonDict, room_id: str, user_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
remote_room_hosts = content["remote_room_hosts"]
event_content = content["content"]
@@ -147,11 +144,10 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
async def _handle_request( # type: ignore[override]
self,
request: SynapseRequest,
+ content: JsonDict,
room_id: str,
user_id: str,
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
remote_room_hosts = content["remote_room_hosts"]
event_content = content["content"]
@@ -217,10 +213,8 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore[override]
- self, request: SynapseRequest, invite_event_id: str
+ self, request: SynapseRequest, content: JsonDict, invite_event_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
txn_id = content["txn_id"]
event_content = content["content"]
@@ -285,10 +279,9 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
async def _handle_request( # type: ignore[override]
self,
request: SynapseRequest,
+ content: JsonDict,
knock_event_id: str,
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
txn_id = content["txn_id"]
event_content = content["content"]
@@ -347,7 +340,12 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
return {}
async def _handle_request( # type: ignore[override]
- self, request: Request, room_id: str, user_id: str, change: str
+ self,
+ request: Request,
+ content: JsonDict,
+ room_id: str,
+ user_id: str,
+ change: str,
) -> Tuple[int, JsonDict]:
logger.info("user membership change: %s in %s", user_id, room_id)
diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
index 4a5b08f56f..db16aac9c2 100644
--- a/synapse/replication/http/presence.py
+++ b/synapse/replication/http/presence.py
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, UserID
@@ -56,7 +55,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
return {}
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str
+ self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]:
await self._presence_handler.bump_presence_active_time(
UserID.from_string(user_id)
@@ -107,10 +106,8 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
}
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str
+ self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
await self._presence_handler.set_state(
UserID.from_string(user_id),
content["state"],
diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py
index af5c2f66a7..297e8ad564 100644
--- a/synapse/replication/http/push.py
+++ b/synapse/replication/http/push.py
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
@@ -61,10 +60,8 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
return payload
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str
+ self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
app_id = content["app_id"]
pushkey = content["pushkey"]
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 976c283360..265e601b96 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
@@ -96,10 +95,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str
+ self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
await self.registration_handler.check_registration_ratelimit(content["address"])
# Always default admin users to approved (since it means they were created by
@@ -150,10 +147,8 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
return {"auth_result": auth_result, "access_token": access_token}
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str
+ self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
auth_result = content["auth_result"]
access_token = content["access_token"]
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 4215a1c1bc..27ad914075 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester, UserID
from synapse.util.metrics import Measure
@@ -114,11 +113,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
return payload
async def _handle_request( # type: ignore[override]
- self, request: Request, event_id: str
+ self, request: Request, content: JsonDict, event_id: str
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_send_event_parse"):
- content = parse_json_object_from_request(request)
-
event_dict = content["event"]
room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]]
internal_metadata = content["internal_metadata"]
diff --git a/synapse/replication/http/send_events.py b/synapse/replication/http/send_events.py
index 8889bbb644..4f82c9f96d 100644
--- a/synapse/replication/http/send_events.py
+++ b/synapse/replication/http/send_events.py
@@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester, UserID
from synapse.util.metrics import Measure
@@ -114,10 +113,9 @@ class ReplicationSendEventsRestServlet(ReplicationEndpoint):
return payload
async def _handle_request( # type: ignore[override]
- self, request: Request
+ self, request: Request, payload: JsonDict
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_send_events_parse"):
- payload = parse_json_object_from_request(request)
events_and_context = []
events = payload["events"]
diff --git a/synapse/replication/http/state.py b/synapse/replication/http/state.py
index 838b7584e5..0c524e7de3 100644
--- a/synapse/replication/http/state.py
+++ b/synapse/replication/http/state.py
@@ -57,7 +57,7 @@ class ReplicationUpdateCurrentStateRestServlet(ReplicationEndpoint):
return {}
async def _handle_request( # type: ignore[override]
- self, request: Request, room_id: str
+ self, request: Request, content: JsonDict, room_id: str
) -> Tuple[int, JsonDict]:
writer_instance = self._events_shard_config.get_instance(room_id)
if writer_instance != self._instance_name:
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index c065225362..3c7b5b18ea 100644
--- a/synapse/replication/http/streams.py
+++ b/synapse/replication/http/streams.py
@@ -54,6 +54,10 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
PATH_ARGS = ("stream_name",)
METHOD = "GET"
+ # We don't want to wait for replication streams to catch up, as this gets
+ # called in the process of catching replication streams up.
+ WAIT_FOR_STREAMS = False
+
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
@@ -67,7 +71,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
return {"from_token": from_token, "upto_token": upto_token}
async def _handle_request( # type: ignore[override]
- self, request: Request, stream_name: str
+ self, request: Request, content: JsonDict, stream_name: str
) -> Tuple[int, JsonDict]:
stream = self.streams.get(stream_name)
if stream is None:
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 8f3ce1edd3..f1dc435f8d 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,6 +16,7 @@
import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IAddress, IConnector
from twisted.internet.protocol import ReconnectingClientFactory
@@ -33,15 +34,20 @@ from synapse.replication.tcp.streams import (
PushersStream,
PushRulesStream,
ReceiptsStream,
- TagAccountDataStream,
ToDeviceStream,
TypingStream,
+ UnPartialStatedEventStream,
+ UnPartialStatedRoomStream,
)
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
EventsStreamRow,
)
+from synapse.replication.tcp.streams.partial_state import (
+ UnPartialStatedEventStreamRow,
+ UnPartialStatedRoomStreamRow,
+)
from synapse.types import PersistedEventPosition, ReadReceipt, StreamKeyType, UserID
from synapse.util.async_helpers import Linearizer, timeout_deferred
from synapse.util.metrics import Measure
@@ -53,7 +59,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# How long we allow callers to wait for replication updates before timing out.
-_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30
+_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 5
class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
@@ -117,6 +123,7 @@ class ReplicationDataHandler:
self._streams = hs.get_replication_streams()
self._instance_name = hs.get_instance_name()
self._typing_handler = hs.get_typing_handler()
+ self._state_storage_controller = hs.get_storage_controllers().state
self._notify_pushers = hs.config.worker.start_pushers
self._pusher_pool = hs.get_pusherpool()
@@ -126,9 +133,9 @@ class ReplicationDataHandler:
if hs.should_send_federation():
self.send_handler = FederationSenderHandler(hs)
- # Map from stream to list of deferreds waiting for the stream to
+ # Map from stream and instance to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
- self._streams_to_waiters: Dict[str, List[Tuple[int, Deferred]]] = {}
+ self._streams_to_waiters: Dict[Tuple[str, str], List[Tuple[int, Deferred]]] = {}
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@@ -145,6 +152,9 @@ class ReplicationDataHandler:
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
self.store.process_replication_rows(stream_name, instance_name, token, rows)
+ # NOTE: this must be called after process_replication_rows to ensure any
+ # cache invalidations are first handled before any stream ID advances.
+ self.store.process_replication_position(stream_name, instance_name, token)
if self.send_handler:
await self.send_handler.process_replication_rows(stream_name, token, rows)
@@ -158,7 +168,7 @@ class ReplicationDataHandler:
self.notifier.on_new_event(
StreamKeyType.PUSH_RULES, token, users=[row.user_id for row in rows]
)
- elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
+ elif stream_name in AccountDataStream.NAME:
self.notifier.on_new_event(
StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows]
)
@@ -178,7 +188,7 @@ class ReplicationDataHandler:
elif stream_name == DeviceListsStream.NAME:
all_room_ids: Set[str] = set()
for row in rows:
- if row.entity.startswith("@"):
+ if row.entity.startswith("@") and not row.is_signature:
room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids)
self.notifier.on_new_event(
@@ -197,6 +207,12 @@ class ReplicationDataHandler:
# we don't need to optimise this for multiple rows.
for row in rows:
if row.type != EventsStreamEventRow.TypeId:
+ # The row's data is an `EventsStreamCurrentStateRow`.
+ # When we recompute the current state of a room based on forward
+ # extremities (see `update_current_state`), no new events are
+ # persisted, so we must poke the replication callbacks ourselves.
+ # This functionality is used when finishing up a partial state join.
+ self.notifier.notify_replication()
continue
assert isinstance(row, EventsStreamRow)
assert isinstance(row.data, EventsStreamEventRow)
@@ -236,6 +252,23 @@ class ReplicationDataHandler:
self.notifier.notify_user_joined_room(
row.data.event_id, row.data.room_id
)
+ elif stream_name == UnPartialStatedRoomStream.NAME:
+ for row in rows:
+ assert isinstance(row, UnPartialStatedRoomStreamRow)
+
+ # Wake up any tasks waiting for the room to be un-partial-stated.
+ self._state_storage_controller.notify_room_un_partial_stated(
+ row.room_id
+ )
+ await self.notifier.on_un_partial_stated_room(row.room_id, token)
+ elif stream_name == UnPartialStatedEventStream.NAME:
+ for row in rows:
+ assert isinstance(row, UnPartialStatedEventStreamRow)
+
+ # Wake up any tasks waiting for the event to be un-partial-stated.
+ self._state_storage_controller.notify_event_un_partial_stated(
+ row.event_id
+ )
await self._presence_handler.process_replication_rows(
stream_name, instance_name, token, rows
@@ -244,7 +277,7 @@ class ReplicationDataHandler:
# Notify any waiting deferreds. The list is ordered by position so we
# just iterate through the list until we reach a position that is
# greater than the received row position.
- waiting_list = self._streams_to_waiters.get(stream_name, [])
+ waiting_list = self._streams_to_waiters.get((stream_name, instance_name), [])
# Index of first item with a position after the current token, i.e we
# have called all deferreds before this index. If not overwritten by
@@ -253,14 +286,13 @@ class ReplicationDataHandler:
# `len(list)` works for both cases.
index_of_first_deferred_not_called = len(waiting_list)
+ # We don't fire the deferreds until after we finish iterating over the
+ # list, to avoid the list changing when we fire the deferreds.
+ deferreds_to_callback = []
+
for idx, (position, deferred) in enumerate(waiting_list):
if position <= token:
- try:
- with PreserveLoggingContext():
- deferred.callback(None)
- except Exception:
- # The deferred has been cancelled or timed out.
- pass
+ deferreds_to_callback.append(deferred)
else:
# The list is sorted by position so we don't need to continue
# checking any further entries in the list.
@@ -271,6 +303,14 @@ class ReplicationDataHandler:
# loop. (This maintains the order so no need to resort)
waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
+ for deferred in deferreds_to_callback:
+ try:
+ with PreserveLoggingContext():
+ deferred.callback(None)
+ except Exception:
+ # The deferred has been cancelled or timed out.
+ pass
+
async def on_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
@@ -289,10 +329,18 @@ class ReplicationDataHandler:
self.send_handler.wake_destination(server)
async def wait_for_stream_position(
- self, instance_name: str, stream_name: str, position: int
+ self,
+ instance_name: str,
+ stream_name: str,
+ position: int,
) -> None:
"""Wait until this instance has received updates up to and including
the given stream position.
+
+ Args:
+ instance_name
+ stream_name
+ position
"""
if instance_name == self._instance_name:
@@ -300,7 +348,7 @@ class ReplicationDataHandler:
# anyway in that case we don't need to wait.
return
- current_position = self._streams[stream_name].current_token(self._instance_name)
+ current_position = self._streams[stream_name].current_token(instance_name)
if position <= current_position:
# We're already past the position
return
@@ -312,17 +360,32 @@ class ReplicationDataHandler:
deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
)
- waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
+ waiting_list = self._streams_to_waiters.setdefault(
+ (stream_name, instance_name), []
+ )
waiting_list.append((position, deferred))
waiting_list.sort(key=lambda t: t[0])
# We measure here to get in flight counts and average waiting time.
with Measure(self._clock, "repl.wait_for_stream_position"):
- logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
- await make_deferred_yieldable(deferred)
logger.info(
- "Finished waiting for repl stream %r to reach %s", stream_name, position
+ "Waiting for repl stream %r to reach %s (%s)",
+ stream_name,
+ position,
+ instance_name,
+ )
+ try:
+ await make_deferred_yieldable(deferred)
+ except defer.TimeoutError:
+ logger.error("Timed out waiting for stream %s", stream_name)
+ return
+
+ logger.info(
+ "Finished waiting for repl stream %r to reach %s (%s)",
+ stream_name,
+ position,
+ instance_name,
)
def stop_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
@@ -397,7 +460,11 @@ class FederationSenderHandler:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
- hosts = {row.entity for row in rows if not row.entity.startswith("@")}
+ hosts = {
+ row.entity
+ for row in rows
+ if not row.entity.startswith("@") and not row.is_signature
+ }
for host in hosts:
self.federation_sender.send_device_messages(host, immediate=False)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index ca99381648..ffe882bc99 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -58,7 +58,6 @@ from synapse.replication.tcp.streams import (
PresenceStream,
ReceiptsStream,
Stream,
- TagAccountDataStream,
ToDeviceStream,
TypingStream,
)
@@ -145,7 +144,7 @@ class ReplicationCommandHandler:
continue
- if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
+ if isinstance(stream, AccountDataStream):
# Only add AccountDataStream and TagAccountDataStream as a source on the
# instance in charge of account_data persistence.
if hs.get_instance_name() in hs.config.worker.writers.account_data:
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 9eb3e34695..ce95714ea0 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -199,33 +199,28 @@ class ReplicationStreamer:
# The token has advanced but there is no data to
# send, so we send a `POSITION` to inform other
# workers of the updated position.
- if stream.NAME == EventsStream.NAME:
- # XXX: We only do this for the EventStream as it
- # turns out that e.g. account data streams share
- # their "current token" with each other, meaning
- # that it is *not* safe to send a POSITION.
- # Note: `last_token` may not *actually* be the
- # last token we sent out in a RDATA or POSITION.
- # This can happen if we sent out an RDATA for
- # position X when our current token was say X+1.
- # Other workers will see RDATA for X and then a
- # POSITION with last token of X+1, which will
- # cause them to check if there were any missing
- # updates between X and X+1.
- logger.info(
- "Sending position: %s -> %s",
+ # Note: `last_token` may not *actually* be the
+ # last token we sent out in a RDATA or POSITION.
+ # This can happen if we sent out an RDATA for
+ # position X when our current token was say X+1.
+ # Other workers will see RDATA for X and then a
+ # POSITION with last token of X+1, which will
+ # cause them to check if there were any missing
+ # updates between X and X+1.
+ logger.info(
+ "Sending position: %s -> %s",
+ stream.NAME,
+ current_token,
+ )
+ self.command_handler.send_command(
+ PositionCommand(
stream.NAME,
+ self._instance_name,
+ last_token,
current_token,
)
- self.command_handler.send_command(
- PositionCommand(
- stream.NAME,
- self._instance_name,
- last_token,
- current_token,
- )
- )
+ )
continue
# Some streams return multiple rows with the same stream IDs,
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index b1cd55bf6f..9c67f661a3 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -35,13 +35,15 @@ from synapse.replication.tcp.streams._base import (
PushRulesStream,
ReceiptsStream,
Stream,
- TagAccountDataStream,
ToDeviceStream,
TypingStream,
- UserSignatureStream,
)
from synapse.replication.tcp.streams.events import EventsStream
from synapse.replication.tcp.streams.federation import FederationStream
+from synapse.replication.tcp.streams.partial_state import (
+ UnPartialStatedEventStream,
+ UnPartialStatedRoomStream,
+)
STREAMS_MAP = {
stream.NAME: stream
@@ -58,9 +60,9 @@ STREAMS_MAP = {
DeviceListsStream,
ToDeviceStream,
FederationStream,
- TagAccountDataStream,
AccountDataStream,
- UserSignatureStream,
+ UnPartialStatedRoomStream,
+ UnPartialStatedEventStream,
)
}
@@ -77,7 +79,7 @@ __all__ = [
"CachesStream",
"DeviceListsStream",
"ToDeviceStream",
- "TagAccountDataStream",
"AccountDataStream",
- "UserSignatureStream",
+ "UnPartialStatedRoomStream",
+ "UnPartialStatedEventStream",
]
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index e01155ad59..a4bdb48c0c 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -28,8 +28,8 @@ from typing import (
import attr
+from synapse.api.constants import AccountDataTypes
from synapse.replication.http.streams import ReplicationGetStreamUpdates
-from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -463,18 +463,67 @@ class DeviceListsStream(Stream):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListsStreamRow:
entity: str
+ # Indicates that a user has signed their own device with their user-signing key
+ is_signature: bool
NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow
def __init__(self, hs: "HomeServer"):
- store = hs.get_datastores().main
+ self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(store.get_device_stream_token),
- store.get_all_device_list_changes_for_remotes,
+ current_token_without_instance(self.store.get_device_stream_token),
+ self._update_function,
)
+ async def _update_function(
+ self,
+ instance_name: str,
+ from_token: Token,
+ current_token: Token,
+ target_row_count: int,
+ ) -> StreamUpdateResult:
+ (
+ device_updates,
+ devices_to_token,
+ devices_limited,
+ ) = await self.store.get_all_device_list_changes_for_remotes(
+ instance_name, from_token, current_token, target_row_count
+ )
+
+ (
+ signatures_updates,
+ signatures_to_token,
+ signatures_limited,
+ ) = await self.store.get_all_user_signature_changes_for_remotes(
+ instance_name, from_token, current_token, target_row_count
+ )
+
+ upper_limit_token = current_token
+ if devices_limited:
+ upper_limit_token = min(upper_limit_token, devices_to_token)
+ if signatures_limited:
+ upper_limit_token = min(upper_limit_token, signatures_to_token)
+
+ device_updates = [
+ (stream_id, (entity, False))
+ for stream_id, (entity,) in device_updates
+ if stream_id <= upper_limit_token
+ ]
+
+ signatures_updates = [
+ (stream_id, (entity, True))
+ for stream_id, (entity,) in signatures_updates
+ if stream_id <= upper_limit_token
+ ]
+
+ updates = list(
+ heapq.merge(device_updates, signatures_updates, key=lambda row: row[0])
+ )
+
+ return updates, upper_limit_token, devices_limited or signatures_limited
+
class ToDeviceStream(Stream):
"""New to_device messages for a client"""
@@ -495,27 +544,6 @@ class ToDeviceStream(Stream):
)
-class TagAccountDataStream(Stream):
- """Someone added/removed a tag for a room"""
-
- @attr.s(slots=True, frozen=True, auto_attribs=True)
- class TagAccountDataStreamRow:
- user_id: str
- room_id: str
- data: JsonDict
-
- NAME = "tag_account_data"
- ROW_TYPE = TagAccountDataStreamRow
-
- def __init__(self, hs: "HomeServer"):
- store = hs.get_datastores().main
- super().__init__(
- hs.get_instance_name(),
- current_token_without_instance(store.get_max_account_data_stream_id),
- store.get_all_updated_tags,
- )
-
-
class AccountDataStream(Stream):
"""Global or per room account data was changed"""
@@ -560,6 +588,19 @@ class AccountDataStream(Stream):
to_token = room_results[-1][0]
limited = True
+ tags, tag_to_token, tags_limited = await self.store.get_all_updated_tags(
+ instance_name,
+ from_token,
+ to_token,
+ limit,
+ )
+
+ # again, if the tag results hit the limit, limit the global results to
+ # the same stream token.
+ if tags_limited:
+ to_token = tag_to_token
+ limited = True
+
# convert the global results to the right format, and limit them to the to_token
# at the same time
global_rows = (
@@ -568,11 +609,16 @@ class AccountDataStream(Stream):
if stream_id <= to_token
)
- # we know that the room_results are already limited to `to_token` so no need
- # for a check on `stream_id` here.
room_rows = (
(stream_id, (user_id, room_id, account_data_type))
for stream_id, user_id, room_id, account_data_type in room_results
+ if stream_id <= to_token
+ )
+
+ tag_rows = (
+ (stream_id, (user_id, room_id, AccountDataTypes.TAG))
+ for stream_id, user_id, room_id in tags
+ if stream_id <= to_token
)
# We need to return a sorted list, so merge them together.
@@ -582,24 +628,7 @@ class AccountDataStream(Stream):
# leading to a comparison between the data tuples. The comparison could
# fail due to attempting to compare the `room_id` which results in a
# `TypeError` from comparing a `str` vs `None`.
- updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0]))
- return updates, to_token, limited
-
-
-class UserSignatureStream(Stream):
- """A user has signed their own device with their user-signing key"""
-
- @attr.s(slots=True, frozen=True, auto_attribs=True)
- class UserSignatureStreamRow:
- user_id: str
-
- NAME = "user_signature"
- ROW_TYPE = UserSignatureStreamRow
-
- def __init__(self, hs: "HomeServer"):
- store = hs.get_datastores().main
- super().__init__(
- hs.get_instance_name(),
- current_token_without_instance(store.get_device_stream_token),
- store.get_all_user_signature_changes_for_remotes,
+ updates = list(
+ heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
)
+ return updates, to_token, limited
diff --git a/synapse/replication/tcp/streams/partial_state.py b/synapse/replication/tcp/streams/partial_state.py
new file mode 100644
index 0000000000..a8ce5ffd72
--- /dev/null
+++ b/synapse/replication/tcp/streams/partial_state.py
@@ -0,0 +1,73 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+import attr
+
+from synapse.replication.tcp.streams import Stream
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UnPartialStatedRoomStreamRow:
+ # ID of the room that has been un-partial-stated.
+ room_id: str
+
+
+class UnPartialStatedRoomStream(Stream):
+ """
+ Stream to notify about rooms becoming un-partial-stated;
+ that is, when the background sync finishes such that we now have full state for
+ the room.
+ """
+
+ NAME = "un_partial_stated_room"
+ ROW_TYPE = UnPartialStatedRoomStreamRow
+
+ def __init__(self, hs: "HomeServer"):
+ store = hs.get_datastores().main
+ super().__init__(
+ hs.get_instance_name(),
+ store.get_un_partial_stated_rooms_token,
+ store.get_un_partial_stated_rooms_from_stream,
+ )
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UnPartialStatedEventStreamRow:
+ # ID of the event that has been un-partial-stated.
+ event_id: str
+
+ # True iff the rejection status of the event changed as a result of being
+ # un-partial-stated.
+ rejection_status_changed: bool
+
+
+class UnPartialStatedEventStream(Stream):
+ """
+ Stream to notify about events becoming un-partial-stated.
+ """
+
+ NAME = "un_partial_stated_event"
+ ROW_TYPE = UnPartialStatedEventStreamRow
+
+ def __init__(self, hs: "HomeServer"):
+ store = hs.get_datastores().main
+ super().__init__(
+ hs.get_instance_name(),
+ store.get_un_partial_stated_events_token,
+ store.get_un_partial_stated_events_from_stream,
+ )
diff --git a/synapse/res/templates/_base.html b/synapse/res/templates/_base.html
index 46439fce6a..4b5cc7bcb6 100644
--- a/synapse/res/templates/_base.html
+++ b/synapse/res/templates/_base.html
@@ -13,13 +13,13 @@
diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html
index 406397aaca..f62038e111 100644
--- a/synapse/res/templates/notice_expiry.html
+++ b/synapse/res/templates/notice_expiry.html
@@ -21,13 +21,13 @@
{% if app_name == "Riot" %}
-
+
{% elif app_name == "Vector" %}
-
+
{% elif app_name == "Element" %}
{% else %}
-
+
{% endif %}
|
diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html
index 2add9dd859..7da0fff5e9 100644
--- a/synapse/res/templates/notif_mail.html
+++ b/synapse/res/templates/notif_mail.html
@@ -22,13 +22,13 @@
{%- if app_name == "Riot" %}
-
+
{%- elif app_name == "Vector" %}
-
+
{%- elif app_name == "Element" %}
{%- else %}
-
+
{%- endif %}
|
diff --git a/synapse/res/templates/recaptcha.html b/synapse/res/templates/recaptcha.html
index 8204928cdf..f00992a24b 100644
--- a/synapse/res/templates/recaptcha.html
+++ b/synapse/res/templates/recaptcha.html
@@ -3,11 +3,10 @@
{% block header %}
-
{% endblock %}
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 28542cd774..14c4e6ebbb 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -29,7 +29,7 @@ from synapse.rest.client import (
initial_sync,
keys,
knock,
- login as v1_login,
+ login,
login_token_request,
logout,
mutual_rooms,
@@ -82,6 +82,10 @@ class ClientRestResource(JsonResource):
@staticmethod
def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None:
+ # Some servlets are only registered on the main process (and not worker
+ # processes).
+ is_main_process = hs.config.worker.worker_app is None
+
versions.register_servlets(hs, client_resource)
# Deprecated in r0
@@ -92,45 +96,58 @@ class ClientRestResource(JsonResource):
events.register_servlets(hs, client_resource)
room.register_servlets(hs, client_resource)
- v1_login.register_servlets(hs, client_resource)
+ login.register_servlets(hs, client_resource)
profile.register_servlets(hs, client_resource)
presence.register_servlets(hs, client_resource)
- directory.register_servlets(hs, client_resource)
+ if is_main_process:
+ directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource)
- pusher.register_servlets(hs, client_resource)
+ if is_main_process:
+ pusher.register_servlets(hs, client_resource)
push_rule.register_servlets(hs, client_resource)
- logout.register_servlets(hs, client_resource)
+ if is_main_process:
+ logout.register_servlets(hs, client_resource)
sync.register_servlets(hs, client_resource)
- filter.register_servlets(hs, client_resource)
+ if is_main_process:
+ filter.register_servlets(hs, client_resource)
account.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
- auth.register_servlets(hs, client_resource)
+ if is_main_process:
+ auth.register_servlets(hs, client_resource)
receipts.register_servlets(hs, client_resource)
read_marker.register_servlets(hs, client_resource)
room_keys.register_servlets(hs, client_resource)
keys.register_servlets(hs, client_resource)
- tokenrefresh.register_servlets(hs, client_resource)
+ if is_main_process:
+ tokenrefresh.register_servlets(hs, client_resource)
tags.register_servlets(hs, client_resource)
account_data.register_servlets(hs, client_resource)
- report_event.register_servlets(hs, client_resource)
- openid.register_servlets(hs, client_resource)
- notifications.register_servlets(hs, client_resource)
+ if is_main_process:
+ report_event.register_servlets(hs, client_resource)
+ openid.register_servlets(hs, client_resource)
+ notifications.register_servlets(hs, client_resource)
devices.register_servlets(hs, client_resource)
- thirdparty.register_servlets(hs, client_resource)
+ if is_main_process:
+ thirdparty.register_servlets(hs, client_resource)
sendtodevice.register_servlets(hs, client_resource)
user_directory.register_servlets(hs, client_resource)
- room_upgrade_rest_servlet.register_servlets(hs, client_resource)
+ if is_main_process:
+ room_upgrade_rest_servlet.register_servlets(hs, client_resource)
room_batch.register_servlets(hs, client_resource)
- capabilities.register_servlets(hs, client_resource)
- account_validity.register_servlets(hs, client_resource)
+ if is_main_process:
+ capabilities.register_servlets(hs, client_resource)
+ account_validity.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
- password_policy.register_servlets(hs, client_resource)
- knock.register_servlets(hs, client_resource)
+ if is_main_process:
+ password_policy.register_servlets(hs, client_resource)
+ knock.register_servlets(hs, client_resource)
# moving to /_synapse/admin
- admin.register_servlets_for_client_rest_resource(hs, client_resource)
+ if is_main_process:
+ admin.register_servlets_for_client_rest_resource(hs, client_resource)
# unstable
- mutual_rooms.register_servlets(hs, client_resource)
- login_token_request.register_servlets(hs, client_resource)
- rendezvous.register_servlets(hs, client_resource)
+ if is_main_process:
+ mutual_rooms.register_servlets(hs, client_resource)
+ login_token_request.register_servlets(hs, client_resource)
+ rendezvous.register_servlets(hs, client_resource)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index c62ea22116..79f22a59f1 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -152,7 +152,7 @@ class PurgeHistoryRestServlet(RestServlet):
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
elif "purge_up_to_ts" in body:
ts = body["purge_up_to_ts"]
- if not isinstance(ts, int):
+ if type(ts) is not int:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"purge_up_to_ts must be an int",
@@ -238,6 +238,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
"""
Register all the admin servlets.
"""
+ # Admin servlets aren't registered on workers.
+ if hs.config.worker.worker_app is not None:
+ return
+
register_servlets_for_client_rest_resource(hs, http_server)
BlockRoomRestServlet(hs).register(http_server)
ListRoomRestServlet(hs).register(http_server)
@@ -254,9 +258,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserTokenRestServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server)
- DeviceRestServlet(hs).register(http_server)
- DevicesRestServlet(hs).register(http_server)
- DeleteDevicesRestServlet(hs).register(http_server)
UserMediaStatisticsRestServlet(hs).register(http_server)
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
@@ -280,12 +281,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserByExternalId(hs).register(http_server)
UserByThreePid(hs).register(http_server)
- # Some servlets only get registered for the main process.
- if hs.config.worker.worker_app is None:
- SendServerNoticeServlet(hs).register(http_server)
- BackgroundUpdateEnabledRestServlet(hs).register(http_server)
- BackgroundUpdateRestServlet(hs).register(http_server)
- BackgroundUpdateStartJobRestServlet(hs).register(http_server)
+ DeviceRestServlet(hs).register(http_server)
+ DevicesRestServlet(hs).register(http_server)
+ DeleteDevicesRestServlet(hs).register(http_server)
+ SendServerNoticeServlet(hs).register(http_server)
+ BackgroundUpdateEnabledRestServlet(hs).register(http_server)
+ BackgroundUpdateRestServlet(hs).register(http_server)
+ BackgroundUpdateStartJobRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(
@@ -294,9 +296,11 @@ def register_servlets_for_client_rest_resource(
"""Register only the servlets which need to be exposed on /_matrix/client/xxx"""
WhoisRestServlet(hs).register(http_server)
PurgeHistoryStatusRestServlet(hs).register(http_server)
- DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server)
- ResetPasswordRestServlet(hs).register(http_server)
+ # The following resources can only be run on the main process.
+ if hs.config.worker.worker_app is None:
+ DeactivateAccountRestServlet(hs).register(http_server)
+ ResetPasswordRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
AccountValidityRenewServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index d934880102..3b2f2d9abb 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -16,6 +16,7 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import NotFoundError, SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -43,7 +44,9 @@ class DeviceRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine
@@ -112,7 +115,9 @@ class DevicesRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine
@@ -143,7 +148,9 @@ class DeleteDevicesRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index 6d634eef70..a3beb74e2c 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -16,8 +16,9 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
+from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
-from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict
@@ -60,7 +61,7 @@ class EventReportsRestServlet(RestServlet):
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
- direction = parse_string(request, "dir", default="b")
+ direction = parse_enum(request, "dir", Direction, Direction.BACKWARDS)
user_id = parse_string(request, "user_id")
room_id = parse_string(request, "room_id")
@@ -78,13 +79,6 @@ class EventReportsRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
- if direction not in ("f", "b"):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Unknown direction: %s" % (direction,),
- errcode=Codes.INVALID_PARAM,
- )
-
event_reports, total = await self.store.get_event_reports_paginate(
start, limit, direction, user_id, room_id
)
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
index 023ed92144..e0ee55bd0e 100644
--- a/synapse/rest/admin/federation.py
+++ b/synapse/rest/admin/federation.py
@@ -15,9 +15,10 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
+from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.federation.transport.server import Authenticator
-from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.storage.databases.main.transactions import DestinationSortOrder
@@ -79,7 +80,7 @@ class ListDestinationsRestServlet(RestServlet):
allowed_values=[dest.value for dest in DestinationSortOrder],
)
- direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
+ direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
destinations, total = await self._store.get_destinations_paginate(
start, limit, destination, order_by, direction
@@ -192,7 +193,7 @@ class DestinationMembershipRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
- direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
+ direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
rooms, total = await self._store.get_destination_rooms_paginate(
destination, start, limit, direction
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 73470f09ae..0d072c42a7 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -17,9 +17,16 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
+from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
+from synapse.http.servlet import (
+ RestServlet,
+ parse_boolean,
+ parse_enum,
+ parse_integer,
+ parse_string,
+)
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import (
admin_patterns,
@@ -389,7 +396,7 @@ class UserMediaRestServlet(RestServlet):
# to newest media is on top for backward compatibility.
if b"order_by" not in request.args and b"dir" not in request.args:
order_by = MediaSortOrder.CREATED_TS.value
- direction = "b"
+ direction = Direction.BACKWARDS
else:
order_by = parse_string(
request,
@@ -397,8 +404,8 @@ class UserMediaRestServlet(RestServlet):
default=MediaSortOrder.CREATED_TS.value,
allowed_values=[sort_order.value for sort_order in MediaSortOrder],
)
- direction = parse_string(
- request, "dir", default="f", allowed_values=("f", "b")
+ direction = parse_enum(
+ request, "dir", Direction, default=Direction.FORWARDS
)
media, total = await self.store.get_local_media_by_user_paginate(
@@ -447,7 +454,7 @@ class UserMediaRestServlet(RestServlet):
# to newest media is on top for backward compatibility.
if b"order_by" not in request.args and b"dir" not in request.args:
order_by = MediaSortOrder.CREATED_TS.value
- direction = "b"
+ direction = Direction.BACKWARDS
else:
order_by = parse_string(
request,
@@ -455,8 +462,8 @@ class UserMediaRestServlet(RestServlet):
default=MediaSortOrder.CREATED_TS.value,
allowed_values=[sort_order.value for sort_order in MediaSortOrder],
)
- direction = parse_string(
- request, "dir", default="f", allowed_values=("f", "b")
+ direction = parse_enum(
+ request, "dir", Direction, default=Direction.FORWARDS
)
media, _ = await self.store.get_local_media_by_user_paginate(
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index af606e9252..95e751288b 100644
--- a/synapse/rest/admin/registration_tokens.py
+++ b/synapse/rest/admin/registration_tokens.py
@@ -143,7 +143,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
else:
# Get length of token to generate (default is 16)
length = body.get("length", 16)
- if not isinstance(length, int):
+ if type(length) is not int:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"length must be an integer",
@@ -163,8 +163,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
uses_allowed = body.get("uses_allowed", None)
if not (
- uses_allowed is None
- or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+ uses_allowed is None or (type(uses_allowed) is int and uses_allowed >= 0)
):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
@@ -173,13 +172,13 @@ class NewRegistrationTokenRestServlet(RestServlet):
)
expiry_time = body.get("expiry_time", None)
- if not isinstance(expiry_time, (int, type(None))):
+ if type(expiry_time) not in (int, type(None)):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"expiry_time must be an integer or null",
Codes.INVALID_PARAM,
)
- if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+ if type(expiry_time) is int and expiry_time < self.clock.time_msec():
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"expiry_time must not be in the past",
@@ -284,7 +283,7 @@ class RegistrationTokenRestServlet(RestServlet):
uses_allowed = body["uses_allowed"]
if not (
uses_allowed is None
- or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+ or (type(uses_allowed) is int and uses_allowed >= 0)
):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
@@ -295,13 +294,13 @@ class RegistrationTokenRestServlet(RestServlet):
if "expiry_time" in body:
expiry_time = body["expiry_time"]
- if not isinstance(expiry_time, (int, type(None))):
+ if type(expiry_time) not in (int, type(None)):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"expiry_time must be an integer or null",
Codes.INVALID_PARAM,
)
- if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+ if type(expiry_time) is int and expiry_time < self.clock.time_msec():
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"expiry_time must not be in the past",
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 747e6fda83..1d6e4982d7 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -16,13 +16,14 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from urllib import parse as urlparse
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import Direction, EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.http.servlet import (
ResolveRoomIdMixin,
RestServlet,
assert_params_in_dict,
+ parse_enum,
parse_integer,
parse_json_object_from_request,
parse_string,
@@ -34,9 +35,9 @@ from synapse.rest.admin._base import (
assert_user_is_admin,
)
from synapse.storage.databases.main.room import RoomSortOrder
-from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, RoomID, UserID, create_requester
+from synapse.types.state import StateFilter
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -224,15 +225,8 @@ class ListRoomRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
- direction = parse_string(request, "dir", default="f")
- if direction not in ("f", "b"):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Unknown direction: %s" % (direction,),
- errcode=Codes.INVALID_PARAM,
- )
-
- reverse_order = True if direction == "b" else False
+ direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
+ reverse_order = True if direction == Direction.BACKWARDS else False
# Return list of rooms according to parameters
rooms, total_rooms = await self.store.get_rooms_paginate(
@@ -949,7 +943,7 @@ class RoomTimestampToEventRestServlet(RestServlet):
await assert_user_is_admin(self._auth, requester)
timestamp = parse_integer(request, "ts", required=True)
- direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
+ direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
(
event_id,
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index 3b142b8402..9c45f4650d 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -16,8 +16,9 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
+from synapse.api.constants import Direction
from synapse.api.errors import Codes, SynapseError
-from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.storage.databases.main.stats import UserSortOrder
@@ -102,13 +103,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
- direction = parse_string(request, "dir", default="f")
- if direction not in ("f", "b"):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Unknown direction: %s" % (direction,),
- errcode=Codes.INVALID_PARAM,
- )
+ direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
users_media, total = await self.store.get_users_media_usage_paginate(
start, limit, from_ts, until_ts, order_by, direction, search_term
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 6e0c44be2a..b9dca8ef3a 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -18,12 +18,13 @@ import secrets
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
-from synapse.api.constants import UserTypes
+from synapse.api.constants import Direction, UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_boolean,
+ parse_enum,
parse_integer,
parse_json_object_from_request,
parse_string,
@@ -120,7 +121,7 @@ class UsersRestServletV2(RestServlet):
),
)
- direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
+ direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
users, total = await self.store.get_users_paginate(
start,
@@ -973,7 +974,7 @@ class UserTokenRestServlet(RestServlet):
body = parse_json_object_from_request(request, allow_empty_body=True)
valid_until_ms = body.get("valid_until_ms")
- if valid_until_ms and not isinstance(valid_until_ms, int):
+ if type(valid_until_ms) not in (int, type(None)):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int"
)
@@ -1125,14 +1126,14 @@ class RateLimitRestServlet(RestServlet):
messages_per_second = body.get("messages_per_second", 0)
burst_count = body.get("burst_count", 0)
- if not isinstance(messages_per_second, int) or messages_per_second < 0:
+ if type(messages_per_second) is not int or messages_per_second < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"%r parameter must be a positive int" % (messages_per_second,),
errcode=Codes.INVALID_PARAM,
)
- if not isinstance(burst_count, int) or burst_count < 0:
+ if type(burst_count) is not int or burst_count < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"%r parameter must be a positive int" % (burst_count,),
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 44f622bcce..4373c73662 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -338,6 +338,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ if not self.hs.config.registration.enable_3pid_changes:
+ raise SynapseError(
+ 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
+ )
+
if not self.config.email.can_verify_email:
logger.warning(
"Adding emails have been disabled due to lack of an email config"
@@ -875,19 +880,21 @@ class AccountStatusRestServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
- EmailPasswordRequestTokenRestServlet(hs).register(http_server)
- PasswordRestServlet(hs).register(http_server)
- DeactivateAccountRestServlet(hs).register(http_server)
- EmailThreepidRequestTokenRestServlet(hs).register(http_server)
- MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
- AddThreepidEmailSubmitTokenServlet(hs).register(http_server)
- AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server)
+ if hs.config.worker.worker_app is None:
+ EmailPasswordRequestTokenRestServlet(hs).register(http_server)
+ PasswordRestServlet(hs).register(http_server)
+ DeactivateAccountRestServlet(hs).register(http_server)
+ EmailThreepidRequestTokenRestServlet(hs).register(http_server)
+ MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
+ AddThreepidEmailSubmitTokenServlet(hs).register(http_server)
+ AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
- ThreepidAddRestServlet(hs).register(http_server)
- ThreepidBindRestServlet(hs).register(http_server)
- ThreepidUnbindRestServlet(hs).register(http_server)
- ThreepidDeleteRestServlet(hs).register(http_server)
+ if hs.config.worker.worker_app is None:
+ ThreepidAddRestServlet(hs).register(http_server)
+ ThreepidBindRestServlet(hs).register(http_server)
+ ThreepidUnbindRestServlet(hs).register(http_server)
+ ThreepidDeleteRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)
- if hs.config.experimental.msc3720_enabled:
+ if hs.config.worker.worker_app is None and hs.config.experimental.msc3720_enabled:
AccountStatusRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py
index f13970b898..e805196fec 100644
--- a/synapse/rest/client/account_data.py
+++ b/synapse/rest/client/account_data.py
@@ -41,6 +41,7 @@ class AccountDataServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
+ self._hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self.handler = hs.get_account_data_handler()
@@ -54,6 +55,16 @@ class AccountDataServlet(RestServlet):
body = parse_json_object_from_request(request)
+ # If experimental support for MSC3391 is enabled, then providing an empty dict
+ # as the value for an account data type should be functionally equivalent to
+ # calling the DELETE method on the same type.
+ if self._hs.config.experimental.msc3391_enabled:
+ if body == {}:
+ await self.handler.remove_account_data_for_user(
+ user_id, account_data_type
+ )
+ return 200, {}
+
await self.handler.add_account_data_for_user(user_id, account_data_type, body)
return 200, {}
@@ -72,9 +83,48 @@ class AccountDataServlet(RestServlet):
if event is None:
raise NotFoundError("Account data not found")
+ # If experimental support for MSC3391 is enabled, then this endpoint should
+ # return a 404 if the content for an account data type is an empty dict.
+ if self._hs.config.experimental.msc3391_enabled and event == {}:
+ raise NotFoundError("Account data not found")
+
return 200, event
+class UnstableAccountDataServlet(RestServlet):
+ """
+ Contains an unstable endpoint for removing user account data, as specified by
+ MSC3391. If that MSC is accepted, this code should have unstable prefixes removed
+ and become incorporated into AccountDataServlet above.
+ """
+
+ PATTERNS = client_patterns(
+ "/org.matrix.msc3391/user/(?P[^/]*)"
+ "/account_data/(?P[^/]*)",
+ unstable=True,
+ releases=(),
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.auth = hs.get_auth()
+ self.handler = hs.get_account_data_handler()
+
+ async def on_DELETE(
+ self,
+ request: SynapseRequest,
+ user_id: str,
+ account_data_type: str,
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ if user_id != requester.user.to_string():
+ raise AuthError(403, "Cannot delete account data for other users.")
+
+ await self.handler.remove_account_data_for_user(user_id, account_data_type)
+
+ return 200, {}
+
+
class RoomAccountDataServlet(RestServlet):
"""
PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
@@ -89,6 +139,7 @@ class RoomAccountDataServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
+ self._hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self.handler = hs.get_account_data_handler()
@@ -121,6 +172,16 @@ class RoomAccountDataServlet(RestServlet):
Codes.BAD_JSON,
)
+ # If experimental support for MSC3391 is enabled, then providing an empty dict
+ # as the value for an account data type should be functionally equivalent to
+ # calling the DELETE method on the same type.
+ if self._hs.config.experimental.msc3391_enabled:
+ if body == {}:
+ await self.handler.remove_account_data_for_room(
+ user_id, room_id, account_data_type
+ )
+ return 200, {}
+
await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, body
)
@@ -152,9 +213,63 @@ class RoomAccountDataServlet(RestServlet):
if event is None:
raise NotFoundError("Room account data not found")
+ # If experimental support for MSC3391 is enabled, then this endpoint should
+ # return a 404 if the content for an account data type is an empty dict.
+ if self._hs.config.experimental.msc3391_enabled and event == {}:
+ raise NotFoundError("Room account data not found")
+
return 200, event
+class UnstableRoomAccountDataServlet(RestServlet):
+ """
+ Contains an unstable endpoint for removing room account data, as specified by
+ MSC3391. If that MSC is accepted, this code should have unstable prefixes removed
+ and become incorporated into RoomAccountDataServlet above.
+ """
+
+ PATTERNS = client_patterns(
+ "/org.matrix.msc3391/user/(?P[^/]*)"
+ "/rooms/(?P[^/]*)"
+ "/account_data/(?P[^/]*)",
+ unstable=True,
+ releases=(),
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.auth = hs.get_auth()
+ self.handler = hs.get_account_data_handler()
+
+ async def on_DELETE(
+ self,
+ request: SynapseRequest,
+ user_id: str,
+ room_id: str,
+ account_data_type: str,
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ if user_id != requester.user.to_string():
+ raise AuthError(403, "Cannot delete account data for other users.")
+
+ if not RoomID.is_valid(room_id):
+ raise SynapseError(
+ 400,
+ f"{room_id} is not a valid room ID",
+ Codes.INVALID_PARAM,
+ )
+
+ await self.handler.remove_account_data_for_room(
+ user_id, room_id, account_data_type
+ )
+
+ return 200, {}
+
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AccountDataServlet(hs).register(http_server)
RoomAccountDataServlet(hs).register(http_server)
+
+ if hs.config.experimental.msc3391_enabled:
+ UnstableAccountDataServlet(hs).register(http_server)
+ UnstableRoomAccountDataServlet(hs).register(http_server)
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 8f3cbd4ea2..486c6dbbc5 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -20,6 +20,7 @@ from pydantic import Extra, StrictStr
from synapse.api import errors
from synapse.api.errors import NotFoundError
+from synapse.handlers.device import DeviceHandler
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -80,7 +81,9 @@ class DeleteDevicesRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
self.auth_handler = hs.get_auth_handler()
class PostBody(RequestBodyModel):
@@ -125,7 +128,9 @@ class DeviceRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
self.auth_handler = hs.get_auth_handler()
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
@@ -256,7 +261,9 @@ class DehydratedDeviceServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
@@ -313,7 +320,9 @@ class ClaimDehydratedDeviceServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
class PostBody(RequestBodyModel):
device_id: StrictStr
@@ -333,8 +342,10 @@ class ClaimDehydratedDeviceServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
- DeleteDevicesRestServlet(hs).register(http_server)
+ if hs.config.worker.worker_app is None:
+ DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
- DeviceRestServlet(hs).register(http_server)
- DehydratedDeviceServlet(hs).register(http_server)
- ClaimDehydratedDeviceServlet(hs).register(http_server)
+ if hs.config.worker.worker_app is None:
+ DeviceRestServlet(hs).register(http_server)
+ DehydratedDeviceServlet(hs).register(http_server)
+ ClaimDehydratedDeviceServlet(hs).register(http_server)
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index ee038c7192..7873b363c0 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -376,5 +376,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)
OneTimeKeyServlet(hs).register(http_server)
- SigningKeyUploadServlet(hs).register(http_server)
- SignaturesUploadServlet(hs).register(http_server)
+ if hs.config.worker.worker_app is None:
+ SigningKeyUploadServlet(hs).register(http_server)
+ SignaturesUploadServlet(hs).register(http_server)
diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py
index 23dfa4518f..6d34625ad5 100644
--- a/synapse/rest/client/logout.py
+++ b/synapse/rest/client/logout.py
@@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
+from synapse.handlers.device import DeviceHandler
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
@@ -34,7 +35,9 @@ class LogoutRestServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
- self._device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self._device_handler = handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
@@ -59,7 +62,9 @@ class LogoutAllRestServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
- self._device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self._device_handler = handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index 8191b4e32c..ad5c10c99d 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, List, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, List, Tuple, Union
from synapse.api.errors import (
NotFoundError,
@@ -169,7 +169,7 @@ class PushRuleRestServlet(RestServlet):
raise UnrecognizedRequestError()
-def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec:
+def _rule_spec_from_path(path: List[str]) -> RuleSpec:
"""Turn a sequence of path components into a rule spec
Args:
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 18a282b22c..28b7d30ea8 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -20,7 +20,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict
+from synapse.types import EventID, JsonDict, RoomID
from ._base import client_patterns
@@ -56,6 +56,9 @@ class ReceiptRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
+ if not RoomID.is_valid(room_id) or not event_id.startswith(EventID.SIGIL):
+ raise SynapseError(400, "A valid room ID and event ID must be specified")
+
if receipt_type not in self._known_receipt_types:
raise SynapseError(
400,
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index de810ae3ec..3cb1e7e375 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -949,9 +949,10 @@ def _calculate_registration_flows(
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
- EmailRegisterRequestTokenRestServlet(hs).register(http_server)
- MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
- UsernameAvailabilityRestServlet(hs).register(http_server)
- RegistrationSubmitTokenServlet(hs).register(http_server)
+ if hs.config.worker.worker_app is None:
+ EmailRegisterRequestTokenRestServlet(hs).register(http_server)
+ MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
+ UsernameAvailabilityRestServlet(hs).register(http_server)
+ RegistrationSubmitTokenServlet(hs).register(http_server)
RegistrationTokenValidityRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 9dd59196d9..7456d6f507 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -16,6 +16,7 @@ import logging
import re
from typing import TYPE_CHECKING, Optional, Tuple
+from synapse.api.constants import Direction
from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
@@ -59,7 +60,7 @@ class RelationPaginationServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request(
- self._store, request, default_limit=5, default_dir="b"
+ self._store, request, default_limit=5, default_dir=Direction.BACKWARDS
)
# The unstable version of this API returns an extra field for client
diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py
index 6e962a4532..e2b410cf32 100644
--- a/synapse/rest/client/report_event.py
+++ b/synapse/rest/client/report_event.py
@@ -54,7 +54,7 @@ class ReportEventRestServlet(RestServlet):
"Param 'reason' must be a string",
Codes.BAD_JSON,
)
- if not isinstance(body.get("score", 0), int):
+ if type(body.get("score", 0)) is not int:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Param 'score' must be an integer",
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 5f56eb4c3b..e7bec9f02e 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -26,7 +26,7 @@ from prometheus_client.core import Histogram
from twisted.web.server import Request
from synapse import event_auth
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -44,6 +44,7 @@ from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_boolean,
+ parse_enum,
parse_integer,
parse_json_object_from_request,
parse_string,
@@ -55,9 +56,9 @@ from synapse.logging.opentracing import set_tag
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client._base import client_patterns
from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID
+from synapse.types.state import StateFilter
from synapse.util import json_decoder
from synapse.util.cancellation import cancellable
from synapse.util.stringutils import parse_and_validate_server_name, random_string
@@ -396,12 +397,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- try:
- content = parse_json_object_from_request(request)
- except Exception:
- # Turns out we used to ignore the body entirely, and some clients
- # cheekily send invalid bodies.
- content = {}
+ content = parse_json_object_from_request(request, allow_empty_body=True)
# twisted.web.server.Request.args is incorrectly defined as Optional[Any]
args: Dict[bytes, List[bytes]] = request.args # type: ignore
@@ -952,12 +948,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
}:
raise AuthError(403, "Guest access not allowed")
- try:
- content = parse_json_object_from_request(request)
- except Exception:
- # Turns out we used to ignore the body entirely, and some clients
- # cheekily send invalid bodies.
- content = {}
+ content = parse_json_object_from_request(request, allow_empty_body=True)
if membership_action == "invite" and all(
key in content for key in ("medium", "address")
@@ -1284,17 +1275,14 @@ class TimestampLookupRestServlet(RestServlet):
`dir` can be `f` or `b` to indicate forwards and backwards in time from the
given timestamp.
- GET /_matrix/client/unstable/org.matrix.msc3030/rooms//timestamp_to_event?ts=&dir=
+ GET /_matrix/client/v1/rooms//timestamp_to_event?ts=&dir=
{
"event_id": ...
}
"""
PATTERNS = (
- re.compile(
- "^/_matrix/client/unstable/org.matrix.msc3030"
- "/rooms/(?P[^/]*)/timestamp_to_event$"
- ),
+ re.compile("^/_matrix/client/v1/rooms/(?P[^/]*)/timestamp_to_event$"),
)
def __init__(self, hs: "HomeServer"):
@@ -1310,7 +1298,7 @@ class TimestampLookupRestServlet(RestServlet):
await self._auth.check_user_in_room_or_world_readable(room_id, requester)
timestamp = parse_integer(request, "ts", required=True)
- direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
+ direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
(
event_id,
@@ -1398,9 +1386,7 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet):
)
-def register_servlets(
- hs: "HomeServer", http_server: HttpServer, is_worker: bool = False
-) -> None:
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomStateEventRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
JoinedRoomMemberListRestServlet(hs).register(http_server)
@@ -1421,11 +1407,10 @@ def register_servlets(
RoomAliasListServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server)
RoomCreateRestServlet(hs).register(http_server)
- if hs.config.experimental.msc3030_enabled:
- TimestampLookupRestServlet(hs).register(http_server)
+ TimestampLookupRestServlet(hs).register(http_server)
# Some servlets only get registered for the main process.
- if not is_worker:
+ if hs.config.worker.worker_app is None:
RoomForgetRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py
index 46a8b03829..55d52f0b28 100644
--- a/synapse/rest/client/sendtodevice.py
+++ b/synapse/rest/client/sendtodevice.py
@@ -46,7 +46,6 @@ class SendToDeviceRestServlet(servlet.RestServlet):
def on_PUT(
self, request: SynapseRequest, message_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
- set_tag("message_type", message_type)
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self._put, request, message_type, txn_id
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 48f3f144ea..3a2c6bd36d 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple
from typing_extensions import ParamSpec
+from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
from twisted.web.server import Request
@@ -90,7 +91,7 @@ class HttpTransactionCache:
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
*args: P.args,
**kwargs: P.kwargs,
- ) -> Awaitable[Tuple[int, JsonDict]]:
+ ) -> "Deferred[Tuple[int, JsonDict]]":
"""Fetches the response for this transaction, or executes the given function
to produce a response for this transaction.
diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py
index 116c982ce6..4670fad608 100644
--- a/synapse/rest/client/user_directory.py
+++ b/synapse/rest/client/user_directory.py
@@ -63,8 +63,8 @@ class UserDirectorySearchRestServlet(RestServlet):
body = parse_json_object_from_request(request)
- limit = body.get("limit", 10)
- limit = min(limit, 50)
+ limit = int(body.get("limit", 10))
+ limit = max(min(limit, 50), 0)
try:
search_term = body["search_term"]
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 180a11ef88..e19c0946c0 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -77,6 +77,7 @@ class VersionsRestServlet(RestServlet):
"v1.2",
"v1.3",
"v1.4",
+ "v1.5",
],
# as per MSC1497:
"unstable_features": {
@@ -101,8 +102,6 @@ class VersionsRestServlet(RestServlet):
"org.matrix.msc3827.stable": True,
# Adds support for importing historical messages as per MSC2716
"org.matrix.msc2716": self.config.experimental.msc2716_enabled,
- # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030
- "org.matrix.msc3030": self.config.experimental.msc3030_enabled,
# Adds support for thread relations, per MSC3440.
"org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above
# Support for thread read receipts & notification counts.
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 40b0d39eb2..c70e1837af 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -24,7 +24,6 @@ from matrix_common.types.mxc_uri import MXCUri
import twisted.internet.error
import twisted.web.http
from twisted.internet.defer import Deferred
-from twisted.web.resource import Resource
from synapse.api.errors import (
FederationDeniedError,
@@ -35,6 +34,7 @@ from synapse.api.errors import (
)
from synapse.config._base import ConfigError
from synapse.config.repository import ThumbnailRequirement
+from synapse.http.server import UnrecognizedRequestResource
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -1046,7 +1046,7 @@ class MediaRepository:
return removed_media, len(removed_media)
-class MediaRepositoryResource(Resource):
+class MediaRepositoryResource(UnrecognizedRequestResource):
"""File uploading and downloading.
Uploads are POSTed to a resource which returns a token which is used to GET
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index 827afd868d..7592aa5d47 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import html
import logging
import urllib.parse
from typing import TYPE_CHECKING, List, Optional
@@ -161,7 +162,9 @@ class OEmbedProvider:
title = oembed.get("title")
if title and isinstance(title, str):
- open_graph_response["og:title"] = title
+ # A common WordPress plug-in seems to incorrectly escape entities
+ # in the oEmbed response.
+ open_graph_response["og:title"] = html.unescape(title)
author_name = oembed.get("author_name")
if not isinstance(author_name, str):
@@ -180,9 +183,9 @@ class OEmbedProvider:
# Process each type separately.
oembed_type = oembed.get("type")
if oembed_type == "rich":
- html = oembed.get("html")
- if isinstance(html, str):
- calc_description_and_urls(open_graph_response, html)
+ html_str = oembed.get("html")
+ if isinstance(html_str, str):
+ calc_description_and_urls(open_graph_response, html_str)
elif oembed_type == "photo":
# If this is a photo, use the full image, not the thumbnail.
@@ -192,12 +195,12 @@ class OEmbedProvider:
elif oembed_type == "video":
open_graph_response["og:type"] = "video.other"
- html = oembed.get("html")
- if html and isinstance(html, str):
+ html_str = oembed.get("html")
+ if html_str and isinstance(html_str, str):
calc_description_and_urls(open_graph_response, oembed["html"])
for size in ("width", "height"):
val = oembed.get(size)
- if val is not None and isinstance(val, int):
+ if type(val) is int:
open_graph_response[f"og:video:{size}"] = val
elif oembed_type == "link":
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index a48a4de92a..9480cc5763 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -77,7 +77,7 @@ class Thumbnailer:
image_exif = self.image._getexif() # type: ignore
if image_exif is not None:
image_orientation = image_exif.get(EXIF_ORIENTATION_TAG)
- assert isinstance(image_orientation, int)
+ assert type(image_orientation) is int
self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation)
except Exception as e:
# A lot of parsing errors can happen when parsing EXIF
diff --git a/synapse/server.py b/synapse/server.py
index f0a60d0056..9d6d268f49 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -107,7 +107,7 @@ from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpC
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager
from synapse.module_api import ModuleApi
-from synapse.notifier import Notifier
+from synapse.notifier import Notifier, ReplicationNotifier
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
@@ -389,6 +389,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_notifier(self) -> Notifier:
return Notifier(self)
+ @cache_in_self
+ def get_replication_notifier(self) -> ReplicationNotifier:
+ return ReplicationNotifier()
+
@cache_in_self
def get_auth(self) -> Auth:
return Auth(self)
@@ -510,7 +514,7 @@ class HomeServer(metaclass=abc.ABCMeta):
)
@cache_in_self
- def get_device_handler(self):
+ def get_device_handler(self) -> DeviceWorkerHandler:
if self.config.worker.worker_app:
return DeviceWorkerHandler(self)
else:
@@ -743,7 +747,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_event_client_serializer(self) -> EventClientSerializer:
- return EventClientSerializer()
+ return EventClientSerializer(self.config.experimental.msc3925_inhibit_edit)
@cache_in_self
def get_password_policy_handler(self) -> PasswordPolicyHandler:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 833ffec3de..fdfb46ab82 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -44,8 +44,8 @@ from synapse.logging.context import ContextResourceUsage
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.state import StateFilter
from synapse.types import StateMap
+from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure, measure_func
@@ -202,14 +202,20 @@ class StateHandler:
room_id: the room_id containing the given events.
event_ids: the events whose state should be fetched and resolved.
await_full_state: if `True`, will block if we do not yet have complete state
- at the given `event_id`s, regardless of whether `state_filter` is
- satisfied by partial state.
+ at these events and `state_filter` is not satisfied by partial state.
+ Defaults to `True`.
Returns:
the state dict (a mapping from (event_type, state_key) -> event_id) which
holds the resolution of the states after the given event IDs.
"""
logger.debug("calling resolve_state_groups from compute_state_after_events")
+ if (
+ await_full_state
+ and state_filter
+ and not state_filter.must_await_full_state(self.hs.is_mine_id)
+ ):
+ await_full_state = False
ret = await self.resolve_state_groups_for_events(
room_id, event_ids, await_full_state
)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 69abf6fa87..41d9111019 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -57,7 +57,22 @@ class SQLBaseStore(metaclass=ABCMeta):
token: int,
rows: Iterable[Any],
) -> None:
- pass
+ """
+ Used by storage classes to invalidate caches based on incoming replication data. These
+ must not update any ID generators, use `process_replication_position`.
+ """
+
+ def process_replication_position( # noqa: B027 (no-op by design)
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ ) -> None:
+ """
+ Used by storage classes to advance ID generators based on incoming replication data. This
+ is called after process_replication_rows such that caches are invalidated before any token
+ positions advance.
+ """
def _invalidate_state_caches(
self, room_id: str, members_changed: Collection[str]
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 52e2e35d06..1666c4616a 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -544,6 +544,48 @@ class BackgroundUpdater:
The named index will be dropped upon completion of the new index.
"""
+ async def updater(progress: JsonDict, batch_size: int) -> int:
+ await self.create_index_in_background(
+ index_name=index_name,
+ table=table,
+ columns=columns,
+ where_clause=where_clause,
+ unique=unique,
+ psql_only=psql_only,
+ replaces_index=replaces_index,
+ )
+ await self._end_background_update(update_name)
+ return 1
+
+ self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+ updater, oneshot=True
+ )
+
+ async def create_index_in_background(
+ self,
+ index_name: str,
+ table: str,
+ columns: Iterable[str],
+ where_clause: Optional[str] = None,
+ unique: bool = False,
+ psql_only: bool = False,
+ replaces_index: Optional[str] = None,
+ ) -> None:
+ """Add an index in the background.
+
+ Args:
+ update_name: update_name to register for
+ index_name: name of index to add
+ table: table to add index to
+ columns: columns/expressions to include in index
+ where_clause: A WHERE clause to specify a partial unique index.
+ unique: true to make a UNIQUE index
+ psql_only: true to only create this index on psql databases (useful
+ for virtual sqlite tables)
+ replaces_index: The name of an index that this index replaces.
+ The named index will be dropped upon completion of the new index.
+ """
+
def create_index_psql(conn: Connection) -> None:
conn.rollback()
# postgres insists on autocommit for the index
@@ -618,16 +660,11 @@ class BackgroundUpdater:
else:
runner = create_index_sqlite
- async def updater(progress: JsonDict, batch_size: int) -> int:
- if runner is not None:
- logger.info("Adding index %s to %s", index_name, table)
- await self.db_pool.runWithConnection(runner)
- await self._end_background_update(update_name)
- return 1
+ if runner is None:
+ return
- self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
- updater, oneshot=True
- )
+ logger.info("Adding index %s to %s", index_name, table)
+ await self.db_pool.runWithConnection(runner)
async def _end_background_update(self, update_name: str) -> None:
"""Removes a completed background update task from the queue.
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index c737116224..7e13f691e3 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -58,13 +58,13 @@ from synapse.storage.controllers.state import StateStorageController
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.state import StateFilter
from synapse.types import (
PersistedEventPosition,
RoomStreamToken,
StateMap,
get_domain_from_id,
)
+from synapse.types.state import StateFilter
from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results
from synapse.util.metrics import Measure
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 2b31ce54bb..52efd4a171 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -31,12 +31,12 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.logging.opentracing import tag_args, trace
from synapse.storage.roommember import ProfileInfo
-from synapse.storage.state import StateFilter
from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker,
PartialStateEventsTracker,
)
from synapse.types import MutableStateMap, StateMap
+from synapse.types.state import StateFilter
from synapse.util.cancellation import cancellable
if TYPE_CHECKING:
@@ -493,8 +493,6 @@ class StateStorageController:
up to date.
"""
# FIXME(faster_joins): what do we do here?
- # https://github.com/matrix-org/synapse/issues/12814
- # https://github.com/matrix-org/synapse/issues/12815
# https://github.com/matrix-org/synapse/issues/13008
return await self.stores.main.get_partial_current_state_deltas(
@@ -571,10 +569,11 @@ class StateStorageController:
is arbitrary for rooms with partial state.
"""
# We have to read this list first to mitigate races with un-partial stating.
- # This will be empty for rooms with full state.
hosts_at_join = await self.stores.main.get_partial_state_servers_at_join(
room_id
)
+ if hosts_at_join is None:
+ hosts_at_join = frozenset()
hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0dc44b246c..e20c5c5302 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -667,7 +667,8 @@ class DatabasePool:
)
# also check variables referenced in func's closure
if inspect.isfunction(func):
- f = cast(types.FunctionType, func)
+ # Keep the cast for now---it helps PyCharm to understand what `func` is.
+ f = cast(types.FunctionType, func) # type: ignore[redundant-cast]
if f.__closure__:
for i, cell in enumerate(f.__closure__):
if inspect.isgenerator(cell.cell_contents):
@@ -1129,7 +1130,6 @@ class DatabasePool:
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
desc: str = "simple_upsert",
- lock: bool = True,
) -> bool:
"""Insert a row with values + insertion_values; on conflict, update with values.
@@ -1154,21 +1154,12 @@ class DatabasePool:
requiring that a unique index exist on the column names used to detect a
conflict (i.e. `keyvalues.keys()`).
- If there is no such index, we can "emulate" an upsert with a SELECT followed
- by either an INSERT or an UPDATE. This is unsafe: we cannot make the same
- atomicity guarantees that a native upsert can and are very vulnerable to races
- and crashes. Therefore if we wish to upsert without an appropriate unique index,
- we must either:
-
- 1. Acquire a table-level lock before the emulated upsert (`lock=True`), or
- 2. VERY CAREFULLY ensure that we are the only thread and worker which will be
- writing to this table, in which case we can proceed without a lock
- (`lock=False`).
-
- Generally speaking, you should use `lock=True`. If the table in question has a
- unique index[*], this class will use a native upsert (which is atomic and so can
- ignore the `lock` argument). Otherwise this class will use an emulated upsert,
- in which case we want the safer option unless we been VERY CAREFUL.
+ If there is no such index yet[*], we can "emulate" an upsert with a SELECT
+ followed by either an INSERT or an UPDATE. This is unsafe unless *all* upserters
+ run at the SERIALIZABLE isolation level: we cannot make the same atomicity
+ guarantees that a native upsert can and are very vulnerable to races and
+ crashes. Therefore to upsert without an appropriate unique index, we acquire a
+ table-level lock before the emulated upsert.
[*]: Some tables have unique indices added to them in the background. Those
tables `T` are keys in the dictionary UNIQUE_INDEX_BACKGROUND_UPDATES,
@@ -1189,7 +1180,6 @@ class DatabasePool:
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
desc: description of the transaction, for logging and metrics
- lock: True to lock the table when doing the upsert.
Returns:
Returns True if a row was inserted or updated (i.e. if `values` is
not empty then this always returns True)
@@ -1209,7 +1199,6 @@ class DatabasePool:
keyvalues,
values,
insertion_values,
- lock=lock,
db_autocommit=autocommit,
)
except self.engine.module.IntegrityError as e:
@@ -1232,7 +1221,6 @@ class DatabasePool:
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
- lock: bool = True,
) -> bool:
"""
Pick the UPSERT method which works best on the platform. Either the
@@ -1245,8 +1233,6 @@ class DatabasePool:
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
where_clause: An index predicate to apply to the upsert.
- lock: True to lock the table when doing the upsert. Unused when performing
- a native upsert.
Returns:
Returns True if a row was inserted or updated (i.e. if `values` is
not empty then this always returns True)
@@ -1270,7 +1256,6 @@ class DatabasePool:
values,
insertion_values=insertion_values,
where_clause=where_clause,
- lock=lock,
)
def simple_upsert_txn_emulated(
@@ -1291,14 +1276,15 @@ class DatabasePool:
insertion_values: additional key/values to use only when inserting
where_clause: An index predicate to apply to the upsert.
lock: True to lock the table when doing the upsert.
+ Must not be False unless the table has already been locked.
Returns:
Returns True if a row was inserted or updated (i.e. if `values` is
not empty then this always returns True)
"""
insertion_values = insertion_values or {}
- # We need to lock the table :(, unless we're *really* careful
if lock:
+ # We need to lock the table :(
self.engine.lock_table(txn, table)
def _getwhere(key: str) -> str:
@@ -1406,7 +1392,6 @@ class DatabasePool:
value_names: Collection[str],
value_values: Collection[Collection[Any]],
desc: str,
- lock: bool = True,
) -> None:
"""
Upsert, many times.
@@ -1418,8 +1403,6 @@ class DatabasePool:
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
- lock: True to lock the table when doing the upsert. Unused when performing
- a native upsert.
"""
# We can autocommit if it safe to upsert
@@ -1433,7 +1416,6 @@ class DatabasePool:
key_values,
value_names,
value_values,
- lock=lock,
db_autocommit=autocommit,
)
@@ -1445,7 +1427,6 @@ class DatabasePool:
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[Any]],
- lock: bool = True,
) -> None:
"""
Upsert, many times.
@@ -1457,8 +1438,6 @@ class DatabasePool:
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
- lock: True to lock the table when doing the upsert. Unused when performing
- a native upsert.
"""
if table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_many_txn_native_upsert(
@@ -1466,7 +1445,12 @@ class DatabasePool:
)
else:
return self.simple_upsert_many_txn_emulated(
- txn, table, key_names, key_values, value_names, value_values, lock=lock
+ txn,
+ table,
+ key_names,
+ key_values,
+ value_names,
+ value_values,
)
def simple_upsert_many_txn_emulated(
@@ -1477,7 +1461,6 @@ class DatabasePool:
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[Any]],
- lock: bool = True,
) -> None:
"""
Upsert, many times, but without native UPSERT support or batching.
@@ -1489,18 +1472,16 @@ class DatabasePool:
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
- lock: True to lock the table when doing the upsert.
"""
# No value columns, therefore make a blank list so that the following
# zip() works correctly.
if not value_names:
value_values = [() for x in range(len(key_values))]
- if lock:
- # Lock the table just once, to prevent it being done once per row.
- # Note that, according to Postgres' documentation, once obtained,
- # the lock is held for the remainder of the current transaction.
- self.engine.lock_table(txn, "user_ips")
+ # Lock the table just once, to prevent it being done once per row.
+ # Note that, according to Postgres' documentation, once obtained,
+ # the lock is held for the remainder of the current transaction.
+ self.engine.lock_table(txn, "user_ips")
for keyv, valv in zip(key_values, value_values):
_keys = {x: y for x, y in zip(key_names, keyv)}
@@ -1781,7 +1762,8 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
Returns:
- A list of dictionaries.
+ A list of dictionaries, one per result row, each a mapping between the
+ column names from `retcols` and that column's value for the row.
"""
return await self.runInteraction(
desc,
@@ -1810,6 +1792,10 @@ class DatabasePool:
column names and values to select the rows with, or None to not
apply a WHERE clause.
retcols: the names of the columns to return
+
+ Returns:
+ A list of dictionaries, one per result row, each a mapping between the
+ column names from `retcols` and that column's value for the row.
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1833,7 +1819,7 @@ class DatabasePool:
keyvalues: Optional[Dict[str, Any]] = None,
desc: str = "simple_select_many_batch",
batch_size: int = 100,
- ) -> List[Any]:
+ ) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -1917,6 +1903,19 @@ class DatabasePool:
updatevalues: Dict[str, Any],
desc: str,
) -> int:
+ """
+ Update rows in the given database table.
+ If the given keyvalues don't match anything, nothing will be updated.
+
+ Args:
+ table: The database table to update.
+ keyvalues: A mapping of column name to value to match rows on.
+ updatevalues: A mapping of column name to value to replace in any matched rows.
+ desc: description of the transaction, for logging and metrics.
+
+ Returns:
+ The number of rows that were updated. Will be 0 if no matching rows were found.
+ """
return await self.runInteraction(
desc, self.simple_update_txn, table, keyvalues, updatevalues
)
@@ -1928,6 +1927,19 @@ class DatabasePool:
keyvalues: Dict[str, Any],
updatevalues: Dict[str, Any],
) -> int:
+ """
+ Update rows in the given database table.
+ If the given keyvalues don't match anything, nothing will be updated.
+
+ Args:
+ txn: The database transaction object.
+ table: The database table to update.
+ keyvalues: A mapping of column name to value to match rows on.
+ updatevalues: A mapping of column name to value to replace in any matched rows.
+
+ Returns:
+ The number of rows that were updated. Will be 0 if no matching rows were found.
+ """
if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
else:
@@ -2075,13 +2087,14 @@ class DatabasePool:
retcols: Collection[str],
allow_none: bool = False,
) -> Optional[Dict[str, Any]]:
- select_sql = "SELECT %s FROM %s WHERE %s" % (
- ", ".join(retcols),
- table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues),
- )
+ select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
+
+ if keyvalues:
+ select_sql += " WHERE %s" % (" AND ".join("%s = ?" % k for k in keyvalues),)
+ txn.execute(select_sql, list(keyvalues.values()))
+ else:
+ txn.execute(select_sql)
- txn.execute(select_sql, list(keyvalues.values()))
row = txn.fetchone()
if not row:
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 0e47592be3..837dc7646e 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -17,6 +17,7 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
+from synapse.api.constants import Direction
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import (
DatabasePool,
@@ -167,7 +168,7 @@ class DataStore(
guests: bool = True,
deactivated: bool = False,
order_by: str = UserSortOrder.NAME.value,
- direction: str = "f",
+ direction: Direction = Direction.FORWARDS,
approved: bool = True,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users from
@@ -197,7 +198,7 @@ class DataStore(
# Set ordering
order_by_column = UserSortOrder(order_by).value
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 282687ebce..8a359d7eb8 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -27,7 +27,7 @@ from typing import (
)
from synapse.api.constants import AccountDataTypes
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
+from synapse.replication.tcp.streams import AccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import (
DatabasePool,
@@ -75,6 +75,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="account_data",
instance_name=self._instance_name,
tables=[
@@ -95,6 +96,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# SQLite).
self._account_data_id_gen = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
@@ -123,7 +125,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
async def get_account_data_for_user(
self, user_id: str
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
- """Get all the client account_data for a user.
+ """
+ Get all the client account_data for a user.
+
+ If experimental MSC3391 support is enabled, any entries with an empty
+ content body are excluded; as this means they have been deleted.
Args:
user_id: The user to get the account_data for.
@@ -135,27 +141,48 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def get_account_data_for_user_txn(
txn: LoggingTransaction,
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
- rows = self.db_pool.simple_select_list_txn(
- txn,
- "account_data",
- {"user_id": user_id},
- ["account_data_type", "content"],
- )
+ # The 'content != '{}' condition below prevents us from using
+ # `simple_select_list_txn` here, as it doesn't support conditions
+ # other than 'equals'.
+ sql = """
+ SELECT account_data_type, content FROM account_data
+ WHERE user_id = ?
+ """
+
+ # If experimental MSC3391 support is enabled, then account data entries
+ # with an empty content are considered "deleted". So skip adding them to
+ # the results.
+ if self.hs.config.experimental.msc3391_enabled:
+ sql += " AND content != '{}'"
+
+ txn.execute(sql, (user_id,))
+ rows = self.db_pool.cursor_to_dict(txn)
global_account_data = {
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
- rows = self.db_pool.simple_select_list_txn(
- txn,
- "room_account_data",
- {"user_id": user_id},
- ["room_id", "account_data_type", "content"],
- )
+ # The 'content != '{}' condition below prevents us from using
+ # `simple_select_list_txn` here, as it doesn't support conditions
+ # other than 'equals'.
+ sql = """
+ SELECT room_id, account_data_type, content FROM room_account_data
+ WHERE user_id = ?
+ """
+
+ # If experimental MSC3391 support is enabled, then account data entries
+ # with an empty content are considered "deleted". So skip adding them to
+ # the results.
+ if self.hs.config.experimental.msc3391_enabled:
+ sql += " AND content != '{}'"
+
+ txn.execute(sql, (user_id,))
+ rows = self.db_pool.cursor_to_dict(txn)
by_room: Dict[str, Dict[str, JsonDict]] = {}
for row in rows:
room_data = by_room.setdefault(row["room_id"], {})
+
room_data[row["account_data_type"]] = db_to_json(row["content"])
return global_account_data, by_room
@@ -411,10 +438,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
token: int,
rows: Iterable[Any],
) -> None:
- if stream_name == TagAccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
- elif stream_name == AccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
+ if stream_name == AccountDataStream.NAME:
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
@@ -429,6 +453,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == AccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> int:
@@ -449,9 +480,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
content_json = json_encoder.encode(content)
async with self._account_data_id_gen.get_next() as next_id:
- # no need to lock here as room_account_data has a unique constraint
- # on (user_id, room_id, account_data_type) so simple_upsert will
- # retry if there is a conflict.
await self.db_pool.simple_upsert(
desc="add_room_account_data",
table="room_account_data",
@@ -461,7 +489,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
"account_data_type": account_data_type,
},
values={"stream_id": next_id, "content": content_json},
- lock=False,
)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
@@ -473,6 +500,72 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
return self._account_data_id_gen.get_current_token()
+ async def remove_account_data_for_room(
+ self, user_id: str, room_id: str, account_data_type: str
+ ) -> Optional[int]:
+ """Delete the room account data for the user of a given type.
+
+ Args:
+ user_id: The user to remove account_data for.
+ room_id: The room ID to scope the request to.
+ account_data_type: The account data type to delete.
+
+ Returns:
+ The maximum stream position, or None if there was no matching room account
+ data to delete.
+ """
+ assert self._can_write_to_account_data
+ assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
+
+ def _remove_account_data_for_room_txn(
+ txn: LoggingTransaction, next_id: int
+ ) -> bool:
+ """
+ Args:
+ txn: The transaction object.
+ next_id: The stream_id to update any existing rows to.
+
+ Returns:
+ True if an entry in room_account_data had its content set to '{}',
+ otherwise False. This informs callers of whether there actually was an
+ existing room account data entry to delete, or if the call was a no-op.
+ """
+ # We can't use `simple_update` as it doesn't have the ability to specify
+ # where clauses other than '=', which we need for `content != '{}'` below.
+ sql = """
+ UPDATE room_account_data
+ SET stream_id = ?, content = '{}'
+ WHERE user_id = ?
+ AND room_id = ?
+ AND account_data_type = ?
+ AND content != '{}'
+ """
+ txn.execute(
+ sql,
+ (next_id, user_id, room_id, account_data_type),
+ )
+ # Return true if any rows were updated.
+ return txn.rowcount != 0
+
+ async with self._account_data_id_gen.get_next() as next_id:
+ row_updated = await self.db_pool.runInteraction(
+ "remove_account_data_for_room",
+ _remove_account_data_for_room_txn,
+ next_id,
+ )
+
+ if not row_updated:
+ return None
+
+ self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+ self.get_account_data_for_user.invalidate((user_id,))
+ self.get_account_data_for_room.invalidate((user_id, room_id))
+ self.get_account_data_for_room_and_type.prefill(
+ (user_id, room_id, account_data_type), {}
+ )
+
+ return self._account_data_id_gen.get_current_token()
+
async def add_account_data_for_user(
self, user_id: str, account_data_type: str, content: JsonDict
) -> int:
@@ -517,15 +610,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
) -> None:
content_json = json_encoder.encode(content)
- # no need to lock here as account_data has a unique constraint on
- # (user_id, account_data_type) so simple_upsert will retry if
- # there is a conflict.
self.db_pool.simple_upsert_txn(
txn,
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
values={"stream_id": next_id, "content": content_json},
- lock=False,
)
# Ignored users get denormalized into a separate table as an optimisation.
@@ -577,6 +666,108 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
+ async def remove_account_data_for_user(
+ self,
+ user_id: str,
+ account_data_type: str,
+ ) -> Optional[int]:
+ """
+ Delete a single piece of user account data by type.
+
+ A "delete" is performed by updating a potentially existing row in the
+ "account_data" database table for (user_id, account_data_type) and
+ setting its content to "{}".
+
+ Args:
+ user_id: The user ID to modify the account data of.
+ account_data_type: The type to remove.
+
+ Returns:
+ The maximum stream position, or None if there was no matching account data
+ to delete.
+ """
+ assert self._can_write_to_account_data
+ assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
+
+ def _remove_account_data_for_user_txn(
+ txn: LoggingTransaction, next_id: int
+ ) -> bool:
+ """
+ Args:
+ txn: The transaction object.
+ next_id: The stream_id to update any existing rows to.
+
+ Returns:
+ True if an entry in account_data had its content set to '{}', otherwise
+ False. This informs callers of whether there actually was an existing
+ account data entry to delete, or if the call was a no-op.
+ """
+ # We can't use `simple_update` as it doesn't have the ability to specify
+ # where clauses other than '=', which we need for `content != '{}'` below.
+ sql = """
+ UPDATE account_data
+ SET stream_id = ?, content = '{}'
+ WHERE user_id = ?
+ AND account_data_type = ?
+ AND content != '{}'
+ """
+ txn.execute(sql, (next_id, user_id, account_data_type))
+ if txn.rowcount == 0:
+ # We didn't update any rows. This means that there was no matching room
+ # account data entry to delete in the first place.
+ return False
+
+ # Ignored users get denormalized into a separate table as an optimisation.
+ if account_data_type == AccountDataTypes.IGNORED_USER_LIST:
+ # If this method was called with the ignored users account data type, we
+ # simply delete all ignored users.
+
+ # First pull all the users that this user ignores.
+ previously_ignored_users = set(
+ self.db_pool.simple_select_onecol_txn(
+ txn,
+ table="ignored_users",
+ keyvalues={"ignorer_user_id": user_id},
+ retcol="ignored_user_id",
+ )
+ )
+
+ # Then delete them from the database.
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="ignored_users",
+ keyvalues={"ignorer_user_id": user_id},
+ )
+
+ # Invalidate the cache for ignored users which were removed.
+ for ignored_user_id in previously_ignored_users:
+ self._invalidate_cache_and_stream(
+ txn, self.ignored_by, (ignored_user_id,)
+ )
+
+ # Invalidate for this user the cache tracking ignored users.
+ self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
+
+ return True
+
+ async with self._account_data_id_gen.get_next() as next_id:
+ row_updated = await self.db_pool.runInteraction(
+ "remove_account_data_for_user",
+ _remove_account_data_for_user_txn,
+ next_id,
+ )
+
+ if not row_updated:
+ return None
+
+ self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+ self.get_account_data_for_user.invalidate((user_id,))
+ self.get_global_account_data_by_type_for_user.prefill(
+ (user_id, account_data_type), {}
+ )
+
+ return self._account_data_id_gen.get_current_token()
+
async def purge_account_data_for_user(self, user_id: str) -> None:
"""
Removes ALL the account data for a user.
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 63046c0527..c2c8018ee2 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -20,7 +20,7 @@ from synapse.appservice import (
ApplicationService,
ApplicationServiceState,
AppServiceTransaction,
- TransactionOneTimeKeyCounts,
+ TransactionOneTimeKeysCount,
TransactionUnusedFallbackKeys,
)
from synapse.config.appservice import load_appservices
@@ -260,7 +260,7 @@ class ApplicationServiceTransactionWorkerStore(
events: List[EventBase],
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
- one_time_key_counts: TransactionOneTimeKeyCounts,
+ one_time_keys_count: TransactionOneTimeKeysCount,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
) -> AppServiceTransaction:
@@ -273,7 +273,7 @@ class ApplicationServiceTransactionWorkerStore(
events: A list of persistent events to put in the transaction.
ephemeral: A list of ephemeral events to put in the transaction.
to_device_messages: A list of to-device messages to put in the transaction.
- one_time_key_counts: Counts of remaining one-time keys for relevant
+ one_time_keys_count: Counts of remaining one-time keys for relevant
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
@@ -299,7 +299,7 @@ class ApplicationServiceTransactionWorkerStore(
events=events,
ephemeral=ephemeral,
to_device_messages=to_device_messages,
- one_time_key_counts=one_time_key_counts,
+ one_time_keys_count=one_time_keys_count,
unused_fallback_keys=unused_fallback_keys,
device_list_summary=device_list_summary,
)
@@ -379,7 +379,7 @@ class ApplicationServiceTransactionWorkerStore(
events=events,
ephemeral=[],
to_device_messages=[],
- one_time_key_counts={},
+ one_time_keys_count={},
unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
)
@@ -451,8 +451,6 @@ class ApplicationServiceTransactionWorkerStore(
table="application_services_state",
keyvalues={"as_id": service.id},
values={f"{stream_type}_stream_id": pos},
- # no need to lock when emulating upsert: as_id is a unique key
- lock=False,
desc="set_appservice_stream_type_pos",
)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index ddb7397714..5b66431691 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -75,6 +75,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
+ notifier=hs.get_replication_notifier(),
stream_name="caches",
instance_name=hs.get_instance_name(),
tables=[
@@ -164,9 +165,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
backfilled=True,
)
elif stream_name == CachesStream.NAME:
- if self._cache_id_gen:
- self._cache_id_gen.advance(instance_name, token)
-
for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME:
if row.keys is None:
@@ -182,6 +180,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == CachesStream.NAME:
+ if self._cache_id_gen:
+ self._cache_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data
@@ -259,6 +265,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if relates_to:
self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
+ self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
self._attempt_to_invalidate_cache(
"get_aggregation_groups_for_event", (relates_to,)
)
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 73c95ffb6f..8e61aba454 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -26,8 +26,15 @@ from typing import (
cast,
)
+from synapse.api.constants import EventContentFields
from synapse.logging import issue9533_logger
-from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.logging.opentracing import (
+ SynapseTags,
+ log_kv,
+ set_tag,
+ start_active_span,
+ trace,
+)
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -84,6 +91,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="to_device",
instance_name=self._instance_name,
tables=[("device_inbox", "instance_name", "stream_id")],
@@ -94,7 +102,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
else:
self._can_write_to_device = True
self._device_inbox_id_gen = StreamIdGenerator(
- db_conn, "device_inbox", "stream_id"
+ db_conn, hs.get_replication_notifier(), "device_inbox", "stream_id"
)
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
@@ -150,6 +158,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
return super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == ToDeviceStream.NAME:
+ self._device_inbox_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
def get_to_device_stream_token(self) -> int:
return self._device_inbox_id_gen.get_current_token()
@@ -397,6 +412,17 @@ class DeviceInboxWorkerStore(SQLBaseStore):
(recipient_user_id, recipient_device_id), []
).append(message_dict)
+ # start a new span for each message, so that we can tag each separately
+ with start_active_span("get_to_device_message"):
+ set_tag(SynapseTags.TO_DEVICE_TYPE, message_dict["type"])
+ set_tag(SynapseTags.TO_DEVICE_SENDER, message_dict["sender"])
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT, recipient_user_id)
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, recipient_device_id)
+ set_tag(
+ SynapseTags.TO_DEVICE_MSGID,
+ message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID),
+ )
+
if limit is not None and rowcount == limit:
# We ended up bumping up against the message limit. There may be more messages
# to retrieve. Return what we have, as well as the last stream position that
@@ -678,12 +704,35 @@ class DeviceInboxWorkerStore(SQLBaseStore):
],
)
- if remote_messages_by_destination:
- issue9533_logger.debug(
- "Queued outgoing to-device messages with stream_id %i for %s",
- stream_id,
- list(remote_messages_by_destination.keys()),
- )
+ for destination, edu in remote_messages_by_destination.items():
+ if issue9533_logger.isEnabledFor(logging.DEBUG):
+ issue9533_logger.debug(
+ "Queued outgoing to-device messages with "
+ "stream_id %i, EDU message_id %s, type %s for %s: %s",
+ stream_id,
+ edu["message_id"],
+ edu["type"],
+ destination,
+ [
+ f"{user_id}/{device_id} (msgid "
+ f"{msg.get(EventContentFields.TO_DEVICE_MSGID)})"
+ for (user_id, messages_by_device) in edu["messages"].items()
+ for (device_id, msg) in messages_by_device.items()
+ ],
+ )
+
+ for (user_id, messages_by_device) in edu["messages"].items():
+ for (device_id, msg) in messages_by_device.items():
+ with start_active_span("store_outgoing_to_device_message"):
+ set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"])
+ set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"])
+ set_tag(SynapseTags.TO_DEVICE_TYPE, edu["type"])
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id)
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id)
+ set_tag(
+ SynapseTags.TO_DEVICE_MSGID,
+ msg.get(EventContentFields.TO_DEVICE_MSGID),
+ )
async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self._clock.time_msec()
@@ -801,7 +850,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# Only insert into the local inbox if the device exists on
# this server
device_id = row["device_id"]
- message_json = json_encoder.encode(messages_by_device[device_id])
+
+ with start_active_span("serialise_to_device_message"):
+ msg = messages_by_device[device_id]
+ set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"])
+ set_tag(SynapseTags.TO_DEVICE_SENDER, msg["sender"])
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id)
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id)
+ set_tag(
+ SynapseTags.TO_DEVICE_MSGID,
+ msg["content"].get(EventContentFields.TO_DEVICE_MSGID),
+ )
+ message_json = json_encoder.encode(msg)
+
messages_json_for_user[device_id] = message_json
if messages_json_for_user:
@@ -821,15 +882,20 @@ class DeviceInboxWorkerStore(SQLBaseStore):
],
)
- issue9533_logger.debug(
- "Stored to-device messages with stream_id %i for %s",
- stream_id,
- [
- (user_id, device_id)
- for (user_id, messages_by_device) in local_by_user_then_device.items()
- for device_id in messages_by_device.keys()
- ],
- )
+ if issue9533_logger.isEnabledFor(logging.DEBUG):
+ issue9533_logger.debug(
+ "Stored to-device messages with stream_id %i: %s",
+ stream_id,
+ [
+ f"{user_id}/{device_id} (msgid "
+ f"{msg['content'].get(EventContentFields.TO_DEVICE_MSGID)})"
+ for (
+ user_id,
+ messages_by_device,
+ ) in messages_by_user_then_device.items()
+ for (device_id, msg) in messages_by_device.items()
+ ],
+ )
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 57230df5ae..e8b6cc6b80 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -38,7 +38,7 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
+from synapse.replication.tcp.streams._base import DeviceListsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -54,11 +54,14 @@ from synapse.storage.util.id_generators import (
AbstractStreamIdTracker,
StreamIdGenerator,
)
-from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
+from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.caches.stream_change_cache import (
+ AllEntitiesChangedResult,
+ StreamChangeCache,
+)
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -89,12 +92,14 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
# class below that is used on the main process.
self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
+ ("device_lists_remote_pending", "stream_id"),
],
is_writer=hs.config.worker.worker_app is None,
)
@@ -159,18 +164,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == DeviceListsStream.NAME:
- self._device_list_id_gen.advance(instance_name, token)
self._invalidate_caches_for_devices(token, rows)
- elif stream_name == UserSignatureStream.NAME:
- self._device_list_id_gen.advance(instance_name, token)
- for row in rows:
- self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
+
return super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == DeviceListsStream.NAME:
+ self._device_list_id_gen.advance(instance_name, token)
+
+ super().process_replication_position(stream_name, instance_name, token)
+
def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
) -> None:
for row in rows:
+ if row.is_signature:
+ self._user_signature_stream_cache.entity_has_changed(row.entity, token)
+ continue
+
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
@@ -799,18 +812,66 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_cached_device_list_changes(
self,
from_key: int,
- ) -> Optional[List[str]]:
+ ) -> AllEntitiesChangedResult:
"""Get set of users whose devices have changed since `from_key`, or None
if that information is not in our cache.
"""
return self._device_list_stream_cache.get_all_entities_changed(from_key)
+ @cancellable
+ async def get_all_devices_changed(
+ self,
+ from_key: int,
+ to_key: int,
+ ) -> Set[str]:
+ """Get all users whose devices have changed in the given range.
+
+ Args:
+ from_key: The minimum device lists stream token to query device list
+ changes for, exclusive.
+ to_key: The maximum device lists stream token to query device list
+ changes for, inclusive.
+
+ Returns:
+ The set of user_ids whose devices have changed since `from_key`
+ (exclusive) until `to_key` (inclusive).
+ """
+
+ result = self._device_list_stream_cache.get_all_entities_changed(from_key)
+
+ if result.hit:
+ # We know which users might have changed devices.
+ if not result.entities:
+ # If no users then we can return early.
+ return set()
+
+ # Otherwise we need to filter down the list
+ return await self.get_users_whose_devices_changed(
+ from_key, result.entities, to_key
+ )
+
+ # If the cache didn't tell us anything, we just need to query the full
+ # range.
+ sql = """
+ SELECT DISTINCT user_id FROM device_lists_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ """
+
+ rows = await self.db_pool.execute(
+ "get_all_devices_changed",
+ None,
+ sql,
+ from_key,
+ to_key,
+ )
+ return {u for u, in rows}
+
@cancellable
async def get_users_whose_devices_changed(
self,
from_key: int,
- user_ids: Optional[Collection[str]] = None,
+ user_ids: Collection[str],
to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
@@ -830,46 +891,31 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"""
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
- user_ids_to_check: Optional[Collection[str]]
- if user_ids is None:
- # Get set of all users that have had device list changes since 'from_key'
- user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
- from_key
- )
- else:
- # The same as above, but filter results to only those users in 'user_ids'
- user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
- user_ids, from_key
- )
+ user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
+ user_ids, from_key
+ )
+ # If an empty set was returned, there's nothing to do.
if not user_ids_to_check:
return set()
+ if to_key is None:
+ to_key = self._device_list_id_gen.get_current_token()
+
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
- changes: Set[str] = set()
-
- stream_id_where_clause = "stream_id > ?"
- sql_args = [from_key]
-
- if to_key:
- stream_id_where_clause += " AND stream_id <= ?"
- sql_args.append(to_key)
-
- sql = f"""
+ sql = """
SELECT DISTINCT user_id FROM device_lists_stream
- WHERE {stream_id_where_clause}
- AND
+ WHERE ? < stream_id AND stream_id <= ? AND %s
"""
+ changes: Set[str] = set()
+
# Query device changes with a batch of users at a time
- # Assertion for mypy's benefit; see also
- # https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
- assert user_ids_to_check is not None
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
- txn.execute(sql + clause, sql_args + args)
+ txn.execute(sql % (clause,), [from_key, to_key] + args)
changes.update(user_id for user_id, in txn)
return changes
@@ -1026,16 +1072,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return {row["user_id"] for row in rows}
- async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
+ async def mark_remote_users_device_caches_as_stale(
+ self, user_ids: StrCollection
+ ) -> None:
"""Records that the server has reason to believe the cache of the devices
for the remote users is out of date.
"""
- await self.db_pool.simple_upsert(
- table="device_lists_remote_resync",
- keyvalues={"user_id": user_id},
- values={},
- insertion_values={"added_ts": self._clock.time_msec()},
- desc="mark_remote_user_device_cache_as_stale",
+
+ def _mark_remote_users_device_caches_as_stale_txn(
+ txn: LoggingTransaction,
+ ) -> None:
+ # TODO add insertion_values support to simple_upsert_many and use
+ # that!
+ for user_id in user_ids:
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="device_lists_remote_resync",
+ keyvalues={"user_id": user_id},
+ values={},
+ insertion_values={"added_ts": self._clock.time_msec()},
+ )
+
+ await self.db_pool.runInteraction(
+ "mark_remote_users_device_caches_as_stale",
+ _mark_remote_users_device_caches_as_stale_txn,
)
async def mark_remote_user_device_cache_as_valid(self, user_id: str) -> None:
@@ -1441,6 +1501,13 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
self._remove_duplicate_outbound_pokes,
)
+ self.db_pool.updates.register_background_index_update(
+ "device_lists_changes_in_room_by_room_index",
+ index_name="device_lists_changes_in_room_by_room_idx",
+ table="device_lists_changes_in_room",
+ columns=["room_id", "stream_id"],
+ )
+
async def _drop_device_list_streams_non_unique_indexes(
self, progress: JsonDict, batch_size: int
) -> int:
@@ -1737,9 +1804,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
values={"content": json_encoder.encode(content)},
- # we don't need to lock, because we assume we are the only thread
- # updating this user's devices.
- lock=False,
)
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
@@ -1753,9 +1817,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
values={"stream_id": stream_id},
- # again, we can assume we are the only thread updating this user's
- # extremity.
- lock=False,
)
async def update_remote_device_list_cache(
@@ -1808,9 +1869,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
values={"stream_id": stream_id},
- # we don't need to lock, because we can assume we are the only thread
- # updating this user's extremity.
- lock=False,
)
async def add_device_change_to_streams(
@@ -2008,27 +2066,48 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def get_uncoverted_outbound_room_pokes(
- self, limit: int = 10
+ self, start_stream_id: int, start_room_id: str, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
"""Get device list changes by room that have not yet been handled and
written to `device_lists_outbound_pokes`.
+ Args:
+ start_stream_id: Together with `start_room_id`, indicates the position after
+ which to return device list changes.
+ start_room_id: Together with `start_stream_id`, indicates the position after
+ which to return device list changes.
+ limit: The maximum number of device list changes to return.
+
Returns:
- A list of user ID, device ID, room ID, stream ID and optional opentracing context.
+ A list of user ID, device ID, room ID, stream ID and optional opentracing
+ context, in order of ascending (stream ID, room ID).
"""
sql = """
SELECT user_id, device_id, room_id, stream_id, opentracing_context
FROM device_lists_changes_in_room
- WHERE NOT converted_to_destinations
- ORDER BY stream_id
+ WHERE
+ (stream_id, room_id) > (?, ?) AND
+ stream_id <= ? AND
+ NOT converted_to_destinations
+ ORDER BY stream_id ASC, room_id ASC
LIMIT ?
"""
def get_uncoverted_outbound_room_pokes_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
- txn.execute(sql, (limit,))
+ txn.execute(
+ sql,
+ (
+ start_stream_id,
+ start_room_id,
+ # Avoid returning rows if there may be uncommitted device list
+ # changes with smaller stream IDs.
+ self._device_list_id_gen.get_current_token(),
+ limit,
+ ),
+ )
return [
(
@@ -2050,49 +2129,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_id: str,
room_id: str,
- stream_id: Optional[int],
hosts: Collection[str],
context: Optional[Dict[str, str]],
) -> None:
"""Queue the device update to be sent to the given set of hosts,
calculated from the room ID.
-
- Marks the associated row in `device_lists_changes_in_room` as handled,
- if `stream_id` is provided.
"""
+ if not hosts:
+ return
def add_device_list_outbound_pokes_txn(
txn: LoggingTransaction, stream_ids: List[int]
) -> None:
- if hosts:
- self._add_device_outbound_poke_to_stream_txn(
- txn,
- user_id=user_id,
- device_id=device_id,
- hosts=hosts,
- stream_ids=stream_ids,
- context=context,
- )
-
- if stream_id:
- self.db_pool.simple_update_txn(
- txn,
- table="device_lists_changes_in_room",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- "stream_id": stream_id,
- "room_id": room_id,
- },
- updatevalues={"converted_to_destinations": True},
- )
-
- if not hosts:
- # If there are no hosts then we don't try and generate stream IDs.
- return await self.db_pool.runInteraction(
- "add_device_list_outbound_pokes",
- add_device_list_outbound_pokes_txn,
- [],
+ self._add_device_outbound_poke_to_stream_txn(
+ txn,
+ user_id=user_id,
+ device_id=device_id,
+ hosts=hosts,
+ stream_ids=stream_ids,
+ context=context,
)
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
@@ -2156,3 +2211,37 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"get_pending_remote_device_list_updates_for_room",
get_pending_remote_device_list_updates_for_room_txn,
)
+
+ async def get_device_change_last_converted_pos(self) -> Tuple[int, str]:
+ """
+ Get the position of the last row in `device_list_changes_in_room` that has been
+ converted to `device_lists_outbound_pokes`.
+
+ Rows with a strictly greater position where `converted_to_destinations` is
+ `FALSE` have not been converted.
+ """
+
+ row = await self.db_pool.simple_select_one(
+ table="device_lists_changes_converted_stream_position",
+ keyvalues={},
+ retcols=["stream_id", "room_id"],
+ desc="get_device_change_last_converted_pos",
+ )
+ return row["stream_id"], row["room_id"]
+
+ async def set_device_change_last_converted_pos(
+ self,
+ stream_id: int,
+ room_id: str,
+ ) -> None:
+ """
+ Set the position of the last row in `device_list_changes_in_room` that has been
+ converted to `device_lists_outbound_pokes`.
+ """
+
+ await self.db_pool.simple_update_one(
+ table="device_lists_changes_converted_stream_position",
+ keyvalues={},
+ updatevalues={"stream_id": stream_id, "room_id": room_id},
+ desc="set_device_change_last_converted_pos",
+ )
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index cf33e73e2b..c4ac6c33ba 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -33,7 +33,7 @@ from typing_extensions import Literal
from synapse.api.constants import DeviceKeyAlgorithms
from synapse.appservice import (
- TransactionOneTimeKeyCounts,
+ TransactionOneTimeKeysCount,
TransactionUnusedFallbackKeys,
)
from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -140,7 +140,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cancellable
async def get_e2e_device_keys_for_cs_api(
self,
- query_list: List[Tuple[str, Optional[str]]],
+ query_list: Collection[Tuple[str, Optional[str]]],
include_displaynames: bool = True,
) -> Dict[str, Dict[str, JsonDict]]:
"""Fetch a list of device keys, formatted suitably for the C/S API.
@@ -514,7 +514,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
async def count_bulk_e2e_one_time_keys_for_as(
self, user_ids: Collection[str]
- ) -> TransactionOneTimeKeyCounts:
+ ) -> TransactionOneTimeKeysCount:
"""
Counts, in bulk, the one-time keys for all the users specified.
Intended to be used by application services for populating OTK counts in
@@ -528,7 +528,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
def _count_bulk_e2e_one_time_keys_txn(
txn: LoggingTransaction,
- ) -> TransactionOneTimeKeyCounts:
+ ) -> TransactionOneTimeKeysCount:
user_in_where_clause, user_parameters = make_in_list_sql_clause(
self.database_engine, "user_id", user_ids
)
@@ -541,7 +541,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"""
txn.execute(sql, user_parameters)
- result: TransactionOneTimeKeyCounts = {}
+ result: TransactionOneTimeKeysCount = {}
for user_id, device_id, algorithm, count in txn:
# We deliberately construct empty dictionaries for
@@ -1181,7 +1181,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
super().__init__(database, db_conn, hs)
self._cross_signing_id_gen = StreamIdGenerator(
- db_conn, "e2e_cross_signing_keys", "stream_id"
+ db_conn,
+ hs.get_replication_notifier(),
+ "e2e_cross_signing_keys",
+ "stream_id",
)
async def set_e2e_device_keys(
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 309a4ba664..bbee02ab18 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1686,7 +1686,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
},
insertion_values={},
desc="insert_insertion_extremity",
- lock=False,
)
async def insert_received_event_to_staging(
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index b283ab0f9c..3a0c370fde 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -74,6 +74,7 @@ receipt.
"""
import logging
+from collections import defaultdict
from typing import (
TYPE_CHECKING,
Collection,
@@ -95,6 +96,7 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
+ PostgresEngine,
)
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
@@ -272,15 +274,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
self._clear_old_push_actions_staging, 30 * 60 * 1000
)
- self.db_pool.updates.register_background_index_update(
- "event_push_summary_unique_index",
- index_name="event_push_summary_unique_index",
- table="event_push_summary",
- columns=["user_id", "room_id"],
- unique=True,
- replaces_index="event_push_summary_user_rm",
- )
-
self.db_pool.updates.register_background_index_update(
"event_push_summary_unique_index2",
index_name="event_push_summary_unique_index2",
@@ -463,6 +456,153 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return result
+ async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]:
+ """Get the notification count by room for a user. Only considers notifications,
+ not highlight or unread counts, and threads are currently aggregated under their room.
+
+ This function is intentionally not cached because it is called to calculate the
+ unread badge for push notifications and thus the result is expected to change.
+
+ Note that this function assumes the user is a member of the room. Because
+ summary rows are not removed when a user leaves a room, the caller must
+ filter out those results from the result.
+
+ Returns:
+ A map of room ID to notification counts for the given user.
+ """
+ return await self.db_pool.runInteraction(
+ "get_unread_counts_by_room_for_user",
+ self._get_unread_counts_by_room_for_user_txn,
+ user_id,
+ )
+
+ def _get_unread_counts_by_room_for_user_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> Dict[str, int]:
+ receipt_types_clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
+ )
+ args.extend([user_id, user_id])
+
+ receipts_cte = f"""
+ WITH all_receipts AS (
+ SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering
+ FROM receipts_linearized
+ LEFT JOIN events USING (room_id, event_id)
+ WHERE
+ {receipt_types_clause}
+ AND user_id = ?
+ GROUP BY room_id, thread_id
+ )
+ """
+
+ receipts_joins = """
+ LEFT JOIN (
+ SELECT room_id, thread_id,
+ max_receipt_stream_ordering AS threaded_receipt_stream_ordering
+ FROM all_receipts
+ WHERE thread_id IS NOT NULL
+ ) AS threaded_receipts USING (room_id, thread_id)
+ LEFT JOIN (
+ SELECT room_id, thread_id,
+ max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering
+ FROM all_receipts
+ WHERE thread_id IS NULL
+ ) AS unthreaded_receipts USING (room_id)
+ """
+
+ # First get summary counts by room / thread for the user. We use the max receipt
+ # stream ordering of both threaded & unthreaded receipts to compare against the
+ # summary table.
+ #
+ # PostgreSQL and SQLite differ in comparing scalar numerics.
+ if isinstance(self.database_engine, PostgresEngine):
+ # GREATEST ignores NULLs.
+ max_clause = """GREATEST(
+ threaded_receipt_stream_ordering,
+ unthreaded_receipt_stream_ordering
+ )"""
+ else:
+ # MAX returns NULL if any are NULL, so COALESCE to 0 first.
+ max_clause = """MAX(
+ COALESCE(threaded_receipt_stream_ordering, 0),
+ COALESCE(unthreaded_receipt_stream_ordering, 0)
+ )"""
+
+ sql = f"""
+ {receipts_cte}
+ SELECT eps.room_id, eps.thread_id, notif_count
+ FROM event_push_summary AS eps
+ {receipts_joins}
+ WHERE user_id = ?
+ AND notif_count != 0
+ AND (
+ (last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause})
+ OR last_receipt_stream_ordering = {max_clause}
+ )
+ """
+ txn.execute(sql, args)
+
+ seen_thread_ids = set()
+ room_to_count: Dict[str, int] = defaultdict(int)
+
+ for room_id, thread_id, notif_count in txn:
+ room_to_count[room_id] += notif_count
+ seen_thread_ids.add(thread_id)
+
+ # Now get any event push actions that haven't been rotated using the same OR
+ # join and filter by receipt and event push summary rotated up to stream ordering.
+ sql = f"""
+ {receipts_cte}
+ SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
+ FROM event_push_actions AS epa
+ {receipts_joins}
+ WHERE user_id = ?
+ AND epa.notif = 1
+ AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering)
+ AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
+ AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
+ GROUP BY epa.room_id, epa.thread_id
+ """
+ txn.execute(sql, args)
+
+ for room_id, thread_id, notif_count in txn:
+ # Note: only count push actions we have valid summaries for with up to date receipt.
+ if thread_id not in seen_thread_ids:
+ continue
+ room_to_count[room_id] += notif_count
+
+ thread_id_clause, thread_ids_args = make_in_list_sql_clause(
+ self.database_engine, "epa.thread_id", seen_thread_ids
+ )
+
+ # Finally re-check event_push_actions for any rooms not in the summary, ignoring
+ # the rotated up-to position. This handles the case where a read receipt has arrived
+ # but not been rotated meaning the summary table is out of date, so we go back to
+ # the push actions table.
+ sql = f"""
+ {receipts_cte}
+ SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
+ FROM event_push_actions AS epa
+ {receipts_joins}
+ WHERE user_id = ?
+ AND NOT {thread_id_clause}
+ AND epa.notif = 1
+ AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
+ AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
+ GROUP BY epa.room_id
+ """
+
+ args.extend(thread_ids_args)
+ txn.execute(sql, args)
+
+ for room_id, notif_count in txn:
+ room_to_count[room_id] += notif_count
+
+ return room_to_count
+
@cached(tree=True, max_entries=5000, iterable=True)
async def get_unread_event_push_actions_by_room_for_user(
self,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index d68f127f9b..1536937b67 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1651,7 +1651,7 @@ class PersistEventsStore:
if self._ephemeral_messages_enabled:
# If there's an expiry timestamp on the event, store it.
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
- if isinstance(expiry_ts, int) and not event.is_state():
+ if type(expiry_ts) is int and not event.is_state():
self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
# Insert into the room_memberships table.
@@ -2049,6 +2049,10 @@ class PersistEventsStore:
self.store._invalidate_cache_and_stream(
txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,)
)
+ if rel_type == RelationTypes.REFERENCE:
+ self.store._invalidate_cache_and_stream(
+ txn, self.store.get_references_for_event, (redacted_relates_to,)
+ )
if rel_type == RelationTypes.REPLACE:
self.store._invalidate_cache_and_stream(
txn, self.store.get_applicable_edit, (redacted_relates_to,)
@@ -2129,10 +2133,10 @@ class PersistEventsStore:
):
if (
"min_lifetime" in event.content
- and not isinstance(event.content.get("min_lifetime"), int)
+ and type(event.content["min_lifetime"]) is not int
) or (
"max_lifetime" in event.content
- and not isinstance(event.content.get("max_lifetime"), int)
+ and type(event.content["max_lifetime"]) is not int
):
# Ignore the event if one of the value isn't an integer.
return
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 9e31798ab1..b9d3c36d60 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -69,6 +69,8 @@ class _BackgroundUpdates:
EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections"
+ EVENTS_JUMP_TO_DATE_INDEX = "events_jump_to_date_index"
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _CalculateChainCover:
@@ -260,6 +262,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._background_events_populate_state_key_rejections,
)
+ # Add an index that would be useful for jumping to date using
+ # get_event_id_for_timestamp.
+ self.db_pool.updates.register_background_index_update(
+ _BackgroundUpdates.EVENTS_JUMP_TO_DATE_INDEX,
+ index_name="events_jump_to_date_idx",
+ table="events",
+ columns=["room_id", "origin_server_ts"],
+ where_clause="NOT outlier",
+ )
+
async def _background_reindex_fields_sender(
self, progress: JsonDict, batch_size: int
) -> int:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 7405db5b3d..a9259fe446 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -16,11 +16,11 @@ import logging
import threading
import weakref
from enum import Enum, auto
+from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
Collection,
- Container,
Dict,
Iterable,
List,
@@ -38,7 +38,7 @@ from typing_extensions import Literal
from twisted.internet import defer
-from synapse.api.constants import EventTypes
+from synapse.api.constants import Direction, EventTypes
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
@@ -59,8 +59,9 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
-from synapse.replication.tcp.streams import BackfillStream
+from synapse.replication.tcp.streams import BackfillStream, UnPartialStatedEventStream
from synapse.replication.tcp.streams.events import EventsStream
+from synapse.replication.tcp.streams.partial_state import UnPartialStatedEventStreamRow
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -70,12 +71,14 @@ from synapse.storage.database import (
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
+ AbstractStreamIdGenerator,
AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
+from synapse.types.state import StateFilter
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
@@ -107,6 +110,10 @@ event_fetch_ongoing_gauge = Gauge(
)
+class InvalidEventError(Exception):
+ """The event retrieved from the database is invalid and cannot be used."""
+
+
@attr.s(slots=True, auto_attribs=True)
class EventCacheEntry:
event: EventBase
@@ -188,6 +195,7 @@ class EventsWorkerStore(SQLBaseStore):
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="events",
instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")],
@@ -197,6 +205,7 @@ class EventsWorkerStore(SQLBaseStore):
self._backfill_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="backfill",
instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")],
@@ -214,12 +223,14 @@ class EventsWorkerStore(SQLBaseStore):
# SQLite).
self._stream_id_gen = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"events",
"stream_ordering",
is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"events",
"stream_ordering",
step=-1,
@@ -291,19 +302,128 @@ class EventsWorkerStore(SQLBaseStore):
id_column="chain_id",
)
+ self._un_partial_stated_events_stream_id_gen: AbstractStreamIdGenerator
+
+ if isinstance(database.engine, PostgresEngine):
+ self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ notifier=hs.get_replication_notifier(),
+ stream_name="un_partial_stated_event_stream",
+ instance_name=hs.get_instance_name(),
+ tables=[
+ ("un_partial_stated_event_stream", "instance_name", "stream_id")
+ ],
+ sequence_name="un_partial_stated_event_stream_sequence",
+ # TODO(faster_joins, multiple writers) Support multiple writers.
+ writers=["master"],
+ )
+ else:
+ self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
+ db_conn,
+ hs.get_replication_notifier(),
+ "un_partial_stated_event_stream",
+ "stream_id",
+ )
+
+ def get_un_partial_stated_events_token(self, instance_name: str) -> int:
+ return (
+ self._un_partial_stated_events_stream_id_gen.get_current_token_for_writer(
+ instance_name
+ )
+ )
+
+ async def get_un_partial_stated_events_from_stream(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, Tuple[str, bool]]], int, bool]:
+ """Get updates for the un-partial-stated events replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
+
+ if last_id == current_id:
+ return [], current_id, False
+
+ def get_un_partial_stated_events_from_stream_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, Tuple[str, bool]]], int, bool]:
+ sql = """
+ SELECT stream_id, event_id, rejection_status_changed
+ FROM un_partial_stated_event_stream
+ WHERE ? < stream_id AND stream_id <= ? AND instance_name = ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, current_id, instance_name, limit))
+ updates = [
+ (
+ row[0],
+ (
+ row[1],
+ bool(row[2]),
+ ),
+ )
+ for row in txn
+ ]
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return await self.db_pool.runInteraction(
+ "get_un_partial_stated_events_from_stream",
+ get_un_partial_stated_events_from_stream_txn,
+ )
+
def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[Any],
+ ) -> None:
+ if stream_name == UnPartialStatedEventStream.NAME:
+ for row in rows:
+ assert isinstance(row, UnPartialStatedEventStreamRow)
+
+ self.is_partial_state_event.invalidate((row.event_id,))
+
+ if row.rejection_status_changed:
+ # If the partial-stated event became rejected or unrejected
+ # when it wasn't before, we need to invalidate this cache.
+ self._invalidate_local_get_event_cache(row.event_id)
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
+
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
self._backfill_id_gen.advance(instance_name, -token)
-
- super().process_replication_rows(stream_name, instance_name, token, rows)
+ elif stream_name == UnPartialStatedEventStream.NAME:
+ self._un_partial_stated_events_stream_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
async def have_censored_event(self, event_id: str) -> bool:
"""Check if an event has been censored, i.e. if the content of the event has been erased
@@ -879,7 +999,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_stripped_room_state_from_event_context(
self,
context: EventContext,
- state_types_to_include: Container[str],
+ state_keys_to_include: StateFilter,
membership_user_id: Optional[str] = None,
) -> List[JsonDict]:
"""
@@ -892,7 +1012,7 @@ class EventsWorkerStore(SQLBaseStore):
Args:
context: The event context to retrieve state of the room from.
- state_types_to_include: The type of state events to include.
+ state_keys_to_include: The state events to include, for each event type.
membership_user_id: An optional user ID to include the stripped membership state
events of. This is useful when generating the stripped state of a room for
invites. We want to send membership events of the inviter, so that the
@@ -901,21 +1021,25 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
A list of dictionaries, each representing a stripped state event from the room.
"""
- current_state_ids = await context.get_current_state_ids()
+ if membership_user_id:
+ types = chain(
+ state_keys_to_include.to_types(),
+ [(EventTypes.Member, membership_user_id)],
+ )
+ filter = StateFilter.from_types(types)
+ else:
+ filter = state_keys_to_include
+ selected_state_ids = await context.get_current_state_ids(filter)
# We know this event is not an outlier, so this must be
# non-None.
- assert current_state_ids is not None
+ assert selected_state_ids is not None
- # The state to include
- state_to_include_ids = [
- e_id
- for k, e_id in current_state_ids.items()
- if k[0] in state_types_to_include
- or (membership_user_id and k == (EventTypes.Member, membership_user_id))
- ]
+ # Confusingly, get_current_state_events may return events that are discarded by
+ # the filter, if they're in context._state_delta_due_to_event. Strip these away.
+ selected_state_ids = filter.filter_state(selected_state_ids)
- state_to_include = await self.get_events(state_to_include_ids)
+ state_to_include = await self.get_events(selected_state_ids.values())
return [
{
@@ -1190,7 +1314,7 @@ class EventsWorkerStore(SQLBaseStore):
# invites, so just accept it for all membership events.
#
if d["type"] != EventTypes.Member:
- raise Exception(
+ raise InvalidEventError(
"Room %s for event %s is unknown" % (d["room_id"], event_id)
)
@@ -2116,7 +2240,7 @@ class EventsWorkerStore(SQLBaseStore):
)
async def get_event_id_for_timestamp(
- self, room_id: str, timestamp: int, direction: str
+ self, room_id: str, timestamp: int, direction: Direction
) -> Optional[str]:
"""Find the closest event to the given timestamp in the given direction.
@@ -2124,14 +2248,14 @@ class EventsWorkerStore(SQLBaseStore):
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
- direction: ["f"|"b"] to indicate whether we should navigate forward
+ direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
The closest event_id otherwise None if we can't find any event in
the given direction.
"""
- if direction == "b":
+ if direction == Direction.BACKWARDS:
# Find closest event *before* a given timestamp. We use descending
# (which gives values largest to smallest) because we want the
# largest possible timestamp *before* the given timestamp.
@@ -2183,9 +2307,6 @@ class EventsWorkerStore(SQLBaseStore):
return None
- if direction not in ("f", "b"):
- raise ValueError("Unknown direction: %s" % (direction,))
-
return await self.db_pool.runInteraction(
"get_event_id_for_timestamp_txn",
get_event_id_for_timestamp_txn,
@@ -2287,6 +2408,9 @@ class EventsWorkerStore(SQLBaseStore):
This can happen, for example, when resyncing state during a faster join.
+ It is the caller's responsibility to ensure that other workers are
+ sent a notification so that they call `_invalidate_local_get_event_cache()`.
+
Args:
txn:
event_id: ID of event to update
@@ -2325,14 +2449,3 @@ class EventsWorkerStore(SQLBaseStore):
)
self.invalidate_get_event_cache_after_txn(txn, event_id)
-
- # TODO(faster_joins): invalidate the cache on workers. Ideally we'd just
- # call '_send_invalidation_to_replication', but we actually need the other
- # end to call _invalidate_local_get_event_cache() rather than (just)
- # _get_event_cache.invalidate().
- #
- # One solution might be to (somehow) get the workers to call
- # _invalidate_caches_for_event() (though that will invalidate more than
- # strictly necessary).
- #
- # https://github.com/matrix-org/synapse/issues/12994
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 9b172a64d8..b202c5eb87 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -26,6 +26,7 @@ from typing import (
cast,
)
+from synapse.api.constants import Direction
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -176,7 +177,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
limit: int,
user_id: str,
order_by: str = MediaSortOrder.CREATED_TS.value,
- direction: str = "f",
+ direction: Direction = Direction.FORWARDS,
) -> Tuple[List[Dict[str, Any]], int]:
"""Get a paginated list of metadata for a local piece of media
which an user_id has uploaded
@@ -199,7 +200,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
# Set ordering
order_by_column = MediaSortOrder(order_by).value
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 9769a18a9d..beb210f8ee 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -77,6 +77,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
self._presence_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="presence_stream",
instance_name=self._instance_name,
tables=[("presence_stream", "instance_name", "stream_id")],
@@ -85,7 +86,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
)
else:
self._presence_id_gen = StreamIdGenerator(
- db_conn, "presence_stream", "stream_id"
+ db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
)
self.hs = hs
@@ -439,8 +440,14 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
rows: Iterable[Any],
) -> None:
if stream_name == PresenceStream.NAME:
- self._presence_id_gen.advance(instance_name, token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,))
return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == PresenceStream.NAME:
+ self._presence_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 12ad44dbb3..9b2bbe060d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -84,7 +84,13 @@ def _load_rules(
push_rules = PushRules(ruleslist)
filtered_rules = FilteredPushRules(
- push_rules, enabled_map, msc3664_enabled=experimental_config.msc3664_enabled
+ push_rules,
+ enabled_map,
+ msc1767_enabled=experimental_config.msc1767_enabled,
+ msc3664_enabled=experimental_config.msc3664_enabled,
+ msc3381_polls_enabled=experimental_config.msc3381_polls_enabled,
+ msc3952_intentional_mentions=experimental_config.msc3952_intentional_mentions,
+ msc3958_suppress_edits_enabled=experimental_config.msc3958_supress_edit_notifs,
)
return filtered_rules
@@ -114,6 +120,7 @@ class PushRulesWorkerStore(
# class below that is used on the main process.
self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"push_rules_stream",
"stream_id",
is_writer=hs.config.worker.worker_app is None,
@@ -151,6 +158,13 @@ class PushRulesWorkerStore(
self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == PushRulesStream.NAME:
+ self._push_rules_stream_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
rows = await self.db_pool.simple_select_list(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index fee37b9ce4..df53e726e6 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -62,6 +62,7 @@ class PusherWorkerStore(SQLBaseStore):
# class below that is used on the main process.
self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"pushers",
"id",
extra_tables=[("deleted_pushers", "stream_id")],
@@ -111,12 +112,12 @@ class PusherWorkerStore(SQLBaseStore):
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
- def process_replication_rows(
- self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(instance_name, token)
- return super().process_replication_rows(stream_name, instance_name, token, rows)
+ super().process_replication_position(stream_name, instance_name, token)
async def get_pushers_by_app_id_and_pushkey(
self, app_id: str, pushkey: str
@@ -325,14 +326,11 @@ class PusherWorkerStore(SQLBaseStore):
async def set_throttle_params(
self, pusher_id: str, room_id: str, params: ThrottleParams
) -> None:
- # no need to lock because `pusher_throttle` has a primary key on
- # (pusher, room_id) so simple_upsert will retry
await self.db_pool.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
{"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms},
desc="set_throttle_params",
- lock=False,
)
async def _remove_deactivated_pushers(self, progress: dict, batch_size: int) -> int:
@@ -589,8 +587,6 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
device_id: Optional[str] = None,
) -> None:
async with self._pushers_id_gen.get_next() as stream_id:
- # no need to lock because `pushers` has a unique key on
- # (app_id, pushkey, user_name) so simple_upsert will retry
await self.db_pool.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
@@ -609,7 +605,6 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
"device_id": device_id,
},
desc="add_pusher",
- lock=False,
)
user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate(
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index a580e4bdda..29972d5204 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -73,6 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
self._receipts_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="receipts",
instance_name=self._instance_name,
tables=[("receipts_linearized", "instance_name", "stream_id")],
@@ -91,6 +92,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# SQLite).
self._receipts_id_gen = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"receipts_linearized",
"stream_id",
is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
@@ -588,6 +590,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
return super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == ReceiptsStream.NAME:
+ self._receipts_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
def _insert_linearized_receipt_txn(
self,
txn: LoggingTransaction,
@@ -924,39 +933,6 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
return batch_size
- async def _create_receipts_index(self, index_name: str, table: str) -> None:
- """Adds a unique index on `(room_id, receipt_type, user_id)` to the given
- receipts table, for non-thread receipts."""
-
- def _create_index(conn: LoggingDatabaseConnection) -> None:
- conn.rollback()
-
- # we have to set autocommit, because postgres refuses to
- # CREATE INDEX CONCURRENTLY without it.
- if isinstance(self.database_engine, PostgresEngine):
- conn.set_session(autocommit=True)
-
- try:
- c = conn.cursor()
-
- # Now that the duplicates are gone, we can create the index.
- concurrently = (
- "CONCURRENTLY"
- if isinstance(self.database_engine, PostgresEngine)
- else ""
- )
- sql = f"""
- CREATE UNIQUE INDEX {concurrently} {index_name}
- ON {table}(room_id, receipt_type, user_id)
- WHERE thread_id IS NULL
- """
- c.execute(sql)
- finally:
- if isinstance(self.database_engine, PostgresEngine):
- conn.set_session(autocommit=False)
-
- await self.db_pool.runWithConnection(_create_index)
-
async def _background_receipts_linearized_unique_index(
self, progress: dict, batch_size: int
) -> int:
@@ -965,10 +941,14 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
receipts."""
def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None:
+ if isinstance(self.database_engine, PostgresEngine):
+ ROW_ID_NAME = "ctid"
+ else:
+ ROW_ID_NAME = "rowid"
+
# Identify any duplicate receipts arising from
# https://github.com/matrix-org/synapse/issues/14406.
- # We expect the following query to use the per-thread receipt index and take
- # less than a minute.
+ # The following query takes less than a minute on matrix.org.
sql = """
SELECT MAX(stream_id), room_id, receipt_type, user_id
FROM receipts_linearized
@@ -980,28 +960,45 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
duplicate_keys = cast(List[Tuple[int, str, str, str]], list(txn))
# Then remove duplicate receipts, keeping the one with the highest
- # `stream_id`. There should only be a single receipt with any given
- # `stream_id`.
- for max_stream_id, room_id, receipt_type, user_id in duplicate_keys:
- sql = """
+ # `stream_id`. Since there might be duplicate rows with the same
+ # `stream_id`, we delete by the ctid instead.
+ for stream_id, room_id, receipt_type, user_id in duplicate_keys:
+ sql = f"""
+ SELECT {ROW_ID_NAME}
+ FROM receipts_linearized
+ WHERE
+ room_id = ? AND
+ receipt_type = ? AND
+ user_id = ? AND
+ thread_id IS NULL AND
+ stream_id = ?
+ LIMIT 1
+ """
+ txn.execute(sql, (room_id, receipt_type, user_id, stream_id))
+ row_id = cast(Tuple[str], txn.fetchone())[0]
+
+ sql = f"""
DELETE FROM receipts_linearized
WHERE
room_id = ? AND
receipt_type = ? AND
user_id = ? AND
thread_id IS NULL AND
- stream_id < ?
+ {ROW_ID_NAME} != ?
"""
- txn.execute(sql, (room_id, receipt_type, user_id, max_stream_id))
+ txn.execute(sql, (room_id, receipt_type, user_id, row_id))
await self.db_pool.runInteraction(
self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME,
_remote_duplicate_receipts_txn,
)
- await self._create_receipts_index(
- "receipts_linearized_unique_index",
- "receipts_linearized",
+ await self.db_pool.updates.create_index_in_background(
+ index_name="receipts_linearized_unique_index",
+ table="receipts_linearized",
+ columns=["room_id", "receipt_type", "user_id"],
+ where_clause="thread_id IS NULL",
+ unique=True,
)
await self.db_pool.updates._end_background_update(
@@ -1050,9 +1047,12 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
_remote_duplicate_receipts_txn,
)
- await self._create_receipts_index(
- "receipts_graph_unique_index",
- "receipts_graph",
+ await self.db_pool.updates.create_index_in_background(
+ index_name="receipts_graph_unique_index",
+ table="receipts_graph",
+ columns=["room_id", "receipt_type", "user_id"],
+ where_clause="thread_id IS NULL",
+ unique=True,
)
await self.db_pool.updates._end_background_update(
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index ca431002c8..0018d6f7ab 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -20,6 +20,7 @@ from typing import (
FrozenSet,
Iterable,
List,
+ Mapping,
Optional,
Set,
Tuple,
@@ -29,7 +30,7 @@ from typing import (
import attr
-from synapse.api.constants import MAIN_TIMELINE, RelationTypes
+from synapse.api.constants import MAIN_TIMELINE, Direction, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
@@ -39,9 +40,13 @@ from synapse.storage.database import (
LoggingTransaction,
make_in_list_sql_clause,
)
-from synapse.storage.databases.main.stream import generate_pagination_where_clause
+from synapse.storage.databases.main.stream import (
+ generate_next_token,
+ generate_pagination_bounds,
+ generate_pagination_where_clause,
+)
from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
+from synapse.types import JsonDict, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -81,8 +86,6 @@ class _RelatedEvent:
event_id: str
# The sender of the related event.
sender: str
- topological_ordering: Optional[int]
- stream_ordering: int
class RelationsWorkerStore(SQLBaseStore):
@@ -165,7 +168,7 @@ class RelationsWorkerStore(SQLBaseStore):
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
limit: int = 5,
- direction: str = "b",
+ direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
@@ -178,8 +181,8 @@ class RelationsWorkerStore(SQLBaseStore):
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
limit: Only fetch the most recent `limit` events.
- direction: Whether to fetch the most recent first (`"b"`) or the
- oldest first (`"f"`).
+ direction: Whether to fetch the most recent first (backwards) or the
+ oldest first (forwards).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
@@ -208,24 +211,23 @@ class RelationsWorkerStore(SQLBaseStore):
where_clause.append("type = ?")
where_args.append(event_type)
+ order, from_bound, to_bound = generate_pagination_bounds(
+ direction,
+ from_token.room_key if from_token else None,
+ to_token.room_key if to_token else None,
+ )
+
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=from_token.room_key.as_historical_tuple()
- if from_token
- else None,
- to_token=to_token.room_key.as_historical_tuple() if to_token else None,
+ from_token=from_bound,
+ to_token=to_bound,
engine=self.database_engine,
)
if pagination_clause:
where_clause.append(pagination_clause)
- if direction == "b":
- order = "DESC"
- else:
- order = "ASC"
-
sql = """
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM event_relations
@@ -245,13 +247,17 @@ class RelationsWorkerStore(SQLBaseStore):
txn.execute(sql, where_args + [limit + 1])
events = []
- for event_id, relation_type, sender, topo_ordering, stream_ordering in txn:
+ topo_orderings: List[int] = []
+ stream_orderings: List[int] = []
+ for event_id, relation_type, sender, topo_ordering, stream_ordering in cast(
+ List[Tuple[str, str, str, int, int]], txn
+ ):
# Do not include edits for redacted events as they leak event
# content.
if not is_redacted or relation_type != RelationTypes.REPLACE:
- events.append(
- _RelatedEvent(event_id, sender, topo_ordering, stream_ordering)
- )
+ events.append(_RelatedEvent(event_id, sender))
+ topo_orderings.append(topo_ordering)
+ stream_orderings.append(stream_ordering)
# If there are more events, generate the next pagination key from the
# last event returned.
@@ -260,17 +266,12 @@ class RelationsWorkerStore(SQLBaseStore):
# Instead of using the last row (which tells us there is more
# data), use the last row to be returned.
events = events[:limit]
+ topo_orderings = topo_orderings[:limit]
+ stream_orderings = stream_orderings[:limit]
- topo = events[-1].topological_ordering
- token = events[-1].stream_ordering
- if direction == "b":
- # Tokens are positions between events.
- # This token points *after* the last event in the chunk.
- # We need it to point to the event before it in the chunk
- # when we are going backwards so we subtract one from the
- # stream part.
- token -= 1
- next_key = RoomStreamToken(topo, token)
+ next_key = generate_next_token(
+ direction, topo_orderings[-1], stream_orderings[-1]
+ )
if from_token:
next_token = from_token.copy_and_replace(
@@ -287,6 +288,7 @@ class RelationsWorkerStore(SQLBaseStore):
to_device_key=0,
device_list_key=0,
groups_key=0,
+ un_partial_stated_rooms_key=0,
)
return events[:limit], next_token
@@ -394,111 +396,195 @@ class RelationsWorkerStore(SQLBaseStore):
)
return result is not None
- @cached(tree=True)
- async def get_aggregation_groups_for_event(
- self, event_id: str, room_id: str, limit: int = 5
- ) -> List[JsonDict]:
- """Get a list of annotations on the event, grouped by event type and
+ @cached()
+ async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="get_aggregation_groups_for_event", list_name="event_ids"
+ )
+ async def get_aggregation_groups_for_events(
+ self, event_ids: Collection[str]
+ ) -> Mapping[str, Optional[List[JsonDict]]]:
+ """Get a list of annotations on the given events, grouped by event type and
aggregation key, sorted by count.
This is used e.g. to get the what and how many reactions have happend
on an event.
Args:
- event_id: Fetch events that relate to this event ID.
- room_id: The room the event belongs to.
- limit: Only fetch the `limit` groups.
+ event_ids: Fetch events that relate to these event IDs.
Returns:
- List of groups of annotations that match. Each row is a dict with
- `type`, `key` and `count` fields.
+ A map of event IDs to a list of groups of annotations that match.
+ Each entry is a dict with `type`, `key` and `count` fields.
+ """
+ # The number of entries to return per event ID.
+ limit = 5
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "relates_to_id", event_ids
+ )
+ args.append(RelationTypes.ANNOTATION)
+
+ sql = f"""
+ SELECT
+ relates_to_id,
+ annotation.type,
+ aggregation_key,
+ COUNT(DISTINCT annotation.sender)
+ FROM events AS annotation
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS parent ON
+ parent.event_id = relates_to_id
+ AND parent.room_id = annotation.room_id
+ WHERE
+ {clause}
+ AND relation_type = ?
+ GROUP BY relates_to_id, annotation.type, aggregation_key
+ ORDER BY relates_to_id, COUNT(*) DESC
"""
- args = [
- event_id,
- room_id,
- RelationTypes.ANNOTATION,
- limit,
- ]
-
- sql = """
- SELECT type, aggregation_key, COUNT(DISTINCT sender)
- FROM event_relations
- INNER JOIN events USING (event_id)
- WHERE relates_to_id = ? AND room_id = ? AND relation_type = ?
- GROUP BY relation_type, type, aggregation_key
- ORDER BY COUNT(*) DESC
- LIMIT ?
- """
-
- def _get_aggregation_groups_for_event_txn(
+ def _get_aggregation_groups_for_events_txn(
txn: LoggingTransaction,
- ) -> List[JsonDict]:
+ ) -> Mapping[str, List[JsonDict]]:
txn.execute(sql, args)
- return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
+ result: Dict[str, List[JsonDict]] = {}
+ for event_id, type, key, count in cast(
+ List[Tuple[str, str, str, int]], txn
+ ):
+ event_results = result.setdefault(event_id, [])
+
+ # Limit the number of results per event ID.
+ if len(event_results) == limit:
+ continue
+
+ event_results.append({"type": type, "key": key, "count": count})
+
+ return result
return await self.db_pool.runInteraction(
- "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
+ "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn
)
async def get_aggregation_groups_for_users(
- self,
- event_id: str,
- room_id: str,
- limit: int,
- users: FrozenSet[str] = frozenset(),
- ) -> Dict[Tuple[str, str], int]:
+ self, event_ids: Collection[str], users: FrozenSet[str]
+ ) -> Dict[str, Dict[Tuple[str, str], int]]:
"""Fetch the partial aggregations for an event for specific users.
This is used, in conjunction with get_aggregation_groups_for_event, to
remove information from the results for ignored users.
Args:
- event_id: Fetch events that relate to this event ID.
- room_id: The room the event belongs to.
- limit: Only fetch the `limit` groups.
+ event_ids: Fetch events that relate to these event IDs.
users: The users to fetch information for.
Returns:
- A map of (event type, aggregation key) to a count of users.
+ A map of event ID to a map of (event type, aggregation key) to a
+ count of users.
"""
if not users:
return {}
- args: List[Union[str, int]] = [
- event_id,
- room_id,
- RelationTypes.ANNOTATION,
- ]
+ events_sql, args = make_in_list_sql_clause(
+ self.database_engine, "relates_to_id", event_ids
+ )
users_sql, users_args = make_in_list_sql_clause(
- self.database_engine, "sender", users
+ self.database_engine, "annotation.sender", users
)
args.extend(users_args)
+ args.append(RelationTypes.ANNOTATION)
sql = f"""
- SELECT type, aggregation_key, COUNT(DISTINCT sender)
- FROM event_relations
- INNER JOIN events USING (event_id)
- WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql}
- GROUP BY relation_type, type, aggregation_key
- ORDER BY COUNT(*) DESC
- LIMIT ?
+ SELECT
+ relates_to_id,
+ annotation.type,
+ aggregation_key,
+ COUNT(DISTINCT annotation.sender)
+ FROM events AS annotation
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS parent ON
+ parent.event_id = relates_to_id
+ AND parent.room_id = annotation.room_id
+ WHERE {events_sql} AND {users_sql} AND relation_type = ?
+ GROUP BY relates_to_id, annotation.type, aggregation_key
+ ORDER BY relates_to_id, COUNT(*) DESC
"""
def _get_aggregation_groups_for_users_txn(
txn: LoggingTransaction,
- ) -> Dict[Tuple[str, str], int]:
- txn.execute(sql, args + [limit])
+ ) -> Dict[str, Dict[Tuple[str, str], int]]:
+ txn.execute(sql, args)
- return {(row[0], row[1]): row[2] for row in txn}
+ result: Dict[str, Dict[Tuple[str, str], int]] = {}
+ for event_id, type, key, count in cast(
+ List[Tuple[str, str, str, int]], txn
+ ):
+ result.setdefault(event_id, {})[(type, key)] = count
+
+ return result
return await self.db_pool.runInteraction(
"get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
)
+ @cached()
+ async def get_references_for_event(self, event_id: str) -> List[JsonDict]:
+ raise NotImplementedError()
+
+ @cachedList(cached_method_name="get_references_for_event", list_name="event_ids")
+ async def get_references_for_events(
+ self, event_ids: Collection[str]
+ ) -> Mapping[str, Optional[List[_RelatedEvent]]]:
+ """Get a list of references to the given events.
+
+ Args:
+ event_ids: Fetch events that relate to these event IDs.
+
+ Returns:
+ A map of event IDs to a list of related event IDs (and their senders).
+ """
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "relates_to_id", event_ids
+ )
+ args.append(RelationTypes.REFERENCE)
+
+ sql = f"""
+ SELECT relates_to_id, ref.event_id, ref.sender
+ FROM events AS ref
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS parent ON
+ parent.event_id = relates_to_id
+ AND parent.room_id = ref.room_id
+ WHERE
+ {clause}
+ AND relation_type = ?
+ ORDER BY ref.topological_ordering, ref.stream_ordering
+ """
+
+ def _get_references_for_events_txn(
+ txn: LoggingTransaction,
+ ) -> Mapping[str, List[_RelatedEvent]]:
+ txn.execute(sql, args)
+
+ result: Dict[str, List[_RelatedEvent]] = {}
+ for relates_to_id, event_id, sender in cast(
+ List[Tuple[str, str, str]], txn
+ ):
+ result.setdefault(relates_to_id, []).append(
+ _RelatedEvent(event_id, sender)
+ )
+
+ return result
+
+ return await self.db_pool.runInteraction(
+ "_get_references_for_events_txn", _get_references_for_events_txn
+ )
+
@cached()
def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
raise NotImplementedError()
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 4fbaefad73..644bbb8878 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1,5 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019, 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ from abc import abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
+ AbstractSet,
Any,
Awaitable,
Collection,
@@ -25,7 +26,7 @@ from typing import (
List,
Mapping,
Optional,
- Sequence,
+ Set,
Tuple,
Union,
cast,
@@ -34,6 +35,7 @@ from typing import (
import attr
from synapse.api.constants import (
+ Direction,
EventContentFields,
EventTypes,
JoinRules,
@@ -43,6 +45,7 @@ from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
+from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -50,11 +53,17 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import IdGenerator
-from synapse.types import JsonDict, RetentionPolicy, ThirdPartyInstanceID
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdGenerator,
+ IdGenerator,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
+from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.stringutils import MXC_REGEX
if TYPE_CHECKING:
@@ -100,7 +109,7 @@ class RoomSortOrder(Enum):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PartialStateResyncInfo:
joined_via: Optional[str]
- servers_in_room: List[str] = attr.ib(factory=list)
+ servers_in_room: Set[str] = attr.ib(factory=set)
class RoomWorkerStore(CacheInvalidationWorkerStore):
@@ -114,6 +123,37 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
self.config: HomeServerConfig = hs.config
+ self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
+
+ if isinstance(database.engine, PostgresEngine):
+ self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ notifier=hs.get_replication_notifier(),
+ stream_name="un_partial_stated_room_stream",
+ instance_name=self._instance_name,
+ tables=[
+ ("un_partial_stated_room_stream", "instance_name", "stream_id")
+ ],
+ sequence_name="un_partial_stated_room_stream_sequence",
+ # TODO(faster_joins, multiple writers) Support multiple writers.
+ writers=["master"],
+ )
+ else:
+ self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
+ db_conn,
+ hs.get_replication_notifier(),
+ "un_partial_stated_room_stream",
+ "stream_id",
+ )
+
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == UnPartialStatedRoomStream.NAME:
+ self._un_partial_stated_rooms_stream_id_gen.advance(instance_name, token)
+ return super().process_replication_position(stream_name, instance_name, token)
+
async def store_room(
self,
room_id: str,
@@ -912,7 +952,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
event_json = db_to_json(content_json)
content = event_json["content"]
content_url = content.get("url")
- thumbnail_url = content.get("info", {}).get("thumbnail_url")
+ info = content.get("info")
+ if isinstance(info, dict):
+ thumbnail_url = info.get("thumbnail_url")
+ else:
+ thumbnail_url = None
for url in (content_url, thumbnail_url):
if not url:
@@ -1149,21 +1193,35 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
get_rooms_for_retention_period_in_range_txn,
)
- @cached(iterable=True)
- async def get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]:
- """Gets the list of servers in a partial state room at the time we joined it.
+ async def get_partial_state_servers_at_join(
+ self, room_id: str
+ ) -> Optional[AbstractSet[str]]:
+ """Gets the set of servers in a partial state room at the time we joined it.
Returns:
The `servers_in_room` list from the `/send_join` response for partial state
rooms. May not be accurate or complete, as it comes from a remote
homeserver.
- An empty list for full state rooms.
+ `None` for full state rooms.
"""
- return await self.db_pool.simple_select_onecol(
- "partial_state_rooms_servers",
- keyvalues={"room_id": room_id},
- retcol="server_name",
- desc="get_partial_state_servers_at_join",
+ servers_in_room = await self._get_partial_state_servers_at_join(room_id)
+
+ if len(servers_in_room) == 0:
+ return None
+
+ return servers_in_room
+
+ @cached(iterable=True)
+ async def _get_partial_state_servers_at_join(
+ self, room_id: str
+ ) -> AbstractSet[str]:
+ return frozenset(
+ await self.db_pool.simple_select_onecol(
+ "partial_state_rooms_servers",
+ keyvalues={"room_id": room_id},
+ retcol="server_name",
+ desc="get_partial_state_servers_at_join",
+ )
)
async def get_partial_state_room_resync_info(
@@ -1208,75 +1266,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
# partial-joined between the two SELECTs, but this is unlikely to happen
# in practice.)
continue
- entry.servers_in_room.append(server_name)
+ entry.servers_in_room.add(server_name)
return room_servers
- async def clear_partial_state_room(self, room_id: str) -> bool:
- """Clears the partial state flag for a room.
-
- Args:
- room_id: The room whose partial state flag is to be cleared.
-
- Returns:
- `True` if the partial state flag has been cleared successfully.
-
- `False` if the partial state flag could not be cleared because the room
- still contains events with partial state.
- """
- try:
- await self.db_pool.runInteraction(
- "clear_partial_state_room", self._clear_partial_state_room_txn, room_id
- )
- return True
- except self.db_pool.engine.module.IntegrityError as e:
- # Assume that any `IntegrityError`s are due to partial state events.
- logger.info(
- "Exception while clearing lazy partial-state-room %s, retrying: %s",
- room_id,
- e,
- )
- return False
-
- def _clear_partial_state_room_txn(
- self, txn: LoggingTransaction, room_id: str
- ) -> None:
- DatabasePool.simple_delete_txn(
- txn,
- table="partial_state_rooms_servers",
- keyvalues={"room_id": room_id},
- )
- DatabasePool.simple_delete_one_txn(
- txn,
- table="partial_state_rooms",
- keyvalues={"room_id": room_id},
- )
- self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
- self._invalidate_cache_and_stream(
- txn, self.get_partial_state_servers_at_join, (room_id,)
- )
-
- # We now delete anything from `device_lists_remote_pending` with a
- # stream ID less than the minimum
- # `partial_state_rooms.device_lists_stream_id`, as we no longer need them.
- device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn(
- txn,
- table="partial_state_rooms",
- keyvalues={},
- retcol="MIN(device_lists_stream_id)",
- allow_none=True,
- )
- if device_lists_stream_id is None:
- # There are no rooms being currently partially joined, so we delete everything.
- txn.execute("DELETE FROM device_lists_remote_pending")
- else:
- sql = """
- DELETE FROM device_lists_remote_pending
- WHERE stream_id <= ?
- """
- txn.execute(sql, (device_lists_stream_id,))
-
- @cached()
+ @cached(max_entries=10000)
async def is_partial_state_room(self, room_id: str) -> bool:
"""Checks if this room has partial state.
@@ -1295,6 +1289,27 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
return entry is not None
+ @cachedList(cached_method_name="is_partial_state_room", list_name="room_ids")
+ async def is_partial_state_room_batched(
+ self, room_ids: StrCollection
+ ) -> Mapping[str, bool]:
+ """Checks if the given rooms have partial state.
+
+ Returns true for "partial-state" rooms, which means that the state
+ at events in the room, and `current_state_events`, may not yet be
+ complete.
+ """
+
+ rows: List[Dict[str, str]] = await self.db_pool.simple_select_many_batch(
+ table="partial_state_rooms",
+ column="room_id",
+ iterable=room_ids,
+ retcols=("room_id",),
+ desc="is_partial_state_room_batched",
+ )
+ partial_state_rooms = {row_dict["room_id"] for row_dict in rows}
+ return {room_id: room_id in partial_state_rooms for room_id in room_ids}
+
async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
self, room_id: str
) -> Tuple[str, int]:
@@ -1311,6 +1326,97 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
)
return result["join_event_id"], result["device_lists_stream_id"]
+ def get_un_partial_stated_rooms_token(self, instance_name: str) -> int:
+ return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer(
+ instance_name
+ )
+
+ async def get_un_partial_stated_rooms_between(
+ self, last_id: int, current_id: int, room_ids: Collection[str]
+ ) -> Set[str]:
+ """Get all rooms that got un partial stated between `last_id` exclusive and
+ `current_id` inclusive.
+
+ Returns:
+ The list of room ids.
+ """
+
+ if last_id == current_id:
+ return set()
+
+ def _get_un_partial_stated_rooms_between_txn(
+ txn: LoggingTransaction,
+ ) -> Set[str]:
+ sql = """
+ SELECT DISTINCT room_id FROM un_partial_stated_room_stream
+ WHERE ? < stream_id AND stream_id <= ? AND
+ """
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
+
+ txn.execute(sql + clause, [last_id, current_id] + args)
+
+ return {r[0] for r in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_un_partial_stated_rooms_between",
+ _get_un_partial_stated_rooms_between_txn,
+ )
+
+ async def get_un_partial_stated_rooms_from_stream(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
+ """Get updates for un partial stated rooms replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
+
+ if last_id == current_id:
+ return [], current_id, False
+
+ def get_un_partial_stated_rooms_from_stream_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
+ sql = """
+ SELECT stream_id, room_id
+ FROM un_partial_stated_room_stream
+ WHERE ? < stream_id AND stream_id <= ? AND instance_name = ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, current_id, instance_name, limit))
+ updates = [(row[0], (row[1],)) for row in txn]
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return await self.db_pool.runInteraction(
+ "get_un_partial_stated_rooms_from_stream",
+ get_un_partial_stated_rooms_from_stream_txn,
+ )
+
class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1802,6 +1908,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
+ self._instance_name = hs.get_instance_name()
+
async def upsert_room_on_join(
self, room_id: str, room_version: RoomVersion, state_events: List[EventBase]
) -> None:
@@ -1843,15 +1951,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
"creator": room_creator,
"has_auth_chain_index": has_auth_chain_index,
},
- # rooms has a unique constraint on room_id, so no need to lock when doing an
- # emulated upsert.
- lock=False,
)
async def store_partial_state_room(
self,
room_id: str,
- servers: Collection[str],
+ servers: AbstractSet[str],
device_lists_stream_id: int,
joined_via: str,
) -> None:
@@ -1866,11 +1971,13 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
Args:
room_id: the ID of the room
- servers: other servers known to be in the room
+ servers: other servers known to be in the room. must include `joined_via`.
device_lists_stream_id: the device_lists stream ID at the time when we first
joined the room.
joined_via: the server name we requested a partial join from.
"""
+ assert joined_via in servers
+
await self.db_pool.runInteraction(
"store_partial_state_room",
self._store_partial_state_room_txn,
@@ -1884,7 +1991,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self,
txn: LoggingTransaction,
room_id: str,
- servers: Collection[str],
+ servers: AbstractSet[str],
device_lists_stream_id: int,
joined_via: str,
) -> None:
@@ -1907,7 +2014,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
)
self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
self._invalidate_cache_and_stream(
- txn, self.get_partial_state_servers_at_join, (room_id,)
+ txn, self._get_partial_state_servers_at_join, (room_id,)
)
async def write_partial_state_rooms_join_event_id(
@@ -1966,9 +2073,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
"creator": "",
"has_auth_chain_index": has_auth_chain_index,
},
- # rooms has a unique constraint on room_id, so no need to lock when doing an
- # emulated upsert.
- lock=False,
)
async def set_room_is_public(self, room_id: str, is_public: bool) -> None:
@@ -2117,7 +2221,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self,
start: int,
limit: int,
- direction: str = "b",
+ direction: Direction = Direction.BACKWARDS,
user_id: Optional[str] = None,
room_id: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
@@ -2126,8 +2230,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
Args:
start: event offset to begin the query from
limit: number of rows to retrieve
- direction: Whether to fetch the most recent first (`"b"`) or the
- oldest first (`"f"`)
+ direction: Whether to fetch the most recent first (backwards) or the
+ oldest first (forwards)
user_id: search for user_id. Ignored if user_id is None
room_id: search for room_id. Ignored if room_id is None
Returns:
@@ -2149,7 +2253,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
filters.append("er.room_id LIKE ?")
args.extend(["%" + room_id + "%"])
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
@@ -2272,3 +2376,84 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self.is_room_blocked,
(room_id,),
)
+
+ async def clear_partial_state_room(self, room_id: str) -> Optional[int]:
+ """Clears the partial state flag for a room.
+
+ Args:
+ room_id: The room whose partial state flag is to be cleared.
+
+ Returns:
+ The corresponding stream id for the un-partial-stated rooms stream.
+
+ `None` if the partial state flag could not be cleared because the room
+ still contains events with partial state.
+ """
+ try:
+ async with self._un_partial_stated_rooms_stream_id_gen.get_next() as un_partial_state_room_stream_id:
+ await self.db_pool.runInteraction(
+ "clear_partial_state_room",
+ self._clear_partial_state_room_txn,
+ room_id,
+ un_partial_state_room_stream_id,
+ )
+ return un_partial_state_room_stream_id
+ except self.db_pool.engine.module.IntegrityError as e:
+ # Assume that any `IntegrityError`s are due to partial state events.
+ logger.info(
+ "Exception while clearing lazy partial-state-room %s, retrying: %s",
+ room_id,
+ e,
+ )
+ return None
+
+ def _clear_partial_state_room_txn(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ un_partial_state_room_stream_id: int,
+ ) -> None:
+ DatabasePool.simple_delete_txn(
+ txn,
+ table="partial_state_rooms_servers",
+ keyvalues={"room_id": room_id},
+ )
+ DatabasePool.simple_delete_one_txn(
+ txn,
+ table="partial_state_rooms",
+ keyvalues={"room_id": room_id},
+ )
+ self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
+ self._invalidate_cache_and_stream(
+ txn, self._get_partial_state_servers_at_join, (room_id,)
+ )
+
+ DatabasePool.simple_insert_txn(
+ txn,
+ "un_partial_stated_room_stream",
+ {
+ "stream_id": un_partial_state_room_stream_id,
+ "instance_name": self._instance_name,
+ "room_id": room_id,
+ },
+ )
+
+ # We now delete anything from `device_lists_remote_pending` with a
+ # stream ID less than the minimum
+ # `partial_state_rooms.device_lists_stream_id`, as we no longer need them.
+ device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn(
+ txn,
+ table="partial_state_rooms",
+ keyvalues={},
+ retcol="MIN(device_lists_stream_id)",
+ allow_none=True,
+ )
+ if device_lists_stream_id is None:
+ # There are no rooms being currently partially joined, so we delete everything.
+ txn.execute("DELETE FROM device_lists_remote_pending")
+ else:
+ sql = """
+ DELETE FROM device_lists_remote_pending
+ WHERE stream_id <= ?
+ """
+ txn.execute(sql, (device_lists_stream_id,))
diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py
index 39e80f6f5b..131f357d04 100644
--- a/synapse/storage/databases/main/room_batch.py
+++ b/synapse/storage/databases/main/room_batch.py
@@ -44,6 +44,4 @@ class RoomBatchStore(SQLBaseStore):
table="event_to_state_groups",
keyvalues={"event_id": event_id},
values={"state_group": state_group_id, "event_id": event_id},
- # Unique constraint on event_id so we don't have to lock
- lock=False,
)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index f02c1d7ea7..ea6a5e2f34 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from itertools import chain
from typing import (
TYPE_CHECKING,
+ AbstractSet,
Collection,
Dict,
FrozenSet,
@@ -47,7 +49,13 @@ from synapse.storage.roommember import (
ProfileInfo,
RoomsForUser,
)
-from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id
+from synapse.types import (
+ JsonDict,
+ PersistedEventPosition,
+ StateMap,
+ StrCollection,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -385,7 +393,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self,
user_id: str,
membership_list: Collection[str],
- excluded_rooms: Optional[List[str]] = None,
+ excluded_rooms: StrCollection = (),
) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
@@ -412,10 +420,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
# Now we filter out forgotten and excluded rooms
- rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id)
+ rooms_to_exclude = await self.get_forgotten_rooms_for_user(user_id)
if excluded_rooms is not None:
- rooms_to_exclude.update(set(excluded_rooms))
+ # Take a copy to avoid mutating the in-cache set
+ rooms_to_exclude = set(rooms_to_exclude)
+ rooms_to_exclude.update(excluded_rooms)
return [room for room in rooms if room.room_id not in rooms_to_exclude]
@@ -1122,12 +1132,33 @@ class RoomMemberWorkerStore(EventsWorkerStore):
else:
# The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles.
+ #
+ # We need to fetch all hosts joined to the room according to `state` by
+ # inspecting all join memberships in `state`. However, if the `state` is
+ # relatively recent then many of its events are likely to be held in
+ # the current state of the room, which is easily available and likely
+ # cached.
+ #
+ # We therefore compute the set of `state` events not in the
+ # current state and only fetch those.
+ current_memberships = (
+ await self._get_approximate_current_memberships_in_room(room_id)
+ )
+ unknown_state_events = {}
+ joined_users_in_current_state = []
+
+ for (type, state_key), event_id in state.items():
+ if event_id not in current_memberships:
+ unknown_state_events[type, state_key] = event_id
+ elif current_memberships[event_id] == Membership.JOIN:
+ joined_users_in_current_state.append(state_key)
+
joined_user_ids = await self.get_joined_user_ids_from_state(
- room_id, state
+ room_id, unknown_state_events
)
cache.hosts_to_joined_users = {}
- for user_id in joined_user_ids:
+ for user_id in chain(joined_user_ids, joined_users_in_current_state):
host = intern_string(get_domain_from_id(user_id))
cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
@@ -1138,6 +1169,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return frozenset(cache.hosts_to_joined_users)
+ async def _get_approximate_current_memberships_in_room(
+ self, room_id: str
+ ) -> Mapping[str, Optional[str]]:
+ """Build a map from event id to membership, for all events in the current state.
+
+ The event ids of non-memberships events (e.g. `m.room.power_levels`) are present
+ in the result, mapped to values of `None`.
+
+ The result is approximate for partially-joined rooms. It is fully accurate
+ for fully-joined rooms.
+ """
+
+ rows = await self.db_pool.simple_select_list(
+ "current_state_events",
+ keyvalues={"room_id": room_id},
+ retcols=("event_id", "membership"),
+ desc="has_completed_background_updates",
+ )
+ return {row["event_id"]: row["membership"] for row in rows}
+
@cached(max_entries=10000)
def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
return _JoinedHostsCache()
@@ -1169,7 +1220,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return count == 0
@cached()
- async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
+ async def get_forgotten_rooms_for_user(self, user_id: str) -> AbstractSet[str]:
"""Gets all rooms the user has forgotten.
Args:
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index af7bebee80..ba325d390b 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,7 @@
# limitations under the License.
import collections.abc
import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Set, Tuple
import attr
@@ -24,6 +24,8 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.opentracing import trace
+from synapse.replication.tcp.streams import UnPartialStatedEventStream
+from synapse.replication.tcp.streams.partial_state import UnPartialStatedEventStreamRow
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -33,8 +35,8 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
-from synapse.storage.state import StateFilter
from synapse.types import JsonDict, JsonMapping, StateMap
+from synapse.types.state import StateFilter
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
@@ -80,6 +82,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
+ self._instance_name: str = hs.get_instance_name()
+
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
+ if stream_name == UnPartialStatedEventStream.NAME:
+ for row in rows:
+ assert isinstance(row, UnPartialStatedEventStreamRow)
+ self._get_state_group_for_event.invalidate((row.event_id,))
+ self.is_partial_state_event.invalidate((row.event_id,))
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
async def get_room_version(self, room_id: str) -> RoomVersion:
"""Get the room_version of a given room
@@ -404,18 +422,21 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
context: EventContext,
) -> None:
"""Update the state group for a partial state event"""
- await self.db_pool.runInteraction(
- "update_state_for_partial_state_event",
- self._update_state_for_partial_state_event_txn,
- event,
- context,
- )
+ async with self._un_partial_stated_events_stream_id_gen.get_next() as un_partial_state_event_stream_id:
+ await self.db_pool.runInteraction(
+ "update_state_for_partial_state_event",
+ self._update_state_for_partial_state_event_txn,
+ event,
+ context,
+ un_partial_state_event_stream_id,
+ )
def _update_state_for_partial_state_event_txn(
self,
txn: LoggingTransaction,
event: EventBase,
context: EventContext,
+ un_partial_state_event_stream_id: int,
) -> None:
# we shouldn't have any outliers here
assert not event.internal_metadata.is_outlier()
@@ -436,7 +457,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# the event may now be rejected where it was not before, or vice versa,
# in which case we need to update the rejected flags.
- if bool(context.rejected) != (event.rejected_reason is not None):
+ rejection_status_changed = bool(context.rejected) != (
+ event.rejected_reason is not None
+ )
+ if rejection_status_changed:
self.mark_event_rejected_txn(txn, event.event_id, context.rejected)
self.db_pool.simple_delete_one_txn(
@@ -445,8 +469,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
keyvalues={"event_id": event.event_id},
)
- # TODO(faster_joins): need to do something about workers here
- # https://github.com/matrix-org/synapse/issues/12994
txn.call_after(self.is_partial_state_event.invalidate, (event.event_id,))
txn.call_after(
self._get_state_group_for_event.prefill,
@@ -454,6 +476,18 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
state_group,
)
+ self.db_pool.simple_insert_txn(
+ txn,
+ "un_partial_stated_event_stream",
+ {
+ "stream_id": un_partial_state_event_stream_id,
+ "instance_name": self._instance_name,
+ "event_id": event.event_id,
+ "rejection_status_changed": rejection_status_changed,
+ },
+ )
+ txn.call_after(self.hs.get_notifier().on_new_replication_data)
+
class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 356d4ca788..d7b7d0c3c9 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -22,13 +22,14 @@ from typing_extensions import Counter
from twisted.internet.defer import DeferredLock
-from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
from synapse.api.errors import StoreError
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
+from synapse.storage.databases.main.events_worker import InvalidEventError
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@@ -554,7 +555,17 @@ class StatsStore(StateDeltasStore):
"get_initial_state_for_room", _fetch_current_state_stats
)
- state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined]
+ try:
+ state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined]
+ except InvalidEventError as e:
+ # If an exception occurs fetching events then the room is broken;
+ # skip process it to avoid being stuck on a room.
+ logger.warning(
+ "Failed to fetch events for room %s, skipping stats calculation: %r.",
+ room_id,
+ e,
+ )
+ return
room_state: Dict[str, Union[None, bool, str]] = {
"join_rules": None,
@@ -652,7 +663,7 @@ class StatsStore(StateDeltasStore):
from_ts: Optional[int] = None,
until_ts: Optional[int] = None,
order_by: Optional[str] = UserSortOrder.USER_ID.value,
- direction: Optional[str] = "f",
+ direction: Direction = Direction.FORWARDS,
search_term: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users and their uploaded local media
@@ -703,7 +714,7 @@ class StatsStore(StateDeltasStore):
500, "Incorrect value for order_by provided: %s" % order_by
)
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index cc27ec3804..818c46182e 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -55,6 +55,7 @@ from typing_extensions import Literal
from twisted.internet import defer
+from synapse.api.constants import Direction
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -67,7 +68,7 @@ from synapse.storage.database import (
make_in_list_sql_clause,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import PersistedEventPosition, RoomStreamToken
from synapse.util.caches.descriptors import cached
@@ -86,7 +87,6 @@ MAX_STREAM_SIZE = 1000
_STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
-
# Used as return values for pagination APIs
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventDictReturn:
@@ -104,7 +104,7 @@ class _EventsAround:
def generate_pagination_where_clause(
- direction: str,
+ direction: Direction,
column_names: Tuple[str, str],
from_token: Optional[Tuple[Optional[int], int]],
to_token: Optional[Tuple[Optional[int], int]],
@@ -130,27 +130,26 @@ def generate_pagination_where_clause(
token, but include those that match the to token.
Args:
- direction: Whether we're paginating backwards("b") or forwards ("f").
+ direction: Whether we're paginating backwards or forwards.
column_names: The column names to bound. Must *not* be user defined as
these get inserted directly into the SQL statement without escapes.
from_token: The start point for the pagination. This is an exclusive
- minimum bound if direction is "f", and an inclusive maximum bound if
- direction is "b".
+ minimum bound if direction is forwards, and an inclusive maximum bound if
+ direction is backwards.
to_token: The endpoint point for the pagination. This is an inclusive
- maximum bound if direction is "f", and an exclusive minimum bound if
- direction is "b".
+ maximum bound if direction is forwards, and an exclusive minimum bound if
+ direction is backwards.
engine: The database engine to generate the clauses for
Returns:
The sql expression
"""
- assert direction in ("b", "f")
where_clause = []
if from_token:
where_clause.append(
_make_generic_sql_bound(
- bound=">=" if direction == "b" else "<",
+ bound=">=" if direction == Direction.BACKWARDS else "<",
column_names=column_names,
values=from_token,
engine=engine,
@@ -160,7 +159,7 @@ def generate_pagination_where_clause(
if to_token:
where_clause.append(
_make_generic_sql_bound(
- bound="<" if direction == "b" else ">=",
+ bound="<" if direction == Direction.BACKWARDS else ">=",
column_names=column_names,
values=to_token,
engine=engine,
@@ -170,6 +169,104 @@ def generate_pagination_where_clause(
return " AND ".join(where_clause)
+def generate_pagination_bounds(
+ direction: Direction,
+ from_token: Optional[RoomStreamToken],
+ to_token: Optional[RoomStreamToken],
+) -> Tuple[
+ str, Optional[Tuple[Optional[int], int]], Optional[Tuple[Optional[int], int]]
+]:
+ """
+ Generate a start and end point for this page of events.
+
+ Args:
+ direction: Whether pagination is going forwards or backwards.
+ from_token: The token to start pagination at, or None to start at the first value.
+ to_token: The token to end pagination at, or None to not limit the end point.
+
+ Returns:
+ A three tuple of:
+
+ ASC or DESC for sorting of the query.
+
+ The starting position as a tuple of ints representing
+ (topological position, stream position) or None if no from_token was
+ provided. The topological position may be None for live tokens.
+
+ The end position in the same format as the starting position, or None
+ if no to_token was provided.
+ """
+
+ # Tokens really represent positions between elements, but we use
+ # the convention of pointing to the event before the gap. Hence
+ # we have a bit of asymmetry when it comes to equalities.
+ if direction == Direction.BACKWARDS:
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ # The bounds for the stream tokens are complicated by the fact
+ # that we need to handle the instance_map part of the tokens. We do this
+ # by fetching all events between the min stream token and the maximum
+ # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
+ # then filtering the results.
+ from_bound: Optional[Tuple[Optional[int], int]] = None
+ if from_token:
+ if from_token.topological is not None:
+ from_bound = from_token.as_historical_tuple()
+ elif direction == Direction.BACKWARDS:
+ from_bound = (
+ None,
+ from_token.get_max_stream_pos(),
+ )
+ else:
+ from_bound = (
+ None,
+ from_token.stream,
+ )
+
+ to_bound: Optional[Tuple[Optional[int], int]] = None
+ if to_token:
+ if to_token.topological is not None:
+ to_bound = to_token.as_historical_tuple()
+ elif direction == Direction.BACKWARDS:
+ to_bound = (
+ None,
+ to_token.stream,
+ )
+ else:
+ to_bound = (
+ None,
+ to_token.get_max_stream_pos(),
+ )
+
+ return order, from_bound, to_bound
+
+
+def generate_next_token(
+ direction: Direction, last_topo_ordering: int, last_stream_ordering: int
+) -> RoomStreamToken:
+ """
+ Generate the next room stream token based on the currently returned data.
+
+ Args:
+ direction: Whether pagination is going forwards or backwards.
+ last_topo_ordering: The last topological ordering being returned.
+ last_stream_ordering: The last stream ordering being returned.
+
+ Returns:
+ A new RoomStreamToken to return to the client.
+ """
+ if direction == Direction.BACKWARDS:
+ # Tokens are positions between events.
+ # This token points *after* the last event in the chunk.
+ # We need it to point to the event before it in the chunk
+ # when we are going backwards so we subtract one from the
+ # stream part.
+ last_stream_ordering -= 1
+ return RoomStreamToken(last_topo_ordering, last_stream_ordering)
+
+
def _make_generic_sql_bound(
bound: str,
column_names: Tuple[str, str],
@@ -801,13 +898,66 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
before this stream ordering.
"""
- last_row = await self.get_room_event_before_stream_ordering(
- room_id=room_id,
- stream_ordering=end_token.stream,
+ def get_last_event_in_room_before_stream_ordering_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[str]:
+ # We need to handle the fact that the stream tokens can be vector
+ # clocks. We do this by getting all rows between the minimum and
+ # maximum stream ordering in the token, plus one row less than the
+ # minimum stream ordering. We then filter the results against the
+ # token and return the first row that matches.
+
+ sql = """
+ SELECT * FROM (
+ SELECT instance_name, stream_ordering, topological_ordering, event_id
+ FROM events
+ LEFT JOIN rejections USING (event_id)
+ WHERE room_id = ?
+ AND ? < stream_ordering AND stream_ordering <= ?
+ AND NOT outlier
+ AND rejections.event_id IS NULL
+ ORDER BY stream_ordering DESC
+ ) AS a
+ UNION
+ SELECT * FROM (
+ SELECT instance_name, stream_ordering, topological_ordering, event_id
+ FROM events
+ LEFT JOIN rejections USING (event_id)
+ WHERE room_id = ?
+ AND stream_ordering <= ?
+ AND NOT outlier
+ AND rejections.event_id IS NULL
+ ORDER BY stream_ordering DESC
+ LIMIT 1
+ ) AS b
+ """
+ txn.execute(
+ sql,
+ (
+ room_id,
+ end_token.stream,
+ end_token.get_max_stream_pos(),
+ room_id,
+ end_token.stream,
+ ),
+ )
+
+ for instance_name, stream_ordering, topological_ordering, event_id in txn:
+ if _filter_results(
+ lower_token=None,
+ upper_token=end_token,
+ instance_name=instance_name,
+ topological_ordering=topological_ordering,
+ stream_ordering=stream_ordering,
+ ):
+ return event_id
+
+ return None
+
+ return await self.db_pool.runInteraction(
+ "get_last_event_in_room_before_stream_ordering",
+ get_last_event_in_room_before_stream_ordering_txn,
)
- if last_row:
- return last_row[2]
- return None
async def get_current_room_stream_token_for_room_id(
self, room_id: str
@@ -891,12 +1041,40 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id
stream_key
"""
- sql = (
- "SELECT coalesce(MIN(topological_ordering), 0) FROM events"
- " WHERE room_id = ? AND stream_ordering >= ?"
- )
+ if isinstance(self.database_engine, PostgresEngine):
+ min_function = "LEAST"
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ min_function = "MIN"
+ else:
+ raise RuntimeError(f"Unknown database engine {self.database_engine}")
+
+ # This query used to be
+ # SELECT COALESCE(MIN(topological_ordering), 0) FROM events
+ # WHERE room_id = ? and events.stream_ordering >= {stream_key}
+ # which returns 0 if the stream_key is newer than any event in
+ # the room. That's not wrong, but it seems to interact oddly with backfill,
+ # requiring a second call to /messages to actually backfill from a remote
+ # homeserver.
+ #
+ # Instead, rollback the stream ordering to that after the most recent event in
+ # this room.
+ sql = f"""
+ WITH fallback(max_stream_ordering) AS (
+ SELECT MAX(stream_ordering)
+ FROM events
+ WHERE room_id = ?
+ )
+ SELECT COALESCE(MIN(topological_ordering), 0) FROM events
+ WHERE
+ room_id = ?
+ AND events.stream_ordering >= {min_function}(
+ ?,
+ (SELECT max_stream_ordering FROM fallback)
+ )
+ """
+
row = await self.db_pool.execute(
- "get_current_topological_token", None, sql, room_id, stream_key
+ "get_current_topological_token", None, sql, room_id, room_id, stream_key
)
return row[0][0] if row else 0
@@ -1022,7 +1200,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn,
room_id,
before_token,
- direction="b",
+ direction=Direction.BACKWARDS,
limit=before_limit,
event_filter=event_filter,
)
@@ -1032,7 +1210,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn,
room_id,
after_token,
- direction="f",
+ direction=Direction.FORWARDS,
limit=after_limit,
event_filter=event_filter,
)
@@ -1195,7 +1373,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id: str,
from_token: RoomStreamToken,
to_token: Optional[RoomStreamToken] = None,
- direction: str = "b",
+ direction: Direction = Direction.BACKWARDS,
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
@@ -1206,8 +1384,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id
from_token: The token used to stream from
to_token: A token which if given limits the results to only those before
- direction: Either 'b' or 'f' to indicate whether we are paginating
- forwards or backwards from `from_key`.
+ direction: Indicates whether we are paginating forwards or backwards
+ from `from_key`.
limit: The maximum number of events to return.
event_filter: If provided filters the events to
those that match the filter.
@@ -1219,47 +1397,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
`to_token`), or `limit` is zero.
"""
- # Tokens really represent positions between elements, but we use
- # the convention of pointing to the event before the gap. Hence
- # we have a bit of asymmetry when it comes to equalities.
args = [False, room_id]
- if direction == "b":
- order = "DESC"
- else:
- order = "ASC"
- # The bounds for the stream tokens are complicated by the fact
- # that we need to handle the instance_map part of the tokens. We do this
- # by fetching all events between the min stream token and the maximum
- # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
- # then filtering the results.
- if from_token.topological is not None:
- from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple()
- elif direction == "b":
- from_bound = (
- None,
- from_token.get_max_stream_pos(),
- )
- else:
- from_bound = (
- None,
- from_token.stream,
- )
-
- to_bound: Optional[Tuple[Optional[int], int]] = None
- if to_token:
- if to_token.topological is not None:
- to_bound = to_token.as_historical_tuple()
- elif direction == "b":
- to_bound = (
- None,
- to_token.stream,
- )
- else:
- to_bound = (
- None,
- to_token.get_max_stream_pos(),
- )
+ order, from_bound, to_bound = generate_pagination_bounds(
+ direction, from_token, to_token
+ )
bounds = generate_pagination_where_clause(
direction=direction,
@@ -1346,8 +1488,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
_EventDictReturn(event_id, topological_ordering, stream_ordering)
for event_id, instance_name, topological_ordering, stream_ordering in txn
if _filter_results(
- lower_token=to_token if direction == "b" else from_token,
- upper_token=from_token if direction == "b" else to_token,
+ lower_token=to_token
+ if direction == Direction.BACKWARDS
+ else from_token,
+ upper_token=from_token
+ if direction == Direction.BACKWARDS
+ else to_token,
instance_name=instance_name,
topological_ordering=topological_ordering,
stream_ordering=stream_ordering,
@@ -1355,16 +1501,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
][:limit]
if rows:
- topo = rows[-1].topological_ordering
- token = rows[-1].stream_ordering
- if direction == "b":
- # Tokens are positions between events.
- # This token points *after* the last event in the chunk.
- # We need it to point to the event before it in the chunk
- # when we are going backwards so we subtract one from the
- # stream part.
- token -= 1
- next_token = RoomStreamToken(topo, token)
+ assert rows[-1].topological_ordering is not None
+ next_token = generate_next_token(
+ direction, rows[-1].topological_ordering, rows[-1].stream_ordering
+ )
else:
# TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token
@@ -1377,7 +1517,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id: str,
from_key: RoomStreamToken,
to_key: Optional[RoomStreamToken] = None,
- direction: str = "b",
+ direction: Direction = Direction.BACKWARDS,
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[EventBase], RoomStreamToken]:
@@ -1387,8 +1527,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id
from_key: The token used to stream from
to_key: A token which if given limits the results to only those before
- direction: Either 'b' or 'f' to indicate whether we are paginating
- forwards or backwards from `from_key`.
+ direction: Indicates whether we are paginating forwards or backwards
+ from `from_key`.
limit: The maximum number of events to return.
event_filter: If provided filters the events to those that match the filter.
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index b0f5de67a3..d5500cdd47 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -17,7 +17,8 @@
import logging
from typing import Any, Dict, Iterable, List, Tuple, cast
-from synapse.replication.tcp.streams import TagAccountDataStream
+from synapse.api.constants import AccountDataTypes
+from synapse.replication.tcp.streams import AccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
@@ -54,7 +55,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]:
+ ) -> Tuple[List[Tuple[int, str, str]], int, bool]:
"""Get updates for tags replication stream.
Args:
@@ -73,7 +74,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
The token returned can be used in a subsequent call to this
function to get further updatees.
- The updates are a list of 2-tuples of stream ID and the row data
+ The updates are a list of tuples of stream ID, user ID and room ID
"""
if last_id == current_id:
@@ -96,38 +97,13 @@ class TagsWorkerStore(AccountDataWorkerStore):
"get_all_updated_tags", get_all_updated_tags_txn
)
- def get_tag_content(
- txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
- ) -> List[Tuple[int, Tuple[str, str, str]]]:
- sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
- results = []
- for stream_id, user_id, room_id in tag_ids:
- txn.execute(sql, (user_id, room_id))
- tags = []
- for tag, content in txn:
- tags.append(json_encoder.encode(tag) + ":" + content)
- tag_json = "{" + ",".join(tags) + "}"
- results.append((stream_id, (user_id, room_id, tag_json)))
-
- return results
-
- batch_size = 50
- results = []
- for i in range(0, len(tag_ids), batch_size):
- tags = await self.db_pool.runInteraction(
- "get_all_updated_tag_content",
- get_tag_content,
- tag_ids[i : i + batch_size],
- )
- results.extend(tags)
-
limited = False
upto_token = current_id
- if len(results) >= limit:
- upto_token = results[-1][0]
+ if len(tag_ids) >= limit:
+ upto_token = tag_ids[-1][0]
limited = True
- return results, upto_token, limited
+ return tag_ids, upto_token, limited
async def get_updated_tags(
self, user_id: str, stream_id: int
@@ -299,11 +275,13 @@ class TagsWorkerStore(AccountDataWorkerStore):
token: int,
rows: Iterable[Any],
) -> None:
- if stream_name == TagAccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
+ if stream_name == AccountDataStream.NAME:
for row in rows:
- self.get_tags_for_user.invalidate((row.user_id,))
- self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+ if row.data_type == AccountDataTypes.TAG:
+ self.get_tags_for_user.invalidate((row.user_id,))
+ self._account_data_stream_cache.entity_has_changed(
+ row.user_id, token
+ )
super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index f8c6877ee8..6b33d809b6 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast
import attr
from canonicaljson import encode_canonical_json
+from synapse.api.constants import Direction
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import db_to_json
from synapse.storage.database import (
@@ -496,7 +497,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
limit: int,
destination: Optional[str] = None,
order_by: str = DestinationSortOrder.DESTINATION.value,
- direction: str = "f",
+ direction: Direction = Direction.FORWARDS,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of destinations.
This will return a json list of destinations and the
@@ -518,7 +519,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
) -> Tuple[List[JsonDict], int]:
order_by_column = DestinationSortOrder(order_by).value
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
@@ -550,7 +551,11 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
)
async def get_destination_rooms_paginate(
- self, destination: str, start: int, limit: int, direction: str = "f"
+ self,
+ destination: str,
+ start: int,
+ limit: int,
+ direction: Direction = Direction.FORWARDS,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of destination's rooms.
This will return a json list of rooms and the
@@ -569,7 +574,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 698d6f7515..14ef5b040d 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -26,6 +26,14 @@ from typing import (
cast,
)
+try:
+ # Figure out if ICU support is available for searching users.
+ import icu
+
+ USE_ICU = True
+except ModuleNotFoundError:
+ USE_ICU = False
+
from typing_extensions import TypedDict
from synapse.api.errors import StoreError
@@ -481,7 +489,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
table="user_directory",
keyvalues={"user_id": user_id},
values={"display_name": display_name, "avatar_url": avatar_url},
- lock=False, # We're only inserter
)
if isinstance(self.database_engine, PostgresEngine):
@@ -511,7 +518,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
table="user_directory_search",
keyvalues={"user_id": user_id},
values={"value": value},
- lock=False, # We're only inserter
)
else:
# This should be unreachable.
@@ -888,7 +894,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
limited = len(results) > limit
- return {"limited": limited, "results": results}
+ return {"limited": limited, "results": results[0:limit]}
def _parse_query_sqlite(search_term: str) -> str:
@@ -902,7 +908,7 @@ def _parse_query_sqlite(search_term: str) -> str:
"""
# Pull out the individual words, discarding any non-word characters.
- results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+ results = _parse_words(search_term)
return " & ".join("(%s* OR %s)" % (result, result) for result in results)
@@ -912,12 +918,63 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
We use this so that we can add prefix matching, which isn't something
that is supported by default.
"""
-
- # Pull out the individual words, discarding any non-word characters.
- results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+ results = _parse_words(search_term)
both = " & ".join("(%s:* | %s)" % (result, result) for result in results)
exact = " & ".join("%s" % (result,) for result in results)
prefix = " & ".join("%s:*" % (result,) for result in results)
return both, exact, prefix
+
+
+def _parse_words(search_term: str) -> List[str]:
+ """Split the provided search string into a list of its words.
+
+ If support for ICU (International Components for Unicode) is available, use it.
+ Otherwise, fall back to using a regex to detect word boundaries. This latter
+ solution works well enough for most latin-based languages, but doesn't work as well
+ with other languages.
+
+ Args:
+ search_term: The search string.
+
+ Returns:
+ A list of the words in the search string.
+ """
+ if USE_ICU:
+ return _parse_words_with_icu(search_term)
+
+ return re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+
+
+def _parse_words_with_icu(search_term: str) -> List[str]:
+ """Break down the provided search string into its individual words using ICU
+ (International Components for Unicode).
+
+ Args:
+ search_term: The search string.
+
+ Returns:
+ A list of the words in the search string.
+ """
+ results = []
+ breaker = icu.BreakIterator.createWordInstance(icu.Locale.getDefault())
+ breaker.setText(search_term)
+ i = 0
+ while True:
+ j = breaker.nextBoundary()
+ if j < 0:
+ break
+
+ result = search_term[i:j]
+
+ # libicu considers spaces and punctuation between words as words, but we don't
+ # want to include those in results as they would result in syntax errors in SQL
+ # queries (e.g. "foo bar" would result in the search query including "foo & &
+ # bar").
+ if len(re.findall(r"([\w\-]+)", result, re.UNICODE)):
+ results.append(result)
+
+ i = j
+
+ return results
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index a7fcc564a9..d743282f13 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -22,8 +22,8 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
-from synapse.storage.state import StateFilter
from synapse.types import MutableStateMap, StateMap
+from synapse.types.state import StateFilter
from synapse.util.caches import intern_string
if TYPE_CHECKING:
@@ -93,13 +93,6 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
- where_clause, where_args = state_filter.make_sql_filter_clause()
-
- # Unless the filter clause is empty, we're going to append it after an
- # existing where clause
- if where_clause:
- where_clause = " AND (%s)" % (where_clause,)
-
if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is
# a temporary hack until we can add the right indices in
@@ -110,31 +103,91 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
# against `state_groups_state` to fetch the latest state.
# It assumes that previous state groups are always numerically
# lesser.
- # The PARTITION is used to get the event_id in the greatest state
- # group for the given type, state_key.
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
sql = """
- WITH RECURSIVE state(state_group) AS (
+ WITH RECURSIVE sgs(state_group) AS (
VALUES(?::bigint)
UNION ALL
- SELECT prev_state_group FROM state_group_edges e, state s
+ SELECT prev_state_group FROM state_group_edges e, sgs s
WHERE s.state_group = e.state_group
)
- SELECT DISTINCT ON (type, state_key)
- type, state_key, event_id
- FROM state_groups_state
- WHERE state_group IN (
- SELECT state_group FROM state
- ) %s
- ORDER BY type, state_key, state_group DESC
+ %s
"""
+ overall_select_query_args: List[Union[int, str]] = []
+
+ # This is an optimization to create a select clause per-condition. This
+ # makes the query planner a lot smarter on what rows should pull out in the
+ # first place and we end up with something that takes 10x less time to get a
+ # result.
+ use_condition_optimization = (
+ not state_filter.include_others and not state_filter.is_full()
+ )
+ state_filter_condition_combos: List[Tuple[str, Optional[str]]] = []
+ # We don't need to caclculate this list if we're not using the condition
+ # optimization
+ if use_condition_optimization:
+ for etype, state_keys in state_filter.types.items():
+ if state_keys is None:
+ state_filter_condition_combos.append((etype, None))
+ else:
+ for state_key in state_keys:
+ state_filter_condition_combos.append((etype, state_key))
+ # And here is the optimization itself. We don't want to do the optimization
+ # if there are too many individual conditions. 10 is an arbitrary number
+ # with no testing behind it but we do know that we specifically made this
+ # optimization for when we grab the necessary state out for
+ # `filter_events_for_client` which just uses 2 conditions
+ # (`EventTypes.RoomHistoryVisibility` and `EventTypes.Member`).
+ if use_condition_optimization and len(state_filter_condition_combos) < 10:
+ select_clause_list: List[str] = []
+ for etype, skey in state_filter_condition_combos:
+ if skey is None:
+ where_clause = "(type = ?)"
+ overall_select_query_args.extend([etype])
+ else:
+ where_clause = "(type = ? AND state_key = ?)"
+ overall_select_query_args.extend([etype, skey])
+
+ select_clause_list.append(
+ f"""
+ (
+ SELECT DISTINCT ON (type, state_key)
+ type, state_key, event_id
+ FROM state_groups_state
+ INNER JOIN sgs USING (state_group)
+ WHERE {where_clause}
+ ORDER BY type, state_key, state_group DESC
+ )
+ """
+ )
+
+ overall_select_clause = " UNION ".join(select_clause_list)
+ else:
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+ # Unless the filter clause is empty, we're going to append it after an
+ # existing where clause
+ if where_clause:
+ where_clause = " AND (%s)" % (where_clause,)
+
+ overall_select_query_args.extend(where_args)
+
+ overall_select_clause = f"""
+ SELECT DISTINCT ON (type, state_key)
+ type, state_key, event_id
+ FROM state_groups_state
+ WHERE state_group IN (
+ SELECT state_group FROM sgs
+ ) {where_clause}
+ ORDER BY type, state_key, state_group DESC
+ """
+
for group in groups:
args: List[Union[int, str]] = [group]
- args.extend(where_args)
+ args.extend(overall_select_query_args)
- txn.execute(sql % (where_clause,), args)
+ txn.execute(sql % (overall_select_clause,), args)
for row in txn:
typ, state_key, event_id = row
key = (intern_string(typ), intern_string(state_key))
@@ -142,6 +195,12 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
else:
max_entries_returned = state_filter.max_entries_returned()
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+ # Unless the filter clause is empty, we're going to append it after an
+ # existing where clause
+ if where_clause:
+ where_clause = " AND (%s)" % (where_clause,)
+
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index f8cfcaca83..1a7232b276 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -25,10 +25,10 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
-from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import MutableStateMap, StateKey, StateMap
+from synapse.types.state import StateFilter
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.cancellation import cancellable
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 70e594a68f..0363cdc038 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -132,6 +132,10 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM
"""Execute a chunk of SQL containing multiple semicolon-delimited statements.
This is not provided by DBAPI2, and so needs engine-specific support.
+
+ Any ongoing transaction is committed before executing the script in its own
+ transaction. The script transaction is left open and it is the responsibility of
+ the caller to commit it.
"""
...
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 719a517336..b350f57ccb 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -77,7 +77,7 @@ class PostgresEngine(
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
- self._version = cast(int, db_conn.server_version)
+ self._version = db_conn.server_version
allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
# Are we on a supported PostgreSQL version?
@@ -220,5 +220,9 @@ class PostgresEngine(
"""Execute a chunk of SQL containing multiple semicolon-delimited statements.
Psycopg2 seems happy to do this in DBAPI2's `execute()` function.
+
+ For consistency with SQLite, any ongoing transaction is committed before
+ executing the script in its own transaction. The script transaction is
+ left open and it is the responsibility of the caller to commit it.
"""
- cursor.execute(script)
+ cursor.execute(f"COMMIT; BEGIN TRANSACTION; {script}")
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 14260442b6..28751e89a5 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -135,13 +135,16 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
> than one statement with it, it will raise a Warning. Use executescript() if
> you want to execute multiple SQL statements with one call.
- Though the docs for `executescript` warn:
+ The script is prefixed with a `BEGIN TRANSACTION`, since the docs for
+ `executescript` warn:
> If there is a pending transaction, an implicit COMMIT statement is executed
> first. No other implicit transaction control is performed; any transaction
> control must be added to sql_script.
"""
- cursor.executescript(script)
+ # The implementation of `executescript` can be found at
+ # https://github.com/python/cpython/blob/3.11/Modules/_sqlite/cursor.c#L1035.
+ cursor.executescript(f"BEGIN TRANSACTION; {script}")
# Following functions taken from: https://github.com/coleifer/peewee
diff --git a/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql b/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql
new file mode 100644
index 0000000000..93d7fcb79b
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql
@@ -0,0 +1,53 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Prior to this schema delta, we tracked the set of unconverted rows in
+-- `device_lists_changes_in_room` using the `converted_to_destinations` flag. When rows
+-- were converted to `device_lists_outbound_pokes`, the `converted_to_destinations` flag
+-- would be set.
+--
+-- After this schema delta, the `converted_to_destinations` is still populated like
+-- before, but the set of unconverted rows is determined by the `stream_id` in the new
+-- `device_lists_changes_converted_stream_position` table.
+--
+-- If rolled back, Synapse will re-send all device list changes that happened since the
+-- schema delta.
+
+CREATE TABLE IF NOT EXISTS device_lists_changes_converted_stream_position(
+ Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
+ -- The (stream id, room id) of the last row in `device_lists_changes_in_room` that
+ -- has been converted to `device_lists_outbound_pokes`. Rows with a strictly larger
+ -- (stream id, room id) where `converted_to_destinations` is `FALSE` have not been
+ -- converted.
+ stream_id BIGINT NOT NULL,
+ -- `room_id` may be an empty string, which compares less than all valid room IDs.
+ room_id TEXT NOT NULL,
+ CHECK (Lock='X')
+);
+
+INSERT INTO device_lists_changes_converted_stream_position (stream_id, room_id) VALUES (
+ (
+ SELECT COALESCE(
+ -- The last converted stream id is the smallest unconverted stream id minus
+ -- one.
+ MIN(stream_id) - 1,
+ -- If there is no unconverted stream id, the last converted stream id is the
+ -- largest stream id.
+ -- Otherwise, pick 1, since stream ids start at 2.
+ (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room)
+ ) FROM device_lists_changes_in_room WHERE NOT converted_to_destinations
+ ),
+ ''
+);
diff --git a/synapse/storage/schema/main/delta/73/13add_device_lists_index.sql b/synapse/storage/schema/main/delta/73/13add_device_lists_index.sql
new file mode 100644
index 0000000000..3725022a13
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/13add_device_lists_index.sql
@@ -0,0 +1,20 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Adds an index on `device_lists_changes_in_room (room_id, stream_id)`, which
+-- speeds up `/sync` queries.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (7313, 'device_lists_changes_in_room_by_room_index', '{}');
diff --git a/synapse/storage/schema/main/delta/73/20_un_partial_stated_room_stream.sql b/synapse/storage/schema/main/delta/73/20_un_partial_stated_room_stream.sql
new file mode 100644
index 0000000000..743196cfe3
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/20_un_partial_stated_room_stream.sql
@@ -0,0 +1,32 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Stream for notifying that a room has become un-partial-stated.
+CREATE TABLE un_partial_stated_room_stream(
+ -- Position in the stream
+ stream_id BIGINT PRIMARY KEY NOT NULL,
+
+ -- Which instance wrote this entry.
+ instance_name TEXT NOT NULL,
+
+ -- Which room has been un-partial-stated.
+ room_id TEXT NOT NULL REFERENCES rooms(room_id) ON DELETE CASCADE
+);
+
+-- We want an index here because of the foreign key constraint:
+-- upon deleting a room, the database needs to be able to check here.
+-- This index is not unique because we can join a room multiple times in a server's lifetime,
+-- so the same room could be un-partial-stated multiple times!
+CREATE INDEX un_partial_stated_room_stream_room_id ON un_partial_stated_room_stream (room_id);
diff --git a/synapse/storage/schema/main/delta/73/21_un_partial_stated_room_stream_seq.sql.postgres b/synapse/storage/schema/main/delta/73/21_un_partial_stated_room_stream_seq.sql.postgres
new file mode 100644
index 0000000000..c1aac0b385
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/21_un_partial_stated_room_stream_seq.sql.postgres
@@ -0,0 +1,20 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS un_partial_stated_room_stream_sequence;
+
+SELECT setval('un_partial_stated_room_stream_sequence', (
+ SELECT COALESCE(MAX(stream_id), 1) FROM un_partial_stated_room_stream
+));
diff --git a/synapse/storage/schema/main/delta/73/22_rebuild_user_dir_stats.sql b/synapse/storage/schema/main/delta/73/22_rebuild_user_dir_stats.sql
new file mode 100644
index 0000000000..afab1e4bb7
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/22_rebuild_user_dir_stats.sql
@@ -0,0 +1,29 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+ -- Set up user directory staging tables.
+ (7322, 'populate_user_directory_createtables', '{}', NULL),
+ -- Run through each room and update the user directory according to who is in it.
+ (7322, 'populate_user_directory_process_rooms', '{}', 'populate_user_directory_createtables'),
+ -- Insert all users into the user directory, if search_all_users is on.
+ (7322, 'populate_user_directory_process_users', '{}', 'populate_user_directory_process_rooms'),
+ -- Clean up user directory staging tables.
+ (7322, 'populate_user_directory_cleanup', '{}', 'populate_user_directory_process_users'),
+ -- Rebuild the room_stats_current and room_stats_state tables.
+ (7322, 'populate_stats_process_rooms', '{}', NULL),
+ -- Update the user_stats_current table.
+ (7322, 'populate_stats_process_users', '{}', NULL)
+ON CONFLICT (update_name) DO NOTHING;
diff --git a/synapse/storage/schema/main/delta/73/22_un_partial_stated_event_stream.sql b/synapse/storage/schema/main/delta/73/22_un_partial_stated_event_stream.sql
new file mode 100644
index 0000000000..0e571f78c3
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/22_un_partial_stated_event_stream.sql
@@ -0,0 +1,34 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Stream for notifying that an event has become un-partial-stated.
+CREATE TABLE un_partial_stated_event_stream(
+ -- Position in the stream
+ stream_id BIGINT PRIMARY KEY NOT NULL,
+
+ -- Which instance wrote this entry.
+ instance_name TEXT NOT NULL,
+
+ -- Which event has been un-partial-stated.
+ event_id TEXT NOT NULL REFERENCES events(event_id) ON DELETE CASCADE,
+
+ -- true iff the `rejected` status of the event changed when it became
+ -- un-partial-stated.
+ rejection_status_changed BOOLEAN NOT NULL
+);
+
+-- We want an index here because of the foreign key constraint:
+-- upon deleting an event, the database needs to be able to check here.
+CREATE UNIQUE INDEX un_partial_stated_event_stream_room_id ON un_partial_stated_event_stream (event_id);
diff --git a/synapse/storage/schema/main/delta/73/23_fix_thread_index.sql b/synapse/storage/schema/main/delta/73/23_fix_thread_index.sql
new file mode 100644
index 0000000000..ec519ceebf
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/23_fix_thread_index.sql
@@ -0,0 +1,33 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- If a Synapse deployment made a large jump in versions (from < 1.62.0 to >= 1.70.0)
+-- in a single upgrade then it might be possible for the event_push_summary_unique_index
+-- to be created in the background from delta 71/02event_push_summary_unique.sql after
+-- delta 73/06thread_notifications_thread_id_idx.sql is executed, causing it to
+-- not drop the event_push_summary_unique_index index.
+--
+-- See https://github.com/matrix-org/synapse/issues/14641
+
+-- Stop the index from being scheduled for creation in the background.
+DELETE FROM background_updates WHERE update_name = 'event_push_summary_unique_index';
+
+-- The above background job also replaces another index, so ensure that side-effect
+-- is applied.
+DROP INDEX IF EXISTS event_push_summary_user_rm;
+
+-- Fix deployments which ran the 73/06thread_notifications_thread_id_idx.sql delta
+-- before the event_push_summary_unique_index background job was run.
+DROP INDEX IF EXISTS event_push_summary_unique_index;
diff --git a/synapse/storage/schema/main/delta/73/23_un_partial_stated_room_stream_seq.sql.postgres b/synapse/storage/schema/main/delta/73/23_un_partial_stated_room_stream_seq.sql.postgres
new file mode 100644
index 0000000000..1ec24702f3
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/23_un_partial_stated_room_stream_seq.sql.postgres
@@ -0,0 +1,20 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS un_partial_stated_event_stream_sequence;
+
+SELECT setval('un_partial_stated_event_stream_sequence', (
+ SELECT COALESCE(MAX(stream_id), 1) FROM un_partial_stated_event_stream
+));
diff --git a/synapse/storage/schema/main/delta/73/24_events_jump_to_date_index.sql b/synapse/storage/schema/main/delta/73/24_events_jump_to_date_index.sql
new file mode 100644
index 0000000000..67059909a1
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/24_events_jump_to_date_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (7324, 'events_jump_to_date_index', '{}');
diff --git a/synapse/storage/schema/main/delta/73/25drop_presence.sql b/synapse/storage/schema/main/delta/73/25drop_presence.sql
new file mode 100644
index 0000000000..9f6ffa20b6
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/25drop_presence.sql
@@ -0,0 +1,17 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this table is unused
+DROP TABLE presence;
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 0d7108f01b..9adff3f4f5 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -20,6 +20,7 @@ from collections import OrderedDict
from contextlib import contextmanager
from types import TracebackType
from typing import (
+ TYPE_CHECKING,
AsyncContextManager,
ContextManager,
Dict,
@@ -49,6 +50,9 @@ from synapse.storage.database import (
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import PostgresSequenceGenerator
+if TYPE_CHECKING:
+ from synapse.notifier import ReplicationNotifier
+
logger = logging.getLogger(__name__)
@@ -182,6 +186,7 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
def __init__(
self,
db_conn: LoggingDatabaseConnection,
+ notifier: "ReplicationNotifier",
table: str,
column: str,
extra_tables: Iterable[Tuple[str, str]] = (),
@@ -205,6 +210,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
+ self._notifier = notifier
+
def advance(self, instance_name: str, new_id: int) -> None:
# Advance should never be called on a writer instance, only over replication
if self._is_writer:
@@ -227,6 +234,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
with self._lock:
self._unfinished_ids.pop(next_id)
+ self._notifier.notify_replication()
+
return _AsyncCtxManagerWrapper(manager())
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
@@ -250,6 +259,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
for next_id in next_ids:
self._unfinished_ids.pop(next_id)
+ self._notifier.notify_replication()
+
return _AsyncCtxManagerWrapper(manager())
def get_current_token(self) -> int:
@@ -296,6 +307,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self,
db_conn: LoggingDatabaseConnection,
db: DatabasePool,
+ notifier: "ReplicationNotifier",
stream_name: str,
instance_name: str,
tables: List[Tuple[str, str, str]],
@@ -304,6 +316,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
positive: bool = True,
) -> None:
self._db = db
+ self._notifier = notifier
self._stream_name = stream_name
self._instance_name = instance_name
self._positive = positive
@@ -378,6 +391,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._current_positions.values(), default=1
)
+ if not writers:
+ # If there have been no explicit writers given then any instance can
+ # write to the stream. In which case, let's pre-seed our own
+ # position with the current minimum.
+ self._current_positions[self._instance_name] = self._persisted_upto_position
+
def _load_current_ids(
self,
db_conn: LoggingDatabaseConnection,
@@ -529,7 +548,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
# controls the return type. If `None` or omitted, the context manager yields
# a single integer stream_id; otherwise it yields a list of stream_ids.
- return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
+ return cast(
+ AsyncContextManager[int], _MultiWriterCtxManager(self, self._notifier)
+ )
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
# If we have a list of instances that are allowed to write to this
@@ -538,7 +559,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
raise Exception("Tried to allocate stream ID on non-writer")
# Cast safety: see get_next.
- return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))
+ return cast(
+ AsyncContextManager[List[int]],
+ _MultiWriterCtxManager(self, self._notifier, n),
+ )
def get_next_txn(self, txn: LoggingTransaction) -> int:
"""
@@ -557,6 +581,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
+ txn.call_after(self._notifier.notify_replication)
# Update the `stream_positions` table with newly updated stream
# ID (unless self._writers is not set in which case we don't
@@ -695,24 +720,22 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
heapq.heappush(self._known_persisted_positions, new_id)
- # If we're a writer and we don't have any active writes we update our
- # current position to the latest position seen. This allows the instance
- # to report a recent position when asked, rather than a potentially old
- # one (if this instance hasn't written anything for a while).
- our_current_position = self._current_positions.get(self._instance_name)
- if (
- our_current_position
- and not self._unfinished_ids
- and not self._in_flight_fetches
- ):
- self._current_positions[self._instance_name] = max(
- our_current_position, new_id
- )
-
# We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less
# that that have been persisted).
- min_curr = min(self._current_positions.values(), default=0)
+ our_current_position = self._current_positions.get(self._instance_name, 0)
+ min_curr = min(
+ (
+ token
+ for name, token in self._current_positions.items()
+ if name != self._instance_name
+ ),
+ default=our_current_position,
+ )
+
+ if our_current_position and (self._unfinished_ids or self._in_flight_fetches):
+ min_curr = min(min_curr, our_current_position)
+
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
# We now iterate through the seen positions, discarding those that are
@@ -783,6 +806,7 @@ class _MultiWriterCtxManager:
"""Async context manager returned by MultiWriterIdGenerator"""
id_gen: MultiWriterIdGenerator
+ notifier: "ReplicationNotifier"
multiple_ids: Optional[int] = None
stream_ids: List[int] = attr.Factory(list)
@@ -810,6 +834,8 @@ class _MultiWriterCtxManager:
for i in self.stream_ids:
self.id_gen._mark_id_as_finished(i)
+ self.notifier.notify_replication()
+
if exc_type is not None:
return False
diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py
index 2dcd43d0a2..c6c8a0315c 100644
--- a/synapse/streams/__init__.py
+++ b/synapse/streams/__init__.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Collection, Generic, List, Optional, Tuple, TypeVar
+from typing import Generic, List, Optional, Tuple, TypeVar
-from synapse.types import UserID
+from synapse.types import StrCollection, UserID
# The key, this is either a stream token or int.
K = TypeVar("K")
@@ -28,7 +28,7 @@ class EventSource(Generic[K, R]):
user: UserID,
from_key: K,
limit: int,
- room_ids: Collection[str],
+ room_ids: StrCollection,
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[R], K]:
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 6df2de919c..a044280410 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -16,8 +16,9 @@ from typing import Optional
import attr
+from synapse.api.constants import Direction
from synapse.api.errors import SynapseError
-from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.servlet import parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.storage.databases.main import DataStore
from synapse.types import StreamToken
@@ -34,7 +35,7 @@ class PaginationConfig:
from_token: Optional[StreamToken]
to_token: Optional[StreamToken]
- direction: str
+ direction: Direction
limit: int
@classmethod
@@ -43,11 +44,9 @@ class PaginationConfig:
store: "DataStore",
request: SynapseRequest,
default_limit: int,
- default_dir: str = "f",
+ default_dir: Direction = Direction.FORWARDS,
) -> "PaginationConfig":
- direction = parse_string(
- request, "dir", default=default_dir, allowed_values=["f", "b"]
- )
+ direction = parse_enum(request, "dir", Direction, default=default_dir)
from_tok_str = parse_string(request, "from")
to_tok_str = parse_string(request, "to")
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index f331e1af16..d7084d2358 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -53,11 +53,15 @@ class EventSources:
*(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner))
)
self.store = hs.get_datastores().main
+ self._instance_name = hs.get_instance_name()
def get_current_token(self) -> StreamToken:
push_rules_key = self.store.get_max_push_rules_stream_id()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
+ un_partial_stated_rooms_key = self.store.get_un_partial_stated_rooms_token(
+ self._instance_name
+ )
token = StreamToken(
room_key=self.sources.room.get_current_key(),
@@ -70,9 +74,23 @@ class EventSources:
device_list_key=device_list_key,
# Groups key is unused.
groups_key=0,
+ un_partial_stated_rooms_key=un_partial_stated_rooms_key,
)
return token
+ @trace
+ async def get_start_token_for_pagination(self, room_id: str) -> StreamToken:
+ """Get the start token for a given room to be used to paginate
+ events.
+
+ The returned token does not have the current values for fields other
+ than `room`, since they are not used during pagination.
+
+ Returns:
+ The start token for pagination.
+ """
+ return StreamToken.START
+
@trace
async def get_current_token_for_pagination(self, room_id: str) -> StreamToken:
"""Get the current token for a given room to be used to paginate
@@ -94,5 +112,6 @@ class EventSources:
to_device_key=0,
device_list_key=0,
groups_key=0,
+ un_partial_stated_rooms_key=0,
)
return token
diff --git a/synapse/types.py b/synapse/types/__init__.py
similarity index 97%
rename from synapse/types.py
rename to synapse/types/__init__.py
index f2d436ddc3..f82d1cfc29 100644
--- a/synapse/types.py
+++ b/synapse/types/__init__.py
@@ -17,6 +17,7 @@ import re
import string
from typing import (
TYPE_CHECKING,
+ AbstractSet,
Any,
ClassVar,
Dict,
@@ -77,6 +78,10 @@ JsonMapping = Mapping[str, Any]
# A JSON-serialisable object.
JsonSerializable = object
+# Collection[str] that does not include str itself; str being a Sequence[str]
+# is very misleading and results in bugs.
+StrCollection = Union[Tuple[str, ...], List[str], AbstractSet[str]]
+
# Note that this seems to require inheriting *directly* from Interface in order
# for mypy-zope to realize it is an interface.
@@ -600,6 +605,12 @@ class RoomStreamToken:
elif self.instance_map:
entries = []
for name, pos in self.instance_map.items():
+ if pos <= self.stream:
+ # Ignore instances who are below the minimum stream position
+ # (we might know they've advanced without seeing a recent
+ # write from them).
+ continue
+
instance_id = await store.get_id_for_instance(name)
entries.append(f"{instance_id}.{pos}")
@@ -623,6 +634,7 @@ class StreamKeyType:
PUSH_RULES: Final = "push_rules_key"
TO_DEVICE: Final = "to_device_key"
DEVICE_LIST: Final = "device_list_key"
+ UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key"
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -630,7 +642,7 @@ class StreamToken:
"""A collection of keys joined together by underscores in the following
order and which represent the position in their respective streams.
- ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1`
+ ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379`
1. `room_key`: `s2633508` which is a `RoomStreamToken`
- `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59`
- See the docstring for `RoomStreamToken` for more details.
@@ -642,12 +654,13 @@ class StreamToken:
7. `to_device_key`: `274711`
8. `device_list_key`: `265584`
9. `groups_key`: `1` (note that this key is now unused)
+ 10. `un_partial_stated_rooms_key`: `379`
You can see how many of these keys correspond to the various
fields in a "/sync" response:
```json
{
- "next_batch": "s12_4_0_1_1_1_1_4_1",
+ "next_batch": "s12_4_0_1_1_1_1_4_1_1",
"presence": {
"events": []
},
@@ -659,7 +672,7 @@ class StreamToken:
"!QrZlfIDQLNLdZHqTnt:hs1": {
"timeline": {
"events": [],
- "prev_batch": "s10_4_0_1_1_1_1_4_1",
+ "prev_batch": "s10_4_0_1_1_1_1_4_1_1",
"limited": false
},
"state": {
@@ -695,6 +708,7 @@ class StreamToken:
device_list_key: int
# Note that the groups key is no longer used and may have bogus values.
groups_key: int
+ un_partial_stated_rooms_key: int
_SEPARATOR = "_"
START: ClassVar["StreamToken"]
@@ -733,6 +747,7 @@ class StreamToken:
# serialized so that there will not be confusion in the future
# if additional tokens are added.
str(self.groups_key),
+ str(self.un_partial_stated_rooms_key),
]
)
@@ -765,7 +780,7 @@ class StreamToken:
return attr.evolve(self, **{key: new_value})
-StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0)
+StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
@attr.s(slots=True, frozen=True, auto_attribs=True)
diff --git a/synapse/storage/state.py b/synapse/types/state.py
similarity index 96%
rename from synapse/storage/state.py
rename to synapse/types/state.py
index 0004d955b4..743a4f9217 100644
--- a/synapse/storage/state.py
+++ b/synapse/types/state.py
@@ -118,6 +118,15 @@ class StateFilter:
)
)
+ def to_types(self) -> Iterable[Tuple[str, Optional[str]]]:
+ """The inverse to `from_types`."""
+ for (event_type, state_keys) in self.types.items():
+ if state_keys is None:
+ yield event_type, None
+ else:
+ for state_key in state_keys:
+ yield event_type, state_key
+
@staticmethod
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
"""Creates a filter that returns all non-member events, plus the member
@@ -343,6 +352,15 @@ class StateFilter:
for s in state_keys
]
+ def wildcard_types(self) -> List[str]:
+ """Returns a list of event types which require us to fetch all state keys.
+ This will be empty unless `has_wildcards` returns True.
+
+ Returns:
+ A list of event types.
+ """
+ return [t for t, state_keys in self.types.items() if state_keys is None]
+
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
"""Return the filter split into two: one which assumes it's exclusively
matching against member state, and one which assumes it's matching
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 8a63a73bca..d612fca03d 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -205,7 +205,10 @@ T = TypeVar("T")
async def concurrently_execute(
- func: Callable[[T], Any], args: Iterable[T], limit: int
+ func: Callable[[T], Any],
+ args: Iterable[T],
+ limit: int,
+ delay_cancellation: bool = False,
) -> None:
"""Executes the function with each argument concurrently while limiting
the number of concurrent executions.
@@ -215,6 +218,8 @@ async def concurrently_execute(
args: List of arguments to pass to func, each invocation of func
gets a single argument.
limit: Maximum number of conccurent executions.
+ delay_cancellation: Whether to delay cancellation until after the invocations
+ have finished.
Returns:
None, when all function invocations have finished. The return values
@@ -233,9 +238,16 @@ async def concurrently_execute(
# We use `itertools.islice` to handle the case where the number of args is
# less than the limit, avoiding needlessly spawning unnecessary background
# tasks.
- await yieldable_gather_results(
- _concurrently_execute_inner, (value for value in itertools.islice(it, limit))
- )
+ if delay_cancellation:
+ await yieldable_gather_results_delaying_cancellation(
+ _concurrently_execute_inner,
+ (value for value in itertools.islice(it, limit)),
+ )
+ else:
+ await yieldable_gather_results(
+ _concurrently_execute_inner,
+ (value for value in itertools.islice(it, limit)),
+ )
P = ParamSpec("P")
@@ -292,6 +304,41 @@ async def yieldable_gather_results(
raise dfe.subFailure.value from None
+async def yieldable_gather_results_delaying_cancellation(
+ func: Callable[Concatenate[T, P], Awaitable[R]],
+ iter: Iterable[T],
+ *args: P.args,
+ **kwargs: P.kwargs,
+) -> List[R]:
+ """Executes the function with each argument concurrently.
+ Cancellation is delayed until after all the results have been gathered.
+
+ See `yieldable_gather_results`.
+
+ Args:
+ func: Function to execute that returns a Deferred
+ iter: An iterable that yields items that get passed as the first
+ argument to the function
+ *args: Arguments to be passed to each call to func
+ **kwargs: Keyword arguments to be passed to each call to func
+
+ Returns
+ A list containing the results of the function
+ """
+ try:
+ return await make_deferred_yieldable(
+ delay_cancellation(
+ defer.gatherResults(
+ [run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type]
+ consumeErrors=True,
+ )
+ )
+ )
+ except defer.FirstError as dfe:
+ assert isinstance(dfe.subFailure.value, BaseException)
+ raise dfe.subFailure.value from None
+
+
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 4264730e4f..740d9585cf 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -53,9 +53,9 @@ F = TypeVar("F", bound=Callable[..., Any])
class CachedFunction(Generic[F]):
- invalidate: Any = None
- invalidate_all: Any = None
- prefill: Any = None
+ invalidate: Callable[[Tuple[Any, ...]], None]
+ invalidate_all: Callable[[], None]
+ prefill: Callable[[Tuple[Any, ...], Any], None]
cache: Any = None
num_args: Any = None
@@ -503,7 +503,7 @@ def cachedList(
is specified as a list that is iterated through to lookup keys in the
original cache. A new tuple consisting of the (deduplicated) keys that weren't in
the cache gets passed to the original function, which is expected to results
- in a map of key to value for each passed value. THe new results are stored in the
+ in a map of key to value for each passed value. The new results are stored in the
original cache. Note that any missing values are cached as None.
Args:
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index dcf0eac3bf..452d5d04c1 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -788,26 +788,21 @@ class LruCache(Generic[KT, VT]):
def __contains__(self, key: KT) -> bool:
return self.contains(key)
- def set_cache_factor(self, factor: float) -> bool:
+ def set_cache_factor(self, factor: float) -> None:
"""
Set the cache factor for this individual cache.
This will trigger a resize if it changes, which may require evicting
items from the cache.
-
- Returns:
- Whether the cache changed size or not.
"""
if not self.apply_cache_factor_from_config:
- return False
+ return
new_size = int(self._original_max_size * factor)
if new_size != self.max_size:
self.max_size = new_size
if self._on_resize:
self._on_resize()
- return True
- return False
def __del__(self) -> None:
# We're about to be deleted, so we make sure to clear up all the nodes
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index a3eb5f741b..340e5e9145 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -167,12 +167,10 @@ class ResponseCache(Generic[KV]):
# the should_cache bit, we leave it in the cache for now and schedule
# its removal later.
if self.timeout_sec and context.should_cache:
- self.clock.call_later(
- self.timeout_sec, self._result_cache.pop, key, None
- )
+ self.clock.call_later(self.timeout_sec, self.unset, key)
else:
# otherwise, remove the result immediately.
- self._result_cache.pop(key, None)
+ self.unset(key)
return r
# make sure we do this *after* adding the entry to result_cache,
@@ -181,6 +179,14 @@ class ResponseCache(Generic[KV]):
result.addBoth(on_complete)
return entry
+ def unset(self, key: KV) -> None:
+ """Remove the cached value for this key from the cache, if any.
+
+ Args:
+ key: key used to remove the cached value
+ """
+ self._result_cache.pop(key, None)
+
async def wrap(
self,
key: KV,
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 666f4b6895..1657459549 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -16,6 +16,7 @@ import logging
import math
from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union
+import attr
from sortedcontainers import SortedDict
from synapse.util import caches
@@ -26,14 +27,41 @@ logger = logging.getLogger(__name__)
EntityType = str
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class AllEntitiesChangedResult:
+ """Return type of `get_all_entities_changed`.
+
+ Callers must check that there was a cache hit, via `result.hit`, before
+ using the entities in `result.entities`.
+
+ This specifically does *not* implement helpers such as `__bool__` to ensure
+ that callers do the correct checks.
+ """
+
+ _entities: Optional[List[EntityType]]
+
+ @property
+ def hit(self) -> bool:
+ return self._entities is not None
+
+ @property
+ def entities(self) -> List[EntityType]:
+ assert self._entities is not None
+ return self._entities
+
+
class StreamChangeCache:
- """Keeps track of the stream positions of the latest change in a set of entities.
+ """
+ Keeps track of the stream positions of the latest change in a set of entities.
- Typically the entity will be a room or user id.
+ The entity will is typically a room ID or user ID, but can be any string.
- Given a list of entities and a stream position, it will give a subset of
- entities that may have changed since that position. If position key is too
- old then the cache will simply return all given entities.
+ Can be queried for whether a specific entity has changed after a stream position
+ or for a list of changed entities after a stream position. See the individual
+ methods for more information.
+
+ Only tracks to a maximum cache size, any position earlier than the earliest
+ known stream position must be treated as unknown.
"""
def __init__(
@@ -45,16 +73,20 @@ class StreamChangeCache:
) -> None:
self._original_max_size: int = max_size
self._max_size = math.floor(max_size)
- self._entity_to_key: Dict[EntityType, int] = {}
- # map from stream id to the a set of entities which changed at that stream id.
+ # map from stream id to the set of entities which changed at that stream id.
self._cache: SortedDict[int, Set[EntityType]] = SortedDict()
+ # map from entity to the stream ID of the latest change for that entity.
+ #
+ # Must be kept in sync with _cache.
+ self._entity_to_key: Dict[EntityType, int] = {}
# the earliest stream_pos for which we can reliably answer
# get_all_entities_changed. In other words, one less than the earliest
# stream_pos for which we know _cache is valid.
#
self._earliest_known_stream_pos = current_stream_pos
+
self.name = name
self.metrics = caches.register_cache(
"cache", self.name, self._cache, resize_callback=self.set_cache_factor
@@ -82,22 +114,46 @@ class StreamChangeCache:
return False
def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool:
- """Returns True if the entity may have been updated since stream_pos"""
+ """
+ Returns True if the entity may have been updated after stream_pos.
+
+ Args:
+ entity: The entity to check for changes.
+ stream_pos: The stream position to check for changes after.
+
+ Return:
+ True if the entity may have been updated, this happens if:
+ * The given stream position is at or earlier than the earliest
+ known stream position.
+ * The given stream position is earlier than the latest change for
+ the entity.
+
+ False otherwise:
+ * The entity is unknown.
+ * The given stream position is at or later than the latest change
+ for the entity.
+ """
assert isinstance(stream_pos, int)
- if stream_pos < self._earliest_known_stream_pos:
+ # _cache is not valid at or before the earliest known stream position, so
+ # return that the entity has changed.
+ if stream_pos <= self._earliest_known_stream_pos:
self.metrics.inc_misses()
return True
+ # If the entity is unknown, it hasn't changed.
latest_entity_change_pos = self._entity_to_key.get(entity, None)
if latest_entity_change_pos is None:
self.metrics.inc_hits()
return False
+ # This is a known entity, return true if the stream position is earlier
+ # than the last change.
if stream_pos < latest_entity_change_pos:
self.metrics.inc_misses()
return True
+ # Otherwise, the stream position is after the latest change: return false.
self.metrics.inc_hits()
return False
@@ -105,23 +161,35 @@ class StreamChangeCache:
self, entities: Collection[EntityType], stream_pos: int
) -> Union[Set[EntityType], FrozenSet[EntityType]]:
"""
- Returns subset of entities that have had new things since the given
- position. Entities unknown to the cache will be returned. If the
- position is too old it will just return the given list.
+ Returns the subset of the given entities that have had changes after the given position.
+
+ Entities unknown to the cache will be returned.
+
+ If the position is too old it will just return the given list.
+
+ Args:
+ entities: Entities to check for changes.
+ stream_pos: The stream position to check for changes after.
+
+ Return:
+ A subset of entities which have changed after the given stream position.
+
+ This will be all entities if the given stream position is at or earlier
+ than the earliest known stream position.
"""
- changed_entities = self.get_all_entities_changed(stream_pos)
- if changed_entities is not None:
+ cache_result = self.get_all_entities_changed(stream_pos)
+ if cache_result.hit:
# We now do an intersection, trying to do so in the most efficient
# way possible (some of these sets are *large*). First check in the
- # given iterable is already set that we can reuse, otherwise we
+ # given iterable is already a set that we can reuse, otherwise we
# create a set of the *smallest* of the two iterables and call
# `intersection(..)` on it (this can be twice as fast as the reverse).
if isinstance(entities, (set, frozenset)):
- result = entities.intersection(changed_entities)
- elif len(changed_entities) < len(entities):
- result = set(changed_entities).intersection(entities)
+ result = entities.intersection(cache_result.entities)
+ elif len(cache_result.entities) < len(entities):
+ result = set(cache_result.entities).intersection(entities)
else:
- result = set(entities).intersection(changed_entities)
+ result = set(entities).intersection(cache_result.entities)
self.metrics.inc_hits()
else:
result = set(entities)
@@ -130,43 +198,76 @@ class StreamChangeCache:
return result
def has_any_entity_changed(self, stream_pos: int) -> bool:
- """Returns if any entity has changed"""
- assert type(stream_pos) is int
+ """
+ Returns true if any entity has changed after the given stream position.
- if not self._cache:
- # If the cache is empty, nothing can have changed.
- return False
+ Args:
+ stream_pos: The stream position to check for changes after.
- if stream_pos >= self._earliest_known_stream_pos:
- self.metrics.inc_hits()
- return self._cache.bisect_right(stream_pos) < len(self._cache)
- else:
+ Return:
+ True if any entity has changed after the given stream position or
+ if the given stream position is at or earlier than the earliest
+ known stream position.
+
+ False otherwise.
+ """
+ assert isinstance(stream_pos, int)
+
+ # _cache is not valid at or before the earliest known stream position, so
+ # return that an entity has changed.
+ if stream_pos <= self._earliest_known_stream_pos:
self.metrics.inc_misses()
return True
- def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
- """Returns all entities that have had new things since the given
- position. If the position is too old it will return None.
+ # If the cache is empty, nothing can have changed.
+ if not self._cache:
+ self.metrics.inc_misses()
+ return False
+
+ self.metrics.inc_hits()
+ return stream_pos < self._cache.peekitem()[0]
+
+ def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult:
+ """
+ Returns all entities that have had changes after the given position.
+
+ If the stream change cache does not go far enough back, i.e. the
+ position is too old, it will return None.
Returns the entities in the order that they were changed.
- """
- assert type(stream_pos) is int
- if stream_pos < self._earliest_known_stream_pos:
- return None
+ Args:
+ stream_pos: The stream position to check for changes after.
+
+ Return:
+ A class indicating if we have the requested data cached, and if so
+ includes the entities in the order they were changed.
+ """
+ assert isinstance(stream_pos, int)
+
+ # _cache is not valid at or before the earliest known stream position, so
+ # return None to mark that it is unknown if an entity has changed.
+ if stream_pos <= self._earliest_known_stream_pos:
+ return AllEntitiesChangedResult(None)
changed_entities: List[EntityType] = []
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
changed_entities.extend(self._cache[k])
- return changed_entities
+ return AllEntitiesChangedResult(changed_entities)
def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
- """Informs the cache that the entity has been changed at the given
- position.
"""
- assert type(stream_pos) is int
+ Informs the cache that the entity has been changed at the given position.
+ Args:
+ entity: The entity to mark as changed.
+ stream_pos: The stream position to update the entity to.
+ """
+ assert isinstance(stream_pos, int)
+
+ # For a change before _cache is valid (e.g. at or before the earliest known
+ # stream position) there's nothing to do.
if stream_pos <= self._earliest_known_stream_pos:
return
@@ -189,6 +290,11 @@ class StreamChangeCache:
self._evict()
def _evict(self) -> None:
+ """
+ Ensure the cache has not exceeded the maximum size.
+
+ Evicts entries until it is at the maximum size.
+ """
# if the cache is too big, remove entries
while len(self._cache) > self._max_size:
k, r = self._cache.popitem(0)
@@ -199,5 +305,12 @@ class StreamChangeCache:
def get_max_pos_of_last_change(self, entity: EntityType) -> int:
"""Returns an upper bound of the stream id of the last change to an
entity.
+
+ Args:
+ entity: The entity to check.
+
+ Return:
+ The stream position of the latest change for the given entity or
+ the earliest known stream position if the entitiy is unknown.
"""
return self._entity_to_key.get(entity, self._earliest_known_stream_pos)
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
index a0606851f7..39fab4fe06 100644
--- a/synapse/util/httpresourcetree.py
+++ b/synapse/util/httpresourcetree.py
@@ -15,7 +15,9 @@
import logging
from typing import Dict
-from twisted.web.resource import NoResource, Resource
+from twisted.web.resource import Resource
+
+from synapse.http.server import UnrecognizedRequestResource
logger = logging.getLogger(__name__)
@@ -49,7 +51,7 @@ def create_resource_tree(
for path_seg in full_path.split(b"/")[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
- child_resource: Resource = NoResource()
+ child_resource: Resource = UnrecognizedRequestResource()
last_resource.putChild(path_seg, child_resource)
res_id = _resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
index 5df03d3ddc..644c341e8c 100644
--- a/synapse/util/macaroons.py
+++ b/synapse/util/macaroons.py
@@ -110,6 +110,9 @@ class OidcSessionData:
ui_auth_session_id: str
"""The session ID of the ongoing UI Auth ("" if this is a login)"""
+ code_verifier: str
+ """The random string used in the RFC7636 code challenge ("" if PKCE is not being used)."""
+
class MacaroonGenerator:
def __init__(self, clock: Clock, location: str, secret_key: bytes):
@@ -187,6 +190,7 @@ class MacaroonGenerator:
macaroon.add_first_party_caveat(
f"ui_auth_session_id = {session_data.ui_auth_session_id}"
)
+ macaroon.add_first_party_caveat(f"code_verifier = {session_data.code_verifier}")
macaroon.add_first_party_caveat(f"time < {expiry}")
return macaroon.serialize()
@@ -278,6 +282,7 @@ class MacaroonGenerator:
v.satisfy_general(lambda c: c.startswith("idp_id = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
+ v.satisfy_general(lambda c: c.startswith("code_verifier = "))
satisfy_expiry(v, self._clock.time_msec)
v.verify(macaroon, self._secret_key)
@@ -287,11 +292,13 @@ class MacaroonGenerator:
idp_id = get_value_from_macaroon(macaroon, "idp_id")
client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
+ code_verifier = get_value_from_macaroon(macaroon, "code_verifier")
return OidcSessionData(
nonce=nonce,
idp_id=idp_id,
client_redirect_url=client_redirect_url,
ui_auth_session_id=ui_auth_session_id,
+ code_verifier=code_verifier,
)
def _generate_base_macaroon(self, type: MacaroonType) -> pymacaroons.Macaroon:
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 84337b0796..e01645f1ab 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -364,12 +364,22 @@ class _PerHostRatelimiter:
def _on_exit(self, request_id: object) -> None:
logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))
- self.current_processing.discard(request_id)
- try:
- # start processing the next item on the queue.
- _, deferred = self.ready_request_queue.popitem(last=False)
- with PreserveLoggingContext():
- deferred.callback(None)
- except KeyError:
- pass
+ # When requests complete synchronously, we will recursively start the next
+ # request in the queue. To avoid stack exhaustion, we defer starting the next
+ # request until the next reactor tick.
+
+ def start_next_request() -> None:
+ # We only remove the completed request from the list when we're about to
+ # start the next one, otherwise we can allow extra requests through.
+ self.current_processing.discard(request_id)
+ try:
+ # start processing the next item on the queue.
+ _, deferred = self.ready_request_queue.popitem(last=False)
+
+ with PreserveLoggingContext():
+ deferred.callback(None)
+ except KeyError:
+ pass
+
+ self.clock.call_later(0.0, start_next_request)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index b443857571..e442de3173 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -26,8 +26,8 @@ from synapse.events.utils import prune_event
from synapse.logging.opentracing import trace
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
-from synapse.storage.state import StateFilter
from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
+from synapse.types.state import StateFilter
from synapse.util import Clock
logger = logging.getLogger(__name__)
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index e0f363555b..6e36e73f0d 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -31,7 +31,7 @@ from synapse.api.errors import (
from synapse.appservice import ApplicationService
from synapse.server import HomeServer
from synapse.storage.databases.main.registration import TokenLookupResult
-from synapse.types import Requester
+from synapse.types import Requester, UserID
from synapse.util import Clock
from tests import unittest
@@ -41,10 +41,12 @@ from tests.utils import mock_getRawHeaders
class AuthTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = Mock()
- hs.datastores.main = self.store
+ # type-ignore: datastores is None until hs.setup() is called---but it'll
+ # have been called by the HomeserverTestCase machinery.
+ hs.datastores.main = self.store # type: ignore[union-attr]
hs.get_auth_handler().store = self.store
self.auth = Auth(hs)
@@ -61,7 +63,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.insert_client_ip = simple_async_mock(None)
self.store.is_support_user = simple_async_mock(False)
- def test_get_user_by_req_user_valid_token(self):
+ def test_get_user_by_req_user_valid_token(self) -> None:
user_info = TokenLookupResult(
user_id=self.test_user, token_id=5, device_id="device"
)
@@ -74,7 +76,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEqual(requester.user.to_string(), self.test_user)
- def test_get_user_by_req_user_bad_token(self):
+ def test_get_user_by_req_user_bad_token(self) -> None:
self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
@@ -86,7 +88,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
- def test_get_user_by_req_user_missing_token(self):
+ def test_get_user_by_req_user_missing_token(self) -> None:
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = simple_async_mock(user_info)
@@ -98,7 +100,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
- def test_get_user_by_req_appservice_valid_token(self):
+ def test_get_user_by_req_appservice_valid_token(self) -> None:
app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
)
@@ -112,7 +114,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEqual(requester.user.to_string(), self.test_user)
- def test_get_user_by_req_appservice_valid_token_good_ip(self):
+ def test_get_user_by_req_appservice_valid_token_good_ip(self) -> None:
from netaddr import IPSet
app_service = Mock(
@@ -131,7 +133,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEqual(requester.user.to_string(), self.test_user)
- def test_get_user_by_req_appservice_valid_token_bad_ip(self):
+ def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None:
from netaddr import IPSet
app_service = Mock(
@@ -153,7 +155,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
- def test_get_user_by_req_appservice_bad_token(self):
+ def test_get_user_by_req_appservice_bad_token(self) -> None:
self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_access_token = simple_async_mock(None)
@@ -166,7 +168,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
- def test_get_user_by_req_appservice_missing_token(self):
+ def test_get_user_by_req_appservice_missing_token(self) -> None:
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None)
@@ -179,7 +181,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
- def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
+ def test_get_user_by_req_appservice_valid_token_valid_user_id(self) -> None:
masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
@@ -200,7 +202,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
requester.user.to_string(), masquerading_user_id.decode("utf8")
)
- def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
+ def test_get_user_by_req_appservice_valid_token_bad_user_id(self) -> None:
masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
@@ -217,7 +219,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.get_failure(self.auth.get_user_by_req(request), AuthError)
@override_config({"experimental_features": {"msc3202_device_masquerading": True}})
- def test_get_user_by_req_appservice_valid_token_valid_device_id(self):
+ def test_get_user_by_req_appservice_valid_token_valid_device_id(self) -> None:
"""
Tests that when an application service passes the device_id URL parameter
with the ID of a valid device for the user in question,
@@ -249,7 +251,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(requester.device_id, masquerading_device_id.decode("utf8"))
@override_config({"experimental_features": {"msc3202_device_masquerading": True}})
- def test_get_user_by_req_appservice_valid_token_invalid_device_id(self):
+ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None:
"""
Tests that when an application service passes the device_id URL parameter
with an ID that is not a valid device ID for the user in question,
@@ -279,7 +281,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(failure.value.code, 400)
self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE)
- def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self):
+ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None:
self.store.get_user_by_access_token = simple_async_mock(
TokenLookupResult(
user_id="@baldrick:matrix.org",
@@ -298,7 +300,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.get_success(self.auth.get_user_by_req(request))
self.store.insert_client_ip.assert_called_once()
- def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self):
+ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None:
self.auth._track_puppeted_user_ips = True
self.store.get_user_by_access_token = simple_async_mock(
TokenLookupResult(
@@ -318,7 +320,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.get_success(self.auth.get_user_by_req(request))
self.assertEqual(self.store.insert_client_ip.call_count, 2)
- def test_get_user_from_macaroon(self):
+ def test_get_user_from_macaroon(self) -> None:
self.store.get_user_by_access_token = simple_async_mock(None)
user_id = "@baldrick:matrix.org"
@@ -336,7 +338,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth.get_user_by_access_token(serialized), InvalidClientTokenError
)
- def test_get_guest_user_from_macaroon(self):
+ def test_get_guest_user_from_macaroon(self) -> None:
self.store.get_user_by_id = simple_async_mock({"is_guest": True})
self.store.get_user_by_access_token = simple_async_mock(None)
@@ -357,7 +359,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id)
- def test_blocking_mau(self):
+ def test_blocking_mau(self) -> None:
self.auth_blocking._limit_usage_by_mau = False
self.auth_blocking._max_mau_value = 50
lots_of_users = 100
@@ -381,7 +383,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
self.get_success(self.auth_blocking.check_auth_blocking())
- def test_blocking_mau__depending_on_user_type(self):
+ def test_blocking_mau__depending_on_user_type(self) -> None:
self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True
@@ -400,7 +402,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
# Real users not allowed
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
- def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
+ def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(
+ self,
+ ) -> None:
self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._track_appservice_user_ips = False
@@ -418,7 +422,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
sender="@appservice:sender",
)
requester = Requester(
- user="@appservice:server",
+ user=UserID.from_string("@appservice:server"),
access_token_id=None,
device_id="FOOBAR",
is_guest=False,
@@ -428,7 +432,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.get_success(self.auth_blocking.check_auth_blocking(requester=requester))
- def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
+ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(
+ self,
+ ) -> None:
self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._track_appservice_user_ips = True
@@ -446,7 +452,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
sender="@appservice:sender",
)
requester = Requester(
- user="@appservice:server",
+ user=UserID.from_string("@appservice:server"),
access_token_id=None,
device_id="FOOBAR",
is_guest=False,
@@ -459,7 +465,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
- def test_reserved_threepid(self):
+ def test_reserved_threepid(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._max_mau_value = 1
self.store.get_monthly_active_count = simple_async_mock(2)
@@ -476,7 +482,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid))
- def test_hs_disabled(self):
+ def test_hs_disabled(self) -> None:
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
e = self.get_failure(
@@ -486,7 +492,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403)
- def test_hs_disabled_no_server_notices_user(self):
+ def test_hs_disabled_no_server_notices_user(self) -> None:
"""Check that 'hs_disabled_message' works correctly when there is no
server_notices user.
"""
@@ -503,7 +509,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403)
- def test_server_notices_mxid_special_cased(self):
+ def test_server_notices_mxid_special_cased(self) -> None:
self.auth_blocking._hs_disabled = True
user = "@user:server"
self.auth_blocking._server_notices_mxid = user
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index d5524d296e..0f45615160 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -14,40 +14,36 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from typing import List
from unittest.mock import patch
import jsonschema
from frozendict import frozendict
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EduTypes, EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
-from synapse.events import make_event_from_dict
+from synapse.api.presence import UserPresenceState
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
+from tests.events.test_utils import MockEvent
user_localpart = "test_user"
-def MockEvent(**kwargs):
- if "event_id" not in kwargs:
- kwargs["event_id"] = "fake_event_id"
- if "type" not in kwargs:
- kwargs["type"] = "fake_type"
- if "content" not in kwargs:
- kwargs["content"] = {}
- return make_event_from_dict(kwargs)
-
-
class FilteringTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.filtering = hs.get_filtering()
self.datastore = hs.get_datastores().main
- def test_errors_on_invalid_filters(self):
+ def test_errors_on_invalid_filters(self) -> None:
# See USER_FILTER_SCHEMA for the filter schema.
- invalid_filters = [
+ invalid_filters: List[JsonDict] = [
# `account_data` must be a dictionary
{"account_data": "Hello World"},
# `event_fields` entries must not contain backslashes
@@ -63,10 +59,10 @@ class FilteringTestCase(unittest.HomeserverTestCase):
with self.assertRaises(SynapseError):
self.filtering.check_valid_filter(filter)
- def test_ignores_unknown_filter_fields(self):
+ def test_ignores_unknown_filter_fields(self) -> None:
# For forward compatibility, we must ignore unknown filter fields.
# See USER_FILTER_SCHEMA for the filter schema.
- filters = [
+ filters: List[JsonDict] = [
{"org.matrix.msc9999.future_option": True},
{"presence": {"org.matrix.msc9999.future_option": True}},
{"room": {"org.matrix.msc9999.future_option": True}},
@@ -76,8 +72,8 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.filtering.check_valid_filter(filter)
# Must not raise.
- def test_valid_filters(self):
- valid_filters = [
+ def test_valid_filters(self) -> None:
+ valid_filters: List[JsonDict] = [
{
"room": {
"timeline": {"limit": 20},
@@ -132,22 +128,22 @@ class FilteringTestCase(unittest.HomeserverTestCase):
except jsonschema.ValidationError as e:
self.fail(e)
- def test_limits_are_applied(self):
+ def test_limits_are_applied(self) -> None:
# TODO
pass
- def test_definition_types_works_with_literals(self):
+ def test_definition_types_works_with_literals(self) -> None:
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
self.assertTrue(Filter(self.hs, definition)._check(event))
- def test_definition_types_works_with_wildcards(self):
+ def test_definition_types_works_with_wildcards(self) -> None:
definition = {"types": ["m.*", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
self.assertTrue(Filter(self.hs, definition)._check(event))
- def test_definition_types_works_with_unknowns(self):
+ def test_definition_types_works_with_unknowns(self) -> None:
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
event = MockEvent(
sender="@foo:bar",
@@ -156,24 +152,24 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(Filter(self.hs, definition)._check(event))
- def test_definition_not_types_works_with_literals(self):
+ def test_definition_not_types_works_with_literals(self) -> None:
definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
self.assertFalse(Filter(self.hs, definition)._check(event))
- def test_definition_not_types_works_with_wildcards(self):
+ def test_definition_not_types_works_with_wildcards(self) -> None:
definition = {"not_types": ["m.room.message", "org.matrix.*"]}
event = MockEvent(
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
)
self.assertFalse(Filter(self.hs, definition)._check(event))
- def test_definition_not_types_works_with_unknowns(self):
+ def test_definition_not_types_works_with_unknowns(self) -> None:
definition = {"not_types": ["m.*", "org.*"]}
event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar")
self.assertTrue(Filter(self.hs, definition)._check(event))
- def test_definition_not_types_takes_priority_over_types(self):
+ def test_definition_not_types_takes_priority_over_types(self) -> None:
definition = {
"not_types": ["m.*", "org.*"],
"types": ["m.room.message", "m.room.topic"],
@@ -181,35 +177,35 @@ class FilteringTestCase(unittest.HomeserverTestCase):
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
self.assertFalse(Filter(self.hs, definition)._check(event))
- def test_definition_senders_works_with_literals(self):
+ def test_definition_senders_works_with_literals(self) -> None:
definition = {"senders": ["@flibble:wibble"]}
event = MockEvent(
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
)
self.assertTrue(Filter(self.hs, definition)._check(event))
- def test_definition_senders_works_with_unknowns(self):
+ def test_definition_senders_works_with_unknowns(self) -> None:
definition = {"senders": ["@flibble:wibble"]}
event = MockEvent(
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
)
self.assertFalse(Filter(self.hs, definition)._check(event))
- def test_definition_not_senders_works_with_literals(self):
+ def test_definition_not_senders_works_with_literals(self) -> None:
definition = {"not_senders": ["@flibble:wibble"]}
event = MockEvent(
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
)
self.assertFalse(Filter(self.hs, definition)._check(event))
- def test_definition_not_senders_works_with_unknowns(self):
+ def test_definition_not_senders_works_with_unknowns(self) -> None:
definition = {"not_senders": ["@flibble:wibble"]}
event = MockEvent(
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
)
self.assertTrue(Filter(self.hs, definition)._check(event))
- def test_definition_not_senders_takes_priority_over_senders(self):
+ def test_definition_not_senders_takes_priority_over_senders(self) -> None:
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets", "@misspiggy:muppets"],
@@ -219,14 +215,14 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(Filter(self.hs, definition)._check(event))
- def test_definition_rooms_works_with_literals(self):
+ def test_definition_rooms_works_with_literals(self) -> None:
definition = {"rooms": ["!secretbase:unknown"]}
event = MockEvent(
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
)
self.assertTrue(Filter(self.hs, definition)._check(event))
- def test_definition_rooms_works_with_unknowns(self):
+ def test_definition_rooms_works_with_unknowns(self) -> None:
definition = {"rooms": ["!secretbase:unknown"]}
event = MockEvent(
sender="@foo:bar",
@@ -235,7 +231,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(Filter(self.hs, definition)._check(event))
- def test_definition_not_rooms_works_with_literals(self):
+ def test_definition_not_rooms_works_with_literals(self) -> None:
definition = {"not_rooms": ["!anothersecretbase:unknown"]}
event = MockEvent(
sender="@foo:bar",
@@ -244,7 +240,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(Filter(self.hs, definition)._check(event))
- def test_definition_not_rooms_works_with_unknowns(self):
+ def test_definition_not_rooms_works_with_unknowns(self) -> None:
definition = {"not_rooms": ["!secretbase:unknown"]}
event = MockEvent(
sender="@foo:bar",
@@ -253,7 +249,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertTrue(Filter(self.hs, definition)._check(event))
- def test_definition_not_rooms_takes_priority_over_rooms(self):
+ def test_definition_not_rooms_takes_priority_over_rooms(self) -> None:
definition = {
"not_rooms": ["!secretbase:unknown"],
"rooms": ["!secretbase:unknown"],
@@ -263,7 +259,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(Filter(self.hs, definition)._check(event))
- def test_definition_combined_event(self):
+ def test_definition_combined_event(self) -> None:
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"],
@@ -279,7 +275,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertTrue(Filter(self.hs, definition)._check(event))
- def test_definition_combined_event_bad_sender(self):
+ def test_definition_combined_event_bad_sender(self) -> None:
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"],
@@ -295,7 +291,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(Filter(self.hs, definition)._check(event))
- def test_definition_combined_event_bad_room(self):
+ def test_definition_combined_event_bad_room(self) -> None:
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"],
@@ -311,7 +307,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(Filter(self.hs, definition)._check(event))
- def test_definition_combined_event_bad_type(self):
+ def test_definition_combined_event_bad_type(self) -> None:
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"],
@@ -327,7 +323,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(Filter(self.hs, definition)._check(event))
- def test_filter_labels(self):
+ def test_filter_labels(self) -> None:
definition = {"org.matrix.labels": ["#fun"]}
event = MockEvent(
sender="@foo:bar",
@@ -356,7 +352,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertTrue(Filter(self.hs, definition)._check(event))
- def test_filter_not_labels(self):
+ def test_filter_not_labels(self) -> None:
definition = {"org.matrix.not_labels": ["#fun"]}
event = MockEvent(
sender="@foo:bar",
@@ -377,7 +373,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.assertTrue(Filter(self.hs, definition)._check(event))
@unittest.override_config({"experimental_features": {"msc3874_enabled": True}})
- def test_filter_rel_type(self):
+ def test_filter_rel_type(self) -> None:
definition = {"org.matrix.msc3874.rel_types": ["m.thread"]}
event = MockEvent(
sender="@foo:bar",
@@ -407,7 +403,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.assertTrue(Filter(self.hs, definition)._check(event))
@unittest.override_config({"experimental_features": {"msc3874_enabled": True}})
- def test_filter_not_rel_type(self):
+ def test_filter_not_rel_type(self) -> None:
definition = {"org.matrix.msc3874.not_rel_types": ["m.thread"]}
event = MockEvent(
sender="@foo:bar",
@@ -436,15 +432,25 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.assertTrue(Filter(self.hs, definition)._check(event))
- def test_filter_presence_match(self):
- user_filter_json = {"presence": {"types": ["m.*"]}}
+ def test_filter_presence_match(self) -> None:
+ """Check that filter_presence return events which matches the filter."""
+ user_filter_json = {"presence": {"senders": ["@foo:bar"]}}
filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
)
- event = MockEvent(sender="@foo:bar", type="m.profile")
- events = [event]
+ presence_states = [
+ UserPresenceState(
+ user_id="@foo:bar",
+ state="unavailable",
+ last_active_ts=0,
+ last_federation_update_ts=0,
+ last_user_sync_ts=0,
+ status_msg=None,
+ currently_active=False,
+ ),
+ ]
user_filter = self.get_success(
self.filtering.get_user_filter(
@@ -452,23 +458,29 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
)
- results = self.get_success(user_filter.filter_presence(events=events))
- self.assertEqual(events, results)
+ results = self.get_success(user_filter.filter_presence(presence_states))
+ self.assertEqual(presence_states, results)
- def test_filter_presence_no_match(self):
- user_filter_json = {"presence": {"types": ["m.*"]}}
+ def test_filter_presence_no_match(self) -> None:
+ """Check that filter_presence does not return events rejected by the filter."""
+ user_filter_json = {"presence": {"not_senders": ["@foo:bar"]}}
filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart + "2", user_filter=user_filter_json
)
)
- event = MockEvent(
- event_id="$asdasd:localhost",
- sender="@foo:bar",
- type="custom.avatar.3d.crazy",
- )
- events = [event]
+ presence_states = [
+ UserPresenceState(
+ user_id="@foo:bar",
+ state="unavailable",
+ last_active_ts=0,
+ last_federation_update_ts=0,
+ last_user_sync_ts=0,
+ status_msg=None,
+ currently_active=False,
+ ),
+ ]
user_filter = self.get_success(
self.filtering.get_user_filter(
@@ -476,10 +488,10 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
)
- results = self.get_success(user_filter.filter_presence(events=events))
+ results = self.get_success(user_filter.filter_presence(presence_states))
self.assertEqual([], results)
- def test_filter_room_state_match(self):
+ def test_filter_room_state_match(self) -> None:
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = self.get_success(
self.datastore.add_user_filter(
@@ -498,7 +510,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
results = self.get_success(user_filter.filter_room_state(events=events))
self.assertEqual(events, results)
- def test_filter_room_state_no_match(self):
+ def test_filter_room_state_no_match(self) -> None:
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = self.get_success(
self.datastore.add_user_filter(
@@ -519,7 +531,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
results = self.get_success(user_filter.filter_room_state(events))
self.assertEqual([], results)
- def test_filter_rooms(self):
+ def test_filter_rooms(self) -> None:
definition = {
"rooms": ["!allowed:example.com", "!excluded:example.com"],
"not_rooms": ["!excluded:example.com"],
@@ -535,7 +547,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.assertEqual(filtered_room_ids, ["!allowed:example.com"])
- def test_filter_relations(self):
+ def test_filter_relations(self) -> None:
events = [
# An event without a relation.
MockEvent(
@@ -551,9 +563,8 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="org.matrix.custom.event",
room_id="!foo:bar",
),
- # Non-EventBase objects get passed through.
- {},
]
+ jsondicts: List[JsonDict] = [{}]
# For the following tests we patch the datastore method (intead of injecting
# events). This is a bit cheeky, but tests the logic of _check_event_relations.
@@ -561,7 +572,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
# Filter for a particular sender.
definition = {"related_by_senders": ["@foo:bar"]}
- async def events_have_relations(*args, **kwargs):
+ async def events_have_relations(*args: object, **kwargs: object) -> List[str]:
return ["$with_relation"]
with patch.object(
@@ -572,9 +583,17 @@ class FilteringTestCase(unittest.HomeserverTestCase):
Filter(self.hs, definition)._check_event_relations(events)
)
)
- self.assertEqual(filtered_events, events[1:])
+ # Non-EventBase objects get passed through.
+ filtered_jsondicts = list(
+ self.get_success(
+ Filter(self.hs, definition)._check_event_relations(jsondicts)
+ )
+ )
- def test_add_filter(self):
+ self.assertEqual(filtered_events, events[1:])
+ self.assertEqual(filtered_jsondicts, [{}])
+
+ def test_add_filter(self) -> None:
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = self.get_success(
@@ -595,7 +614,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
),
)
- def test_get_filter(self):
+ def test_get_filter(self) -> None:
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = self.get_success(
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index c86f783c5b..fa6c1c02ce 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -6,9 +6,12 @@ from tests import unittest
class TestRatelimiter(unittest.HomeserverTestCase):
- def test_allowed_via_can_do_action(self):
+ def test_allowed_via_can_do_action(self) -> None:
limiter = Ratelimiter(
- store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
+ store=self.hs.get_datastores().main,
+ clock=self.clock,
+ rate_hz=0.1,
+ burst_count=1,
)
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(None, key="test_id", _time_now_s=0)
@@ -28,9 +31,9 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertTrue(allowed)
self.assertEqual(20.0, time_allowed)
- def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
+ def test_allowed_appservice_ratelimited_via_can_requester_do_action(self) -> None:
appservice = ApplicationService(
- None,
+ token="fake_token",
id="foo",
rate_limited=True,
sender="@as:example.com",
@@ -38,7 +41,10 @@ class TestRatelimiter(unittest.HomeserverTestCase):
as_requester = create_requester("@user:example.com", app_service=appservice)
limiter = Ratelimiter(
- store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
+ store=self.hs.get_datastores().main,
+ clock=self.clock,
+ rate_hz=0.1,
+ burst_count=1,
)
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(as_requester, _time_now_s=0)
@@ -58,9 +64,9 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertTrue(allowed)
self.assertEqual(20.0, time_allowed)
- def test_allowed_appservice_via_can_requester_do_action(self):
+ def test_allowed_appservice_via_can_requester_do_action(self) -> None:
appservice = ApplicationService(
- None,
+ token="fake_token",
id="foo",
rate_limited=False,
sender="@as:example.com",
@@ -68,7 +74,10 @@ class TestRatelimiter(unittest.HomeserverTestCase):
as_requester = create_requester("@user:example.com", app_service=appservice)
limiter = Ratelimiter(
- store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
+ store=self.hs.get_datastores().main,
+ clock=self.clock,
+ rate_hz=0.1,
+ burst_count=1,
)
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(as_requester, _time_now_s=0)
@@ -88,9 +97,12 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertTrue(allowed)
self.assertEqual(-1, time_allowed)
- def test_allowed_via_ratelimit(self):
+ def test_allowed_via_ratelimit(self) -> None:
limiter = Ratelimiter(
- store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
+ store=self.hs.get_datastores().main,
+ clock=self.clock,
+ rate_hz=0.1,
+ burst_count=1,
)
# Shouldn't raise
@@ -108,13 +120,16 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter.ratelimit(None, key="test_id", _time_now_s=10)
)
- def test_allowed_via_can_do_action_and_overriding_parameters(self):
+ def test_allowed_via_can_do_action_and_overriding_parameters(self) -> None:
"""Test that we can override options of can_do_action that would otherwise fail
an action
"""
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
limiter = Ratelimiter(
- store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
+ store=self.hs.get_datastores().main,
+ clock=self.clock,
+ rate_hz=0.1,
+ burst_count=1,
)
# First attempt should be allowed
@@ -154,13 +169,16 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertTrue(allowed)
self.assertEqual(1.0, time_allowed)
- def test_allowed_via_ratelimit_and_overriding_parameters(self):
+ def test_allowed_via_ratelimit_and_overriding_parameters(self) -> None:
"""Test that we can override options of the ratelimit method that would otherwise
fail an action
"""
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
limiter = Ratelimiter(
- store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
+ store=self.hs.get_datastores().main,
+ clock=self.clock,
+ rate_hz=0.1,
+ burst_count=1,
)
# First attempt should be allowed
@@ -186,9 +204,12 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
)
- def test_pruning(self):
+ def test_pruning(self) -> None:
limiter = Ratelimiter(
- store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
+ store=self.hs.get_datastores().main,
+ clock=self.clock,
+ rate_hz=0.1,
+ burst_count=1,
)
self.get_success_or_raise(
limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
@@ -202,7 +223,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertNotIn("test_id_1", limiter.actions)
- def test_db_user_override(self):
+ def test_db_user_override(self) -> None:
"""Test that users that have ratelimiting disabled in the DB aren't
ratelimited.
"""
@@ -223,15 +244,18 @@ class TestRatelimiter(unittest.HomeserverTestCase):
)
)
- limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1)
+ limiter = Ratelimiter(store=store, clock=self.clock, rate_hz=0.1, burst_count=1)
# Shouldn't raise
for _ in range(20):
self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
- def test_multiple_actions(self):
+ def test_multiple_actions(self) -> None:
limiter = Ratelimiter(
- store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+ store=self.hs.get_datastores().main,
+ clock=self.clock,
+ rate_hz=0.1,
+ burst_count=3,
)
# Test that 4 actions aren't allowed with a maximum burst of 3.
allowed, time_allowed = self.get_success_or_raise(
@@ -295,7 +319,10 @@ class TestRatelimiter(unittest.HomeserverTestCase):
extra tokens by timing requests.
"""
limiter = Ratelimiter(
- store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+ store=self.hs.get_datastores().main,
+ clock=self.clock,
+ rate_hz=0.1,
+ burst_count=3,
)
def consume_at(time: float) -> bool:
@@ -317,7 +344,10 @@ class TestRatelimiter(unittest.HomeserverTestCase):
def test_record_action_which_doesnt_fill_bucket(self) -> None:
limiter = Ratelimiter(
- store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+ store=self.hs.get_datastores().main,
+ clock=self.clock,
+ rate_hz=0.1,
+ burst_count=3,
)
# Observe two actions, leaving room in the bucket for one more.
@@ -337,7 +367,10 @@ class TestRatelimiter(unittest.HomeserverTestCase):
def test_record_action_which_fills_bucket(self) -> None:
limiter = Ratelimiter(
- store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+ store=self.hs.get_datastores().main,
+ clock=self.clock,
+ rate_hz=0.1,
+ burst_count=3,
)
# Observe three actions, filling up the bucket.
@@ -363,7 +396,10 @@ class TestRatelimiter(unittest.HomeserverTestCase):
def test_record_action_which_overfills_bucket(self) -> None:
limiter = Ratelimiter(
- store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+ store=self.hs.get_datastores().main,
+ clock=self.clock,
+ rate_hz=0.1,
+ burst_count=3,
)
# Observe four actions, exceeding the bucket.
diff --git a/tests/app/test_homeserver_start.py b/tests/app/test_homeserver_start.py
index cbcada0451..788c935537 100644
--- a/tests/app/test_homeserver_start.py
+++ b/tests/app/test_homeserver_start.py
@@ -19,7 +19,7 @@ from tests.config.utils import ConfigFileTestCase
class HomeserverAppStartTestCase(ConfigFileTestCase):
- def test_wrong_start_caught(self):
+ def test_wrong_start_caught(self) -> None:
# Generate a config with a worker_app
self.generate_config()
# Add a blank line as otherwise the next addition ends up on a line with a comment
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 8d03da7f96..5d89ba94ad 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -11,26 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List
from unittest.mock import Mock, patch
from parameterized import parameterized
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.homeserver import SynapseHomeServer
from synapse.config.server import parse_listener_def
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests.server import make_request
from tests.unittest import HomeserverTestCase
class FederationReaderOpenIDListenerTests(HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(
federation_http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs
- def default_config(self):
+ def default_config(self) -> JsonDict:
conf = super().default_config()
# we're using FederationReaderServer, which uses a SlavedStore, so we
# have to tell the FederationHandler not to try to access stuff that is only
@@ -47,7 +53,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
(["openid"], "auth_fail"),
]
)
- def test_openid_listener(self, names, expectation):
+ def test_openid_listener(self, names: List[str], expectation: str) -> None:
"""
Test different openid listener configurations.
@@ -81,7 +87,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
@patch("synapse.app.homeserver.KeyResource", new=Mock())
class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(
federation_http_client=None, homeserver_to_use=SynapseHomeServer
)
@@ -95,7 +101,7 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
(["openid"], "auth_fail"),
]
)
- def test_openid_listener(self, names, expectation):
+ def test_openid_listener(self, names: List[str], expectation: str) -> None:
"""
Test different openid listener configurations.
diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py
index df731eb599..a860eedbcf 100644
--- a/tests/app/test_phone_stats_home.py
+++ b/tests/app/test_phone_stats_home.py
@@ -1,8 +1,11 @@
import synapse
from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
+from tests.server import ThreadedMemoryReactorClock
from tests.unittest import HomeserverTestCase
FIVE_MINUTES_IN_SECONDS = 300
@@ -19,7 +22,7 @@ class PhoneHomeTestCase(HomeserverTestCase):
# Override the retention time for the user_ips table because otherwise it
# gets pruned too aggressively for our R30 test.
@unittest.override_config({"user_ips_max_age": "365d"})
- def test_r30_minimum_usage(self):
+ def test_r30_minimum_usage(self) -> None:
"""
Tests the minimum amount of interaction necessary for the R30 metric
to consider a user 'retained'.
@@ -68,7 +71,7 @@ class PhoneHomeTestCase(HomeserverTestCase):
r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
self.assertEqual(r30_results, {"all": 0})
- def test_r30_minimum_usage_using_default_config(self):
+ def test_r30_minimum_usage_using_default_config(self) -> None:
"""
Tests the minimum amount of interaction necessary for the R30 metric
to consider a user 'retained'.
@@ -122,7 +125,7 @@ class PhoneHomeTestCase(HomeserverTestCase):
r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
self.assertEqual(r30_results, {"all": 0})
- def test_r30_user_must_be_retained_for_at_least_a_month(self):
+ def test_r30_user_must_be_retained_for_at_least_a_month(self) -> None:
"""
Tests that a newly-registered user must be retained for a whole month
before appearing in the R30 statistic, even if they post every day
@@ -164,12 +167,14 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
login.register_servlets,
]
- def _advance_to(self, desired_time_secs: float):
+ def _advance_to(self, desired_time_secs: float) -> None:
now = self.hs.get_clock().time()
assert now < desired_time_secs
self.reactor.advance(desired_time_secs - now)
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(
+ self, reactor: ThreadedMemoryReactorClock, clock: Clock
+ ) -> HomeServer:
hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock)
# We don't want our tests to actually report statistics, so check
@@ -181,7 +186,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
start_phone_stats_home(hs)
return hs
- def test_r30v2_minimum_usage(self):
+ def test_r30v2_minimum_usage(self) -> None:
"""
Tests the minimum amount of interaction necessary for the R30v2 metric
to consider a user 'retained'.
@@ -250,7 +255,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
)
- def test_r30v2_user_must_be_retained_for_at_least_a_month(self):
+ def test_r30v2_user_must_be_retained_for_at_least_a_month(self) -> None:
"""
Tests that a newly-registered user must be retained for a whole month
before appearing in the R30v2 statistic, even if they post every day
@@ -316,7 +321,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
r30_results, {"all": 1, "android": 1, "electron": 0, "ios": 0, "web": 0}
)
- def test_r30v2_returning_dormant_users_not_counted(self):
+ def test_r30v2_returning_dormant_users_not_counted(self) -> None:
"""
Tests that dormant users (users inactive for a long time) do not
contribute to R30v2 when they return for just a single day.
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 0b22afdc75..0a1ae83a2b 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -69,7 +69,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
events=events,
ephemeral=[],
to_device_messages=[], # txn made and saved
- one_time_key_counts={},
+ one_time_keys_count={},
unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
)
@@ -96,7 +96,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
events=events,
ephemeral=[],
to_device_messages=[], # txn made and saved
- one_time_key_counts={},
+ one_time_keys_count={},
unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
)
@@ -125,7 +125,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
events=events,
ephemeral=[],
to_device_messages=[],
- one_time_key_counts={},
+ one_time_keys_count={},
unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
)
diff --git a/tests/config/test___main__.py b/tests/config/test___main__.py
index b1c73d3612..cb5d4b05c3 100644
--- a/tests/config/test___main__.py
+++ b/tests/config/test___main__.py
@@ -17,15 +17,15 @@ from tests.config.utils import ConfigFileTestCase
class ConfigMainFileTestCase(ConfigFileTestCase):
- def test_executes_without_an_action(self):
+ def test_executes_without_an_action(self) -> None:
self.generate_config()
main(["", "-c", self.config_file])
- def test_read__error_if_key_not_found(self):
+ def test_read__error_if_key_not_found(self) -> None:
self.generate_config()
with self.assertRaises(SystemExit):
main(["", "read", "foo.bar.hello", "-c", self.config_file])
- def test_read__passes_if_key_found(self):
+ def test_read__passes_if_key_found(self) -> None:
self.generate_config()
main(["", "read", "server.server_name", "-c", self.config_file])
diff --git a/tests/config/test_api.py b/tests/config/test_api.py
new file mode 100644
index 0000000000..6773c9a277
--- /dev/null
+++ b/tests/config/test_api.py
@@ -0,0 +1,145 @@
+from unittest import TestCase as StdlibTestCase
+
+import yaml
+
+from synapse.config import ConfigError
+from synapse.config.api import ApiConfig
+from synapse.types.state import StateFilter
+
+DEFAULT_PREJOIN_STATE_PAIRS = {
+ ("m.room.join_rules", ""),
+ ("m.room.canonical_alias", ""),
+ ("m.room.avatar", ""),
+ ("m.room.encryption", ""),
+ ("m.room.name", ""),
+ ("m.room.create", ""),
+ ("m.room.topic", ""),
+}
+
+
+class TestRoomPrejoinState(StdlibTestCase):
+ def read_config(self, source: str) -> ApiConfig:
+ config = ApiConfig()
+ config.read_config(yaml.safe_load(source))
+ return config
+
+ def test_no_prejoin_state(self) -> None:
+ config = self.read_config("foo: bar")
+ self.assertFalse(config.room_prejoin_state.has_wildcards())
+ self.assertEqual(
+ set(config.room_prejoin_state.concrete_types()), DEFAULT_PREJOIN_STATE_PAIRS
+ )
+
+ def test_disable_default_event_types(self) -> None:
+ config = self.read_config(
+ """
+room_prejoin_state:
+ disable_default_event_types: true
+ """
+ )
+ self.assertEqual(config.room_prejoin_state, StateFilter.none())
+
+ def test_event_without_state_key(self) -> None:
+ config = self.read_config(
+ """
+room_prejoin_state:
+ disable_default_event_types: true
+ additional_event_types:
+ - foo
+ """
+ )
+ self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
+ self.assertEqual(config.room_prejoin_state.concrete_types(), [])
+
+ def test_event_with_specific_state_key(self) -> None:
+ config = self.read_config(
+ """
+room_prejoin_state:
+ disable_default_event_types: true
+ additional_event_types:
+ - [foo, bar]
+ """
+ )
+ self.assertFalse(config.room_prejoin_state.has_wildcards())
+ self.assertEqual(
+ set(config.room_prejoin_state.concrete_types()),
+ {("foo", "bar")},
+ )
+
+ def test_repeated_event_with_specific_state_key(self) -> None:
+ config = self.read_config(
+ """
+room_prejoin_state:
+ disable_default_event_types: true
+ additional_event_types:
+ - [foo, bar]
+ - [foo, baz]
+ """
+ )
+ self.assertFalse(config.room_prejoin_state.has_wildcards())
+ self.assertEqual(
+ set(config.room_prejoin_state.concrete_types()),
+ {("foo", "bar"), ("foo", "baz")},
+ )
+
+ def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
+ config = self.read_config(
+ """
+room_prejoin_state:
+ disable_default_event_types: true
+ additional_event_types:
+ - [foo, bar]
+ - foo
+ """
+ )
+ self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
+ self.assertEqual(config.room_prejoin_state.concrete_types(), [])
+
+ config = self.read_config(
+ """
+room_prejoin_state:
+ disable_default_event_types: true
+ additional_event_types:
+ - foo
+ - [foo, bar]
+ """
+ )
+ self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
+ self.assertEqual(config.room_prejoin_state.concrete_types(), [])
+
+ def test_bad_event_type_entry_raises(self) -> None:
+ with self.assertRaises(ConfigError):
+ self.read_config(
+ """
+room_prejoin_state:
+ additional_event_types:
+ - []
+ """
+ )
+
+ with self.assertRaises(ConfigError):
+ self.read_config(
+ """
+room_prejoin_state:
+ additional_event_types:
+ - [a]
+ """
+ )
+
+ with self.assertRaises(ConfigError):
+ self.read_config(
+ """
+room_prejoin_state:
+ additional_event_types:
+ - [a, b, c]
+ """
+ )
+
+ with self.assertRaises(ConfigError):
+ self.read_config(
+ """
+room_prejoin_state:
+ additional_event_types:
+ - [true, 1.23]
+ """
+ )
diff --git a/tests/config/test_background_update.py b/tests/config/test_background_update.py
index 0c32c1ca29..e4bad2ba6e 100644
--- a/tests/config/test_background_update.py
+++ b/tests/config/test_background_update.py
@@ -22,7 +22,7 @@ class BackgroundUpdateConfigTestCase(HomeserverTestCase):
# Tests that the default values in the config are correctly loaded. Note that the default
# values are loaded when the corresponding config options are commented out, which is why there isn't
# a config specified here.
- def test_default_configuration(self):
+ def test_default_configuration(self) -> None:
background_updater = BackgroundUpdater(
self.hs, self.hs.get_datastores().main.db_pool
)
@@ -46,7 +46,7 @@ class BackgroundUpdateConfigTestCase(HomeserverTestCase):
"""
)
)
- def test_custom_configuration(self):
+ def test_custom_configuration(self) -> None:
background_updater = BackgroundUpdater(
self.hs, self.hs.get_datastores().main.db_pool
)
diff --git a/tests/config/test_base.py b/tests/config/test_base.py
index 6a52f862f4..3fbfe6c1da 100644
--- a/tests/config/test_base.py
+++ b/tests/config/test_base.py
@@ -24,13 +24,13 @@ from tests import unittest
class BaseConfigTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
# The root object needs a server property with a public_baseurl.
root = Mock()
root.server.public_baseurl = "http://test"
self.config = Config(root)
- def test_loading_missing_templates(self):
+ def test_loading_missing_templates(self) -> None:
# Use a temporary directory that exists on the system, but that isn't likely to
# contain template files
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -50,7 +50,7 @@ class BaseConfigTestCase(unittest.TestCase):
"Template file did not contain our test string",
)
- def test_loading_custom_templates(self):
+ def test_loading_custom_templates(self) -> None:
# Use a temporary directory that exists on the system
with tempfile.TemporaryDirectory() as tmp_dir:
# Create a temporary bogus template file
@@ -79,7 +79,7 @@ class BaseConfigTestCase(unittest.TestCase):
"Template file did not contain our test string",
)
- def test_multiple_custom_template_directories(self):
+ def test_multiple_custom_template_directories(self) -> None:
"""Tests that directories are searched in the right order if multiple custom
template directories are provided.
"""
@@ -137,7 +137,7 @@ class BaseConfigTestCase(unittest.TestCase):
for td in tempdirs:
td.cleanup()
- def test_loading_template_from_nonexistent_custom_directory(self):
+ def test_loading_template_from_nonexistent_custom_directory(self) -> None:
with self.assertRaises(ConfigError):
self.config.read_templates(
["some_filename.html"], ("a_nonexistent_directory",)
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
index d2b3c299e3..96f66af328 100644
--- a/tests/config/test_cache.py
+++ b/tests/config/test_cache.py
@@ -13,26 +13,27 @@
# limitations under the License.
from synapse.config.cache import CacheConfig, add_resizable_cache
+from synapse.types import JsonDict
from synapse.util.caches.lrucache import LruCache
from tests.unittest import TestCase
class CacheConfigTests(TestCase):
- def setUp(self):
+ def setUp(self) -> None:
# Reset caches before each test since there's global state involved.
self.config = CacheConfig()
self.config.reset()
- def tearDown(self):
+ def tearDown(self) -> None:
# Also reset the caches after each test to leave state pristine.
self.config.reset()
- def test_individual_caches_from_environ(self):
+ def test_individual_caches_from_environ(self) -> None:
"""
Individual cache factors will be loaded from the environment.
"""
- config = {}
+ config: JsonDict = {}
self.config._environ = {
"SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
"SYNAPSE_NOT_CACHE": "BLAH",
@@ -42,15 +43,15 @@ class CacheConfigTests(TestCase):
self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0})
- def test_config_overrides_environ(self):
+ def test_config_overrides_environ(self) -> None:
"""
Individual cache factors defined in the environment will take precedence
over those in the config.
"""
- config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}}
+ config: JsonDict = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}}
self.config._environ = {
"SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
- "SYNAPSE_CACHE_FACTOR_FOO": 1,
+ "SYNAPSE_CACHE_FACTOR_FOO": "1",
}
self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches()
@@ -60,104 +61,104 @@ class CacheConfigTests(TestCase):
{"foo": 1.0, "bar": 3.0, "something_or_other": 2.0},
)
- def test_individual_instantiated_before_config_load(self):
+ def test_individual_instantiated_before_config_load(self) -> None:
"""
If a cache is instantiated before the config is read, it will be given
the default cache size in the interim, and then resized once the config
is loaded.
"""
- cache = LruCache(100)
+ cache: LruCache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 50)
- config = {"caches": {"per_cache_factors": {"foo": 3}}}
+ config: JsonDict = {"caches": {"per_cache_factors": {"foo": 3}}}
self.config.read_config(config)
self.config.resize_all_caches()
self.assertEqual(cache.max_size, 300)
- def test_individual_instantiated_after_config_load(self):
+ def test_individual_instantiated_after_config_load(self) -> None:
"""
If a cache is instantiated after the config is read, it will be
immediately resized to the correct size given the per_cache_factor if
there is one.
"""
- config = {"caches": {"per_cache_factors": {"foo": 2}}}
+ config: JsonDict = {"caches": {"per_cache_factors": {"foo": 2}}}
self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches()
- cache = LruCache(100)
+ cache: LruCache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 200)
- def test_global_instantiated_before_config_load(self):
+ def test_global_instantiated_before_config_load(self) -> None:
"""
If a cache is instantiated before the config is read, it will be given
the default cache size in the interim, and then resized to the new
default cache size once the config is loaded.
"""
- cache = LruCache(100)
+ cache: LruCache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 50)
- config = {"caches": {"global_factor": 4}}
+ config: JsonDict = {"caches": {"global_factor": 4}}
self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches()
self.assertEqual(cache.max_size, 400)
- def test_global_instantiated_after_config_load(self):
+ def test_global_instantiated_after_config_load(self) -> None:
"""
If a cache is instantiated after the config is read, it will be
immediately resized to the correct size given the global factor if there
is no per-cache factor.
"""
- config = {"caches": {"global_factor": 1.5}}
+ config: JsonDict = {"caches": {"global_factor": 1.5}}
self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches()
- cache = LruCache(100)
+ cache: LruCache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 150)
- def test_cache_with_asterisk_in_name(self):
+ def test_cache_with_asterisk_in_name(self) -> None:
"""Some caches have asterisks in their name, test that they are set correctly."""
- config = {
+ config: JsonDict = {
"caches": {
"per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2}
}
}
self.config._environ = {
"SYNAPSE_CACHE_FACTOR_CACHE_A": "2",
- "SYNAPSE_CACHE_FACTOR_CACHE_B": 3,
+ "SYNAPSE_CACHE_FACTOR_CACHE_B": "3",
}
self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches()
- cache_a = LruCache(100)
+ cache_a: LruCache = LruCache(100)
add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
self.assertEqual(cache_a.max_size, 200)
- cache_b = LruCache(100)
+ cache_b: LruCache = LruCache(100)
add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor)
self.assertEqual(cache_b.max_size, 300)
- cache_c = LruCache(100)
+ cache_c: LruCache = LruCache(100)
add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor)
self.assertEqual(cache_c.max_size, 200)
- def test_apply_cache_factor_from_config(self):
+ def test_apply_cache_factor_from_config(self) -> None:
"""Caches can disable applying cache factor updates, mainly used by
event cache size.
"""
- config = {"caches": {"event_cache_size": "10k"}}
+ config: JsonDict = {"caches": {"event_cache_size": "10k"}}
self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches()
- cache = LruCache(
+ cache: LruCache = LruCache(
max_size=self.config.event_cache_size,
apply_cache_factor_from_config=False,
)
diff --git a/tests/config/test_database.py b/tests/config/test_database.py
index 9eca10bbe9..240277bcc6 100644
--- a/tests/config/test_database.py
+++ b/tests/config/test_database.py
@@ -20,7 +20,7 @@ from tests import unittest
class DatabaseConfigTestCase(unittest.TestCase):
- def test_database_configured_correctly(self):
+ def test_database_configured_correctly(self) -> None:
conf = yaml.safe_load(
DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path")
)
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index fdfbb0e38e..3a02366932 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -25,14 +25,14 @@ from tests import unittest
class ConfigGenerationTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.dir = tempfile.mkdtemp()
self.file = os.path.join(self.dir, "homeserver.yaml")
- def tearDown(self):
+ def tearDown(self) -> None:
shutil.rmtree(self.dir)
- def test_generate_config_generates_files(self):
+ def test_generate_config_generates_files(self) -> None:
with redirect_stdout(StringIO()):
HomeServerConfig.load_or_generate_config(
"",
@@ -56,7 +56,7 @@ class ConfigGenerationTestCase(unittest.TestCase):
os.path.join(os.getcwd(), "homeserver.log"),
)
- def assert_log_filename_is(self, log_config_file, expected):
+ def assert_log_filename_is(self, log_config_file: str, expected: str) -> None:
with open(log_config_file) as f:
config = f.read()
# find the 'filename' line
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 69a4e9413b..fcbe79cc7a 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -21,14 +21,14 @@ from tests.config.utils import ConfigFileTestCase
class ConfigLoadingFileTestCase(ConfigFileTestCase):
- def test_load_fails_if_server_name_missing(self):
+ def test_load_fails_if_server_name_missing(self) -> None:
self.generate_config_and_remove_lines_containing("server_name")
with self.assertRaises(ConfigError):
HomeServerConfig.load_config("", ["-c", self.config_file])
with self.assertRaises(ConfigError):
HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
- def test_generates_and_loads_macaroon_secret_key(self):
+ def test_generates_and_loads_macaroon_secret_key(self) -> None:
self.generate_config()
with open(self.config_file) as f:
@@ -58,7 +58,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
"was: %r" % (config2.key.macaroon_secret_key,)
)
- def test_load_succeeds_if_macaroon_secret_key_missing(self):
+ def test_load_succeeds_if_macaroon_secret_key_missing(self) -> None:
self.generate_config_and_remove_lines_containing("macaroon")
config1 = HomeServerConfig.load_config("", ["-c", self.config_file])
config2 = HomeServerConfig.load_config("", ["-c", self.config_file])
@@ -73,7 +73,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
config1.key.macaroon_secret_key, config3.key.macaroon_secret_key
)
- def test_disable_registration(self):
+ def test_disable_registration(self) -> None:
self.generate_config()
self.add_lines_to_config(
["enable_registration: true", "disable_registration: true"]
@@ -93,7 +93,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
assert config3 is not None
self.assertTrue(config3.registration.enable_registration)
- def test_stats_enabled(self):
+ def test_stats_enabled(self) -> None:
self.generate_config_and_remove_lines_containing("enable_metrics")
self.add_lines_to_config(["enable_metrics: true"])
@@ -101,7 +101,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
config = HomeServerConfig.load_config("", ["-c", self.config_file])
self.assertFalse(config.metrics.metrics_flags.known_servers)
- def test_depreciated_identity_server_flag_throws_error(self):
+ def test_depreciated_identity_server_flag_throws_error(self) -> None:
self.generate_config()
# Needed to ensure that actual key/value pair added below don't end up on a line with a comment
self.add_lines_to_config([" "])
diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py
index 1b63e1adfd..f12147eaa0 100644
--- a/tests/config/test_ratelimiting.py
+++ b/tests/config/test_ratelimiting.py
@@ -18,7 +18,7 @@ from tests.utils import default_config
class RatelimitConfigTestCase(TestCase):
- def test_parse_rc_federation(self):
+ def test_parse_rc_federation(self) -> None:
config_dict = default_config("test")
config_dict["rc_federation"] = {
"window_size": 20000,
diff --git a/tests/config/test_registration_config.py b/tests/config/test_registration_config.py
index 33d7b70e32..f6869d7f06 100644
--- a/tests/config/test_registration_config.py
+++ b/tests/config/test_registration_config.py
@@ -21,7 +21,7 @@ from tests.utils import default_config
class RegistrationConfigTestCase(ConfigFileTestCase):
- def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self):
+ def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self) -> None:
"""
session_lifetime should logically be larger than, or at least as large as,
all the different token lifetimes.
@@ -91,7 +91,7 @@ class RegistrationConfigTestCase(ConfigFileTestCase):
"",
)
- def test_refuse_to_start_if_open_registration_and_no_verification(self):
+ def test_refuse_to_start_if_open_registration_and_no_verification(self) -> None:
self.generate_config()
self.add_lines_to_config(
[
diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py
index db745815ef..297ab37792 100644
--- a/tests/config/test_room_directory.py
+++ b/tests/config/test_room_directory.py
@@ -20,7 +20,7 @@ from tests import unittest
class RoomDirectoryConfigTestCase(unittest.TestCase):
- def test_alias_creation_acl(self):
+ def test_alias_creation_acl(self) -> None:
config = yaml.safe_load(
"""
alias_creation_rules:
@@ -78,7 +78,7 @@ class RoomDirectoryConfigTestCase(unittest.TestCase):
)
)
- def test_room_publish_acl(self):
+ def test_room_publish_acl(self) -> None:
config = yaml.safe_load(
"""
alias_creation_rules: []
diff --git a/tests/config/test_server.py b/tests/config/test_server.py
index 1f27a54701..41a3fb0b6d 100644
--- a/tests/config/test_server.py
+++ b/tests/config/test_server.py
@@ -21,7 +21,7 @@ from tests import unittest
class ServerConfigTestCase(unittest.TestCase):
- def test_is_threepid_reserved(self):
+ def test_is_threepid_reserved(self) -> None:
user1 = {"medium": "email", "address": "user1@example.com"}
user2 = {"medium": "email", "address": "user2@example.com"}
user3 = {"medium": "email", "address": "user3@example.com"}
@@ -32,7 +32,7 @@ class ServerConfigTestCase(unittest.TestCase):
self.assertFalse(is_threepid_reserved(config, user3))
self.assertFalse(is_threepid_reserved(config, user1_msisdn))
- def test_unsecure_listener_no_listeners_open_private_ports_false(self):
+ def test_unsecure_listener_no_listeners_open_private_ports_false(self) -> None:
conf = yaml.safe_load(
ServerConfig().generate_config_section(
"CONFDIR", "/data_dir_path", "che.org", False, None
@@ -52,7 +52,7 @@ class ServerConfigTestCase(unittest.TestCase):
self.assertEqual(conf["listeners"], expected_listeners)
- def test_unsecure_listener_no_listeners_open_private_ports_true(self):
+ def test_unsecure_listener_no_listeners_open_private_ports_true(self) -> None:
conf = yaml.safe_load(
ServerConfig().generate_config_section(
"CONFDIR", "/data_dir_path", "che.org", True, None
@@ -71,7 +71,7 @@ class ServerConfigTestCase(unittest.TestCase):
self.assertEqual(conf["listeners"], expected_listeners)
- def test_listeners_set_correctly_open_private_ports_false(self):
+ def test_listeners_set_correctly_open_private_ports_false(self) -> None:
listeners = [
{
"port": 8448,
@@ -95,7 +95,7 @@ class ServerConfigTestCase(unittest.TestCase):
self.assertEqual(conf["listeners"], listeners)
- def test_listeners_set_correctly_open_private_ports_true(self):
+ def test_listeners_set_correctly_open_private_ports_true(self) -> None:
listeners = [
{
"port": 8448,
@@ -131,14 +131,14 @@ class ServerConfigTestCase(unittest.TestCase):
class GenerateIpSetTestCase(unittest.TestCase):
- def test_empty(self):
+ def test_empty(self) -> None:
ip_set = generate_ip_set(())
self.assertFalse(ip_set)
ip_set = generate_ip_set((), ())
self.assertFalse(ip_set)
- def test_generate(self):
+ def test_generate(self) -> None:
"""Check adding IPv4 and IPv6 addresses."""
# IPv4 address
ip_set = generate_ip_set(("1.2.3.4",))
@@ -160,7 +160,7 @@ class GenerateIpSetTestCase(unittest.TestCase):
ip_set = generate_ip_set(("1.2.3.4", "::1.2.3.4"))
self.assertEqual(len(ip_set.iter_cidrs()), 4)
- def test_extra(self):
+ def test_extra(self) -> None:
"""Extra IP addresses are treated the same."""
ip_set = generate_ip_set((), ("1.2.3.4",))
self.assertEqual(len(ip_set.iter_cidrs()), 4)
@@ -172,7 +172,7 @@ class GenerateIpSetTestCase(unittest.TestCase):
ip_set = generate_ip_set(("1.2.3.4",), ("1.2.3.4",))
self.assertEqual(len(ip_set.iter_cidrs()), 4)
- def test_bad_value(self):
+ def test_bad_value(self) -> None:
"""An error should be raised if a bad value is passed in."""
with self.assertRaises(ConfigError):
generate_ip_set(("not-an-ip",))
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index 9ba5781573..7510fc4643 100644
--- a/tests/config/test_tls.py
+++ b/tests/config/test_tls.py
@@ -13,13 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import cast
+
import idna
from OpenSSL import SSL
from synapse.config._base import Config, RootConfig
+from synapse.config.homeserver import HomeServerConfig
from synapse.config.tls import ConfigError, TlsConfig
-from synapse.crypto.context_factory import FederationPolicyForHTTPS
+from synapse.crypto.context_factory import (
+ FederationPolicyForHTTPS,
+ SSLClientConnectionCreator,
+)
+from synapse.types import JsonDict
from tests.unittest import TestCase
@@ -27,7 +34,7 @@ from tests.unittest import TestCase
class FakeServer(Config):
section = "server"
- def has_tls_listener(self):
+ def has_tls_listener(self) -> bool:
return False
@@ -36,21 +43,21 @@ class TestConfig(RootConfig):
class TLSConfigTests(TestCase):
- def test_tls_client_minimum_default(self):
+ def test_tls_client_minimum_default(self) -> None:
"""
The default client TLS version is 1.0.
"""
- config = {}
+ config: JsonDict = {}
t = TestConfig()
t.tls.read_config(config, config_dir_path="", data_dir_path="")
self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
- def test_tls_client_minimum_set(self):
+ def test_tls_client_minimum_set(self) -> None:
"""
The default client TLS version can be set to 1.0, 1.1, and 1.2.
"""
- config = {"federation_client_minimum_tls_version": 1}
+ config: JsonDict = {"federation_client_minimum_tls_version": 1}
t = TestConfig()
t.tls.read_config(config, config_dir_path="", data_dir_path="")
self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
@@ -76,7 +83,7 @@ class TLSConfigTests(TestCase):
t.tls.read_config(config, config_dir_path="", data_dir_path="")
self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2")
- def test_tls_client_minimum_1_point_3_missing(self):
+ def test_tls_client_minimum_1_point_3_missing(self) -> None:
"""
If TLS 1.3 support is missing and it's configured, it will raise a
ConfigError.
@@ -88,7 +95,7 @@ class TLSConfigTests(TestCase):
self.addCleanup(setattr, SSL, "SSL.OP_NO_TLSv1_3", OP_NO_TLSv1_3)
assert not hasattr(SSL, "OP_NO_TLSv1_3")
- config = {"federation_client_minimum_tls_version": 1.3}
+ config: JsonDict = {"federation_client_minimum_tls_version": 1.3}
t = TestConfig()
with self.assertRaises(ConfigError) as e:
t.tls.read_config(config, config_dir_path="", data_dir_path="")
@@ -100,7 +107,7 @@ class TLSConfigTests(TestCase):
),
)
- def test_tls_client_minimum_1_point_3_exists(self):
+ def test_tls_client_minimum_1_point_3_exists(self) -> None:
"""
If TLS 1.3 support exists and it's configured, it will be settable.
"""
@@ -110,20 +117,20 @@ class TLSConfigTests(TestCase):
self.addCleanup(lambda: delattr(SSL, "OP_NO_TLSv1_3"))
assert hasattr(SSL, "OP_NO_TLSv1_3")
- config = {"federation_client_minimum_tls_version": 1.3}
+ config: JsonDict = {"federation_client_minimum_tls_version": 1.3}
t = TestConfig()
t.tls.read_config(config, config_dir_path="", data_dir_path="")
self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.3")
- def test_tls_client_minimum_set_passed_through_1_2(self):
+ def test_tls_client_minimum_set_passed_through_1_2(self) -> None:
"""
The configured TLS version is correctly configured by the ContextFactory.
"""
- config = {"federation_client_minimum_tls_version": 1.2}
+ config: JsonDict = {"federation_client_minimum_tls_version": 1.2}
t = TestConfig()
t.tls.read_config(config, config_dir_path="", data_dir_path="")
- cf = FederationPolicyForHTTPS(t)
+ cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t))
options = _get_ssl_context_options(cf._verify_ssl_context)
# The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2
@@ -131,15 +138,15 @@ class TLSConfigTests(TestCase):
self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0)
self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
- def test_tls_client_minimum_set_passed_through_1_0(self):
+ def test_tls_client_minimum_set_passed_through_1_0(self) -> None:
"""
The configured TLS version is correctly configured by the ContextFactory.
"""
- config = {"federation_client_minimum_tls_version": 1}
+ config: JsonDict = {"federation_client_minimum_tls_version": 1}
t = TestConfig()
t.tls.read_config(config, config_dir_path="", data_dir_path="")
- cf = FederationPolicyForHTTPS(t)
+ cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t))
options = _get_ssl_context_options(cf._verify_ssl_context)
# The context has not had any of the NO_TLS set.
@@ -147,11 +154,11 @@ class TLSConfigTests(TestCase):
self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0)
self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
- def test_whitelist_idna_failure(self):
+ def test_whitelist_idna_failure(self) -> None:
"""
The federation certificate whitelist will not allow IDNA domain names.
"""
- config = {
+ config: JsonDict = {
"federation_certificate_verification_whitelist": [
"example.com",
"*.ドメイン.テスト",
@@ -163,11 +170,11 @@ class TLSConfigTests(TestCase):
)
self.assertIn("IDNA domain names", str(e))
- def test_whitelist_idna_result(self):
+ def test_whitelist_idna_result(self) -> None:
"""
The federation certificate whitelist will match on IDNA encoded names.
"""
- config = {
+ config: JsonDict = {
"federation_certificate_verification_whitelist": [
"example.com",
"*.xn--eckwd4c7c.xn--zckzah",
@@ -176,14 +183,16 @@ class TLSConfigTests(TestCase):
t = TestConfig()
t.tls.read_config(config, config_dir_path="", data_dir_path="")
- cf = FederationPolicyForHTTPS(t)
+ cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t))
# Not in the whitelist
opts = cf.get_options(b"notexample.com")
+ assert isinstance(opts, SSLClientConnectionCreator)
self.assertTrue(opts._verifier._verify_certs)
# Caught by the wildcard
opts = cf.get_options(idna.encode("テスト.ドメイン.テスト"))
+ assert isinstance(opts, SSLClientConnectionCreator)
self.assertFalse(opts._verifier._verify_certs)
@@ -191,4 +200,4 @@ def _get_ssl_context_options(ssl_context: SSL.Context) -> int:
"""get the options bits from an openssl context object"""
# the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to
# use the low-level interface
- return SSL._lib.SSL_CTX_get_options(ssl_context._context)
+ return SSL._lib.SSL_CTX_get_options(ssl_context._context) # type: ignore[attr-defined]
diff --git a/tests/config/test_util.py b/tests/config/test_util.py
index 3d4929daac..7073654832 100644
--- a/tests/config/test_util.py
+++ b/tests/config/test_util.py
@@ -21,7 +21,7 @@ from tests.unittest import TestCase
class ValidateConfigTestCase(TestCase):
"""Test cases for synapse.config._util.validate_config"""
- def test_bad_object_in_array(self):
+ def test_bad_object_in_array(self) -> None:
"""malformed objects within an array should be validated correctly"""
# consider a structure:
diff --git a/tests/config/utils.py b/tests/config/utils.py
index 94c18a052b..4c0e8a064a 100644
--- a/tests/config/utils.py
+++ b/tests/config/utils.py
@@ -17,19 +17,20 @@ import tempfile
import unittest
from contextlib import redirect_stdout
from io import StringIO
+from typing import List
from synapse.config.homeserver import HomeServerConfig
class ConfigFileTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.dir = tempfile.mkdtemp()
self.config_file = os.path.join(self.dir, "homeserver.yaml")
- def tearDown(self):
+ def tearDown(self) -> None:
shutil.rmtree(self.dir)
- def generate_config(self):
+ def generate_config(self) -> None:
with redirect_stdout(StringIO()):
HomeServerConfig.load_or_generate_config(
"",
@@ -43,7 +44,7 @@ class ConfigFileTestCase(unittest.TestCase):
],
)
- def generate_config_and_remove_lines_containing(self, needle):
+ def generate_config_and_remove_lines_containing(self, needle: str) -> None:
self.generate_config()
with open(self.config_file) as f:
@@ -52,7 +53,7 @@ class ConfigFileTestCase(unittest.TestCase):
with open(self.config_file, "w") as f:
f.write("".join(contents))
- def add_lines_to_config(self, lines):
+ def add_lines_to_config(self, lines: List[str]) -> None:
with open(self.config_file, "a") as f:
for line in lines:
f.write(line + "\n")
diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py
index 8fa710c9dc..2b0972eee8 100644
--- a/tests/crypto/test_event_signing.py
+++ b/tests/crypto/test_event_signing.py
@@ -33,12 +33,12 @@ HOSTNAME = "domain"
class EventSigningTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.signing_key: SigningKey = decode_signing_key_base64(
KEY_ALG, KEY_VER, SIGNING_KEY_SEED
)
- def test_sign_minimal(self):
+ def test_sign_minimal(self) -> None:
event_dict = {
"event_id": "$0:domain",
"origin": "domain",
@@ -69,7 +69,7 @@ class EventSigningTestCase(unittest.TestCase):
"aIbygsSdLOFzvdDjww8zUVKCmI02eP9xtyJxc/cLiBA",
)
- def test_sign_message(self):
+ def test_sign_message(self) -> None:
event_dict = {
"content": {"body": "Here is the message content"},
"event_id": "$0:domain",
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 63628aa6b0..0e8af2da54 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
-from typing import Dict, List
+from typing import Any, Dict, List, Optional, cast
from unittest.mock import Mock
import attr
@@ -20,10 +20,11 @@ import canonicaljson
import signedjson.key
import signedjson.sign
from signedjson.key import encode_verify_key_base64, get_verify_key
-from signedjson.types import SigningKey
+from signedjson.types import SigningKey, VerifyKey
from twisted.internet import defer
from twisted.internet.defer import Deferred, ensureDeferred
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
@@ -33,11 +34,15 @@ from synapse.crypto.keyring import (
StoreKeyFetcher,
)
from synapse.logging.context import (
+ ContextRequest,
LoggingContext,
current_context,
make_deferred_yieldable,
)
+from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -45,15 +50,15 @@ from tests.unittest import logcontext_clean, override_config
class MockPerspectiveServer:
- def __init__(self):
+ def __init__(self) -> None:
self.server_name = "mock_server"
- self.key = signedjson.key.generate_signing_key(0)
+ self.key = signedjson.key.generate_signing_key("0")
- def get_verify_keys(self):
+ def get_verify_keys(self) -> Dict[str, str]:
vk = signedjson.key.get_verify_key(self.key)
return {"%s:%s" % (vk.alg, vk.version): encode_verify_key_base64(vk)}
- def get_signed_key(self, server_name, verify_key):
+ def get_signed_key(self, server_name: str, verify_key: VerifyKey) -> JsonDict:
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
res = {
"server_name": server_name,
@@ -64,34 +69,36 @@ class MockPerspectiveServer:
self.sign_response(res)
return res
- def sign_response(self, res):
+ def sign_response(self, res: JsonDict) -> None:
signedjson.sign.sign_json(res, self.server_name, self.key)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class FakeRequest:
- id = attr.ib()
+ id: str
@logcontext_clean
class KeyringTestCase(unittest.HomeserverTestCase):
- def check_context(self, val, expected):
+ def check_context(
+ self, val: ContextRequest, expected: Optional[ContextRequest]
+ ) -> ContextRequest:
self.assertEqual(getattr(current_context(), "request", None), expected)
return val
- def test_verify_json_objects_for_server_awaits_previous_requests(self):
+ def test_verify_json_objects_for_server_awaits_previous_requests(self) -> None:
mock_fetcher = Mock()
mock_fetcher.get_keys = Mock()
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
# a signed object that we are going to try to validate
- key1 = signedjson.key.generate_signing_key(1)
- json1 = {}
+ key1 = signedjson.key.generate_signing_key("1")
+ json1: JsonDict = {}
signedjson.sign.sign_json(json1, "server10", key1)
# start off a first set of lookups. We make the mock fetcher block until this
# deferred completes.
- first_lookup_deferred = Deferred()
+ first_lookup_deferred: "Deferred[None]" = Deferred()
async def first_lookup_fetch(
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
@@ -106,8 +113,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
mock_fetcher.get_keys.side_effect = first_lookup_fetch
- async def first_lookup():
- with LoggingContext("context_11", request=FakeRequest("context_11")):
+ async def first_lookup() -> None:
+ with LoggingContext(
+ "context_11", request=cast(ContextRequest, FakeRequest("context_11"))
+ ):
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1, 0), ("server11", {}, 0)]
)
@@ -144,8 +153,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
mock_fetcher.get_keys.side_effect = second_lookup_fetch
second_lookup_state = [0]
- async def second_lookup():
- with LoggingContext("context_12", request=FakeRequest("context_12")):
+ async def second_lookup() -> None:
+ with LoggingContext(
+ "context_12", request=cast(ContextRequest, FakeRequest("context_12"))
+ ):
res_deferreds_2 = kr.verify_json_objects_for_server(
[
(
@@ -175,10 +186,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.get_success(d0)
self.get_success(d2)
- def test_verify_json_for_server(self):
+ def test_verify_json_for_server(self) -> None:
kr = keyring.Keyring(self.hs)
- key1 = signedjson.key.generate_signing_key(1)
+ key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_verify_keys(
"server9",
time.time() * 1000,
@@ -186,7 +197,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
)
self.get_success(r)
- json1 = {}
+ json1: JsonDict = {}
signedjson.sign.sign_json(json1, "server9", key1)
# should fail immediately on an unsigned object
@@ -198,12 +209,12 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# self.assertFalse(d.called)
self.get_success(d)
- def test_verify_for_local_server(self):
+ def test_verify_for_local_server(self) -> None:
"""Ensure that locally signed JSON can be verified without fetching keys
over federation
"""
kr = keyring.Keyring(self.hs)
- json1 = {}
+ json1: JsonDict = {}
signedjson.sign.sign_json(json1, self.hs.hostname, self.hs.signing_key)
# Test that verify_json_for_server succeeds on a object signed by ourselves
@@ -216,22 +227,24 @@ class KeyringTestCase(unittest.HomeserverTestCase):
{
"old_signing_keys": {
f"{OLD_KEY.alg}:{OLD_KEY.version}": {
- "key": encode_verify_key_base64(OLD_KEY.verify_key),
+ "key": encode_verify_key_base64(
+ signedjson.key.get_verify_key(OLD_KEY)
+ ),
"expired_ts": 1000,
}
}
}
)
- def test_verify_for_local_server_old_key(self):
+ def test_verify_for_local_server_old_key(self) -> None:
"""Can also use keys in old_signing_keys for verification"""
- json1 = {}
+ json1: JsonDict = {}
signedjson.sign.sign_json(json1, self.hs.hostname, self.OLD_KEY)
kr = keyring.Keyring(self.hs)
d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
self.get_success(d)
- def test_verify_for_local_server_unknown_key(self):
+ def test_verify_for_local_server_unknown_key(self) -> None:
"""Local keys that we no longer have should be fetched via the fetcher"""
# the key we'll sign things with (nb, not known to the Keyring)
@@ -253,14 +266,14 @@ class KeyringTestCase(unittest.HomeserverTestCase):
)
# sign the json
- json1 = {}
+ json1: JsonDict = {}
signedjson.sign.sign_json(json1, self.hs.hostname, key2)
# ... and check we can verify it.
d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
self.get_success(d)
- def test_verify_json_for_server_with_null_valid_until_ms(self):
+ def test_verify_json_for_server_with_null_valid_until_ms(self) -> None:
"""Tests that we correctly handle key requests for keys we've stored
with a null `ts_valid_until_ms`
"""
@@ -271,15 +284,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
)
- key1 = signedjson.key.generate_signing_key(1)
+ key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_verify_keys(
"server9",
time.time() * 1000,
- [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))],
+ # None is not a valid value in FetchKeyResult, but we're abusing this
+ # API to insert null values into the database. The nulls get converted
+ # to 0 when fetched in KeyStore.get_server_verify_keys.
+ [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))], # type: ignore[arg-type]
)
self.get_success(r)
- json1 = {}
+ json1: JsonDict = {}
signedjson.sign.sign_json(json1, "server9", key1)
# should fail immediately on an unsigned object
@@ -304,9 +320,9 @@ class KeyringTestCase(unittest.HomeserverTestCase):
)
self.get_success(d)
- def test_verify_json_dedupes_key_requests(self):
+ def test_verify_json_dedupes_key_requests(self) -> None:
"""Two requests for the same key should be deduped."""
- key1 = signedjson.key.generate_signing_key(1)
+ key1 = signedjson.key.generate_signing_key("1")
async def get_keys(
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
@@ -322,7 +338,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
mock_fetcher.get_keys = Mock(side_effect=get_keys)
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
- json1 = {}
+ json1: JsonDict = {}
signedjson.sign.sign_json(json1, "server1", key1)
# the first request should succeed; the second should fail because the key
@@ -346,9 +362,9 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# there should have been a single call to the fetcher
mock_fetcher.get_keys.assert_called_once()
- def test_verify_json_falls_back_to_other_fetchers(self):
+ def test_verify_json_falls_back_to_other_fetchers(self) -> None:
"""If the first fetcher cannot provide a recent enough key, we fall back"""
- key1 = signedjson.key.generate_signing_key(1)
+ key1 = signedjson.key.generate_signing_key("1")
async def get_keys1(
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
@@ -372,7 +388,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
- json1 = {}
+ json1: JsonDict = {}
signedjson.sign.sign_json(json1, "server1", key1)
results = kr.verify_json_objects_for_server(
@@ -402,12 +418,12 @@ class KeyringTestCase(unittest.HomeserverTestCase):
@logcontext_clean
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock()
hs = self.setup_test_homeserver(federation_http_client=self.http_client)
return hs
- def test_get_keys_from_server(self):
+ def test_get_keys_from_server(self) -> None:
# arbitrarily advance the clock a bit
self.reactor.advance(100)
@@ -431,9 +447,9 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
}
signedjson.sign.sign_json(response, SERVER_NAME, testkey)
- async def get_json(destination, path, **kwargs):
+ async def get_json(destination: str, path: str, **kwargs: Any) -> JsonDict:
self.assertEqual(destination, SERVER_NAME)
- self.assertEqual(path, "/_matrix/key/v2/server/key1")
+ self.assertEqual(path, "/_matrix/key/v2/server")
return response
self.http_client.get_json.side_effect = get_json
@@ -469,21 +485,9 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
self.assertEqual(keys, {})
- def test_keyid_containing_forward_slash(self) -> None:
- """We should url-encode any url unsafe chars in key ids.
-
- Detects https://github.com/matrix-org/synapse/issues/14488.
- """
- fetcher = ServerKeyFetcher(self.hs)
- self.get_success(fetcher.get_keys("example.com", ["key/potato"], 0))
-
- self.http_client.get_json.assert_called_once()
- args, kwargs = self.http_client.get_json.call_args
- self.assertEqual(kwargs["path"], "/_matrix/key/v2/server/key%2Fpotato")
-
class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_perspective_server = MockPerspectiveServer()
self.http_client = Mock()
@@ -534,7 +538,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
Tell the mock http client to expect a perspectives-server key query
"""
- async def post_json(destination, path, data, **kwargs):
+ async def post_json(
+ destination: str, path: str, data: JsonDict, **kwargs: Any
+ ) -> JsonDict:
self.assertEqual(destination, self.mock_perspective_server.server_name)
self.assertEqual(path, "/_matrix/key/v2/query")
@@ -545,7 +551,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.side_effect = post_json
- def test_get_keys_from_perspectives(self):
+ def test_get_keys_from_perspectives(self) -> None:
# arbitrarily advance the clock a bit
self.reactor.advance(100)
@@ -590,7 +596,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
- def test_get_multiple_keys_from_perspectives(self):
+ def test_get_multiple_keys_from_perspectives(self) -> None:
"""Check that we can correctly request multiple keys for the same server"""
fetcher = PerspectivesKeyFetcher(self.hs)
@@ -618,7 +624,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
VALID_UNTIL_TS,
)
- async def post_json(destination, path, data, **kwargs):
+ async def post_json(
+ destination: str, path: str, data: JsonDict, **kwargs: str
+ ) -> JsonDict:
self.assertEqual(destination, self.mock_perspective_server.server_name)
self.assertEqual(path, "/_matrix/key/v2/query")
@@ -660,7 +668,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
# finally, ensure that only one request was sent
self.assertEqual(self.http_client.post_json.call_count, 1)
- def test_get_perspectives_own_key(self):
+ def test_get_perspectives_own_key(self) -> None:
"""Check that we can get the perspectives server's own keys
This is slightly complicated by the fact that the perspectives server may
@@ -709,7 +717,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
- def test_invalid_perspectives_responses(self):
+ def test_invalid_perspectives_responses(self) -> None:
"""Check that invalid responses from the perspectives server are rejected"""
# arbitrarily advance the clock a bit
self.reactor.advance(100)
@@ -720,12 +728,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
testverifykey_id = "ed25519:ver1"
VALID_UNTIL_TS = 200 * 1000
- def build_response():
+ def build_response() -> dict:
return self.build_perspectives_response(
SERVER_NAME, testkey, VALID_UNTIL_TS
)
- def get_key_from_perspectives(response):
+ def get_key_from_perspectives(response: JsonDict) -> Dict[str, FetchKeyResult]:
fetcher = PerspectivesKeyFetcher(self.hs)
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
return self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
@@ -749,6 +757,6 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
-def get_key_id(key):
+def get_key_id(key: SigningKey) -> str:
"""Get the matrix ID tag for a given SigningKey or VerifyKey"""
return "%s:%s" % (key.alg, key.version)
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 685a9a6d52..a9893def74 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -16,6 +16,8 @@ from unittest.mock import Mock
import attr
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EduTypes
from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router
from synapse.federation.units import Transaction
@@ -23,11 +25,13 @@ from synapse.handlers.presence import UserPresenceState
from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login, presence, room
+from synapse.server import HomeServer
from synapse.types import JsonDict, StreamToken, create_requester
+from synapse.util import Clock
from tests.handlers.test_sync import generate_sync_config
from tests.test_utils import simple_async_mock
-from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
+from tests.unittest import FederatingHomeserverTestCase, override_config
@attr.s
@@ -49,9 +53,7 @@ class LegacyPresenceRouterTestModule:
}
return users_to_state
- async def get_interested_users(
- self, user_id: str
- ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+ async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
if user_id in self._config.users_who_should_receive_all_presence:
return PresenceRouter.ALL_USERS
@@ -71,9 +73,14 @@ class LegacyPresenceRouterTestModule:
# Initialise a typed config object
config = PresenceRouterTestConfig()
- config.users_who_should_receive_all_presence = config_dict.get(
+ users_who_should_receive_all_presence = config_dict.get(
"users_who_should_receive_all_presence"
)
+ assert isinstance(users_who_should_receive_all_presence, list)
+
+ config.users_who_should_receive_all_presence = (
+ users_who_should_receive_all_presence
+ )
return config
@@ -96,9 +103,7 @@ class PresenceRouterTestModule:
}
return users_to_state
- async def get_interested_users(
- self, user_id: str
- ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+ async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
if user_id in self._config.users_who_should_receive_all_presence:
return PresenceRouter.ALL_USERS
@@ -118,14 +123,26 @@ class PresenceRouterTestModule:
# Initialise a typed config object
config = PresenceRouterTestConfig()
- config.users_who_should_receive_all_presence = config_dict.get(
+ users_who_should_receive_all_presence = config_dict.get(
"users_who_should_receive_all_presence"
)
+ assert isinstance(users_who_should_receive_all_presence, list)
+
+ config.users_who_should_receive_all_presence = (
+ users_who_should_receive_all_presence
+ )
return config
class PresenceRouterTestCase(FederatingHomeserverTestCase):
+ """
+ Test cases using a custom PresenceRouter
+
+ By default in test cases, federation sending is disabled. This class re-enables it
+ for the main process by setting `federation_sender_instances` to None.
+ """
+
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -133,7 +150,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
presence.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
fed_transport_client = Mock(spec=["send_transaction"])
fed_transport_client.send_transaction = simple_async_mock({})
@@ -146,10 +163,17 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
return hs
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.sync_handler = self.hs.get_sync_handler()
self.module_api = homeserver.get_module_api()
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["federation_sender_instances"] = None
+ return config
+
@override_config(
{
"presence": {
@@ -162,10 +186,9 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
},
}
},
- "send_federation": True,
}
)
- def test_receiving_all_presence_legacy(self):
+ def test_receiving_all_presence_legacy(self) -> None:
self.receiving_all_presence_test_body()
@override_config(
@@ -180,13 +203,12 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
},
},
],
- "send_federation": True,
}
)
- def test_receiving_all_presence(self):
+ def test_receiving_all_presence(self) -> None:
self.receiving_all_presence_test_body()
- def receiving_all_presence_test_body(self):
+ def receiving_all_presence_test_body(self) -> None:
"""Test that a user that does not share a room with another other can receive
presence for them, due to presence routing.
"""
@@ -290,10 +312,9 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
},
}
},
- "send_federation": True,
}
)
- def test_send_local_online_presence_to_with_module_legacy(self):
+ def test_send_local_online_presence_to_with_module_legacy(self) -> None:
self.send_local_online_presence_to_with_module_test_body()
@override_config(
@@ -310,13 +331,12 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
},
},
],
- "send_federation": True,
}
)
- def test_send_local_online_presence_to_with_module(self):
+ def test_send_local_online_presence_to_with_module(self) -> None:
self.send_local_online_presence_to_with_module_test_body()
- def send_local_online_presence_to_with_module_test_body(self):
+ def send_local_online_presence_to_with_module_test_body(self) -> None:
"""Tests that send_local_presence_to_users sends local online presence to a set
of specified local and remote users, with a custom PresenceRouter module enabled.
"""
@@ -439,18 +459,18 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
continue
# EDUs can contain multiple presence updates
- for presence_update in edu["content"]["push"]:
+ for presence_edu in edu["content"]["push"]:
# Check for presence updates that contain the user IDs we're after
- found_users.add(presence_update["user_id"])
+ found_users.add(presence_edu["user_id"])
# Ensure that no offline states are being sent out
- self.assertNotEqual(presence_update["presence"], "offline")
+ self.assertNotEqual(presence_edu["presence"], "offline")
self.assertEqual(found_users, expected_users)
def send_presence_update(
- testcase: TestCase,
+ testcase: FederatingHomeserverTestCase,
user_id: str,
access_token: str,
presence_state: str,
@@ -471,7 +491,7 @@ def send_presence_update(
def sync_presence(
- testcase: TestCase,
+ testcase: FederatingHomeserverTestCase,
user_id: str,
since_token: Optional[StreamToken] = None,
) -> Tuple[List[UserPresenceState], StreamToken]:
@@ -492,7 +512,7 @@ def sync_presence(
requester = create_requester(user_id)
sync_config = generate_sync_config(requester.user.to_string())
sync_result = testcase.get_success(
- testcase.sync_handler.wait_for_sync_for_user(
+ testcase.hs.get_sync_handler().wait_for_sync_for_user(
requester, sync_config, since_token
)
)
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 8ddce83b83..6687c28e8f 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -12,9 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
from tests.test_utils.event_injection import create_event
@@ -27,7 +32,7 @@ class TestEventContext(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
@@ -35,7 +40,7 @@ class TestEventContext(unittest.HomeserverTestCase):
self.user_tok = self.login("u1", "pass")
self.room_id = self.helper.create_room_as(tok=self.user_tok)
- def test_serialize_deserialize_msg(self):
+ def test_serialize_deserialize_msg(self) -> None:
"""Test that an EventContext for a message event is the same after
serialize/deserialize.
"""
@@ -51,7 +56,7 @@ class TestEventContext(unittest.HomeserverTestCase):
self._check_serialize_deserialize(event, context)
- def test_serialize_deserialize_state_no_prev(self):
+ def test_serialize_deserialize_state_no_prev(self) -> None:
"""Test that an EventContext for a state event (with not previous entry)
is the same after serialize/deserialize.
"""
@@ -67,7 +72,7 @@ class TestEventContext(unittest.HomeserverTestCase):
self._check_serialize_deserialize(event, context)
- def test_serialize_deserialize_state_prev(self):
+ def test_serialize_deserialize_state_prev(self) -> None:
"""Test that an EventContext for a state event (which replaces a
previous entry) is the same after serialize/deserialize.
"""
@@ -84,7 +89,9 @@ class TestEventContext(unittest.HomeserverTestCase):
self._check_serialize_deserialize(event, context)
- def _check_serialize_deserialize(self, event, context):
+ def _check_serialize_deserialize(
+ self, event: EventBase, context: EventContext
+ ) -> None:
serialized = self.get_success(context.serialize(event, self.store))
d_context = EventContext.deserialize(self._storage_controllers, serialized)
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index b1c47efac7..4174a237ec 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -12,30 +12,60 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import unittest as stdlib_unittest
+from typing import Any, List, Mapping, Optional
+
from synapse.api.constants import EventContentFields
from synapse.api.room_versions import RoomVersions
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import (
+ PowerLevelsContent,
SerializeEventConfig,
copy_and_fixup_power_levels_contents,
+ maybe_upsert_event_field,
prune_event,
serialize_event,
)
+from synapse.types import JsonDict
from synapse.util.frozenutils import freeze
-from tests import unittest
-
-def MockEvent(**kwargs):
+def MockEvent(**kwargs: Any) -> EventBase:
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
+ if "content" not in kwargs:
+ kwargs["content"] = {}
return make_event_from_dict(kwargs)
-class PruneEventTestCase(unittest.TestCase):
- def run_test(self, evdict, matchdict, **kwargs):
+class TestMaybeUpsertEventField(stdlib_unittest.TestCase):
+ def test_update_okay(self) -> None:
+ event = make_event_from_dict({"event_id": "$1234"})
+ success = maybe_upsert_event_field(event, event.unsigned, "key", "value")
+ self.assertTrue(success)
+ self.assertEqual(event.unsigned["key"], "value")
+
+ def test_update_not_okay(self) -> None:
+ event = make_event_from_dict({"event_id": "$1234"})
+ LARGE_STRING = "a" * 100_000
+ success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
+ self.assertFalse(success)
+ self.assertNotIn("key", event.unsigned)
+
+ def test_update_not_okay_leaves_original_value(self) -> None:
+ event = make_event_from_dict(
+ {"event_id": "$1234", "unsigned": {"key": "value"}}
+ )
+ LARGE_STRING = "a" * 100_000
+ success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
+ self.assertFalse(success)
+ self.assertEqual(event.unsigned["key"], "value")
+
+
+class PruneEventTestCase(stdlib_unittest.TestCase):
+ def run_test(self, evdict: JsonDict, matchdict: JsonDict, **kwargs: Any) -> None:
"""
Asserts that a new event constructed with `evdict` will look like
`matchdict` when it is redacted.
@@ -49,7 +79,7 @@ class PruneEventTestCase(unittest.TestCase):
prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict
)
- def test_minimal(self):
+ def test_minimal(self) -> None:
self.run_test(
{"type": "A", "event_id": "$test:domain"},
{
@@ -61,7 +91,7 @@ class PruneEventTestCase(unittest.TestCase):
},
)
- def test_basic_keys(self):
+ def test_basic_keys(self) -> None:
"""Ensure that the keys that should be untouched are kept."""
# Note that some of the values below don't really make sense, but the
# pruning of events doesn't worry about the values of any fields (with
@@ -113,7 +143,7 @@ class PruneEventTestCase(unittest.TestCase):
room_version=RoomVersions.MSC2176,
)
- def test_unsigned(self):
+ def test_unsigned(self) -> None:
"""Ensure that unsigned properties get stripped (except age_ts and replaces_state)."""
self.run_test(
{
@@ -134,7 +164,7 @@ class PruneEventTestCase(unittest.TestCase):
},
)
- def test_content(self):
+ def test_content(self) -> None:
"""The content dictionary should be stripped in most cases."""
self.run_test(
{"type": "C", "event_id": "$test:domain", "content": {"things": "here"}},
@@ -169,7 +199,7 @@ class PruneEventTestCase(unittest.TestCase):
},
)
- def test_create(self):
+ def test_create(self) -> None:
"""Create events are partially redacted until MSC2176."""
self.run_test(
{
@@ -198,7 +228,7 @@ class PruneEventTestCase(unittest.TestCase):
room_version=RoomVersions.MSC2176,
)
- def test_power_levels(self):
+ def test_power_levels(self) -> None:
"""Power level events keep a variety of content keys."""
self.run_test(
{
@@ -248,7 +278,7 @@ class PruneEventTestCase(unittest.TestCase):
room_version=RoomVersions.MSC2176,
)
- def test_alias_event(self):
+ def test_alias_event(self) -> None:
"""Alias events have special behavior up through room version 6."""
self.run_test(
{
@@ -277,7 +307,7 @@ class PruneEventTestCase(unittest.TestCase):
room_version=RoomVersions.V6,
)
- def test_redacts(self):
+ def test_redacts(self) -> None:
"""Redaction events have no special behaviour until MSC2174/MSC2176."""
self.run_test(
@@ -303,7 +333,7 @@ class PruneEventTestCase(unittest.TestCase):
room_version=RoomVersions.MSC2176,
)
- def test_join_rules(self):
+ def test_join_rules(self) -> None:
"""Join rules events have changed behavior starting with MSC3083."""
self.run_test(
{
@@ -346,7 +376,7 @@ class PruneEventTestCase(unittest.TestCase):
room_version=RoomVersions.V8,
)
- def test_member(self):
+ def test_member(self) -> None:
"""Member events have changed behavior starting with MSC3375."""
self.run_test(
{
@@ -391,13 +421,13 @@ class PruneEventTestCase(unittest.TestCase):
)
-class SerializeEventTestCase(unittest.TestCase):
- def serialize(self, ev, fields):
+class SerializeEventTestCase(stdlib_unittest.TestCase):
+ def serialize(self, ev: EventBase, fields: Optional[List[str]]) -> JsonDict:
return serialize_event(
ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
)
- def test_event_fields_works_with_keys(self):
+ def test_event_fields_works_with_keys(self) -> None:
self.assertEqual(
self.serialize(
MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"]
@@ -405,7 +435,7 @@ class SerializeEventTestCase(unittest.TestCase):
{"room_id": "!foo:bar"},
)
- def test_event_fields_works_with_nested_keys(self):
+ def test_event_fields_works_with_nested_keys(self) -> None:
self.assertEqual(
self.serialize(
MockEvent(
@@ -418,7 +448,7 @@ class SerializeEventTestCase(unittest.TestCase):
{"content": {"body": "A message"}},
)
- def test_event_fields_works_with_dot_keys(self):
+ def test_event_fields_works_with_dot_keys(self) -> None:
self.assertEqual(
self.serialize(
MockEvent(
@@ -431,7 +461,7 @@ class SerializeEventTestCase(unittest.TestCase):
{"content": {"key.with.dots": {}}},
)
- def test_event_fields_works_with_nested_dot_keys(self):
+ def test_event_fields_works_with_nested_dot_keys(self) -> None:
self.assertEqual(
self.serialize(
MockEvent(
@@ -447,7 +477,7 @@ class SerializeEventTestCase(unittest.TestCase):
{"content": {"nested.dot.key": {"leaf.key": 42}}},
)
- def test_event_fields_nops_with_unknown_keys(self):
+ def test_event_fields_nops_with_unknown_keys(self) -> None:
self.assertEqual(
self.serialize(
MockEvent(
@@ -460,7 +490,7 @@ class SerializeEventTestCase(unittest.TestCase):
{"content": {"foo": "bar"}},
)
- def test_event_fields_nops_with_non_dict_keys(self):
+ def test_event_fields_nops_with_non_dict_keys(self) -> None:
self.assertEqual(
self.serialize(
MockEvent(
@@ -473,7 +503,7 @@ class SerializeEventTestCase(unittest.TestCase):
{},
)
- def test_event_fields_nops_with_array_keys(self):
+ def test_event_fields_nops_with_array_keys(self) -> None:
self.assertEqual(
self.serialize(
MockEvent(
@@ -486,7 +516,7 @@ class SerializeEventTestCase(unittest.TestCase):
{},
)
- def test_event_fields_all_fields_if_empty(self):
+ def test_event_fields_all_fields_if_empty(self) -> None:
self.assertEqual(
self.serialize(
MockEvent(
@@ -506,16 +536,16 @@ class SerializeEventTestCase(unittest.TestCase):
},
)
- def test_event_fields_fail_if_fields_not_str(self):
+ def test_event_fields_fail_if_fields_not_str(self) -> None:
with self.assertRaises(TypeError):
self.serialize(
- MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4]
+ MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4] # type: ignore[list-item]
)
-class CopyPowerLevelsContentTestCase(unittest.TestCase):
+class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
def setUp(self) -> None:
- self.test_content = {
+ self.test_content: PowerLevelsContent = {
"ban": 50,
"events": {"m.room.name": 100, "m.room.power_levels": 100},
"events_default": 0,
@@ -528,10 +558,11 @@ class CopyPowerLevelsContentTestCase(unittest.TestCase):
"users_default": 0,
}
- def _test(self, input):
+ def _test(self, input: PowerLevelsContent) -> None:
a = copy_and_fixup_power_levels_contents(input)
self.assertEqual(a["ban"], 50)
+ assert isinstance(a["events"], Mapping)
self.assertEqual(a["events"]["m.room.name"], 100)
# make sure that changing the copy changes the copy and not the orig
@@ -539,18 +570,19 @@ class CopyPowerLevelsContentTestCase(unittest.TestCase):
a["events"]["m.room.power_levels"] = 20
self.assertEqual(input["ban"], 50)
+ assert isinstance(input["events"], Mapping)
self.assertEqual(input["events"]["m.room.power_levels"], 100)
- def test_unfrozen(self):
+ def test_unfrozen(self) -> None:
self._test(self.test_content)
- def test_frozen(self):
+ def test_frozen(self) -> None:
input = freeze(self.test_content)
self._test(input)
- def test_stringy_integers(self):
+ def test_stringy_integers(self) -> None:
"""String representations of decimal integers are converted to integers."""
- input = {
+ input: PowerLevelsContent = {
"a": "100",
"b": {
"foo": 99,
@@ -578,9 +610,9 @@ class CopyPowerLevelsContentTestCase(unittest.TestCase):
def test_invalid_types_raise_type_error(self) -> None:
with self.assertRaises(TypeError):
- copy_and_fixup_power_levels_contents({"a": ["hello", "grandma"]}) # type: ignore[arg-type]
- copy_and_fixup_power_levels_contents({"a": None}) # type: ignore[arg-type]
+ copy_and_fixup_power_levels_contents({"a": ["hello", "grandma"]}) # type: ignore[dict-item]
+ copy_and_fixup_power_levels_contents({"a": None}) # type: ignore[dict-item]
def test_invalid_nesting_raises_type_error(self) -> None:
with self.assertRaises(TypeError):
- copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}})
+ copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}}) # type: ignore[dict-item]
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 2873b4d430..b8fee72898 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -7,13 +7,21 @@ from synapse.federation.sender import PerDestinationQueue, TransactionManager
from synapse.federation.units import Edu
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.types import JsonDict
from synapse.util.retryutils import NotRetryingDestination
from tests.test_utils import event_injection, make_awaitable
-from tests.unittest import FederatingHomeserverTestCase, override_config
+from tests.unittest import FederatingHomeserverTestCase
class FederationCatchUpTestCases(FederatingHomeserverTestCase):
+ """
+ Tests cases of catching up over federation.
+
+ By default for test cases federation sending is disabled. This Test class has it
+ re-enabled for the main process.
+ """
+
servlets = [
admin.register_servlets,
room.register_servlets,
@@ -42,6 +50,11 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.record_transaction
)
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["federation_sender_instances"] = None
+ return config
+
async def record_transaction(self, txn, json_cb):
if self.is_online:
data = json_cb()
@@ -79,7 +92,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
)[0]
return {"event_id": event_id, "stream_ordering": stream_ordering}
- @override_config({"send_federation": True})
def test_catch_up_destination_rooms_tracking(self):
"""
Tests that we populate the `destination_rooms` table as needed.
@@ -105,7 +117,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.assertEqual(row_2["event_id"], event_id_2)
self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1)
- @override_config({"send_federation": True})
def test_catch_up_last_successful_stream_ordering_tracking(self):
"""
Tests that we populate the `destination_rooms` table as needed.
@@ -163,7 +174,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
"Send succeeded but not marked as last_successful_stream_ordering",
)
- @override_config({"send_federation": True}) # critical to federate
def test_catch_up_from_blank_state(self):
"""
Runs an overall test of federation catch-up from scratch.
@@ -260,7 +270,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
return per_dest_queue, results_list
- @override_config({"send_federation": True})
def test_catch_up_loop(self):
"""
Tests the behaviour of _catch_up_transmission_loop.
@@ -325,7 +334,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
event_5.internal_metadata.stream_ordering,
)
- @override_config({"send_federation": True})
def test_catch_up_on_synapse_startup(self):
"""
Tests the behaviour of get_catch_up_outstanding_destinations and
@@ -424,7 +432,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# - all destinations are woken exactly once; they appear once in woken.
self.assertCountEqual(woken, server_names[:-1])
- @override_config({"send_federation": True})
def test_not_latest_event(self):
"""Test that we send the latest event in the room even if its not ours."""
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index f1e357764f..8692d8190f 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -25,10 +25,17 @@ from synapse.rest.client import login
from synapse.types import JsonDict, ReadReceipt
from tests.test_utils import make_awaitable
-from tests.unittest import HomeserverTestCase, override_config
+from tests.unittest import HomeserverTestCase
class FederationSenderReceiptsTestCases(HomeserverTestCase):
+ """
+ Test federation sending to update receipts.
+
+ By default for test cases federation sending is disabled. This Test class has it
+ re-enabled for the main process.
+ """
+
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]),
@@ -38,9 +45,17 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
return_value=make_awaitable({"test", "host2"})
)
+ hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (
+ hs.get_storage_controllers().state.get_current_hosts_in_room
+ )
+
return hs
- @override_config({"send_federation": True})
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["federation_sender_instances"] = None
+ return config
+
def test_send_receipts(self):
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
@@ -83,7 +98,82 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
],
)
- @override_config({"send_federation": True})
+ def test_send_receipts_thread(self):
+ mock_send_transaction = (
+ self.hs.get_federation_transport_client().send_transaction
+ )
+ mock_send_transaction.return_value = make_awaitable({})
+
+ # Create receipts for:
+ #
+ # * The same room / user on multiple threads.
+ # * A different user in the same room.
+ sender = self.hs.get_federation_sender()
+ for user, thread in (
+ ("alice", None),
+ ("alice", "thread"),
+ ("bob", None),
+ ("bob", "diff-thread"),
+ ):
+ receipt = ReadReceipt(
+ "room_id",
+ "m.read",
+ user,
+ ["event_id"],
+ thread_id=thread,
+ data={"ts": 1234},
+ )
+ self.successResultOf(
+ defer.ensureDeferred(sender.send_read_receipt(receipt))
+ )
+
+ self.pump()
+
+ # expect a call to send_transaction with two EDUs to separate threads.
+ mock_send_transaction.assert_called_once()
+ json_cb = mock_send_transaction.call_args[0][1]
+ data = json_cb()
+ # Note that the ordering of the EDUs doesn't matter.
+ self.assertCountEqual(
+ data["edus"],
+ [
+ {
+ "edu_type": EduTypes.RECEIPT,
+ "content": {
+ "room_id": {
+ "m.read": {
+ "alice": {
+ "event_ids": ["event_id"],
+ "data": {"ts": 1234, "thread_id": "thread"},
+ },
+ "bob": {
+ "event_ids": ["event_id"],
+ "data": {"ts": 1234, "thread_id": "diff-thread"},
+ },
+ }
+ }
+ },
+ },
+ {
+ "edu_type": EduTypes.RECEIPT,
+ "content": {
+ "room_id": {
+ "m.read": {
+ "alice": {
+ "event_ids": ["event_id"],
+ "data": {"ts": 1234},
+ },
+ "bob": {
+ "event_ids": ["event_id"],
+ "data": {"ts": 1234},
+ },
+ }
+ }
+ },
+ },
+ ],
+ )
+
def test_send_receipts_with_backoff(self):
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
@@ -170,6 +260,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
class FederationSenderDevicesTestCases(HomeserverTestCase):
+ """
+ Test federation sending to update devices.
+
+ By default for test cases federation sending is disabled. This Test class has it
+ re-enabled for the main process.
+ """
+
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -184,7 +281,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
def default_config(self):
c = super().default_config()
- c["send_federation"] = True
+ # Enable federation sending on the main process.
+ c["federation_sender_instances"] = None
return c
def prepare(self, reactor, clock, hs):
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 177e5b5afc..be719e49c0 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -211,9 +211,8 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
- @override_config({"experimental_features": {"msc3706_enabled": True}})
def test_send_join_partial_state(self) -> None:
- """When MSC3706 support is enabled, /send_join should return partial state"""
+ """/send_join should return partial state, if requested"""
joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
join_result = self._make_join(joining_user)
@@ -224,7 +223,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
)
channel = self.make_signed_federation_request(
"PUT",
- f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
+ f"/_matrix/federation/v2/send_join/{self._room_id}/x?omit_members=true",
content=join_event_dict,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py
index b84c74fc0e..3d61b1e8a9 100644
--- a/tests/federation/transport/test_client.py
+++ b/tests/federation/transport/test_client.py
@@ -13,12 +13,14 @@
# limitations under the License.
import json
+from typing import List, Optional
from unittest.mock import Mock
import ijson.common
from synapse.api.room_versions import RoomVersions
from synapse.federation.transport.client import SendJoinParser
+from synapse.types import JsonDict
from synapse.util import ExceptionBundle
from tests.unittest import TestCase
@@ -66,38 +68,73 @@ class SendJoinParserTestCase(TestCase):
self.assertEqual(len(parsed_response.state), 1, parsed_response)
self.assertEqual(parsed_response.event_dict, {}, parsed_response)
self.assertIsNone(parsed_response.event, parsed_response)
- self.assertFalse(parsed_response.partial_state, parsed_response)
+ self.assertFalse(parsed_response.members_omitted, parsed_response)
self.assertEqual(parsed_response.servers_in_room, None, parsed_response)
def test_partial_state(self) -> None:
- """Check that the partial_state flag is correctly parsed"""
- parser = SendJoinParser(RoomVersions.V1, False)
- response = {
- "org.matrix.msc3706.partial_state": True,
- }
+ """Check that the members_omitted flag is correctly parsed"""
- serialised_response = json.dumps(response).encode()
+ def parse(response: JsonDict) -> bool:
+ parser = SendJoinParser(RoomVersions.V1, False)
+ serialised_response = json.dumps(response).encode()
- # Send data to the parser
- parser.write(serialised_response)
+ # Send data to the parser
+ parser.write(serialised_response)
- # Retrieve and check the parsed SendJoinResponse
- parsed_response = parser.finish()
- self.assertTrue(parsed_response.partial_state)
+ # Retrieve and check the parsed SendJoinResponse
+ parsed_response = parser.finish()
+ return parsed_response.members_omitted
+
+ self.assertTrue(parse({"members_omitted": True}))
+ self.assertTrue(parse({"org.matrix.msc3706.partial_state": True}))
+
+ self.assertFalse(parse({"members_omitted": False}))
+ self.assertFalse(parse({"org.matrix.msc3706.partial_state": False}))
+
+ # If there's a conflict, the stable field wins.
+ self.assertTrue(
+ parse({"members_omitted": True, "org.matrix.msc3706.partial_state": False})
+ )
+ self.assertFalse(
+ parse({"members_omitted": False, "org.matrix.msc3706.partial_state": True})
+ )
def test_servers_in_room(self) -> None:
"""Check that the servers_in_room field is correctly parsed"""
- parser = SendJoinParser(RoomVersions.V1, False)
- response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}
- serialised_response = json.dumps(response).encode()
+ def parse(response: JsonDict) -> Optional[List[str]]:
+ parser = SendJoinParser(RoomVersions.V1, False)
+ serialised_response = json.dumps(response).encode()
- # Send data to the parser
- parser.write(serialised_response)
+ # Send data to the parser
+ parser.write(serialised_response)
- # Retrieve and check the parsed SendJoinResponse
- parsed_response = parser.finish()
- self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"])
+ # Retrieve and check the parsed SendJoinResponse
+ parsed_response = parser.finish()
+ return parsed_response.servers_in_room
+
+ self.assertEqual(
+ parse({"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}),
+ ["hs1", "hs2"],
+ )
+ self.assertEqual(parse({"servers_in_room": ["example.com"]}), ["example.com"])
+
+ # If both are provided, the stable identifier should win
+ self.assertEqual(
+ parse(
+ {
+ "org.matrix.msc3706.servers_in_room": ["old"],
+ "servers_in_room": ["new"],
+ }
+ ),
+ ["new"],
+ )
+
+ # And lastly, we should be able to tell if neither field was present.
+ self.assertEqual(
+ parse({}),
+ None,
+ )
def test_errors_closing_coroutines(self) -> None:
"""Check we close all coroutines, even if closing the first raises an Exception.
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index d21c11b716..ff589c0b6c 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -23,10 +23,10 @@ from synapse.server import HomeServer
from synapse.types import RoomAlias
from tests.test_utils import event_injection
-from tests.unittest import FederatingHomeserverTestCase, TestCase
+from tests.unittest import FederatingHomeserverTestCase, HomeserverTestCase
-class KnockingStrippedStateEventHelperMixin(TestCase):
+class KnockingStrippedStateEventHelperMixin(HomeserverTestCase):
def send_example_state_events_to_room(
self,
hs: "HomeServer",
@@ -49,7 +49,7 @@ class KnockingStrippedStateEventHelperMixin(TestCase):
# To set a canonical alias, we'll need to point an alias at the room first.
canonical_alias = "#fancy_alias:test"
self.get_success(
- self.store.create_room_alias_association(
+ self.hs.get_datastores().main.create_room_alias_association(
RoomAlias.from_string(canonical_alias), room_id, ["test"]
)
)
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index c1579dac61..6f300b8e11 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -38,6 +38,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_handler = hs.get_admin_handler()
+ self._store = hs.get_datastores().main
self.user1 = self.register_user("user1", "password")
self.token1 = self.login("user1", "password")
@@ -236,3 +237,62 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(args[0], room_id)
self.assertEqual(args[1].content["membership"], "knock")
self.assertTrue(args[2]) # Assert there is at least one bit of state
+
+ def test_profile(self) -> None:
+ """Tests that user profile get exported."""
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ writer.write_events.assert_not_called()
+ writer.write_profile.assert_called_once()
+
+ # check only a few values, not all available
+ args = writer.write_profile.call_args[0]
+ self.assertEqual(args[0]["name"], self.user2)
+ self.assertIn("displayname", args[0])
+ self.assertIn("avatar_url", args[0])
+ self.assertIn("threepids", args[0])
+ self.assertIn("external_ids", args[0])
+ self.assertIn("creation_ts", args[0])
+
+ def test_devices(self) -> None:
+ """Tests that user devices get exported."""
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ writer.write_events.assert_not_called()
+ writer.write_devices.assert_called_once()
+
+ args = writer.write_devices.call_args[0]
+ self.assertEqual(len(args[0]), 1)
+ self.assertEqual(args[0][0]["user_id"], self.user2)
+ self.assertIn("device_id", args[0][0])
+ self.assertIsNone(args[0][0]["display_name"])
+ self.assertIsNone(args[0][0]["last_seen_user_agent"])
+ self.assertIsNone(args[0][0]["last_seen_ts"])
+ self.assertIsNone(args[0][0]["last_seen_ip"])
+
+ def test_connections(self) -> None:
+ """Tests that user sessions / connections get exported."""
+ # Insert a user IP
+ self.get_success(
+ self._store.insert_client_ip(
+ self.user2, "access_token", "ip", "user_agent", "MY_DEVICE"
+ )
+ )
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ writer.write_events.assert_not_called()
+ writer.write_connections.assert_called_once()
+
+ args = writer.write_connections.call_args[0]
+ self.assertEqual(len(args[0]), 1)
+ self.assertEqual(args[0][0]["ip"], "ip")
+ self.assertEqual(args[0][0]["user_agent"], "user_agent")
+ self.assertGreater(args[0][0]["last_seen"], 0)
+ self.assertNotIn("access_token", args[0][0])
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 144e49d0fd..a7495ab21a 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -25,13 +25,13 @@ import synapse.storage
from synapse.api.constants import EduTypes, EventTypes
from synapse.appservice import (
ApplicationService,
- TransactionOneTimeKeyCounts,
+ TransactionOneTimeKeysCount,
TransactionUnusedFallbackKeys,
)
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.rest.client import login, receipts, register, room, sendtodevice
from synapse.server import HomeServer
-from synapse.types import RoomStreamToken
+from synapse.types import JsonDict, RoomStreamToken
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -44,7 +44,7 @@ from tests.utils import MockClock
class AppServiceHandlerTestCase(unittest.TestCase):
"""Tests the ApplicationServicesHandler."""
- def setUp(self):
+ def setUp(self) -> None:
self.mock_store = Mock()
self.mock_as_api = Mock()
self.mock_scheduler = Mock()
@@ -61,7 +61,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.handler = ApplicationServicesHandler(hs)
self.event_source = hs.get_event_sources()
- def test_notify_interested_services(self):
+ def test_notify_interested_services(self) -> None:
interested_service = self._mkservice(is_interested_in_event=True)
services = [
self._mkservice(is_interested_in_event=False),
@@ -90,7 +90,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
interested_service, events=[event]
)
- def test_query_user_exists_unknown_user(self):
+ def test_query_user_exists_unknown_user(self) -> None:
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
@@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
- def test_query_user_exists_known_user(self):
+ def test_query_user_exists_known_user(self) -> None:
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
@@ -127,7 +127,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
"query_user called when it shouldn't have been.",
)
- def test_query_room_alias_exists(self):
+ def test_query_room_alias_exists(self) -> None:
room_alias_str = "#foo:bar"
room_alias = Mock()
room_alias.to_string.return_value = room_alias_str
@@ -157,7 +157,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.assertEqual(result.room_id, room_id)
self.assertEqual(result.servers, servers)
- def test_get_3pe_protocols_no_appservices(self):
+ def test_get_3pe_protocols_no_appservices(self) -> None:
self.mock_store.get_app_services.return_value = []
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
@@ -165,7 +165,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_as_api.get_3pe_protocol.assert_not_called()
self.assertEqual(response, {})
- def test_get_3pe_protocols_no_protocols(self):
+ def test_get_3pe_protocols_no_protocols(self) -> None:
service = self._mkservice(False, [])
self.mock_store.get_app_services.return_value = [service]
response = self.successResultOf(
@@ -174,7 +174,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_as_api.get_3pe_protocol.assert_not_called()
self.assertEqual(response, {})
- def test_get_3pe_protocols_protocol_no_response(self):
+ def test_get_3pe_protocols_protocol_no_response(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
@@ -186,7 +186,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.assertEqual(response, {})
- def test_get_3pe_protocols_select_one_protocol(self):
+ def test_get_3pe_protocols_select_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
@@ -202,7 +202,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
)
- def test_get_3pe_protocols_one_protocol(self):
+ def test_get_3pe_protocols_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
@@ -218,7 +218,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
)
- def test_get_3pe_protocols_multiple_protocol(self):
+ def test_get_3pe_protocols_multiple_protocol(self) -> None:
service_one = self._mkservice(False, ["my-protocol"])
service_two = self._mkservice(False, ["other-protocol"])
self.mock_store.get_app_services.return_value = [service_one, service_two]
@@ -237,11 +237,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
},
)
- def test_get_3pe_protocols_multiple_info(self):
+ def test_get_3pe_protocols_multiple_info(self) -> None:
service_one = self._mkservice(False, ["my-protocol"])
service_two = self._mkservice(False, ["my-protocol"])
- async def get_3pe_protocol(service, unusedProtocol):
+ async def get_3pe_protocol(
+ service: ApplicationService, protocol: str
+ ) -> Optional[JsonDict]:
if service == service_one:
return {
"x-protocol-data": 42,
@@ -276,7 +278,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
},
)
- def test_notify_interested_services_ephemeral(self):
+ def test_notify_interested_services_ephemeral(self) -> None:
"""
Test sending ephemeral events to the appservice handler are scheduled
to be pushed out to interested appservices, and that the stream ID is
@@ -306,7 +308,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
580,
)
- def test_notify_interested_services_ephemeral_out_of_order(self):
+ def test_notify_interested_services_ephemeral_out_of_order(self) -> None:
"""
Test sending out of order ephemeral events to the appservice handler
are ignored.
@@ -390,7 +392,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
receipts.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track any outgoing ephemeral events
@@ -417,7 +419,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
"exclusive_as_user", "password", self.exclusive_as_user_device_id
)
- def _notify_interested_services(self):
+ def _notify_interested_services(self) -> None:
# This is normally set in `notify_interested_services` but we need to call the
# internal async version so the reactor gets pushed to completion.
self.hs.get_application_service_handler().current_max += 1
@@ -443,7 +445,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
)
def test_match_interesting_room_members(
self, interesting_user: str, should_notify: bool
- ):
+ ) -> None:
"""
Test to make sure that a interesting user (local or remote) in the room is
notified as expected when someone else in the room sends a message.
@@ -512,7 +514,9 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
else:
self.send_mock.assert_not_called()
- def test_application_services_receive_events_sent_by_interesting_local_user(self):
+ def test_application_services_receive_events_sent_by_interesting_local_user(
+ self,
+ ) -> None:
"""
Test to make sure that a messages sent from a local user can be interesting and
picked up by the appservice.
@@ -568,7 +572,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[0]["type"], "m.room.message")
self.assertEqual(events[0]["sender"], alice)
- def test_sending_read_receipt_batches_to_application_services(self):
+ def test_sending_read_receipt_batches_to_application_services(self) -> None:
"""Tests that a large batch of read receipts are sent correctly to
interested application services.
"""
@@ -644,7 +648,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
)
- def test_application_services_receive_local_to_device(self):
+ def test_application_services_receive_local_to_device(self) -> None:
"""
Test that when a user sends a to-device message to another user
that is an application service's user namespace, the
@@ -722,7 +726,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
)
- def test_application_services_receive_bursts_of_to_device(self):
+ def test_application_services_receive_bursts_of_to_device(self) -> None:
"""
Test that when a user sends >100 to-device messages at once, any
interested AS's will receive them in separate transactions.
@@ -765,7 +769,12 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
fake_device_ids = [f"device_{num}" for num in range(number_of_messages - 1)]
messages = {
self.exclusive_as_user: {
- device_id: to_device_message_content for device_id in fake_device_ids
+ device_id: {
+ "type": "test_to_device_message",
+ "sender": "@some:sender",
+ "content": to_device_message_content,
+ }
+ for device_id in fake_device_ids
}
}
@@ -908,7 +917,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
experimental_feature_enabled: bool,
as_supports_txn_extensions: bool,
as_should_receive_device_list_updates: bool,
- ):
+ ) -> None:
"""
Tests that an application service receives notice of changed device
lists for a user, when a user changes their device lists.
@@ -1065,7 +1074,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
and a room for the users to talk in.
"""
- async def preparation():
+ async def preparation() -> None:
await self._add_otks_for_device(self._sender_user, self._sender_device, 42)
await self._add_fallback_key_for_device(
self._sender_user, self._sender_device, used=True
@@ -1123,7 +1132,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
# Capture what was sent as an AS transaction.
self.send_mock.assert_called()
last_args, _last_kwargs = self.send_mock.call_args
- otks: Optional[TransactionOneTimeKeyCounts] = last_args[self.ARG_OTK_COUNTS]
+ otks: Optional[TransactionOneTimeKeysCount] = last_args[self.ARG_OTK_COUNTS]
unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[
self.ARG_FALLBACK_KEYS
]
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 2b21547d0f..2733719d82 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -199,7 +199,7 @@ class CasHandlerTestCase(HomeserverTestCase):
)
-def _mock_request():
+def _mock_request() -> Mock:
"""Returns a mock which will stand in as a SynapseRequest"""
mock = Mock(
spec=[
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index b8b465d35b..ce7525e29c 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -19,7 +19,7 @@ from typing import Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError, SynapseError
-from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
+from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
from synapse.server import HomeServer
from synapse.util import Clock
@@ -32,7 +32,9 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
- self.handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.handler = handler
self.store = hs.get_datastores().main
return hs
@@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(res, "fco")
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
+ assert dev is not None
self.assertEqual(dev["display_name"], "display name")
def test_device_is_preserved_if_exists(self) -> None:
@@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(res2, "fco")
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
+ assert dev is not None
self.assertEqual(dev["display_name"], "display name")
def test_device_id_is_made_up_if_unspecified(self) -> None:
@@ -95,6 +99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
+ assert dev is not None
self.assertEqual(dev["display_name"], "display")
def test_get_devices_by_user(self) -> None:
@@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
- self.handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.handler = handler
self.registration = hs.get_registration_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
@@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
)
)
- retrieved_device_id, device_data = self.get_success(
- self.handler.get_dehydrated_device(user_id=user_id)
- )
+ result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
+ assert result is not None
+ retrieved_device_id, device_data = result
self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 3b72c4c9d0..90aec484c4 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -20,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.api.errors
import synapse.rest.admin
from synapse.api.constants import EventTypes
+from synapse.events import EventBase
from synapse.rest.client import directory, login, room
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, create_requester
@@ -201,7 +202,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
- def _create_alias(self, user) -> None:
+ def _create_alias(self, user: str) -> None:
# Create a new alias to this room.
self.get_success(
self.store.create_room_alias_association(
@@ -324,7 +325,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
return room_alias
- def _set_canonical_alias(self, content) -> None:
+ def _set_canonical_alias(self, content: JsonDict) -> None:
"""Configure the canonical alias state on the room."""
self.helper.send_state(
self.room_id,
@@ -333,13 +334,15 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
tok=self.admin_user_tok,
)
- def _get_canonical_alias(self):
+ def _get_canonical_alias(self) -> EventBase:
"""Get the canonical alias state of the room."""
- return self.get_success(
+ result = self.get_success(
self._storage_controllers.state.get_current_state_event(
self.room_id, EventTypes.CanonicalAlias, ""
)
)
+ assert result is not None
+ return result
def test_remove_alias(self) -> None:
"""Removing an alias that is the canonical alias should remove it there too."""
@@ -349,8 +352,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
data = self._get_canonical_alias()
- self.assertEqual(data["content"]["alias"], self.test_alias)
- self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
+ self.assertEqual(data.content["alias"], self.test_alias)
+ self.assertEqual(data.content["alt_aliases"], [self.test_alias])
# Finally, delete the alias.
self.get_success(
@@ -360,8 +363,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
data = self._get_canonical_alias()
- self.assertNotIn("alias", data["content"])
- self.assertNotIn("alt_aliases", data["content"])
+ self.assertNotIn("alias", data.content)
+ self.assertNotIn("alt_aliases", data.content)
def test_remove_other_alias(self) -> None:
"""Removing an alias listed as in alt_aliases should remove it there too."""
@@ -378,9 +381,9 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
data = self._get_canonical_alias()
- self.assertEqual(data["content"]["alias"], self.test_alias)
+ self.assertEqual(data.content["alias"], self.test_alias)
self.assertEqual(
- data["content"]["alt_aliases"], [self.test_alias, other_test_alias]
+ data.content["alt_aliases"], [self.test_alias, other_test_alias]
)
# Delete the second alias.
@@ -391,8 +394,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
data = self._get_canonical_alias()
- self.assertEqual(data["content"]["alias"], self.test_alias)
- self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
+ self.assertEqual(data.content["alias"], self.test_alias)
+ self.assertEqual(data.content["alt_aliases"], [self.test_alias])
class TestCreateAliasACL(unittest.HomeserverTestCase):
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 9b7e7a8e9a..6c0b30de9e 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -17,7 +17,11 @@
import copy
from unittest import mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.errors import SynapseError
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
@@ -39,14 +43,14 @@ room_keys = {
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(replication_layer=mock.Mock())
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_room_keys_handler()
self.local_user = "@boris:" + hs.hostname
- def test_get_missing_current_version_info(self):
+ def test_get_missing_current_version_info(self) -> None:
"""Check that we get a 404 if we ask for info about the current version
if there is no version.
"""
@@ -56,7 +60,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)
- def test_get_missing_version_info(self):
+ def test_get_missing_version_info(self) -> None:
"""Check that we get a 404 if we ask for info about a specific version
if it doesn't exist.
"""
@@ -67,9 +71,9 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)
- def test_create_version(self):
+ def test_create_version(self) -> None:
"""Check that we can create and then retrieve versions."""
- res = self.get_success(
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -78,7 +82,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
)
- self.assertEqual(res, "1")
+ self.assertEqual(version, "1")
# check we can retrieve it as the current version
res = self.get_success(self.handler.get_version_info(self.local_user))
@@ -110,7 +114,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
)
# upload a new one...
- res = self.get_success(
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -119,7 +123,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
)
- self.assertEqual(res, "2")
+ self.assertEqual(version, "2")
# check we can retrieve it as the current version
res = self.get_success(self.handler.get_version_info(self.local_user))
@@ -134,7 +138,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
- def test_update_version(self):
+ def test_update_version(self) -> None:
"""Check that we can update versions."""
version = self.get_success(
self.handler.create_version(
@@ -173,7 +177,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
- def test_update_missing_version(self):
+ def test_update_missing_version(self) -> None:
"""Check that we get a 404 on updating nonexistent versions"""
e = self.get_failure(
self.handler.update_version(
@@ -190,7 +194,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)
- def test_update_omitted_version(self):
+ def test_update_omitted_version(self) -> None:
"""Check that the update succeeds if the version is missing from the body"""
version = self.get_success(
self.handler.create_version(
@@ -227,7 +231,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
- def test_update_bad_version(self):
+ def test_update_bad_version(self) -> None:
"""Check that we get a 400 if the version in the body doesn't match"""
version = self.get_success(
self.handler.create_version(
@@ -255,7 +259,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 400)
- def test_delete_missing_version(self):
+ def test_delete_missing_version(self) -> None:
"""Check that we get a 404 on deleting nonexistent versions"""
e = self.get_failure(
self.handler.delete_version(self.local_user, "1"), SynapseError
@@ -263,15 +267,15 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)
- def test_delete_missing_current_version(self):
+ def test_delete_missing_current_version(self) -> None:
"""Check that we get a 404 on deleting nonexistent current version"""
e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError)
res = e.value.code
self.assertEqual(res, 404)
- def test_delete_version(self):
+ def test_delete_version(self) -> None:
"""Check that we can create and then delete versions."""
- res = self.get_success(
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -280,7 +284,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
)
- self.assertEqual(res, "1")
+ self.assertEqual(version, "1")
# check we can delete it
self.get_success(self.handler.delete_version(self.local_user, "1"))
@@ -292,7 +296,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)
- def test_get_missing_backup(self):
+ def test_get_missing_backup(self) -> None:
"""Check that we get a 404 on querying missing backup"""
e = self.get_failure(
self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError
@@ -300,7 +304,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)
- def test_get_missing_room_keys(self):
+ def test_get_missing_room_keys(self) -> None:
"""Check we get an empty response from an empty backup"""
version = self.get_success(
self.handler.create_version(
@@ -319,7 +323,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
# TODO: test the locking semantics when uploading room_keys,
# although this is probably best done in sytest
- def test_upload_room_keys_no_versions(self):
+ def test_upload_room_keys_no_versions(self) -> None:
"""Check that we get a 404 on uploading keys when no versions are defined"""
e = self.get_failure(
self.handler.upload_room_keys(self.local_user, "no_version", room_keys),
@@ -328,7 +332,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)
- def test_upload_room_keys_bogus_version(self):
+ def test_upload_room_keys_bogus_version(self) -> None:
"""Check that we get a 404 on uploading keys when an nonexistent version
is specified
"""
@@ -350,7 +354,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)
- def test_upload_room_keys_wrong_version(self):
+ def test_upload_room_keys_wrong_version(self) -> None:
"""Check that we get a 403 on uploading keys for an old version"""
version = self.get_success(
self.handler.create_version(
@@ -380,7 +384,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 403)
- def test_upload_room_keys_insert(self):
+ def test_upload_room_keys_insert(self) -> None:
"""Check that we can insert and retrieve keys for a session"""
version = self.get_success(
self.handler.create_version(
@@ -416,7 +420,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.assertDictEqual(res, room_keys)
- def test_upload_room_keys_merge(self):
+ def test_upload_room_keys_merge(self) -> None:
"""Check that we can upload a new room_key for an existing session and
have it correctly merged"""
version = self.get_success(
@@ -449,9 +453,11 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
- res = self.get_success(self.handler.get_room_keys(self.local_user, version))
+ res_keys = self.get_success(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
- res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
+ res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"SSBBTSBBIEZJU0gK",
)
@@ -465,9 +471,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
- res = self.get_success(self.handler.get_room_keys(self.local_user, version))
+ res_keys = self.get_success(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
- res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
+ res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
+ "new",
)
# the etag should NOT be equal now, since the key changed
@@ -483,9 +492,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
- res = self.get_success(self.handler.get_room_keys(self.local_user, version))
+ res_keys = self.get_success(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
- res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
+ res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
+ "new",
)
# the etag should be the same since the session did not change
@@ -494,7 +506,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
# TODO: check edge cases as well as the common variations here
- def test_delete_room_keys(self):
+ def test_delete_room_keys(self) -> None:
"""Check that we can insert and delete keys for a session"""
version = self.get_success(
self.handler.create_version(
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index d00c69c229..57675fa407 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import cast
+from typing import Collection, Optional, cast
from unittest import TestCase
from unittest.mock import Mock, patch
+from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
@@ -439,7 +440,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
- def create_invite():
+ def create_invite() -> EventBase:
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
room_version = self.get_success(self.store.get_room_version(room_id))
return event_from_pdu_json(
@@ -655,7 +656,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
EVENT_INVITATION_MEMBERSHIP,
],
partial_state=True,
- servers_in_room=["example.com"],
+ servers_in_room={"example.com"},
)
)
)
@@ -679,3 +680,112 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
f"Stale partial-stated room flag left over for {room_id} after a"
f" failed do_invite_join!",
)
+
+ def test_duplicate_partial_state_room_syncs(self) -> None:
+ """
+ Tests that concurrent partial state syncs are not started for the same room.
+ """
+ is_partial_state = True
+ end_sync: "Deferred[None]" = Deferred()
+
+ async def is_partial_state_room(room_id: str) -> bool:
+ return is_partial_state
+
+ async def sync_partial_state_room(
+ initial_destination: Optional[str],
+ other_destinations: Collection[str],
+ room_id: str,
+ ) -> None:
+ nonlocal end_sync
+ try:
+ await end_sync
+ finally:
+ end_sync = Deferred()
+
+ mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
+ mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
+
+ fed_handler = self.hs.get_federation_handler()
+ store = self.hs.get_datastores().main
+
+ with patch.object(
+ fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+ ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+ # Start the partial state sync.
+ fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+ # Try to start another partial state sync.
+ # Nothing should happen.
+ fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+ # End the partial state sync
+ is_partial_state = False
+ end_sync.callback(None)
+
+ # The partial state sync should not be restarted.
+ self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+ # The next attempt to start the partial state sync should work.
+ is_partial_state = True
+ fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+ def test_partial_state_room_sync_restart(self) -> None:
+ """
+ Tests that partial state syncs are restarted when a second partial state sync
+ was deduplicated and the first partial state sync fails.
+ """
+ is_partial_state = True
+ end_sync: "Deferred[None]" = Deferred()
+
+ async def is_partial_state_room(room_id: str) -> bool:
+ return is_partial_state
+
+ async def sync_partial_state_room(
+ initial_destination: Optional[str],
+ other_destinations: Collection[str],
+ room_id: str,
+ ) -> None:
+ nonlocal end_sync
+ try:
+ await end_sync
+ finally:
+ end_sync = Deferred()
+
+ mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
+ mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
+
+ fed_handler = self.hs.get_federation_handler()
+ store = self.hs.get_datastores().main
+
+ with patch.object(
+ fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+ ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+ # Start the partial state sync.
+ fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+ # Fail the partial state sync.
+ # The partial state sync should not be restarted.
+ end_sync.errback(Exception("Failed to request /state_ids"))
+ self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+ # Start the partial state sync again.
+ fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+ # Deduplicate another partial state sync.
+ fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+ # Fail the partial state sync.
+ # It should restart with the latest parameters.
+ end_sync.errback(Exception("Failed to request /state_ids"))
+ self.assertEqual(mock_sync_partial_state_room.call_count, 3)
+ mock_sync_partial_state_room.assert_called_with(
+ initial_destination="hs3",
+ other_destinations=["hs2"],
+ room_id="room_id",
+ )
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index e448cb1901..70ea4d15d4 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -14,6 +14,8 @@
from typing import Optional
from unittest import mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.errors import AuthError, StoreError
from synapse.api.room_versions import RoomVersion
from synapse.event_auth import (
@@ -26,8 +28,10 @@ from synapse.federation.transport.client import StateRequestResponse
from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import event_injection, make_awaitable
@@ -40,7 +44,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# mock out the federation transport client
self.mock_federation_transport_client = mock.Mock(
spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
@@ -165,7 +169,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
else:
- async def get_event(destination: str, event_id: str, timeout=None):
+ async def get_event(
+ destination: str, event_id: str, timeout: Optional[int] = None
+ ) -> JsonDict:
self.assertEqual(destination, self.OTHER_SERVER_NAME)
self.assertEqual(event_id, prev_event.event_id)
return {"pdus": [prev_event.get_pdu_json()]}
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 99384837d0..c4727ab917 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -14,12 +14,16 @@
import logging
from typing import Tuple
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.types import create_requester
+from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
@@ -35,7 +39,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_event_creation_handler()
self._persist_event_storage_controller = (
self.hs.get_storage_controllers().persistence
@@ -94,7 +98,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
)
)
- def test_duplicated_txn_id(self):
+ def test_duplicated_txn_id(self) -> None:
"""Test that attempting to handle/persist an event with a transaction ID
that has already been persisted correctly returns the old event and does
*not* produce duplicate messages.
@@ -161,7 +165,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# rather than the new one.
self.assertEqual(ret_event1.event_id, ret_event4.event_id)
- def test_duplicated_txn_id_one_call(self):
+ def test_duplicated_txn_id_one_call(self) -> None:
"""Test that we correctly handle duplicates that we try and persist at
the same time.
"""
@@ -185,7 +189,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(events), 2)
self.assertEqual(events[0].event_id, events[1].event_id)
- def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(self):
+ def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(
+ self,
+ ) -> None:
"""When we set allow_no_prev_events=True, should be able to create a
event without any prev_events (only auth_events).
"""
@@ -214,7 +220,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events(
self,
- ):
+ ) -> None:
"""When we set allow_no_prev_events=False, shouldn't be able to create a
event without any prev_events even if it has auth_events. Expect an
exception to be raised.
@@ -245,7 +251,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events(
self,
- ):
+ ) -> None:
"""When we set allow_no_prev_events=True, should be able to create a
event without any prev_events or auth_events. Expect an exception to be
raised.
@@ -277,12 +283,12 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
- def test_allow_server_acl(self):
+ def test_allow_server_acl(self) -> None:
"""Test that sending an ACL that blocks everyone but ourselves works."""
self.helper.send_state(
@@ -293,7 +299,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
expect_code=200,
)
- def test_deny_server_acl_block_outselves(self):
+ def test_deny_server_acl_block_outselves(self) -> None:
"""Test that sending an ACL that blocks ourselves does not work."""
self.helper.send_state(
self.room_id,
@@ -303,7 +309,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
expect_code=400,
)
- def test_deny_redact_server_acl(self):
+ def test_deny_redact_server_acl(self) -> None:
"""Test that attempting to redact an ACL is blocked."""
body = self.helper.send_state(
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 5955410524..adddbd002f 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
-from typing import Any, Dict, Tuple
+from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse
@@ -23,7 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.sso import MappingException
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon
from synapse.util.stringutils import random_string
@@ -34,6 +34,10 @@ from tests.unittest import HomeserverTestCase, override_config
try:
import authlib # noqa: F401
+ from authlib.oidc.core import UserInfo
+ from authlib.oidc.discovery import OpenIDProviderMetadata
+
+ from synapse.handlers.oidc import Token, UserAttributeDict
HAS_OIDC = True
except ImportError:
@@ -70,29 +74,37 @@ EXPLICIT_ENDPOINT_CONFIG = {
class TestMappingProvider:
@staticmethod
- def parse_config(config):
- return
+ def parse_config(config: JsonDict) -> None:
+ return None
- def __init__(self, config):
+ def __init__(self, config: None):
pass
- def get_remote_user_id(self, userinfo):
+ def get_remote_user_id(self, userinfo: "UserInfo") -> str:
return userinfo["sub"]
- async def map_user_attributes(self, userinfo, token):
- return {"localpart": userinfo["username"], "display_name": None}
+ async def map_user_attributes(
+ self, userinfo: "UserInfo", token: "Token"
+ ) -> "UserAttributeDict":
+ # This is testing not providing the full map.
+ return {"localpart": userinfo["username"], "display_name": None} # type: ignore[typeddict-item]
# Do not include get_extra_attributes to test backwards compatibility paths.
class TestMappingProviderExtra(TestMappingProvider):
- async def get_extra_attributes(self, userinfo, token):
+ async def get_extra_attributes(
+ self, userinfo: "UserInfo", token: "Token"
+ ) -> JsonDict:
return {"phone": userinfo["phone"]}
class TestMappingProviderFailures(TestMappingProvider):
- async def map_user_attributes(self, userinfo, token, failures):
- return {
+ # Superclass is testing the legacy interface for map_user_attributes.
+ async def map_user_attributes( # type: ignore[override]
+ self, userinfo: "UserInfo", token: "Token", failures: int
+ ) -> "UserAttributeDict":
+ return { # type: ignore[typeddict-item]
"localpart": userinfo["username"] + (str(failures) if failures else ""),
"display_name": None,
}
@@ -161,13 +173,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.hs_patcher.stop()
return super().tearDown()
- def reset_mocks(self):
+ def reset_mocks(self) -> None:
"""Reset all the Mocks."""
self.fake_server.reset_mocks()
self.render_error.reset_mock()
self.complete_sso_login.reset_mock()
- def metadata_edit(self, values):
+ def metadata_edit(self, values: dict) -> ContextManager[Mock]:
"""Modify the result that will be returned by the well-known query"""
metadata = self.fake_server.get_metadata()
@@ -196,7 +208,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
return _build_callback_request(code, state, session), grant
- def assertRenderedError(self, error, error_description=None):
+ def assertRenderedError(
+ self, error: str, error_description: Optional[str] = None
+ ) -> Tuple[Any, ...]:
self.render_error.assert_called_once()
args = self.render_error.call_args[0]
self.assertEqual(args[1], error)
@@ -273,8 +287,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""Provider metadatas are extensively validated."""
h = self.provider
- def force_load_metadata():
- async def force_load():
+ def force_load_metadata() -> Awaitable[None]:
+ async def force_load() -> "OpenIDProviderMetadata":
return await h.load_metadata(force=True)
return get_awaitable_result(force_load())
@@ -382,6 +396,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(params["client_id"], [CLIENT_ID])
self.assertEqual(len(params["state"]), 1)
self.assertEqual(len(params["nonce"]), 1)
+ self.assertNotIn("code_challenge", params)
# Check what is in the cookies
self.assertEqual(len(req.cookies), 2) # two cookies
@@ -397,12 +412,117 @@ class OidcHandlerTestCase(HomeserverTestCase):
macaroon = pymacaroons.Macaroon.deserialize(cookie)
state = get_value_from_macaroon(macaroon, "state")
nonce = get_value_from_macaroon(macaroon, "nonce")
+ code_verifier = get_value_from_macaroon(macaroon, "code_verifier")
redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
self.assertEqual(params["state"], [state])
self.assertEqual(params["nonce"], [nonce])
+ self.assertEqual(code_verifier, "")
self.assertEqual(redirect, "http://client/redirect")
+ @override_config({"oidc_config": DEFAULT_CONFIG})
+ def test_redirect_request_with_code_challenge(self) -> None:
+ """The redirect request has the right arguments & generates a valid session cookie."""
+ req = Mock(spec=["cookies"])
+ req.cookies = []
+
+ with self.metadata_edit({"code_challenge_methods_supported": ["S256"]}):
+ url = urlparse(
+ self.get_success(
+ self.provider.handle_redirect_request(
+ req, b"http://client/redirect"
+ )
+ )
+ )
+
+ # Ensure the code_challenge param is added to the redirect.
+ params = parse_qs(url.query)
+ self.assertEqual(len(params["code_challenge"]), 1)
+
+ # Check what is in the cookies
+ self.assertEqual(len(req.cookies), 2) # two cookies
+ cookie_header = req.cookies[0]
+
+ # The cookie name and path don't really matter, just that it has to be coherent
+ # between the callback & redirect handlers.
+ parts = [p.strip() for p in cookie_header.split(b";")]
+ self.assertIn(b"Path=/_synapse/client/oidc", parts)
+ name, cookie = parts[0].split(b"=")
+ self.assertEqual(name, b"oidc_session")
+
+ # Ensure the code_verifier is set in the cookie.
+ macaroon = pymacaroons.Macaroon.deserialize(cookie)
+ code_verifier = get_value_from_macaroon(macaroon, "code_verifier")
+ self.assertNotEqual(code_verifier, "")
+
+ @override_config({"oidc_config": {**DEFAULT_CONFIG, "pkce_method": "always"}})
+ def test_redirect_request_with_forced_code_challenge(self) -> None:
+ """The redirect request has the right arguments & generates a valid session cookie."""
+ req = Mock(spec=["cookies"])
+ req.cookies = []
+
+ url = urlparse(
+ self.get_success(
+ self.provider.handle_redirect_request(req, b"http://client/redirect")
+ )
+ )
+
+ # Ensure the code_challenge param is added to the redirect.
+ params = parse_qs(url.query)
+ self.assertEqual(len(params["code_challenge"]), 1)
+
+ # Check what is in the cookies
+ self.assertEqual(len(req.cookies), 2) # two cookies
+ cookie_header = req.cookies[0]
+
+ # The cookie name and path don't really matter, just that it has to be coherent
+ # between the callback & redirect handlers.
+ parts = [p.strip() for p in cookie_header.split(b";")]
+ self.assertIn(b"Path=/_synapse/client/oidc", parts)
+ name, cookie = parts[0].split(b"=")
+ self.assertEqual(name, b"oidc_session")
+
+ # Ensure the code_verifier is set in the cookie.
+ macaroon = pymacaroons.Macaroon.deserialize(cookie)
+ code_verifier = get_value_from_macaroon(macaroon, "code_verifier")
+ self.assertNotEqual(code_verifier, "")
+
+ @override_config({"oidc_config": {**DEFAULT_CONFIG, "pkce_method": "never"}})
+ def test_redirect_request_with_disabled_code_challenge(self) -> None:
+ """The redirect request has the right arguments & generates a valid session cookie."""
+ req = Mock(spec=["cookies"])
+ req.cookies = []
+
+ # The metadata should state that PKCE is enabled.
+ with self.metadata_edit({"code_challenge_methods_supported": ["S256"]}):
+ url = urlparse(
+ self.get_success(
+ self.provider.handle_redirect_request(
+ req, b"http://client/redirect"
+ )
+ )
+ )
+
+ # Ensure the code_challenge param is added to the redirect.
+ params = parse_qs(url.query)
+ self.assertNotIn("code_challenge", params)
+
+ # Check what is in the cookies
+ self.assertEqual(len(req.cookies), 2) # two cookies
+ cookie_header = req.cookies[0]
+
+ # The cookie name and path don't really matter, just that it has to be coherent
+ # between the callback & redirect handlers.
+ parts = [p.strip() for p in cookie_header.split(b";")]
+ self.assertIn(b"Path=/_synapse/client/oidc", parts)
+ name, cookie = parts[0].split(b"=")
+ self.assertEqual(name, b"oidc_session")
+
+ # Ensure the code_verifier is blank in the cookie.
+ macaroon = pymacaroons.Macaroon.deserialize(cookie)
+ code_verifier = get_value_from_macaroon(macaroon, "code_verifier")
+ self.assertEqual(code_verifier, "")
+
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_error(self) -> None:
"""Errors from the provider returned in the callback are displayed."""
@@ -587,7 +707,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
payload=token
)
code = "code"
- ret = self.get_success(self.provider._exchange_code(code))
+ ret = self.get_success(self.provider._exchange_code(code, code_verifier=""))
kwargs = self.fake_server.request.call_args[1]
self.assertEqual(ret, token)
@@ -601,13 +721,34 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(args["client_secret"], [CLIENT_SECRET])
self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+ # Test providing a code verifier.
+ code_verifier = "code_verifier"
+ ret = self.get_success(
+ self.provider._exchange_code(code, code_verifier=code_verifier)
+ )
+ kwargs = self.fake_server.request.call_args[1]
+
+ self.assertEqual(ret, token)
+ self.assertEqual(kwargs["method"], "POST")
+ self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
+
+ args = parse_qs(kwargs["data"].decode("utf-8"))
+ self.assertEqual(args["grant_type"], ["authorization_code"])
+ self.assertEqual(args["code"], [code])
+ self.assertEqual(args["client_id"], [CLIENT_ID])
+ self.assertEqual(args["client_secret"], [CLIENT_SECRET])
+ self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+ self.assertEqual(args["code_verifier"], [code_verifier])
+
# Test error handling
self.fake_server.post_token_handler.return_value = FakeResponse.json(
code=400, payload={"error": "foo", "error_description": "bar"}
)
from synapse.handlers.oidc import OidcError
- exc = self.get_failure(self.provider._exchange_code(code), OidcError)
+ exc = self.get_failure(
+ self.provider._exchange_code(code, code_verifier=""), OidcError
+ )
self.assertEqual(exc.value.error, "foo")
self.assertEqual(exc.value.error_description, "bar")
@@ -615,7 +756,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.fake_server.post_token_handler.return_value = FakeResponse(
code=500, body=b"Not JSON"
)
- exc = self.get_failure(self.provider._exchange_code(code), OidcError)
+ exc = self.get_failure(
+ self.provider._exchange_code(code, code_verifier=""), OidcError
+ )
self.assertEqual(exc.value.error, "server_error")
# Internal server error with JSON body
@@ -623,21 +766,27 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=500, payload={"error": "internal_server_error"}
)
- exc = self.get_failure(self.provider._exchange_code(code), OidcError)
+ exc = self.get_failure(
+ self.provider._exchange_code(code, code_verifier=""), OidcError
+ )
self.assertEqual(exc.value.error, "internal_server_error")
# 4xx error without "error" field
self.fake_server.post_token_handler.return_value = FakeResponse.json(
code=400, payload={}
)
- exc = self.get_failure(self.provider._exchange_code(code), OidcError)
+ exc = self.get_failure(
+ self.provider._exchange_code(code, code_verifier=""), OidcError
+ )
self.assertEqual(exc.value.error, "server_error")
# 2xx error with "error" field
self.fake_server.post_token_handler.return_value = FakeResponse.json(
code=200, payload={"error": "some_error"}
)
- exc = self.get_failure(self.provider._exchange_code(code), OidcError)
+ exc = self.get_failure(
+ self.provider._exchange_code(code, code_verifier=""), OidcError
+ )
self.assertEqual(exc.value.error, "some_error")
@override_config(
@@ -674,7 +823,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# timestamps.
self.reactor.advance(1000)
start_time = self.reactor.seconds()
- ret = self.get_success(self.provider._exchange_code(code))
+ ret = self.get_success(self.provider._exchange_code(code, code_verifier=""))
self.assertEqual(ret, token)
@@ -725,7 +874,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
payload=token
)
code = "code"
- ret = self.get_success(self.provider._exchange_code(code))
+ ret = self.get_success(self.provider._exchange_code(code, code_verifier=""))
self.assertEqual(ret, token)
@@ -1189,6 +1338,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=ui_auth_session_id,
+ code_verifier="",
),
)
@@ -1198,7 +1348,7 @@ def _build_callback_request(
state: str,
session: str,
ip_address: str = "10.0.0.1",
-):
+) -> Mock:
"""Builds a fake SynapseRequest to mock the browser callback
Returns a Mock object which looks like the SynapseRequest we get from a browser
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 75934b1707..0916de64f5 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -15,12 +15,13 @@
"""Tests for the password_auth_provider interface"""
from http import HTTPStatus
-from typing import Any, Type, Union
+from typing import Any, Dict, List, Optional, Type, Union
from unittest.mock import Mock
import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
+from synapse.handlers.account import AccountHandler
from synapse.module_api import ModuleApi
from synapse.rest.client import account, devices, login, logout, register
from synapse.types import JsonDict, UserID
@@ -44,13 +45,13 @@ class LegacyPasswordOnlyAuthProvider:
"""A legacy password_provider which only implements `check_password`."""
@staticmethod
- def parse_config(self):
+ def parse_config(config: JsonDict) -> None:
pass
- def __init__(self, config, account_handler):
+ def __init__(self, config: None, account_handler: AccountHandler):
pass
- def check_password(self, *args):
+ def check_password(self, *args: str) -> Mock:
return mock_password_provider.check_password(*args)
@@ -58,16 +59,16 @@ class LegacyCustomAuthProvider:
"""A legacy password_provider which implements a custom login type."""
@staticmethod
- def parse_config(self):
+ def parse_config(config: JsonDict) -> None:
pass
- def __init__(self, config, account_handler):
+ def __init__(self, config: None, account_handler: AccountHandler):
pass
- def get_supported_login_types(self):
+ def get_supported_login_types(self) -> Dict[str, List[str]]:
return {"test.login_type": ["test_field"]}
- def check_auth(self, *args):
+ def check_auth(self, *args: str) -> Mock:
return mock_password_provider.check_auth(*args)
@@ -75,15 +76,15 @@ class CustomAuthProvider:
"""A module which registers password_auth_provider callbacks for a custom login type."""
@staticmethod
- def parse_config(self):
+ def parse_config(config: JsonDict) -> None:
pass
- def __init__(self, config, api: ModuleApi):
+ def __init__(self, config: None, api: ModuleApi):
api.register_password_auth_provider_callbacks(
auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
)
- def check_auth(self, *args):
+ def check_auth(self, *args: Any) -> Mock:
return mock_password_provider.check_auth(*args)
@@ -92,16 +93,16 @@ class LegacyPasswordCustomAuthProvider:
as a custom type."""
@staticmethod
- def parse_config(self):
+ def parse_config(config: JsonDict) -> None:
pass
- def __init__(self, config, account_handler):
+ def __init__(self, config: None, account_handler: AccountHandler):
pass
- def get_supported_login_types(self):
+ def get_supported_login_types(self) -> Dict[str, List[str]]:
return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
- def check_auth(self, *args):
+ def check_auth(self, *args: str) -> Mock:
return mock_password_provider.check_auth(*args)
@@ -110,10 +111,10 @@ class PasswordCustomAuthProvider:
as well as a password login"""
@staticmethod
- def parse_config(self):
+ def parse_config(config: JsonDict) -> None:
pass
- def __init__(self, config, api: ModuleApi):
+ def __init__(self, config: None, api: ModuleApi):
api.register_password_auth_provider_callbacks(
auth_checkers={
("test.login_type", ("test_field",)): self.check_auth,
@@ -121,10 +122,10 @@ class PasswordCustomAuthProvider:
}
)
- def check_auth(self, *args):
+ def check_auth(self, *args: Any) -> Mock:
return mock_password_provider.check_auth(*args)
- def check_pass(self, *args):
+ def check_pass(self, *args: str) -> Mock:
return mock_password_provider.check_password(*args)
@@ -161,16 +162,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
CALLBACK_USERNAME = "get_username_for_registration"
CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
- def setUp(self):
+ def setUp(self) -> None:
# we use a global mock device, so make sure we are starting with a clean slate
mock_password_provider.reset_mock()
super().setUp()
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
- def test_password_only_auth_progiver_login_legacy(self):
+ def test_password_only_auth_progiver_login_legacy(self) -> None:
self.password_only_auth_provider_login_test_body()
- def password_only_auth_provider_login_test_body(self):
+ def password_only_auth_provider_login_test_body(self) -> None:
# login flows should only have m.login.password
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
@@ -201,10 +202,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
)
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
- def test_password_only_auth_provider_ui_auth_legacy(self):
+ def test_password_only_auth_provider_ui_auth_legacy(self) -> None:
self.password_only_auth_provider_ui_auth_test_body()
- def password_only_auth_provider_ui_auth_test_body(self):
+ def password_only_auth_provider_ui_auth_test_body(self) -> None:
"""UI Auth should delegate correctly to the password provider"""
# create the user, otherwise access doesn't work
@@ -238,10 +239,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
- def test_local_user_fallback_login_legacy(self):
+ def test_local_user_fallback_login_legacy(self) -> None:
self.local_user_fallback_login_test_body()
- def local_user_fallback_login_test_body(self):
+ def local_user_fallback_login_test_body(self) -> None:
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
@@ -255,10 +256,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual("@localuser:test", channel.json_body["user_id"])
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
- def test_local_user_fallback_ui_auth_legacy(self):
+ def test_local_user_fallback_ui_auth_legacy(self) -> None:
self.local_user_fallback_ui_auth_test_body()
- def local_user_fallback_ui_auth_test_body(self):
+ def local_user_fallback_ui_auth_test_body(self) -> None:
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
@@ -298,10 +299,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False},
}
)
- def test_no_local_user_fallback_login_legacy(self):
+ def test_no_local_user_fallback_login_legacy(self) -> None:
self.no_local_user_fallback_login_test_body()
- def no_local_user_fallback_login_test_body(self):
+ def no_local_user_fallback_login_test_body(self) -> None:
"""localdb_enabled can block login with the local password"""
self.register_user("localuser", "localpass")
@@ -320,10 +321,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False},
}
)
- def test_no_local_user_fallback_ui_auth_legacy(self):
+ def test_no_local_user_fallback_ui_auth_legacy(self) -> None:
self.no_local_user_fallback_ui_auth_test_body()
- def no_local_user_fallback_ui_auth_test_body(self):
+ def no_local_user_fallback_ui_auth_test_body(self) -> None:
"""localdb_enabled can block ui auth with the local password"""
self.register_user("localuser", "localpass")
@@ -361,10 +362,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
- def test_password_auth_disabled_legacy(self):
+ def test_password_auth_disabled_legacy(self) -> None:
self.password_auth_disabled_test_body()
- def password_auth_disabled_test_body(self):
+ def password_auth_disabled_test_body(self) -> None:
"""password auth doesn't work if it's disabled across the board"""
# login flows should be empty
flows = self._get_login_flows()
@@ -376,14 +377,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_not_called()
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
- def test_custom_auth_provider_login_legacy(self):
+ def test_custom_auth_provider_login_legacy(self) -> None:
self.custom_auth_provider_login_test_body()
@override_config(providers_config(CustomAuthProvider))
- def test_custom_auth_provider_login(self):
+ def test_custom_auth_provider_login(self) -> None:
self.custom_auth_provider_login_test_body()
- def custom_auth_provider_login_test_body(self):
+ def custom_auth_provider_login_test_body(self) -> None:
# login flows should have the custom flow and m.login.password, since we
# haven't disabled local password lookup.
# (password must come first, because reasons)
@@ -424,14 +425,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
)
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
- def test_custom_auth_provider_ui_auth_legacy(self):
+ def test_custom_auth_provider_ui_auth_legacy(self) -> None:
self.custom_auth_provider_ui_auth_test_body()
@override_config(providers_config(CustomAuthProvider))
- def test_custom_auth_provider_ui_auth(self):
+ def test_custom_auth_provider_ui_auth(self) -> None:
self.custom_auth_provider_ui_auth_test_body()
- def custom_auth_provider_ui_auth_test_body(self):
+ def custom_auth_provider_ui_auth_test_body(self) -> None:
# register the user and log in twice, to get two devices
self.register_user("localuser", "localpass")
tok1 = self.login("localuser", "localpass")
@@ -486,14 +487,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
)
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
- def test_custom_auth_provider_callback_legacy(self):
+ def test_custom_auth_provider_callback_legacy(self) -> None:
self.custom_auth_provider_callback_test_body()
@override_config(providers_config(CustomAuthProvider))
- def test_custom_auth_provider_callback(self):
+ def test_custom_auth_provider_callback(self) -> None:
self.custom_auth_provider_callback_test_body()
- def custom_auth_provider_callback_test_body(self):
+ def custom_auth_provider_callback_test_body(self) -> None:
callback = Mock(return_value=make_awaitable(None))
mock_password_provider.check_auth.return_value = make_awaitable(
@@ -521,16 +522,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
- def test_custom_auth_password_disabled_legacy(self):
+ def test_custom_auth_password_disabled_legacy(self) -> None:
self.custom_auth_password_disabled_test_body()
@override_config(
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
)
- def test_custom_auth_password_disabled(self):
+ def test_custom_auth_password_disabled(self) -> None:
self.custom_auth_password_disabled_test_body()
- def custom_auth_password_disabled_test_body(self):
+ def custom_auth_password_disabled_test_body(self) -> None:
"""Test login with a custom auth provider where password login is disabled"""
self.register_user("localuser", "localpass")
@@ -548,7 +549,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False, "localdb_enabled": False},
}
)
- def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
+ def test_custom_auth_password_disabled_localdb_enabled_legacy(self) -> None:
self.custom_auth_password_disabled_localdb_enabled_test_body()
@override_config(
@@ -557,10 +558,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False, "localdb_enabled": False},
}
)
- def test_custom_auth_password_disabled_localdb_enabled(self):
+ def test_custom_auth_password_disabled_localdb_enabled(self) -> None:
self.custom_auth_password_disabled_localdb_enabled_test_body()
- def custom_auth_password_disabled_localdb_enabled_test_body(self):
+ def custom_auth_password_disabled_localdb_enabled_test_body(self) -> None:
"""Check the localdb_enabled == enabled == False
Regression test for https://github.com/matrix-org/synapse/issues/8914: check
@@ -583,7 +584,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
- def test_password_custom_auth_password_disabled_login_legacy(self):
+ def test_password_custom_auth_password_disabled_login_legacy(self) -> None:
self.password_custom_auth_password_disabled_login_test_body()
@override_config(
@@ -592,10 +593,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
- def test_password_custom_auth_password_disabled_login(self):
+ def test_password_custom_auth_password_disabled_login(self) -> None:
self.password_custom_auth_password_disabled_login_test_body()
- def password_custom_auth_password_disabled_login_test_body(self):
+ def password_custom_auth_password_disabled_login_test_body(self) -> None:
"""log in with a custom auth provider which implements password, but password
login is disabled"""
self.register_user("localuser", "localpass")
@@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
- def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
+ def test_password_custom_auth_password_disabled_ui_auth_legacy(self) -> None:
self.password_custom_auth_password_disabled_ui_auth_test_body()
@override_config(
@@ -624,10 +625,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
- def test_password_custom_auth_password_disabled_ui_auth(self):
+ def test_password_custom_auth_password_disabled_ui_auth(self) -> None:
self.password_custom_auth_password_disabled_ui_auth_test_body()
- def password_custom_auth_password_disabled_ui_auth_test_body(self):
+ def password_custom_auth_password_disabled_ui_auth_test_body(self) -> None:
"""UI Auth with a custom auth provider which implements password, but password
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
@@ -689,7 +690,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False},
}
)
- def test_custom_auth_no_local_user_fallback_legacy(self):
+ def test_custom_auth_no_local_user_fallback_legacy(self) -> None:
self.custom_auth_no_local_user_fallback_test_body()
@override_config(
@@ -698,10 +699,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False},
}
)
- def test_custom_auth_no_local_user_fallback(self):
+ def test_custom_auth_no_local_user_fallback(self) -> None:
self.custom_auth_no_local_user_fallback_test_body()
- def custom_auth_no_local_user_fallback_test_body(self):
+ def custom_auth_no_local_user_fallback_test_body(self) -> None:
"""Test login with a custom auth provider where the local db is disabled"""
self.register_user("localuser", "localpass")
@@ -713,14 +714,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
- def test_on_logged_out(self):
+ def test_on_logged_out(self) -> None:
"""Tests that the on_logged_out callback is called when the user logs out."""
self.register_user("rin", "password")
tok = self.login("rin", "password")
self.called = False
- async def on_logged_out(user_id, device_id, access_token):
+ async def on_logged_out(
+ user_id: str, device_id: Optional[str], access_token: str
+ ) -> None:
self.called = True
on_logged_out = Mock(side_effect=on_logged_out)
@@ -738,7 +741,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
on_logged_out.assert_called_once()
self.assertTrue(self.called)
- def test_username(self):
+ def test_username(self) -> None:
"""Tests that the get_username_for_registration callback can define the username
of a user when registering.
"""
@@ -763,7 +766,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mxid = channel.json_body["user_id"]
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
- def test_username_uia(self):
+ def test_username_uia(self) -> None:
"""Tests that the get_username_for_registration callback is only called at the
end of the UIA flow.
"""
@@ -782,7 +785,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# Set some email configuration so the test doesn't fail because of its absence.
@override_config({"email": {"notif_from": "noreply@test"}})
- def test_3pid_allowed(self):
+ def test_3pid_allowed(self) -> None:
"""Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
the 3PID. Also checks that the module is passed a boolean indicating whether the
@@ -791,7 +794,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self._test_3pid_allowed("rin", False)
self._test_3pid_allowed("kitay", True)
- def test_displayname(self):
+ def test_displayname(self) -> None:
"""Tests that the get_displayname_for_registration callback can define the
display name of a user when registering.
"""
@@ -820,7 +823,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(display_name, username + "-foo")
- def test_displayname_uia(self):
+ def test_displayname_uia(self) -> None:
"""Tests that the get_displayname_for_registration callback is only called at the
end of the UIA flow.
"""
@@ -841,7 +844,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# Check that the callback has been called.
m.assert_called_once()
- def _test_3pid_allowed(self, username: str, registration: bool):
+ def _test_3pid_allowed(self, username: str, registration: bool) -> None:
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
either /register or /account URLs depending on the arguments.
@@ -907,7 +910,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
client is trying to register.
"""
- async def callback(uia_results, params):
+ async def callback(uia_results: JsonDict, params: JsonDict) -> str:
self.assertIn(LoginType.DUMMY, uia_results)
username = params["username"]
return username + "-foo"
@@ -950,12 +953,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
def _send_password_login(self, user: str, password: str) -> FakeChannel:
return self._send_login(type="m.login.password", user=user, password=password)
- def _send_login(self, type, user, **params) -> FakeChannel:
- params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
+ def _send_login(self, type: str, user: str, **extra_params: str) -> FakeChannel:
+ params = {"identifier": {"type": "m.id.user", "user": user}, "type": type}
+ params.update(extra_params)
channel = self.make_request("POST", "/_matrix/client/r0/login", params)
return channel
- def _start_delete_device_session(self, access_token, device_id) -> str:
+ def _start_delete_device_session(self, access_token: str, device_id: str) -> str:
"""Make an initial delete device request, and return the UI Auth session ID"""
channel = self._delete_device(access_token, device_id)
self.assertEqual(channel.code, 401)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index c5981ff965..19f5322317 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
+from typing import Optional, cast
from unittest.mock import Mock, call
from parameterized import parameterized
from signedjson.key import generate_signing_key
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -35,7 +37,9 @@ from synapse.handlers.presence import (
)
from synapse.rest import admin
from synapse.rest.client import room
-from synapse.types import UserID, get_domain_from_id
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.util import Clock
from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -44,10 +48,12 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase
class PresenceUpdateTestCase(unittest.HomeserverTestCase):
servlets = [admin.register_servlets]
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
- def test_offline_to_online(self):
+ def test_offline_to_online(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -85,7 +91,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_online_to_online(self):
+ def test_online_to_online(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -128,7 +134,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_online_to_online_last_active_noop(self):
+ def test_online_to_online_last_active_noop(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -173,7 +179,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_online_to_online_last_active(self):
+ def test_online_to_online_last_active(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -210,7 +216,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_remote_ping_timer(self):
+ def test_remote_ping_timer(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -244,7 +250,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_online_to_offline(self):
+ def test_online_to_offline(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -266,7 +272,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
self.assertEqual(wheel_timer.insert.call_count, 0)
- def test_online_to_idle(self):
+ def test_online_to_idle(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -300,7 +306,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_persisting_presence_updates(self):
+ def test_persisting_presence_updates(self) -> None:
"""Tests that the latest presence state for each user is persisted correctly"""
# Create some test users and presence states for them
presence_states = []
@@ -322,7 +328,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.update_presence(presence_states))
# Check that each update is present in the database
- db_presence_states = self.get_success(
+ db_presence_states_raw = self.get_success(
self.store.get_all_presence_updates(
instance_name="master",
last_id=0,
@@ -332,7 +338,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
)
# Extract presence update user ID and state information into lists of tuples
- db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
+ db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states_raw[0]]
presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
# Compare what we put into the storage with what we got out.
@@ -343,7 +349,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
class PresenceTimeoutTestCase(unittest.TestCase):
"""Tests different timers and that the timer does not change `status_msg` of user."""
- def test_idle_timer(self):
+ def test_idle_timer(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -363,7 +369,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
self.assertEqual(new_state.status_msg, status_msg)
- def test_busy_no_idle(self):
+ def test_busy_no_idle(self) -> None:
"""
Tests that a user setting their presence to busy but idling doesn't turn their
presence state into unavailable.
@@ -387,7 +393,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.BUSY)
self.assertEqual(new_state.status_msg, status_msg)
- def test_sync_timeout(self):
+ def test_sync_timeout(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -407,7 +413,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)
- def test_sync_online(self):
+ def test_sync_online(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -429,7 +435,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.ONLINE)
self.assertEqual(new_state.status_msg, status_msg)
- def test_federation_ping(self):
+ def test_federation_ping(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -448,7 +454,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNotNone(new_state)
self.assertEqual(state, new_state)
- def test_no_timeout(self):
+ def test_no_timeout(self) -> None:
user_id = "@foo:bar"
now = 5000000
@@ -464,7 +470,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNone(new_state)
- def test_federation_timeout(self):
+ def test_federation_timeout(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -487,7 +493,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)
- def test_last_active(self):
+ def test_last_active(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -508,15 +514,15 @@ class PresenceTimeoutTestCase(unittest.TestCase):
class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
- def test_external_process_timeout(self):
+ def test_external_process_timeout(self) -> None:
"""Test that if an external process doesn't update the records for a while
we time out their syncing users presence.
"""
- process_id = 1
+ process_id = "1"
user_id = "@test:server"
# Notify handler that a user is now syncing.
@@ -544,7 +550,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
)
self.assertEqual(state.state, PresenceState.OFFLINE)
- def test_user_goes_offline_by_timeout_status_msg_remain(self):
+ def test_user_goes_offline_by_timeout_status_msg_remain(self) -> None:
"""Test that if a user doesn't update the records for a while
users presence goes `OFFLINE` because of timeout and `status_msg` remains.
"""
@@ -576,7 +582,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(state.state, PresenceState.OFFLINE)
self.assertEqual(state.status_msg, status_msg)
- def test_user_goes_offline_manually_with_no_status_msg(self):
+ def test_user_goes_offline_manually_with_no_status_msg(self) -> None:
"""Test that if a user change presence manually to `OFFLINE`
and no status is set, that `status_msg` is `None`.
"""
@@ -601,7 +607,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(state.state, PresenceState.OFFLINE)
self.assertEqual(state.status_msg, None)
- def test_user_goes_offline_manually_with_status_msg(self):
+ def test_user_goes_offline_manually_with_status_msg(self) -> None:
"""Test that if a user change presence manually to `OFFLINE`
and a status is set, that `status_msg` appears.
"""
@@ -618,7 +624,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
user_id, PresenceState.OFFLINE, "And now here."
)
- def test_user_reset_online_with_no_status(self):
+ def test_user_reset_online_with_no_status(self) -> None:
"""Test that if a user set again the presence manually
and no status is set, that `status_msg` is `None`.
"""
@@ -644,7 +650,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(state.state, PresenceState.ONLINE)
self.assertEqual(state.status_msg, None)
- def test_set_presence_with_status_msg_none(self):
+ def test_set_presence_with_status_msg_none(self) -> None:
"""Test that if a user set again the presence manually
and status is `None`, that `status_msg` is `None`.
"""
@@ -659,7 +665,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# Mark user as online and `status_msg = None`
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
- def test_set_presence_from_syncing_not_set(self):
+ def test_set_presence_from_syncing_not_set(self) -> None:
"""Test that presence is not set by syncing if affect_presence is false"""
user_id = "@test:server"
status_msg = "I'm here!"
@@ -680,7 +686,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# and status message should still be the same
self.assertEqual(state.status_msg, status_msg)
- def test_set_presence_from_syncing_is_set(self):
+ def test_set_presence_from_syncing_is_set(self) -> None:
"""Test that presence is set by syncing if affect_presence is true"""
user_id = "@test:server"
status_msg = "I'm here!"
@@ -699,7 +705,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# we should now be online
self.assertEqual(state.state, PresenceState.ONLINE)
- def test_set_presence_from_syncing_keeps_status(self):
+ def test_set_presence_from_syncing_keeps_status(self) -> None:
"""Test that presence set by syncing retains status message"""
user_id = "@test:server"
status_msg = "I'm here!"
@@ -726,7 +732,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
},
}
)
- def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool):
+ def test_set_presence_from_syncing_keeps_busy(
+ self, test_with_workers: bool
+ ) -> None:
"""Test that presence set by syncing doesn't affect busy status
Args:
@@ -767,7 +775,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
def _set_presencestate_with_status_msg(
self, user_id: str, state: str, status_msg: Optional[str]
- ):
+ ) -> None:
"""Set a PresenceState and status_msg and check the result.
Args:
@@ -790,14 +798,14 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
self.instance_name = hs.get_instance_name()
self.queue = self.presence_handler.get_federation_queue()
- def test_send_and_get(self):
+ def test_send_and_get(self) -> None:
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
@@ -834,7 +842,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
self.assertFalse(limited)
self.assertCountEqual(rows, [])
- def test_send_and_get_split(self):
+ def test_send_and_get_split(self) -> None:
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
@@ -877,7 +885,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
self.assertCountEqual(rows, expected_rows)
- def test_clear_queue_all(self):
+ def test_clear_queue_all(self) -> None:
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
@@ -921,7 +929,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
self.assertCountEqual(rows, expected_rows)
- def test_partially_clear_queue(self):
+ def test_partially_clear_queue(self) -> None:
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
@@ -982,7 +990,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
servlets = [room.register_servlets]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(
"server",
federation_http_client=None,
@@ -990,13 +998,14 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
)
return hs
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = super().default_config()
- config["send_federation"] = True
+ # Enable federation sending on the main process.
+ config["federation_sender_instances"] = None
return config
- def prepare(self, reactor, clock, hs):
- self.federation_sender = hs.get_federation_sender()
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.federation_sender = cast(Mock, hs.get_federation_sender())
self.event_builder_factory = hs.get_event_builder_factory()
self.federation_event_handler = hs.get_federation_event_handler()
self.presence_handler = hs.get_presence_handler()
@@ -1012,7 +1021,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# random key to use.
self.random_signing_key = generate_signing_key("ver")
- def test_remote_joins(self):
+ def test_remote_joins(self) -> None:
# We advance time to something that isn't 0, as we use 0 as a special
# value.
self.reactor.advance(1000000000000)
@@ -1060,7 +1069,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
destinations={"server3"}, states=[expected_state]
)
- def test_remote_gets_presence_when_local_user_joins(self):
+ def test_remote_gets_presence_when_local_user_joins(self) -> None:
# We advance time to something that isn't 0, as we use 0 as a special
# value.
self.reactor.advance(1000000000000)
@@ -1109,7 +1118,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
destinations={"server2", "server3"}, states=[expected_state]
)
- def _add_new_user(self, room_id, user_id):
+ def _add_new_user(self, room_id: str, user_id: str) -> None:
"""Add new user to the room by creating an event and poking the federation API."""
hostname = get_domain_from_id(user_id)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 675aa023ac..7c174782da 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -332,7 +332,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]}
)
- def test_avatar_constraint_on_local_server_with_port(self):
+ def test_avatar_constraint_on_local_server_with_port(self) -> None:
"""Test that avatar metadata is correctly fetched when the media is on a local
server and the server has an explicit port.
@@ -376,7 +376,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc))
)
- def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]):
+ def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None:
"""Stores metadata about files in the database.
Args:
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index b55238650c..f60400ff8d 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -15,14 +15,18 @@
from copy import deepcopy
from typing import List
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EduTypes, ReceiptTypes
+from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
class ReceiptsTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.event_source = hs.get_event_sources().sources.receipt
def test_filters_out_private_receipt(self) -> None:
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 765df75d91..b9332d97dc 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Collection, List, Optional, Tuple
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import (
@@ -22,8 +25,18 @@ from synapse.api.errors import (
ResourceLimitError,
SynapseError,
)
+from synapse.module_api import ModuleApi
+from synapse.server import HomeServer
from synapse.spam_checker_api import RegistrationBehaviour
-from synapse.types import RoomAlias, RoomID, UserID, create_requester
+from synapse.types import (
+ JsonDict,
+ Requester,
+ RoomAlias,
+ RoomID,
+ UserID,
+ create_requester,
+)
+from synapse.util import Clock
from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -33,94 +46,98 @@ from .. import unittest
class TestSpamChecker:
- def __init__(self, config, api):
+ def __init__(self, config: None, api: ModuleApi):
api.register_spam_checker_callbacks(
check_registration_for_spam=self.check_registration_for_spam,
)
@staticmethod
- def parse_config(config):
- return config
+ def parse_config(config: JsonDict) -> None:
+ return None
async def check_registration_for_spam(
self,
- email_threepid,
- username,
- request_info,
- auth_provider_id,
- ):
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ auth_provider_id: Optional[str],
+ ) -> RegistrationBehaviour:
pass
class DenyAll(TestSpamChecker):
async def check_registration_for_spam(
self,
- email_threepid,
- username,
- request_info,
- auth_provider_id,
- ):
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ auth_provider_id: Optional[str],
+ ) -> RegistrationBehaviour:
return RegistrationBehaviour.DENY
class BanAll(TestSpamChecker):
async def check_registration_for_spam(
self,
- email_threepid,
- username,
- request_info,
- auth_provider_id,
- ):
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ auth_provider_id: Optional[str],
+ ) -> RegistrationBehaviour:
return RegistrationBehaviour.SHADOW_BAN
class BanBadIdPUser(TestSpamChecker):
async def check_registration_for_spam(
- self, email_threepid, username, request_info, auth_provider_id=None
- ):
+ self,
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ auth_provider_id: Optional[str] = None,
+ ) -> RegistrationBehaviour:
# Reject any user coming from CAS and whose username contains profanity
- if auth_provider_id == "cas" and "flimflob" in username:
+ if auth_provider_id == "cas" and username and "flimflob" in username:
return RegistrationBehaviour.DENY
return RegistrationBehaviour.ALLOW
class TestLegacyRegistrationSpamChecker:
- def __init__(self, config, api):
+ def __init__(self, config: None, api: ModuleApi):
pass
async def check_registration_for_spam(
self,
- email_threepid,
- username,
- request_info,
- ):
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ ) -> RegistrationBehaviour:
pass
class LegacyAllowAll(TestLegacyRegistrationSpamChecker):
async def check_registration_for_spam(
self,
- email_threepid,
- username,
- request_info,
- ):
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ ) -> RegistrationBehaviour:
return RegistrationBehaviour.ALLOW
class LegacyDenyAll(TestLegacyRegistrationSpamChecker):
async def check_registration_for_spam(
self,
- email_threepid,
- username,
- request_info,
- ):
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ ) -> RegistrationBehaviour:
return RegistrationBehaviour.DENY
class RegistrationTestCase(unittest.HomeserverTestCase):
"""Tests the RegistrationHandler."""
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs_config = self.default_config()
# some of the tests rely on us having a user consent version
@@ -145,7 +162,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_registration_handler()
self.store = self.hs.get_datastores().main
self.lots_of_users = 100
@@ -153,7 +170,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.requester = create_requester("@requester:test")
- def test_user_is_created_and_logged_in_if_doesnt_exist(self):
+ def test_user_is_created_and_logged_in_if_doesnt_exist(self) -> None:
frank = UserID.from_string("@frank:test")
user_id = frank.to_string()
requester = create_requester(user_id)
@@ -164,7 +181,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertIsInstance(result_token, str)
self.assertGreater(len(result_token), 20)
- def test_if_user_exists(self):
+ def test_if_user_exists(self) -> None:
store = self.hs.get_datastores().main
frank = UserID.from_string("@frank:test")
self.get_success(
@@ -180,12 +197,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(result_token is not None)
@override_config({"limit_usage_by_mau": False})
- def test_mau_limits_when_disabled(self):
+ def test_mau_limits_when_disabled(self) -> None:
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "a", "display_name"))
@override_config({"limit_usage_by_mau": True})
- def test_get_or_create_user_mau_not_blocked(self):
+ def test_get_or_create_user_mau_not_blocked(self) -> None:
self.store.count_monthly_users = Mock(
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
@@ -193,7 +210,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.get_success(self.get_or_create_user(self.requester, "c", "User"))
@override_config({"limit_usage_by_mau": True})
- def test_get_or_create_user_mau_blocked(self):
+ def test_get_or_create_user_mau_blocked(self) -> None:
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users)
)
@@ -211,7 +228,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
@override_config({"limit_usage_by_mau": True})
- def test_register_mau_blocked(self):
+ def test_register_mau_blocked(self) -> None:
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users)
)
@@ -229,7 +246,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config(
{"auto_join_rooms": ["#room:test"], "auto_join_rooms_for_guests": False}
)
- def test_auto_join_rooms_for_guests(self):
+ def test_auto_join_rooms_for_guests(self) -> None:
user_id = self.get_success(
self.handler.register_user(localpart="jeff", make_guest=True),
)
@@ -237,7 +254,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(rooms), 0)
@override_config({"auto_join_rooms": ["#room:test"]})
- def test_auto_create_auto_join_rooms(self):
+ def test_auto_create_auto_join_rooms(self) -> None:
room_alias_str = "#room:test"
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -249,7 +266,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(rooms), 1)
@override_config({"auto_join_rooms": []})
- def test_auto_create_auto_join_rooms_with_no_rooms(self):
+ def test_auto_create_auto_join_rooms_with_no_rooms(self) -> None:
frank = UserID.from_string("@frank:test")
user_id = self.get_success(self.handler.register_user(frank.localpart))
self.assertEqual(user_id, frank.to_string())
@@ -257,7 +274,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(rooms), 0)
@override_config({"auto_join_rooms": ["#room:another"]})
- def test_auto_create_auto_join_where_room_is_another_domain(self):
+ def test_auto_create_auto_join_where_room_is_another_domain(self) -> None:
frank = UserID.from_string("@frank:test")
user_id = self.get_success(self.handler.register_user(frank.localpart))
self.assertEqual(user_id, frank.to_string())
@@ -267,13 +284,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config(
{"auto_join_rooms": ["#room:test"], "autocreate_auto_join_rooms": False}
)
- def test_auto_create_auto_join_where_auto_create_is_false(self):
+ def test_auto_create_auto_join_where_auto_create_is_false(self) -> None:
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@override_config({"auto_join_rooms": ["#room:test"]})
- def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self):
+ def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None:
room_alias_str = "#room:test"
self.store.is_real_user = Mock(return_value=make_awaitable(False))
user_id = self.get_success(self.handler.register_user(localpart="support"))
@@ -284,7 +301,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.get_failure(directory_handler.get_association(room_alias), SynapseError)
@override_config({"auto_join_rooms": ["#room:test"]})
- def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self):
+ def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
room_alias_str = "#room:test"
self.store.count_real_users = Mock(return_value=make_awaitable(1))
@@ -299,7 +316,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(rooms), 1)
@override_config({"auto_join_rooms": ["#room:test"]})
- def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self):
+ def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
+ self,
+ ) -> None:
self.store.count_real_users = Mock(return_value=make_awaitable(2))
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
@@ -312,7 +331,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"autocreate_auto_join_rooms_federated": False,
}
)
- def test_auto_create_auto_join_rooms_federated(self):
+ def test_auto_create_auto_join_rooms_federated(self) -> None:
"""
Auto-created rooms that are private require an invite to go to the user
(instead of directly joining it).
@@ -339,7 +358,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config(
{"auto_join_rooms": ["#room:test"], "auto_join_mxid_localpart": "support"}
)
- def test_auto_join_mxid_localpart(self):
+ def test_auto_join_mxid_localpart(self) -> None:
"""
Ensure the user still needs up in the room created by a different user.
"""
@@ -376,7 +395,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"auto_join_mxid_localpart": "support",
}
)
- def test_auto_create_auto_join_room_preset(self):
+ def test_auto_create_auto_join_room_preset(self) -> None:
"""
Auto-created rooms that are private require an invite to go to the user
(instead of directly joining it).
@@ -416,7 +435,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"auto_join_mxid_localpart": "support",
}
)
- def test_auto_create_auto_join_room_preset_guest(self):
+ def test_auto_create_auto_join_room_preset_guest(self) -> None:
"""
Auto-created rooms that are private require an invite to go to the user
(instead of directly joining it).
@@ -454,7 +473,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"auto_join_mxid_localpart": "support",
}
)
- def test_auto_create_auto_join_room_preset_invalid_permissions(self):
+ def test_auto_create_auto_join_room_preset_invalid_permissions(self) -> None:
"""
Auto-created rooms that are private require an invite, check that
registration doesn't completely break if the inviter doesn't have proper
@@ -525,7 +544,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"auto_join_rooms": ["#room:test"],
},
)
- def test_auto_create_auto_join_where_no_consent(self):
+ def test_auto_create_auto_join_where_no_consent(self) -> None:
"""Test to ensure that the first user is not auto-joined to a room if
they have not given general consent.
"""
@@ -550,19 +569,19 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 1)
- def test_register_support_user(self):
+ def test_register_support_user(self) -> None:
user_id = self.get_success(
self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT)
)
d = self.store.is_support_user(user_id)
self.assertTrue(self.get_success(d))
- def test_register_not_support_user(self):
+ def test_register_not_support_user(self) -> None:
user_id = self.get_success(self.handler.register_user(localpart="user"))
d = self.store.is_support_user(user_id)
self.assertFalse(self.get_success(d))
- def test_invalid_user_id_length(self):
+ def test_invalid_user_id_length(self) -> None:
invalid_user_id = "x" * 256
self.get_failure(
self.handler.register_user(localpart=invalid_user_id), SynapseError
@@ -577,7 +596,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
]
}
)
- def test_spam_checker_deny(self):
+ def test_spam_checker_deny(self) -> None:
"""A spam checker can deny registration, which results in an error."""
self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
@@ -590,7 +609,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
]
}
)
- def test_spam_checker_legacy_allow(self):
+ def test_spam_checker_legacy_allow(self) -> None:
"""Tests that a legacy spam checker implementing the legacy 3-arg version of the
check_registration_for_spam callback is correctly called.
@@ -610,7 +629,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
]
}
)
- def test_spam_checker_legacy_deny(self):
+ def test_spam_checker_legacy_deny(self) -> None:
"""Tests that a legacy spam checker implementing the legacy 3-arg version of the
check_registration_for_spam callback is correctly called.
@@ -630,7 +649,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
]
}
)
- def test_spam_checker_shadow_ban(self):
+ def test_spam_checker_shadow_ban(self) -> None:
"""A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
user_id = self.get_success(self.handler.register_user(localpart="user"))
@@ -660,7 +679,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
]
}
)
- def test_spam_checker_receives_sso_type(self):
+ def test_spam_checker_receives_sso_type(self) -> None:
"""Test rejecting registration based on SSO type"""
f = self.get_failure(
self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
@@ -678,8 +697,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
async def get_or_create_user(
- self, requester, localpart, displayname, password_hash=None
- ):
+ self,
+ requester: Requester,
+ localpart: str,
+ displayname: Optional[str],
+ password_hash: Optional[str] = None,
+ ) -> Tuple[str, str]:
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
@@ -734,13 +757,15 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
class RemoteAutoJoinTestCase(unittest.HomeserverTestCase):
"""Tests auto-join on remote rooms."""
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.room_id = "!roomid:remotetest"
- async def update_membership(*args, **kwargs):
+ async def update_membership(*args: Any, **kwargs: Any) -> None:
pass
- async def lookup_room_alias(*args, **kwargs):
+ async def lookup_room_alias(
+ *args: Any, **kwargs: Any
+ ) -> Tuple[RoomID, List[str]]:
return RoomID.from_string(self.room_id), ["remotetest"]
self.room_member_handler = Mock(spec=["update_membership", "lookup_room_alias"])
@@ -750,12 +775,12 @@ class RemoteAutoJoinTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(room_member_handler=self.room_member_handler)
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_registration_handler()
self.store = self.hs.get_datastores().main
@override_config({"auto_join_rooms": ["#room:remotetest"]})
- def test_auto_create_auto_join_remote_room(self):
+ def test_auto_create_auto_join_remote_room(self) -> None:
"""Tests that we don't attempt to create remote rooms, and that we don't attempt
to invite ourselves to rooms we're not in."""
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
index fcde5dab72..df95490d3b 100644
--- a/tests/handlers/test_room.py
+++ b/tests/handlers/test_room.py
@@ -14,7 +14,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
]
@override_config({"encryption_enabled_by_default_for_room_type": "all"})
- def test_encrypted_by_default_config_option_all(self):
+ def test_encrypted_by_default_config_option_all(self) -> None:
"""Tests that invite-only and non-invite-only rooms have encryption enabled by
default when the config option encryption_enabled_by_default_for_room_type is "all".
"""
@@ -45,7 +45,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
@override_config({"encryption_enabled_by_default_for_room_type": "invite"})
- def test_encrypted_by_default_config_option_invite(self):
+ def test_encrypted_by_default_config_option_invite(self) -> None:
"""Tests that only new, invite-only rooms have encryption enabled by default when
the config option encryption_enabled_by_default_for_room_type is "invite".
"""
@@ -76,7 +76,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
)
@override_config({"encryption_enabled_by_default_for_room_type": "off"})
- def test_encrypted_by_default_config_option_off(self):
+ def test_encrypted_by_default_config_option_off(self) -> None:
"""Tests that neither new invite-only nor non-invite-only rooms have encryption
enabled by default when the config option
encryption_enabled_by_default_for_room_type is "off".
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index 6bbfd5dc84..6a38893b68 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -171,7 +171,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
state=[create_event],
auth_chain=[create_event],
partial_state=False,
- servers_in_room=[],
+ servers_in_room=frozenset(),
)
)
)
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index aa650756e4..d907fcaf04 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from unittest import mock
from twisted.internet.defer import ensureDeferred
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import (
EventContentFields,
@@ -34,11 +35,14 @@ from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
+from synapse.util import Clock
from tests import unittest
-def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0):
+def _create_event(
+ room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0
+) -> mock.Mock:
result = mock.Mock(name=room_id)
result.room_id = room_id
result.content = {}
@@ -48,40 +52,40 @@ def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: i
return result
-def _order(*events):
+def _order(*events: mock.Mock) -> List[mock.Mock]:
return sorted(events, key=_child_events_comparison_key)
class TestSpaceSummarySort(unittest.TestCase):
- def test_no_order_last(self):
+ def test_no_order_last(self) -> None:
"""An event with no ordering is placed behind those with an ordering."""
ev1 = _create_event("!abc:test")
ev2 = _create_event("!xyz:test", "xyz")
self.assertEqual([ev2, ev1], _order(ev1, ev2))
- def test_order(self):
+ def test_order(self) -> None:
"""The ordering should be used."""
ev1 = _create_event("!abc:test", "xyz")
ev2 = _create_event("!xyz:test", "abc")
self.assertEqual([ev2, ev1], _order(ev1, ev2))
- def test_order_origin_server_ts(self):
+ def test_order_origin_server_ts(self) -> None:
"""Origin server is a tie-breaker for ordering."""
ev1 = _create_event("!abc:test", origin_server_ts=10)
ev2 = _create_event("!xyz:test", origin_server_ts=30)
self.assertEqual([ev1, ev2], _order(ev1, ev2))
- def test_order_room_id(self):
+ def test_order_room_id(self) -> None:
"""Room ID is a final tie-breaker for ordering."""
ev1 = _create_event("!abc:test")
ev2 = _create_event("!xyz:test")
self.assertEqual([ev1, ev2], _order(ev1, ev2))
- def test_invalid_ordering_type(self):
+ def test_invalid_ordering_type(self) -> None:
"""Invalid orderings are considered the same as missing."""
ev1 = _create_event("!abc:test", 1)
ev2 = _create_event("!xyz:test", "xyz")
@@ -97,7 +101,7 @@ class TestSpaceSummarySort(unittest.TestCase):
ev1 = _create_event("!abc:test", True)
self.assertEqual([ev2, ev1], _order(ev1, ev2))
- def test_invalid_ordering_value(self):
+ def test_invalid_ordering_value(self) -> None:
"""Invalid orderings are considered the same as missing."""
ev1 = _create_event("!abc:test", "foo\n")
ev2 = _create_event("!xyz:test", "xyz")
@@ -115,7 +119,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs
self.handler = self.hs.get_room_summary_handler()
@@ -223,7 +227,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6)
)
- def test_simple_space(self):
+ def test_simple_space(self) -> None:
"""Test a simple space with a single room."""
# The result should have the space and the room in it, along with a link
# from space -> room.
@@ -234,7 +238,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
- def test_large_space(self):
+ def test_large_space(self) -> None:
"""Test a space with a large number of rooms."""
rooms = [self.room]
# Make at least 51 rooms that are part of the space.
@@ -260,7 +264,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result["rooms"] += result2["rooms"]
self._assert_hierarchy(result, expected)
- def test_visibility(self):
+ def test_visibility(self) -> None:
"""A user not in a space cannot inspect it."""
user2 = self.register_user("user2", "pass")
token2 = self.login("user2", "pass")
@@ -380,7 +384,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_hierarchy(result2, [(self.space, [self.room])])
def _create_room_with_join_rule(
- self, join_rule: str, room_version: Optional[str] = None, **extra_content
+ self, join_rule: str, room_version: Optional[str] = None, **extra_content: Any
) -> str:
"""Create a room with the given join rule and add it to the space."""
room_id = self.helper.create_room_as(
@@ -403,7 +407,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._add_child(self.space, room_id, self.token)
return room_id
- def test_filtering(self):
+ def test_filtering(self) -> None:
"""
Rooms should be properly filtered to only include rooms the user has access to.
"""
@@ -476,7 +480,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
- def test_complex_space(self):
+ def test_complex_space(self) -> None:
"""
Create a "complex" space to see how it handles things like loops and subspaces.
"""
@@ -516,7 +520,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
- def test_pagination(self):
+ def test_pagination(self) -> None:
"""Test simple pagination works."""
room_ids = []
for i in range(1, 10):
@@ -553,7 +557,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_hierarchy(result, expected)
self.assertNotIn("next_batch", result)
- def test_invalid_pagination_token(self):
+ def test_invalid_pagination_token(self) -> None:
"""An invalid pagination token, or changing other parameters, shoudl be rejected."""
room_ids = []
for i in range(1, 10):
@@ -604,7 +608,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
SynapseError,
)
- def test_max_depth(self):
+ def test_max_depth(self) -> None:
"""Create a deep tree to test the max depth against."""
spaces = [self.space]
rooms = [self.room]
@@ -659,7 +663,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
]
self._assert_hierarchy(result, expected)
- def test_unknown_room_version(self):
+ def test_unknown_room_version(self) -> None:
"""
If a room with an unknown room version is encountered it should not cause
the entire summary to skip.
@@ -685,7 +689,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
- def test_fed_complex(self):
+ def test_fed_complex(self) -> None:
"""
Return data over federation and ensure that it is handled properly.
"""
@@ -722,7 +726,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
"world_readable": True,
}
- async def summarize_remote_room_hierarchy(_self, room, suggested_only):
+ async def summarize_remote_room_hierarchy(
+ _self: Any, room: Any, suggested_only: bool
+ ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
return requested_room_entry, {subroom: child_room}, set()
# Add a room to the space which is on another server.
@@ -744,7 +750,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
- def test_fed_filtering(self):
+ def test_fed_filtering(self) -> None:
"""
Rooms returned over federation should be properly filtered to only include
rooms the user has access to.
@@ -853,7 +859,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
],
)
- async def summarize_remote_room_hierarchy(_self, room, suggested_only):
+ async def summarize_remote_room_hierarchy(
+ _self: Any, room: Any, suggested_only: bool
+ ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
return subspace_room_entry, dict(children_rooms), set()
# Add a room to the space which is on another server.
@@ -892,7 +900,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
- def test_fed_invited(self):
+ def test_fed_invited(self) -> None:
"""
A room which the user was invited to should be included in the response.
@@ -915,7 +923,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
},
)
- async def summarize_remote_room_hierarchy(_self, room, suggested_only):
+ async def summarize_remote_room_hierarchy(
+ _self: Any, room: Any, suggested_only: bool
+ ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
return fed_room_entry, {}, set()
# Add a room to the space which is on another server.
@@ -936,7 +946,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
- def test_fed_caching(self):
+ def test_fed_caching(self) -> None:
"""
Federation `/hierarchy` responses should be cached.
"""
@@ -1023,7 +1033,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs
self.handler = self.hs.get_room_summary_handler()
@@ -1040,12 +1050,12 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
tok=self.token,
)
- def test_own_room(self):
+ def test_own_room(self) -> None:
"""Test a simple room created by the requester."""
result = self.get_success(self.handler.get_room_summary(self.user, self.room))
self.assertEqual(result.get("room_id"), self.room)
- def test_visibility(self):
+ def test_visibility(self) -> None:
"""A user not in a private room cannot get its summary."""
user2 = self.register_user("user2", "pass")
token2 = self.login("user2", "pass")
@@ -1093,7 +1103,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
result = self.get_success(self.handler.get_room_summary(user2, self.room))
self.assertEqual(result.get("room_id"), self.room)
- def test_fed(self):
+ def test_fed(self) -> None:
"""
Return data over federation and ensure that it is handled properly.
"""
@@ -1105,7 +1115,9 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
{"room_id": fed_room, "world_readable": True},
)
- async def summarize_remote_room_hierarchy(_self, room, suggested_only):
+ async def summarize_remote_room_hierarchy(
+ _self: Any, room: Any, suggested_only: bool
+ ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
return requested_room_entry, {}, set()
with mock.patch(
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index a0f84e2940..9b1b8b9f13 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Set, Tuple
from unittest.mock import Mock
import attr
@@ -20,7 +20,9 @@ import attr
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import RedirectException
+from synapse.module_api import ModuleApi
from synapse.server import HomeServer
+from synapse.types import JsonDict
from synapse.util import Clock
from tests.test_utils import simple_async_mock
@@ -29,6 +31,7 @@ from tests.unittest import HomeserverTestCase, override_config
# Check if we have the dependencies to run the tests.
try:
import saml2.config
+ import saml2.response
from saml2.sigver import SigverError
has_saml2 = True
@@ -56,31 +59,39 @@ class FakeAuthnResponse:
class TestMappingProvider:
- def __init__(self, config, module):
+ def __init__(self, config: None, module: ModuleApi):
pass
@staticmethod
- def parse_config(config):
- return
+ def parse_config(config: JsonDict) -> None:
+ return None
@staticmethod
- def get_saml_attributes(config):
+ def get_saml_attributes(config: None) -> Tuple[Set[str], Set[str]]:
return {"uid"}, {"displayName"}
- def get_remote_user_id(self, saml_response, client_redirect_url):
+ def get_remote_user_id(
+ self, saml_response: "saml2.response.AuthnResponse", client_redirect_url: str
+ ) -> str:
return saml_response.ava["uid"]
def saml_response_to_user_attributes(
- self, saml_response, failures, client_redirect_url
- ):
+ self,
+ saml_response: "saml2.response.AuthnResponse",
+ failures: int,
+ client_redirect_url: str,
+ ) -> dict:
localpart = saml_response.ava["username"] + (str(failures) if failures else "")
return {"mxid_localpart": localpart, "displayname": None}
class TestRedirectMappingProvider(TestMappingProvider):
def saml_response_to_user_attributes(
- self, saml_response, failures, client_redirect_url
- ):
+ self,
+ saml_response: "saml2.response.AuthnResponse",
+ failures: int,
+ client_redirect_url: str,
+ ) -> dict:
raise RedirectException(b"https://custom-saml-redirect/")
@@ -347,7 +358,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
)
-def _mock_request():
+def _mock_request() -> Mock:
"""Returns a mock which will stand in as a SynapseRequest"""
mock = Mock(
spec=[
diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py
index da4bf8b582..8b6e4a40b6 100644
--- a/tests/handlers/test_send_email.py
+++ b/tests/handlers/test_send_email.py
@@ -13,7 +13,7 @@
# limitations under the License.
-from typing import List, Tuple
+from typing import Callable, List, Tuple
from zope.interface import implementer
@@ -28,20 +28,27 @@ from tests.unittest import HomeserverTestCase, override_config
@implementer(interfaces.IMessageDelivery)
class _DummyMessageDelivery:
- def __init__(self):
+ def __init__(self) -> None:
# (recipient, message) tuples
self.messages: List[Tuple[smtp.Address, bytes]] = []
- def receivedHeader(self, helo, origin, recipients):
+ def receivedHeader(
+ self,
+ helo: Tuple[bytes, bytes],
+ origin: smtp.Address,
+ recipients: List[smtp.User],
+ ) -> None:
return None
- def validateFrom(self, helo, origin):
+ def validateFrom(
+ self, helo: Tuple[bytes, bytes], origin: smtp.Address
+ ) -> smtp.Address:
return origin
- def record_message(self, recipient: smtp.Address, message: bytes):
+ def record_message(self, recipient: smtp.Address, message: bytes) -> None:
self.messages.append((recipient, message))
- def validateTo(self, user: smtp.User):
+ def validateTo(self, user: smtp.User) -> Callable[[], interfaces.IMessageSMTP]:
return lambda: _DummyMessage(self, user)
@@ -56,20 +63,20 @@ class _DummyMessage:
self._user = user
self._buffer: List[bytes] = []
- def lineReceived(self, line):
+ def lineReceived(self, line: bytes) -> None:
self._buffer.append(line)
- def eomReceived(self):
+ def eomReceived(self) -> "defer.Deferred[bytes]":
message = b"\n".join(self._buffer) + b"\n"
self._delivery.record_message(self._user.dest, message)
return defer.succeed(b"saved")
- def connectionLost(self):
+ def connectionLost(self) -> None:
pass
class SendEmailHandlerTestCase(HomeserverTestCase):
- def test_send_email(self):
+ def test_send_email(self) -> None:
"""Happy-path test that we can send email to a non-TLS server."""
h = self.hs.get_send_email_handler()
d = ensureDeferred(
@@ -119,7 +126,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
},
}
)
- def test_send_email_force_tls(self):
+ def test_send_email_force_tls(self) -> None:
"""Happy-path test that we can send email to an Implicit TLS server."""
h = self.hs.get_send_email_handler()
d = ensureDeferred(
diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py
new file mode 100644
index 0000000000..137deab138
--- /dev/null
+++ b/tests/handlers/test_sso.py
@@ -0,0 +1,145 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from http import HTTPStatus
+from typing import BinaryIO, Callable, Dict, List, Optional, Tuple
+from unittest.mock import Mock
+
+from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.http_headers import Headers
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.client import RawHeaders
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.test_utils import SMALL_PNG, FakeResponse
+
+
+class TestSSOHandler(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.http_client = Mock(spec=["get_file"])
+ self.http_client.get_file.side_effect = mock_get_file
+ self.http_client.user_agent = b"Synapse Test"
+ hs = self.setup_test_homeserver(
+ proxied_blacklisted_http_client=self.http_client
+ )
+ return hs
+
+ async def test_set_avatar(self) -> None:
+ """Tests successfully setting the avatar of a newly created user"""
+ handler = self.hs.get_sso_handler()
+
+ # Create a new user to set avatar for
+ reg_handler = self.hs.get_registration_handler()
+ user_id = self.get_success(reg_handler.register_user(approved=True))
+
+ self.assertTrue(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ # Ensure avatar is set on this newly created user,
+ # so no need to compare for the exact image
+ profile_handler = self.hs.get_profile_handler()
+ profile = self.get_success(profile_handler.get_profile(user_id))
+ self.assertIsNot(profile["avatar_url"], None)
+
+ @unittest.override_config({"max_avatar_size": 1})
+ async def test_set_avatar_too_big_image(self) -> None:
+ """Tests that saving an avatar fails when it is too big"""
+ handler = self.hs.get_sso_handler()
+
+ # any random user works since image check is supposed to fail
+ user_id = "@sso-user:test"
+
+ self.assertFalse(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ @unittest.override_config({"allowed_avatar_mimetypes": ["image/jpeg"]})
+ async def test_set_avatar_incorrect_mime_type(self) -> None:
+ """Tests that saving an avatar fails when its mime type is not allowed"""
+ handler = self.hs.get_sso_handler()
+
+ # any random user works since image check is supposed to fail
+ user_id = "@sso-user:test"
+
+ self.assertFalse(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ async def test_skip_saving_avatar_when_not_changed(self) -> None:
+ """Tests whether saving of avatar correctly skips if the avatar hasn't
+ changed"""
+ handler = self.hs.get_sso_handler()
+
+ # Create a new user to set avatar for
+ reg_handler = self.hs.get_registration_handler()
+ user_id = self.get_success(reg_handler.register_user(approved=True))
+
+ # set avatar for the first time, should be a success
+ self.assertTrue(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ # get avatar picture for comparison after another attempt
+ profile_handler = self.hs.get_profile_handler()
+ profile = self.get_success(profile_handler.get_profile(user_id))
+ url_to_match = profile["avatar_url"]
+
+ # set same avatar for the second time, should be a success
+ self.assertTrue(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ # compare avatar picture's url from previous step
+ profile = self.get_success(profile_handler.get_profile(user_id))
+ self.assertEqual(profile["avatar_url"], url_to_match)
+
+
+async def mock_get_file(
+ url: str,
+ output_stream: BinaryIO,
+ max_size: Optional[int] = None,
+ headers: Optional[RawHeaders] = None,
+ is_allowed_content_type: Optional[Callable[[str], bool]] = None,
+) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
+
+ fake_response = FakeResponse(code=404)
+ if url == "http://my.server/me.png":
+ fake_response = FakeResponse(
+ code=200,
+ headers=Headers(
+ {"Content-Type": ["image/png"], "Content-Length": [str(len(SMALL_PNG))]}
+ ),
+ body=SMALL_PNG,
+ )
+
+ if max_size is not None and max_size < len(SMALL_PNG):
+ raise SynapseError(
+ HTTPStatus.BAD_GATEWAY,
+ "Requested file is too large > %r bytes" % (max_size,),
+ Codes.TOO_LARGE,
+ )
+
+ if is_allowed_content_type and not is_allowed_content_type("image/png"):
+ raise SynapseError(
+ HTTPStatus.BAD_GATEWAY,
+ (
+ "Requested file's content type not allowed for this operation: %s"
+ % "image/png"
+ ),
+ )
+
+ output_stream.write(fake_response.body)
+
+ return len(SMALL_PNG), {b"Content-Type": [b"image/png"]}, "", 200
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 05f9ec3c51..f1a50c5bcb 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -12,9 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict, List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.storage.databases.main import stats
+from synapse.util import Clock
from tests import unittest
@@ -32,11 +38,11 @@ class StatsRoomTests(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.handler = self.hs.get_stats_handler()
- def _add_background_updates(self):
+ def _add_background_updates(self) -> None:
"""
Add the background updates we need to run.
"""
@@ -63,12 +69,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- async def get_all_room_state(self):
+ async def get_all_room_state(self) -> List[Dict[str, Any]]:
return await self.store.db_pool.simple_select_list(
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
)
- def _get_current_stats(self, stats_type, stat_id):
+ def _get_current_stats(
+ self, stats_type: str, stat_id: str
+ ) -> Optional[Dict[str, Any]]:
table, id_col = stats.TYPE_TO_TABLE[stats_type]
cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
@@ -82,13 +90,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- def _perform_background_initial_update(self):
+ def _perform_background_initial_update(self) -> None:
# Do the initial population of the stats via the background update
self._add_background_updates()
self.wait_for_background_updates()
- def test_initial_room(self):
+ def test_initial_room(self) -> None:
"""
The background updates will build the table from scratch.
"""
@@ -125,7 +133,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(len(r), 1)
self.assertEqual(r[0]["topic"], "foo")
- def test_create_user(self):
+ def test_create_user(self) -> None:
"""
When we create a user, it should have statistics already ready.
"""
@@ -134,12 +142,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
u1stats = self._get_current_stats("user", u1)
- self.assertIsNotNone(u1stats)
+ assert u1stats is not None
# not in any rooms by default
self.assertEqual(u1stats["joined_rooms"], 0)
- def test_create_room(self):
+ def test_create_room(self) -> None:
"""
When we create a room, it should have statistics already ready.
"""
@@ -153,8 +161,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r2 = self.helper.create_room_as(u1, tok=u1token, is_public=False)
r2stats = self._get_current_stats("room", r2)
- self.assertIsNotNone(r1stats)
- self.assertIsNotNone(r2stats)
+ assert r1stats is not None
+ assert r2stats is not None
self.assertEqual(
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
@@ -171,7 +179,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(r2stats["invited_members"], 0)
self.assertEqual(r2stats["banned_members"], 0)
- def test_updating_profile_information_does_not_increase_joined_members_count(self):
+ def test_updating_profile_information_does_not_increase_joined_members_count(
+ self,
+ ) -> None:
"""
Check that the joined_members count does not increase when a user changes their
profile information (which is done by sending another join membership event into
@@ -186,6 +196,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Get the current room stats
r1stats_ante = self._get_current_stats("room", r1)
+ assert r1stats_ante is not None
# Send a profile update into the room
new_profile = {"displayname": "bob"}
@@ -195,6 +206,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Get the new room stats
r1stats_post = self._get_current_stats("room", r1)
+ assert r1stats_post is not None
# Ensure that the user count did not changed
self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"])
@@ -202,7 +214,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"]
)
- def test_send_state_event_nonoverwriting(self):
+ def test_send_state_event_nonoverwriting(self) -> None:
"""
When we send a non-overwriting state event, it increments current_state_events
"""
@@ -218,19 +230,21 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
r1stats_ante = self._get_current_stats("room", r1)
+ assert r1stats_ante is not None
self.helper.send_state(
r1, "cat.hissing", {"value": False}, tok=u1token, state_key="moggy"
)
r1stats_post = self._get_current_stats("room", r1)
+ assert r1stats_post is not None
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
1,
)
- def test_join_first_time(self):
+ def test_join_first_time(self) -> None:
"""
When a user joins a room for the first time, current_state_events and
joined_members should increase by exactly 1.
@@ -246,10 +260,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
u2token = self.login("u2", "pass")
r1stats_ante = self._get_current_stats("room", r1)
+ assert r1stats_ante is not None
self.helper.join(r1, u2, tok=u2token)
r1stats_post = self._get_current_stats("room", r1)
+ assert r1stats_post is not None
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -259,7 +275,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["joined_members"] - r1stats_ante["joined_members"], 1
)
- def test_join_after_leave(self):
+ def test_join_after_leave(self) -> None:
"""
When a user joins a room after being previously left,
joined_members should increase by exactly 1.
@@ -280,10 +296,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.leave(r1, u2, tok=u2token)
r1stats_ante = self._get_current_stats("room", r1)
+ assert r1stats_ante is not None
self.helper.join(r1, u2, tok=u2token)
r1stats_post = self._get_current_stats("room", r1)
+ assert r1stats_post is not None
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -296,7 +314,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["left_members"] - r1stats_ante["left_members"], -1
)
- def test_invited(self):
+ def test_invited(self) -> None:
"""
When a user invites another user, current_state_events and
invited_members should increase by exactly 1.
@@ -311,10 +329,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
u2 = self.register_user("u2", "pass")
r1stats_ante = self._get_current_stats("room", r1)
+ assert r1stats_ante is not None
self.helper.invite(r1, u1, u2, tok=u1token)
r1stats_post = self._get_current_stats("room", r1)
+ assert r1stats_post is not None
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -324,7 +344,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["invited_members"] - r1stats_ante["invited_members"], +1
)
- def test_join_after_invite(self):
+ def test_join_after_invite(self) -> None:
"""
When a user joins a room after being invited and
joined_members should increase by exactly 1.
@@ -344,10 +364,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.invite(r1, u1, u2, tok=u1token)
r1stats_ante = self._get_current_stats("room", r1)
+ assert r1stats_ante is not None
self.helper.join(r1, u2, tok=u2token)
r1stats_post = self._get_current_stats("room", r1)
+ assert r1stats_post is not None
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -360,7 +382,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["invited_members"] - r1stats_ante["invited_members"], -1
)
- def test_left(self):
+ def test_left(self) -> None:
"""
When a user leaves a room after joining and
left_members should increase by exactly 1.
@@ -380,10 +402,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.join(r1, u2, tok=u2token)
r1stats_ante = self._get_current_stats("room", r1)
+ assert r1stats_ante is not None
self.helper.leave(r1, u2, tok=u2token)
r1stats_post = self._get_current_stats("room", r1)
+ assert r1stats_post is not None
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -396,7 +420,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
)
- def test_banned(self):
+ def test_banned(self) -> None:
"""
When a user is banned from a room after joining and
left_members should increase by exactly 1.
@@ -416,10 +440,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.join(r1, u2, tok=u2token)
r1stats_ante = self._get_current_stats("room", r1)
+ assert r1stats_ante is not None
self.helper.change_membership(r1, u1, u2, "ban", tok=u1token)
r1stats_post = self._get_current_stats("room", r1)
+ assert r1stats_post is not None
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -432,7 +458,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
)
- def test_initial_background_update(self):
+ def test_initial_background_update(self) -> None:
"""
Test that statistics can be generated by the initial background update
handler.
@@ -462,6 +488,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats = self._get_current_stats("room", r1)
u1stats = self._get_current_stats("user", u1)
+ assert r1stats is not None
+ assert u1stats is not None
+
self.assertEqual(r1stats["joined_members"], 1)
self.assertEqual(
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
@@ -469,7 +498,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(u1stats["joined_rooms"], 1)
- def test_incomplete_stats(self):
+ def test_incomplete_stats(self) -> None:
"""
This tests that we track incomplete statistics.
@@ -533,8 +562,11 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.wait_for_background_updates()
r1stats_complete = self._get_current_stats("room", r1)
+ assert r1stats_complete is not None
u1stats_complete = self._get_current_stats("user", u1)
+ assert u1stats_complete is not None
u2stats_complete = self._get_current_stats("user", u2)
+ assert u2stats_complete is not None
# now we make our assertions
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index ab5c101eb7..0d9a3de92a 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -14,6 +14,8 @@
from typing import Optional
from unittest.mock import MagicMock, Mock, patch
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import Filtering
@@ -23,6 +25,7 @@ from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
+from synapse.util import Clock
import tests.unittest
import tests.utils
@@ -39,7 +42,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor, clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastores().main
@@ -47,7 +50,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# modify its config instead of the hs'
self.auth_blocking = self.hs.get_auth_blocking()
- def test_wait_for_sync_for_user_auth_blocking(self):
+ def test_wait_for_sync_for_user_auth_blocking(self) -> None:
user_id1 = "@user1:test"
user_id2 = "@user2:test"
sync_config = generate_sync_config(user_id1)
@@ -82,7 +85,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- def test_unknown_room_version(self):
+ def test_unknown_room_version(self) -> None:
"""
A room with an unknown room version should not break sync (and should be excluded).
"""
@@ -186,7 +189,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.assertNotIn(invite_room, [r.room_id for r in result.invited])
self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
- def test_ban_wins_race_with_join(self):
+ def test_ban_wins_race_with_join(self) -> None:
"""Rooms shouldn't appear under "joined" if a join loses a race to a ban.
A complicated edge case. Imagine the following scenario:
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 9c821b3042..1fe9563c98 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -14,21 +14,22 @@
import json
-from typing import Dict
+from typing import Dict, List, Set
from unittest.mock import ANY, Mock, call
-from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
from synapse.api.constants import EduTypes
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
+from synapse.handlers.typing import TypingWriterHandler
from synapse.server import HomeServer
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
from tests import unittest
+from tests.server import ThreadedMemoryReactorClock
from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -62,7 +63,11 @@ def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes:
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ def make_homeserver(
+ self,
+ reactor: ThreadedMemoryReactorClock,
+ clock: Clock,
+ ) -> HomeServer:
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
@@ -75,8 +80,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
# the tests assume that we are starting at unix time 1000
reactor.pump((1000,))
+ self.mock_hs_notifier = Mock()
hs = self.setup_test_homeserver(
- notifier=Mock(),
+ notifier=self.mock_hs_notifier,
federation_http_client=mock_federation_client,
keyring=mock_keyring,
replication_streams={},
@@ -90,32 +96,38 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
return d
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- mock_notifier = hs.get_notifier()
- self.on_new_event = mock_notifier.on_new_event
+ self.on_new_event = self.mock_hs_notifier.on_new_event
- self.handler = hs.get_typing_handler()
+ # hs.get_typing_handler will return a TypingWriterHandler when calling it
+ # from the main process, and a FollowerTypingHandler on workers.
+ # We rely on methods only available on the former, so assert we have the
+ # correct type here. We have to assign self.handler after the assert,
+ # otherwise mypy will treat it as a FollowerTypingHandler
+ handler = hs.get_typing_handler()
+ assert isinstance(handler, TypingWriterHandler)
+ self.handler = handler
self.event_source = hs.get_event_sources().sources.typing
self.datastore = hs.get_datastores().main
+
self.datastore.get_destination_retry_timings = Mock(
return_value=make_awaitable(None)
)
- self.datastore.get_device_updates_by_remote = Mock(
+ self.datastore.get_device_updates_by_remote = Mock( # type: ignore[assignment]
return_value=make_awaitable((0, []))
)
- self.datastore.get_destination_last_successful_stream_ordering = Mock(
+ self.datastore.get_destination_last_successful_stream_ordering = Mock( # type: ignore[assignment]
return_value=make_awaitable(None)
)
- def get_received_txn_response(*args):
- return defer.succeed(None)
+ self.datastore.get_received_txn_response = Mock( # type: ignore[assignment]
+ return_value=make_awaitable(None)
+ )
- self.datastore.get_received_txn_response = get_received_txn_response
-
- self.room_members = []
+ self.room_members: List[UserID] = []
async def check_user_in_room(room_id: str, requester: Requester) -> None:
if requester.user.to_string() not in [
@@ -124,47 +136,54 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
raise AuthError(401, "User is not in the room")
return None
- hs.get_auth().check_user_in_room = check_user_in_room
+ hs.get_auth().check_user_in_room = Mock( # type: ignore[assignment]
+ side_effect=check_user_in_room
+ )
async def check_host_in_room(room_id: str, server_name: str) -> bool:
return room_id == ROOM_ID
- hs.get_event_auth_handler().is_host_in_room = check_host_in_room
+ hs.get_event_auth_handler().is_host_in_room = Mock( # type: ignore[assignment]
+ side_effect=check_host_in_room
+ )
- async def get_current_hosts_in_room(room_id: str):
+ async def get_current_hosts_in_room(room_id: str) -> Set[str]:
return {member.domain for member in self.room_members}
- hs.get_storage_controllers().state.get_current_hosts_in_room = (
- get_current_hosts_in_room
+ hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment]
+ side_effect=get_current_hosts_in_room
)
- hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (
- get_current_hosts_in_room
+ hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock( # type: ignore[assignment]
+ side_effect=get_current_hosts_in_room
)
- async def get_users_in_room(room_id: str):
+ async def get_users_in_room(room_id: str) -> Set[str]:
return {str(u) for u in self.room_members}
- self.datastore.get_users_in_room = get_users_in_room
+ self.datastore.get_users_in_room = Mock(side_effect=get_users_in_room)
- self.datastore.get_user_directory_stream_pos = Mock(
+ self.datastore.get_user_directory_stream_pos = Mock( # type: ignore[assignment]
side_effect=(
- # we deliberately return a non-None stream pos to avoid doing an initial_spam
+ # we deliberately return a non-None stream pos to avoid
+ # doing an initial_sync
lambda: make_awaitable(1)
)
)
- self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None))
+ self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[assignment]
- self.datastore.get_to_device_stream_token = lambda: 0
- self.datastore.get_new_device_msgs_for_remote = (
- lambda *args, **kargs: make_awaitable(([], 0))
+ self.datastore.get_to_device_stream_token = Mock( # type: ignore[assignment]
+ side_effect=lambda: 0
)
- self.datastore.delete_device_msgs_for_remote = (
- lambda *args, **kargs: make_awaitable(None)
+ self.datastore.get_new_device_msgs_for_remote = Mock( # type: ignore[assignment]
+ side_effect=lambda *args, **kargs: make_awaitable(([], 0))
)
- self.datastore.set_received_txn_response = (
- lambda *args, **kwargs: make_awaitable(None)
+ self.datastore.delete_device_msgs_for_remote = Mock( # type: ignore[assignment]
+ side_effect=lambda *args, **kargs: make_awaitable(None)
+ )
+ self.datastore.set_received_txn_response = Mock( # type: ignore[assignment]
+ side_effect=lambda *args, **kwargs: make_awaitable(None)
)
def test_started_typing_local(self) -> None:
@@ -186,7 +205,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.event_source.get_current_key(), 1)
events = self.get_success(
self.event_source.get_new_events(
- user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
+ user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False
)
)
self.assertEqual(
@@ -200,7 +219,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
- @override_config({"send_federation": True})
+ # Enable federation sending on the main process.
+ @override_config({"federation_sender_instances": None})
def test_started_typing_remote_send(self) -> None:
self.room_members = [U_APPLE, U_ONION]
@@ -256,7 +276,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.event_source.get_current_key(), 1)
events = self.get_success(
self.event_source.get_new_events(
- user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
+ user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False
)
)
self.assertEqual(
@@ -297,7 +317,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.event_source.get_new_events(
user=U_APPLE,
from_key=0,
- limit=None,
+ limit=0,
room_ids=[OTHER_ROOM_ID],
is_guest=False,
)
@@ -305,7 +325,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[0], [])
self.assertEqual(events[1], 0)
- @override_config({"send_federation": True})
+ # Enable federation sending on the main process.
+ @override_config({"federation_sender_instances": None})
def test_stopped_typing(self) -> None:
self.room_members = [U_APPLE, U_BANANA, U_ONION]
@@ -349,7 +370,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.event_source.get_current_key(), 1)
events = self.get_success(
self.event_source.get_new_events(
- user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
+ user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False
)
)
self.assertEqual(
@@ -385,7 +406,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.event_source.get_new_events(
user=U_APPLE,
from_key=0,
- limit=None,
+ limit=0,
room_ids=[ROOM_ID],
is_guest=False,
)
@@ -410,7 +431,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.event_source.get_new_events(
user=U_APPLE,
from_key=1,
- limit=None,
+ limit=0,
room_ids=[ROOM_ID],
is_guest=False,
)
@@ -445,7 +466,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.event_source.get_new_events(
user=U_APPLE,
from_key=0,
- limit=None,
+ limit=0,
room_ids=[ROOM_ID],
is_guest=False,
)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 9e39cd97e5..75fc5a17a4 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -56,7 +56,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
- config["update_user_directory"] = True
+ # Re-enables updating the user directory, as that function is needed below.
+ config["update_user_directory_from_worker"] = None
self.appservice = ApplicationService(
token="i_am_an_app_service",
@@ -1045,7 +1046,9 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
- config["update_user_directory"] = True
+ # Re-enables updating the user directory, as that function is needed below. It
+ # will be force disabled later
+ config["update_user_directory_from_worker"] = None
hs = self.setup_test_homeserver(config=config)
self.config = hs.config
diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py
index 1acf5666a8..1c5de95a80 100644
--- a/tests/logging/__init__.py
+++ b/tests/logging/__init__.py
@@ -13,9 +13,11 @@
# limitations under the License.
import logging
+from tests.unittest import TestCase
-class LoggerCleanupMixin:
- def get_logger(self, handler):
+
+class LoggerCleanupMixin(TestCase):
+ def get_logger(self, handler: logging.Handler) -> logging.Logger:
"""
Attach a handler to a logger and add clean-ups to remove revert this.
"""
diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
index 0917e478a5..e28ba84cc2 100644
--- a/tests/logging/test_opentracing.py
+++ b/tests/logging/test_opentracing.py
@@ -153,7 +153,7 @@ class LogContextScopeManagerTestCase(TestCase):
scopes = []
- async def task(i: int):
+ async def task(i: int) -> None:
scope = start_active_span(
f"task{i}",
tracer=self._tracer,
@@ -165,7 +165,7 @@ class LogContextScopeManagerTestCase(TestCase):
self.assertEqual(self._tracer.active_span, scope.span)
scope.close()
- async def root():
+ async def root() -> None:
with start_active_span("root span", tracer=self._tracer) as root_scope:
self.assertEqual(self._tracer.active_span, root_scope.span)
scopes.append(root_scope)
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
index b0d046fe00..c08954d887 100644
--- a/tests/logging/test_remote_handler.py
+++ b/tests/logging/test_remote_handler.py
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.test.proto_helpers import AccumulatingProtocol
+from typing import Tuple
+
+from twisted.internet.protocol import Protocol
+from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from synapse.logging import RemoteHandler
@@ -20,7 +23,9 @@ from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase
-def connect_logging_client(reactor, client_id):
+def connect_logging_client(
+ reactor: MemoryReactorClock, client_id: int
+) -> Tuple[Protocol, AccumulatingProtocol]:
# This is essentially tests.server.connect_client, but disabling autoflush on
# the client transport. This is necessary to avoid an infinite loop due to
# sending of data via the logging transport causing additional logs to be
@@ -35,10 +40,10 @@ def connect_logging_client(reactor, client_id):
class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.reactor, _ = get_clock()
- def test_log_output(self):
+ def test_log_output(self) -> None:
"""
The remote handler delivers logs over TCP.
"""
@@ -51,6 +56,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
client, server = connect_logging_client(self.reactor, 0)
# Trigger data being sent
+ assert isinstance(client.transport, FakeTransport)
client.transport.flush()
# One log message, with a single trailing newline
@@ -61,7 +67,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Ensure the data passed through properly.
self.assertEqual(logs[0], "Hello there, wally!")
- def test_log_backpressure_debug(self):
+ def test_log_backpressure_debug(self) -> None:
"""
When backpressure is hit, DEBUG logs will be shed.
"""
@@ -83,6 +89,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
+ assert isinstance(client.transport, FakeTransport)
client.transport.flush()
# Only the 7 infos made it through, the debugs were elided
@@ -90,7 +97,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
self.assertEqual(len(logs), 7)
self.assertNotIn(b"debug", server.data)
- def test_log_backpressure_info(self):
+ def test_log_backpressure_info(self) -> None:
"""
When backpressure is hit, DEBUG and INFO logs will be shed.
"""
@@ -116,6 +123,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
+ assert isinstance(client.transport, FakeTransport)
client.transport.flush()
# The 10 warnings made it through, the debugs and infos were elided
@@ -124,7 +132,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
self.assertNotIn(b"debug", server.data)
self.assertNotIn(b"info", server.data)
- def test_log_backpressure_cut_middle(self):
+ def test_log_backpressure_cut_middle(self) -> None:
"""
When backpressure is hit, and no more DEBUG and INFOs cannot be culled,
it will cut the middle messages out.
@@ -140,6 +148,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
+ assert isinstance(client.transport, FakeTransport)
client.transport.flush()
# The first five and last five warnings made it through, the debugs and
@@ -151,7 +160,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
logs,
)
- def test_cancel_connection(self):
+ def test_cancel_connection(self) -> None:
"""
Gracefully handle the connection being cancelled.
"""
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 0b0d8737c1..fa27f1279a 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -14,24 +14,28 @@
import json
import logging
from io import BytesIO, StringIO
+from typing import cast
from unittest.mock import Mock, patch
+from twisted.web.http import HTTPChannel
from twisted.web.server import Request
from synapse.http.site import SynapseRequest
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
from synapse.logging.context import LoggingContext, LoggingContextFilter
+from synapse.types import JsonDict
from tests.logging import LoggerCleanupMixin
-from tests.server import FakeChannel
+from tests.server import FakeChannel, get_clock
from tests.unittest import TestCase
class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.output = StringIO()
+ self.reactor, _ = get_clock()
- def get_log_line(self):
+ def get_log_line(self) -> JsonDict:
# One log message, with a single trailing newline.
data = self.output.getvalue()
logs = data.splitlines()
@@ -39,7 +43,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
self.assertEqual(data.count("\n"), 1)
return json.loads(logs[0])
- def test_terse_json_output(self):
+ def test_terse_json_output(self) -> None:
"""
The Terse JSON formatter converts log messages to JSON.
"""
@@ -61,7 +65,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")
- def test_extra_data(self):
+ def test_extra_data(self) -> None:
"""
Additional information can be included in the structured logging.
"""
@@ -93,7 +97,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
self.assertEqual(log["int"], 3)
self.assertIs(log["bool"], True)
- def test_json_output(self):
+ def test_json_output(self) -> None:
"""
The Terse JSON formatter converts log messages to JSON.
"""
@@ -114,7 +118,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")
- def test_with_context(self):
+ def test_with_context(self) -> None:
"""
The logging context should be added to the JSON response.
"""
@@ -139,7 +143,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
self.assertEqual(log["log"], "Hello there, wally!")
self.assertEqual(log["request"], "name")
- def test_with_request_context(self):
+ def test_with_request_context(self) -> None:
"""
Information from the logging context request should be added to the JSON response.
"""
@@ -154,11 +158,13 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
site.server_version_string = "Server v1"
site.reactor = Mock()
site.experimental_cors_msc3886 = False
- request = SynapseRequest(FakeChannel(site, None), site)
+ request = SynapseRequest(
+ cast(HTTPChannel, FakeChannel(site, self.reactor)), site
+ )
# Call requestReceived to finish instantiating the object.
request.content = BytesIO()
- # Partially skip some of the internal processing of SynapseRequest.
- request._started_processing = Mock()
+ # Partially skip some internal processing of SynapseRequest.
+ request._started_processing = Mock() # type: ignore[assignment]
request.request_metrics = Mock(spec=["name"])
with patch.object(Request, "render"):
request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1")
@@ -200,7 +206,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
self.assertEqual(log["protocol"], "1.1")
self.assertEqual(log["user_agent"], "")
- def test_with_exception(self):
+ def test_with_exception(self) -> None:
"""
The logging exception type & value should be added to the JSON response.
"""
diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py
index bddc4228bc..7c3656d049 100644
--- a/tests/metrics/test_metrics.py
+++ b/tests/metrics/test_metrics.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict, Tuple
+
from typing_extensions import Protocol
try:
@@ -22,6 +24,7 @@ except ImportError:
from unittest.mock import patch
from pkg_resources import parse_version
+from prometheus_client.core import Sample
from synapse.app._base import _set_prometheus_client_use_created_metrics
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
@@ -30,7 +33,7 @@ from synapse.util.caches.deferred_cache import DeferredCache
from tests import unittest
-def get_sample_labels_value(sample):
+def get_sample_labels_value(sample: Sample) -> Tuple[Dict[str, str], float]:
"""Extract the labels and values of a sample.
prometheus_client 0.5 changed the sample type to a named tuple with more
@@ -48,12 +51,15 @@ def get_sample_labels_value(sample):
return sample.labels, sample.value
# Otherwise fall back to treating it as a plain 3 tuple.
else:
- _, labels, value = sample
+ # In older versions of prometheus_client Sample was a 3-tuple.
+ labels: Dict[str, str]
+ value: float
+ _, labels, value = sample # type: ignore[misc]
return labels, value
class TestMauLimit(unittest.TestCase):
- def test_basic(self):
+ def test_basic(self) -> None:
class MetricEntry(Protocol):
foo: int
bar: int
@@ -62,11 +68,11 @@ class TestMauLimit(unittest.TestCase):
"test1", "", labels=["test_label"], sub_metrics=["foo", "bar"]
)
- def handle1(metrics):
+ def handle1(metrics: MetricEntry) -> None:
metrics.foo += 2
metrics.bar = max(metrics.bar, 5)
- def handle2(metrics):
+ def handle2(metrics: MetricEntry) -> None:
metrics.foo += 3
metrics.bar = max(metrics.bar, 7)
@@ -116,7 +122,9 @@ class TestMauLimit(unittest.TestCase):
self.get_metrics_from_gauge(gauge),
)
- def get_metrics_from_gauge(self, gauge):
+ def get_metrics_from_gauge(
+ self, gauge: InFlightGauge
+ ) -> Dict[str, Dict[Tuple[str, ...], float]]:
results = {}
for r in gauge.collect():
@@ -129,7 +137,7 @@ class TestMauLimit(unittest.TestCase):
class BuildInfoTests(unittest.TestCase):
- def test_get_build(self):
+ def test_get_build(self) -> None:
"""
The synapse_build_info metric reports the OS version, Python version,
and Synapse version.
@@ -147,7 +155,7 @@ class BuildInfoTests(unittest.TestCase):
class CacheMetricsTests(unittest.HomeserverTestCase):
- def test_cache_metric(self):
+ def test_cache_metric(self) -> None:
"""
Caches produce metrics reflecting their state when scraped.
"""
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 058ca57e55..8f88c0117d 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -110,6 +110,24 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, True)
+ def test_can_set_displayname(self):
+ localpart = "alice_wants_a_new_displayname"
+ user_id = self.register_user(
+ localpart, "1234", displayname="Alice", admin=False
+ )
+ found_userinfo = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+
+ self.get_success(
+ self.module_api.set_displayname(
+ found_userinfo.user_id, "Bob", deactivation=False
+ )
+ )
+ found_profile = self.get_success(
+ self.module_api.get_profile_for_user(localpart)
+ )
+
+ self.assertEqual(found_profile.display_name, "Bob")
+
def test_get_userinfo_by_id(self):
user_id = self.register_user("alice", "1234")
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
@@ -336,7 +354,8 @@ class ModuleApiTestCase(HomeserverTestCase):
# Test sending local online presence to users from the main process
_test_sending_local_online_presence_to_local_user(self, test_with_workers=False)
- @override_config({"send_federation": True})
+ # Enable federation sending on the main process.
+ @override_config({"federation_sender_instances": None})
def test_send_local_online_presence_to_federation(self):
"""Tests that send_local_presence_to_users sends local online presence to remote users."""
# Create a user who will send presence updates
@@ -385,6 +404,9 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api.send_local_online_presence_to([remote_user_id])
)
+ # We don't always send out federation immediately, so we advance the clock.
+ self.reactor.advance(1000)
+
# Check that a presence update was sent as part of a federation transaction
found_update = False
calls = (
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 594e7937a8..7567756135 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -1,15 +1,38 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional
from unittest.mock import patch
+from parameterized import parameterized
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventContentFields, RelationTypes
from synapse.api.room_versions import RoomVersions
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.rest import admin
from synapse.rest.client import login, register, room
-from synapse.types import create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, create_requester
+from synapse.util import Clock
-from tests import unittest
+from tests.test_utils import simple_async_mock
+from tests.unittest import HomeserverTestCase, override_config
-class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
+class TestBulkPushRuleEvaluator(HomeserverTestCase):
servlets = [
admin.register_servlets_for_client_rest_resource,
@@ -18,57 +41,372 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
register.register_servlets,
]
- def test_action_for_event_by_user_handles_noninteger_power_levels(self) -> None:
- """We should convert floats and strings to integers before passing to Rust.
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ # Create a new user and room.
+ self.alice = self.register_user("alice", "pass")
+ self.token = self.login(self.alice, "pass")
+ self.requester = create_requester(self.alice)
+
+ self.room_id = self.helper.create_room_as(
+ # This is deliberately set to V9, because we want to test the logic which
+ # handles stringy power levels. Stringy power levels were outlawed in V10.
+ self.alice,
+ room_version=RoomVersions.V9.identifier,
+ tok=self.token,
+ )
+
+ self.event_creation_handler = self.hs.get_event_creation_handler()
+
+ @parameterized.expand(
+ [
+ # The historically-permitted bad values. Alice's notification should be
+ # allowed if this threshold is at or below her power level (60)
+ ("100", False),
+ ("0", True),
+ (12.34, True),
+ (60.0, True),
+ (67.89, False),
+ # Values that int(...) would not successfully cast should be ignored.
+ # The room notification level should then default to 50, per the spec, so
+ # Alice's notification is allowed.
+ (None, True),
+ # We haven't seen `"room": []` or `"room": {}` in the wild (yet), but
+ # let's check them for paranoia's sake.
+ ([], True),
+ ({}, True),
+ ]
+ )
+ def test_action_for_event_by_user_handles_noninteger_room_power_levels(
+ self, bad_room_level: object, should_permit: bool
+ ) -> None:
+ """We should convert strings in `room` to integers before passing to Rust.
+
+ Test this as follows:
+ - Create a room as Alice and invite two other users Bob and Charlie.
+ - Set PLs so that Alice has PL 60 and `notifications.room` is set to a bad value.
+ - Have Alice create a message notifying @room.
+ - Evaluate notification actions for that message. This should not raise.
+ - Look in the DB to see if that message triggered a highlight for Bob.
+
+ The test is parameterised with two arguments:
+ - the bad power level value for "room", before JSON serisalistion
+ - whether Bob should expect the message to be highlighted
Reproduces #14060.
A lack of validation: the gift that keeps on giving.
"""
- # Create a new user and room.
- alice = self.register_user("alice", "pass")
- token = self.login(alice, "pass")
+ # Join another user to the room, so that there is someone to see Alice's
+ # @room notification.
+ bob = self.register_user("bob", "pass")
+ bob_token = self.login(bob, "pass")
+ self.helper.join(self.room_id, bob, tok=bob_token)
- room_id = self.helper.create_room_as(
- alice, room_version=RoomVersions.V9.identifier, tok=token
- )
-
- # Alter the power levels in that room to include stringy and floaty levels.
- # We need to suppress the validation logic or else it will reject these dodgy
- # values. (Presumably this validation was not always present.)
- event_creation_handler = self.hs.get_event_creation_handler()
- requester = create_requester(alice)
+ # Alter the power levels in that room to include the bad @room notification
+ # level. We need to suppress
+ #
+ # - canonicaljson validation, because canonicaljson forbids floats;
+ # - the event jsonschema validation, because it will forbid bad values; and
+ # - the auth rules checks, because they stop us from creating power levels
+ # with `"room": null`. (We want to test this case, because we have seen it
+ # in the wild.)
+ #
+ # We have seen stringy and null values for "room" in the wild, so presumably
+ # some of this validation was missing in the past.
with patch("synapse.events.validator.validate_canonicaljson"), patch(
"synapse.events.validator.jsonschema.validate"
- ):
- self.helper.send_state(
- room_id,
+ ), patch("synapse.handlers.event_auth.check_state_dependent_auth_rules"):
+ pl_event_id = self.helper.send_state(
+ self.room_id,
"m.room.power_levels",
{
- "users": {alice: "100"}, # stringy
- "notifications": {"room": 100.0}, # float
+ "users": {self.alice: 60},
+ "notifications": {"room": bad_room_level},
},
- token,
+ self.token,
state_key="",
- )
+ )["event_id"]
# Create a new message event, and try to evaluate it under the dodgy
# power level event.
event, context = self.get_success(
- event_creation_handler.create_event(
- requester,
+ self.event_creation_handler.create_event(
+ self.requester,
{
"type": "m.room.message",
- "room_id": room_id,
+ "room_id": self.room_id,
"content": {
"msgtype": "m.text",
- "body": "helo",
+ "body": "helo @room",
},
- "sender": alice,
+ "sender": self.alice,
},
+ prev_event_ids=[pl_event_id],
)
)
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# should not raise
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
+
+ # Did Bob see Alice's @room notification?
+ highlighted_actions = self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_select_list(
+ table="event_push_actions_staging",
+ keyvalues={
+ "event_id": event.event_id,
+ "user_id": bob,
+ "highlight": 1,
+ },
+ retcols=("*",),
+ desc="get_event_push_actions_staging",
+ )
+ )
+ self.assertEqual(len(highlighted_actions), int(should_permit))
+
+ @override_config({"push": {"enabled": False}})
+ def test_action_for_event_by_user_disabled_by_config(self) -> None:
+ """Ensure that push rules are not calculated when disabled in the config"""
+
+ # Create a new message event which should cause a notification.
+ event, context = self.get_success(
+ self.event_creation_handler.create_event(
+ self.requester,
+ {
+ "type": "m.room.message",
+ "room_id": self.room_id,
+ "content": {
+ "msgtype": "m.text",
+ "body": "helo",
+ },
+ "sender": self.alice,
+ },
+ )
+ )
+
+ bulk_evaluator = BulkPushRuleEvaluator(self.hs)
+ # Mock the method which calculates push rules -- we do this instead of
+ # e.g. checking the results in the database because we want to ensure
+ # that code isn't even running.
+ bulk_evaluator._action_for_event_by_user = simple_async_mock() # type: ignore[assignment]
+
+ # Ensure no actions are generated!
+ self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
+ bulk_evaluator._action_for_event_by_user.assert_not_called()
+
+ def _create_and_process(
+ self, bulk_evaluator: BulkPushRuleEvaluator, content: Optional[JsonDict] = None
+ ) -> bool:
+ """Returns true iff the `mentions` trigger an event push action."""
+ # Create a new message event which should cause a notification.
+ event, context = self.get_success(
+ self.event_creation_handler.create_event(
+ self.requester,
+ {
+ "type": "test",
+ "room_id": self.room_id,
+ "content": content or {},
+ "sender": f"@bob:{self.hs.hostname}",
+ },
+ )
+ )
+
+ # Execute the push rule machinery.
+ self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
+
+ # If any actions are generated for this event, return true.
+ result = self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_select_list(
+ table="event_push_actions_staging",
+ keyvalues={"event_id": event.event_id},
+ retcols=("*",),
+ desc="get_event_push_actions_staging",
+ )
+ )
+ return len(result) > 0
+
+ @override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
+ def test_user_mentions(self) -> None:
+ """Test the behavior of an event which includes invalid user mentions."""
+ bulk_evaluator = BulkPushRuleEvaluator(self.hs)
+
+ # Not including the mentions field should not notify.
+ self.assertFalse(self._create_and_process(bulk_evaluator))
+ # An empty mentions field should not notify.
+ self.assertFalse(
+ self._create_and_process(
+ bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {}}
+ )
+ )
+
+ # Non-dict mentions should be ignored.
+ mentions: Any
+ for mentions in (None, True, False, 1, "foo", []):
+ self.assertFalse(
+ self._create_and_process(
+ bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: mentions}
+ )
+ )
+
+ # A non-list should be ignored.
+ for mentions in (None, True, False, 1, "foo", {}):
+ self.assertFalse(
+ self._create_and_process(
+ bulk_evaluator,
+ {EventContentFields.MSC3952_MENTIONS: {"user_ids": mentions}},
+ )
+ )
+
+ # The Matrix ID appearing anywhere in the list should notify.
+ self.assertTrue(
+ self._create_and_process(
+ bulk_evaluator,
+ {EventContentFields.MSC3952_MENTIONS: {"user_ids": [self.alice]}},
+ )
+ )
+ self.assertTrue(
+ self._create_and_process(
+ bulk_evaluator,
+ {
+ EventContentFields.MSC3952_MENTIONS: {
+ "user_ids": ["@another:test", self.alice]
+ }
+ },
+ )
+ )
+
+ # Duplicate user IDs should notify.
+ self.assertTrue(
+ self._create_and_process(
+ bulk_evaluator,
+ {
+ EventContentFields.MSC3952_MENTIONS: {
+ "user_ids": [self.alice, self.alice]
+ }
+ },
+ )
+ )
+
+ # Invalid entries in the list are ignored.
+ self.assertFalse(
+ self._create_and_process(
+ bulk_evaluator,
+ {
+ EventContentFields.MSC3952_MENTIONS: {
+ "user_ids": [None, True, False, {}, []]
+ }
+ },
+ )
+ )
+ self.assertTrue(
+ self._create_and_process(
+ bulk_evaluator,
+ {
+ EventContentFields.MSC3952_MENTIONS: {
+ "user_ids": [None, True, False, {}, [], self.alice]
+ }
+ },
+ )
+ )
+
+ # The legacy push rule should not mention if the mentions field exists.
+ self.assertFalse(
+ self._create_and_process(
+ bulk_evaluator,
+ {
+ "body": self.alice,
+ "msgtype": "m.text",
+ EventContentFields.MSC3952_MENTIONS: {},
+ },
+ )
+ )
+
+ @override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
+ def test_room_mentions(self) -> None:
+ """Test the behavior of an event which includes invalid room mentions."""
+ bulk_evaluator = BulkPushRuleEvaluator(self.hs)
+
+ # Room mentions from those without power should not notify.
+ self.assertFalse(
+ self._create_and_process(
+ bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {"room": True}}
+ )
+ )
+
+ # Room mentions from those with power should notify.
+ self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ {"notifications": {"room": 0}},
+ self.token,
+ state_key="",
+ )
+ self.assertTrue(
+ self._create_and_process(
+ bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {"room": True}}
+ )
+ )
+
+ # Invalid data should not notify.
+ mentions: Any
+ for mentions in (None, False, 1, "foo", [], {}):
+ self.assertFalse(
+ self._create_and_process(
+ bulk_evaluator,
+ {EventContentFields.MSC3952_MENTIONS: {"room": mentions}},
+ )
+ )
+
+ # The legacy push rule should not mention if the mentions field exists.
+ self.assertFalse(
+ self._create_and_process(
+ bulk_evaluator,
+ {
+ "body": "@room",
+ "msgtype": "m.text",
+ EventContentFields.MSC3952_MENTIONS: {},
+ },
+ )
+ )
+
+ @override_config({"experimental_features": {"msc3958_supress_edit_notifs": True}})
+ def test_suppress_edits(self) -> None:
+ """Under the default push rules, event edits should not generate notifications."""
+ bulk_evaluator = BulkPushRuleEvaluator(self.hs)
+
+ # Create & persist an event to use as the parent of the relation.
+ event, context = self.get_success(
+ self.event_creation_handler.create_event(
+ self.requester,
+ {
+ "type": "m.room.message",
+ "room_id": self.room_id,
+ "content": {
+ "msgtype": "m.text",
+ "body": "helo",
+ },
+ "sender": self.alice,
+ },
+ )
+ )
+ self.get_success(
+ self.event_creation_handler.handle_new_client_event(
+ self.requester, events_and_context=[(event, context)]
+ )
+ )
+
+ # Room mentions from those without power should not notify.
+ self.assertFalse(
+ self._create_and_process(
+ bulk_evaluator,
+ {
+ "body": self.alice,
+ "m.relates_to": {
+ "rel_type": RelationTypes.REPLACE,
+ "event_id": event.event_id,
+ },
+ },
+ )
+ )
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index fd14568f55..ab8bb417e7 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -13,25 +13,28 @@
# limitations under the License.
import email.message
import os
-from typing import Dict, List, Sequence, Tuple
+from typing import Any, Dict, List, Sequence, Tuple
import attr
import pkg_resources
from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes, SynapseError
from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
-@attr.s
+@attr.s(auto_attribs=True)
class _User:
"Helper wrapper for user ID and access token"
- id = attr.ib()
- token = attr.ib()
+ id: str
+ token: str
class EmailPusherTests(HomeserverTestCase):
@@ -41,10 +44,9 @@ class EmailPusherTests(HomeserverTestCase):
room.register_servlets,
login.register_servlets,
]
- user_id = True
hijack_auth = False
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["email"] = {
@@ -66,24 +68,23 @@ class EmailPusherTests(HomeserverTestCase):
"riot_base_url": None,
}
config["public_baseurl"] = "http://aaa"
- config["start_pushers"] = True
hs = self.setup_test_homeserver(config=config)
# List[Tuple[Deferred, args, kwargs]]
self.email_attempts: List[Tuple[Deferred, Sequence, Dict]] = []
- def sendmail(*args, **kwargs):
+ def sendmail(*args: Any, **kwargs: Any) -> Deferred:
# This mocks out synapse.reactor.send_email._sendmail.
- d = Deferred()
+ d: Deferred = Deferred()
self.email_attempts.append((d, args, kwargs))
return d
- hs.get_send_email_handler()._sendmail = sendmail
+ hs.get_send_email_handler()._sendmail = sendmail # type: ignore[assignment]
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Register the user who gets notified
self.user_id = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
@@ -130,7 +131,7 @@ class EmailPusherTests(HomeserverTestCase):
self.auth_handler = hs.get_auth_handler()
self.store = hs.get_datastores().main
- def test_need_validated_email(self):
+ def test_need_validated_email(self) -> None:
"""Test that we can only add an email pusher if the user has validated
their email.
"""
@@ -152,7 +153,7 @@ class EmailPusherTests(HomeserverTestCase):
self.assertEqual(400, cm.exception.code)
self.assertEqual(Codes.THREEPID_NOT_FOUND, cm.exception.errcode)
- def test_simple_sends_email(self):
+ def test_simple_sends_email(self) -> None:
# Create a simple room with two users
room = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.helper.invite(
@@ -172,7 +173,7 @@ class EmailPusherTests(HomeserverTestCase):
self._check_for_mail()
- def test_invite_sends_email(self):
+ def test_invite_sends_email(self) -> None:
# Create a room and invite the user to it
room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
self.helper.invite(
@@ -185,7 +186,7 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about the invite
self._check_for_mail()
- def test_invite_to_empty_room_sends_email(self):
+ def test_invite_to_empty_room_sends_email(self) -> None:
# Create a room and invite the user to it
room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
self.helper.invite(
@@ -201,7 +202,7 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about the invite
self._check_for_mail()
- def test_multiple_members_email(self):
+ def test_multiple_members_email(self) -> None:
# We want to test multiple notifications, so we pause processing of push
# while we send messages.
self.pusher._pause_processing()
@@ -228,7 +229,7 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about those messages
self._check_for_mail()
- def test_multiple_rooms(self):
+ def test_multiple_rooms(self) -> None:
# We want to test multiple notifications from multiple rooms, so we pause
# processing of push while we send messages.
self.pusher._pause_processing()
@@ -258,7 +259,7 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about those messages
self._check_for_mail()
- def test_room_notifications_include_avatar(self):
+ def test_room_notifications_include_avatar(self) -> None:
# Create a room and set its avatar.
room = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.helper.send_state(
@@ -291,7 +292,7 @@ class EmailPusherTests(HomeserverTestCase):
)
self.assertIn("_matrix/media/v1/thumbnail/DUMMY_MEDIA_ID", html)
- def test_empty_room(self):
+ def test_empty_room(self) -> None:
"""All users leaving a room shouldn't cause the pusher to break."""
# Create a simple room with two users
room = self.helper.create_room_as(self.user_id, tok=self.access_token)
@@ -310,7 +311,7 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about that message
self._check_for_mail()
- def test_empty_room_multiple_messages(self):
+ def test_empty_room_multiple_messages(self) -> None:
"""All users leaving a room shouldn't cause the pusher to break."""
# Create a simple room with two users
room = self.helper.create_room_as(self.user_id, tok=self.access_token)
@@ -330,7 +331,7 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about that message
self._check_for_mail()
- def test_encrypted_message(self):
+ def test_encrypted_message(self) -> None:
room = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.helper.invite(
room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
@@ -343,7 +344,7 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about that message
self._check_for_mail()
- def test_no_email_sent_after_removed(self):
+ def test_no_email_sent_after_removed(self) -> None:
# Create a simple room with two users
room = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.helper.invite(
@@ -380,7 +381,7 @@ class EmailPusherTests(HomeserverTestCase):
pushers = list(pushers)
self.assertEqual(len(pushers), 0)
- def test_remove_unlinked_pushers_background_job(self):
+ def test_remove_unlinked_pushers_background_job(self) -> None:
"""Checks that all existing pushers associated with unlinked email addresses are removed
upon running the remove_deleted_email_pushers background update.
"""
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index b383b8401f..23447cc310 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
from unittest.mock import Mock
from twisted.internet.defer import Deferred
@@ -41,17 +41,12 @@ class HTTPPusherTests(HomeserverTestCase):
user_id = True
hijack_auth = False
- def default_config(self) -> Dict[str, Any]:
- config = super().default_config()
- config["start_pushers"] = True
- return config
-
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.push_attempts: List[Tuple[Deferred, str, dict]] = []
m = Mock()
- def post_json_get_json(url, body):
+ def post_json_get_json(url: str, body: JsonDict) -> Deferred:
d: Deferred = Deferred()
self.push_attempts.append((d, url, body))
return make_deferred_yieldable(d)
diff --git a/tests/push/test_presentable_names.py b/tests/push/test_presentable_names.py
index aff563919d..d37f8ce262 100644
--- a/tests/push/test_presentable_names.py
+++ b/tests/push/test_presentable_names.py
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, cast
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
-from synapse.events import FrozenEvent
+from synapse.events import EventBase, FrozenEvent
from synapse.push.presentable_names import calculate_room_name
from synapse.types import StateKey, StateMap
@@ -51,13 +51,15 @@ class MockDataStore:
)
async def get_event(
- self, event_id: StateKey, allow_none: bool = False
+ self, event_id: str, allow_none: bool = False
) -> Optional[FrozenEvent]:
assert allow_none, "Mock not configured for allow_none = False"
- return self._events.get(event_id)
+ # Decode the state key from the event ID.
+ state_key = cast(Tuple[str, str], tuple(event_id.split("|", 1)))
+ return self._events.get(state_key)
- async def get_events(self, event_ids: Iterable[StateKey]):
+ async def get_events(self, event_ids: Iterable[StateKey]) -> StateMap[EventBase]:
# This is cheating since it just returns all events.
return self._events
@@ -68,17 +70,17 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase):
def _calculate_room_name(
self,
- events: StateMap[dict],
+ events: Iterable[Tuple[Tuple[str, str], dict]],
user_id: str = "",
fallback_to_members: bool = True,
fallback_to_single_member: bool = True,
- ):
- # This isn't 100% accurate, but works with MockDataStore.
- room_state_ids = {k[0]: k[0] for k in events}
+ ) -> Optional[str]:
+ # Encode the state key into the event ID.
+ room_state_ids = {k[0]: "|".join(k[0]) for k in events}
return self.get_success(
calculate_room_name(
- MockDataStore(events),
+ MockDataStore(events), # type: ignore[arg-type]
room_state_ids,
user_id or self.USER_ID,
fallback_to_members,
@@ -86,9 +88,9 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase):
)
)
- def test_name(self):
+ def test_name(self) -> None:
"""A room name event should be used."""
- events = [
+ events: List[Tuple[Tuple[str, str], dict]] = [
((EventTypes.Name, ""), {"name": "test-name"}),
]
self.assertEqual("test-name", self._calculate_room_name(events))
@@ -100,9 +102,9 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase):
events = [((EventTypes.Name, ""), {"name": 1})]
self.assertEqual(1, self._calculate_room_name(events))
- def test_canonical_alias(self):
+ def test_canonical_alias(self) -> None:
"""An canonical alias should be used."""
- events = [
+ events: List[Tuple[Tuple[str, str], dict]] = [
((EventTypes.CanonicalAlias, ""), {"alias": "#test-name:test"}),
]
self.assertEqual("#test-name:test", self._calculate_room_name(events))
@@ -114,9 +116,9 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase):
events = [((EventTypes.CanonicalAlias, ""), {"alias": "test-name"})]
self.assertEqual("Empty Room", self._calculate_room_name(events))
- def test_invite(self):
+ def test_invite(self) -> None:
"""An invite has special behaviour."""
- events = [
+ events: List[Tuple[Tuple[str, str], dict]] = [
((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
((EventTypes.Member, self.OTHER_USER_ID), {"displayname": "Other User"}),
]
@@ -140,9 +142,9 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase):
]
self.assertEqual("Room Invite", self._calculate_room_name(events))
- def test_no_members(self):
+ def test_no_members(self) -> None:
"""Behaviour of an empty room."""
- events = []
+ events: List[Tuple[Tuple[str, str], dict]] = []
self.assertEqual("Empty Room", self._calculate_room_name(events))
# Note that events with invalid (or missing) membership are ignored.
@@ -152,7 +154,7 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase):
]
self.assertEqual("Empty Room", self._calculate_room_name(events))
- def test_no_other_members(self):
+ def test_no_other_members(self) -> None:
"""Behaviour of a room with no other members in it."""
events = [
(
@@ -185,7 +187,7 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase):
self._calculate_room_name(events, user_id=self.OTHER_USER_ID),
)
- def test_one_other_member(self):
+ def test_one_other_member(self) -> None:
"""Behaviour of a room with a single other member."""
events = [
((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
@@ -209,7 +211,7 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase):
]
self.assertEqual("@user:test", self._calculate_room_name(events))
- def test_other_members(self):
+ def test_other_members(self) -> None:
"""Behaviour of a room with multiple other members."""
# Two other members.
events = [
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index fe7c145840..7c430c4ecb 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Set, Union, cast
import frozendict
@@ -30,16 +30,46 @@ from synapse.rest.client import login, register, room
from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.synapse_rust.push import PushRuleEvaluator
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, JsonMapping, UserID
from synapse.util import Clock
from tests import unittest
from tests.test_utils.event_injection import create_event, inject_member_event
+class FlattenDictTestCase(unittest.TestCase):
+ def test_simple(self) -> None:
+ """Test a dictionary that isn't modified."""
+ input = {"foo": "abc"}
+ self.assertEqual(input, _flatten_dict(input))
+
+ def test_nested(self) -> None:
+ """Nested dictionaries become dotted paths."""
+ input = {"foo": {"bar": "abc"}}
+ self.assertEqual({"foo.bar": "abc"}, _flatten_dict(input))
+
+ def test_non_string(self) -> None:
+ """Non-string items are dropped."""
+ input: Dict[str, Any] = {
+ "woo": "woo",
+ "foo": True,
+ "bar": 1,
+ "baz": None,
+ "fuzz": [],
+ "boo": {},
+ }
+ self.assertEqual({"woo": "woo"}, _flatten_dict(input))
+
+
class PushRuleEvaluatorTestCase(unittest.TestCase):
def _get_evaluator(
- self, content: JsonDict, related_events=None
+ self,
+ content: JsonMapping,
+ *,
+ has_mentions: bool = False,
+ user_mentions: Optional[Set[str]] = None,
+ room_mention: bool = False,
+ related_events: Optional[JsonDict] = None,
) -> PushRuleEvaluator:
event = FrozenEvent(
{
@@ -57,20 +87,23 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
return PushRuleEvaluator(
_flatten_dict(event),
+ has_mentions,
+ user_mentions or set(),
+ room_mention,
room_member_count,
sender_power_level,
- power_levels.get("notifications", {}),
+ cast(Dict[str, int], power_levels.get("notifications", {})),
{} if related_events is None else related_events,
- True,
+ related_event_match_enabled=True,
+ room_version_feature_flags=event.room_version.msc3931_push_features,
+ msc3931_enabled=True,
)
def test_display_name(self) -> None:
"""Check for a matching display name in the body of the event."""
evaluator = self._get_evaluator({"body": "foo bar baz"})
- condition = {
- "kind": "contains_display_name",
- }
+ condition = {"kind": "contains_display_name"}
# Blank names are skipped.
self.assertFalse(evaluator.matches(condition, "@user:test", ""))
@@ -90,8 +123,55 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
# A display name with spaces should work fine.
self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
+ def test_user_mentions(self) -> None:
+ """Check for user mentions."""
+ condition = {"kind": "org.matrix.msc3952.is_user_mention"}
+
+ # No mentions shouldn't match.
+ evaluator = self._get_evaluator({}, has_mentions=True)
+ self.assertFalse(evaluator.matches(condition, "@user:test", None))
+
+ # An empty set shouldn't match
+ evaluator = self._get_evaluator({}, has_mentions=True, user_mentions=set())
+ self.assertFalse(evaluator.matches(condition, "@user:test", None))
+
+ # The Matrix ID appearing anywhere in the mentions list should match
+ evaluator = self._get_evaluator(
+ {}, has_mentions=True, user_mentions={"@user:test"}
+ )
+ self.assertTrue(evaluator.matches(condition, "@user:test", None))
+
+ evaluator = self._get_evaluator(
+ {}, has_mentions=True, user_mentions={"@another:test", "@user:test"}
+ )
+ self.assertTrue(evaluator.matches(condition, "@user:test", None))
+
+ # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
+ # since the BulkPushRuleEvaluator is what handles data sanitisation.
+
+ def test_room_mentions(self) -> None:
+ """Check for room mentions."""
+ condition = {"kind": "org.matrix.msc3952.is_room_mention"}
+
+ # No room mention shouldn't match.
+ evaluator = self._get_evaluator({}, has_mentions=True)
+ self.assertFalse(evaluator.matches(condition, None, None))
+
+ # Room mention should match.
+ evaluator = self._get_evaluator({}, has_mentions=True, room_mention=True)
+ self.assertTrue(evaluator.matches(condition, None, None))
+
+ # A room mention and user mention is valid.
+ evaluator = self._get_evaluator(
+ {}, has_mentions=True, user_mentions={"@another:test"}, room_mention=True
+ )
+ self.assertTrue(evaluator.matches(condition, None, None))
+
+ # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
+ # since the BulkPushRuleEvaluator is what handles data sanitisation.
+
def _assert_matches(
- self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
+ self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None
) -> None:
evaluator = self._get_evaluator(content)
self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
@@ -285,7 +365,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
This tests the behaviour of tweaks_for_actions.
"""
- actions = [
+ actions: List[Union[Dict[str, str], str]] = [
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight"},
"notify",
@@ -296,7 +376,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
{"sound": "default", "highlight": True},
)
- def test_related_event_match(self):
+ def test_related_event_match(self) -> None:
evaluator = self._get_evaluator(
{
"m.relates_to": {
@@ -308,7 +388,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
},
}
},
- {
+ related_events={
"m.in_reply_to": {
"event_id": "$parent_event_id",
"type": "m.room.message",
@@ -395,7 +475,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
)
)
- def test_related_event_match_with_fallback(self):
+ def test_related_event_match_with_fallback(self) -> None:
evaluator = self._get_evaluator(
{
"m.relates_to": {
@@ -408,7 +488,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
},
}
},
- {
+ related_events={
"m.in_reply_to": {
"event_id": "$parent_event_id",
"type": "m.room.message",
@@ -467,7 +547,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
)
)
- def test_related_event_match_no_related_event(self):
+ def test_related_event_match_no_related_event(self) -> None:
evaluator = self._get_evaluator(
{"msgtype": "m.text", "body": "Message without related event"}
)
@@ -516,7 +596,9 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
# Define an application service so that we can register appservice users
self._service_token = "some_token"
self._service = ApplicationService(
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 3029a16dda..6a7174b333 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -307,7 +307,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
stream to the master HS.
Args:
- worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
+ worker_app: Type of worker, e.g. `synapse.app.generic_worker`.
extra_config: Any extra config to use for this instances.
**kwargs: Options that get passed to `self.setup_test_homeserver`,
useful to e.g. pass some mocks for things like `federation_http_client`
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index 936ab4504a..e03d9b4cc0 100644
--- a/tests/replication/http/test__base.py
+++ b/tests/replication/http/test__base.py
@@ -44,7 +44,7 @@ class CancellableReplicationEndpoint(ReplicationEndpoint):
@cancellable
async def _handle_request( # type: ignore[override]
- self, request: Request
+ self, request: Request, content: JsonDict
) -> Tuple[int, JsonDict]:
await self.clock.sleep(1.0)
return HTTPStatus.OK, {"result": True}
@@ -54,6 +54,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
NAME = "uncancellable_sleep"
PATH_ARGS = ()
CACHE = False
+ WAIT_FOR_STREAMS = False
def __init__(self, hs: HomeServer):
super().__init__(hs)
@@ -64,7 +65,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
return {}
async def _handle_request( # type: ignore[override]
- self, request: Request
+ self, request: Request, content: JsonDict
) -> Tuple[int, JsonDict]:
await self.clock.sleep(1.0)
return HTTPStatus.OK, {"result": True}
@@ -85,7 +86,7 @@ class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
def test_cancellable_disconnect(self) -> None:
"""Test that handlers with the `@cancellable` flag can be cancelled."""
path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
- channel = self.make_request("POST", path, await_result=False)
+ channel = self.make_request("POST", path, await_result=False, content={})
test_disconnect(
self.reactor,
channel,
@@ -96,7 +97,7 @@ class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
def test_uncancellable_disconnect(self) -> None:
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
- channel = self.make_request("POST", path, await_result=False)
+ channel = self.make_request("POST", path, await_result=False, content={})
test_disconnect(
self.reactor,
channel,
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
index ffec06a0d6..bcb82c9c80 100644
--- a/tests/replication/tcp/streams/test_federation.py
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -22,9 +22,8 @@ class FederationStreamTestCase(BaseStreamTestCase):
def _get_worker_hs_config(self) -> dict:
# enable federation sending on the worker
config = super()._get_worker_hs_config()
- # TODO: make it so we don't need both of these
- config["send_federation"] = False
- config["worker_app"] = "synapse.app.federation_sender"
+ config["worker_name"] = "federation_sender1"
+ config["federation_sender_instances"] = ["federation_sender1"]
return config
def test_catchup(self):
diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py
new file mode 100644
index 0000000000..2c10eab4db
--- /dev/null
+++ b/tests/replication/tcp/streams/test_partial_state.py
@@ -0,0 +1,65 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet.defer import ensureDeferred
+
+from synapse.rest.client import room
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+
+class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
+ servlets = [room.register_servlets]
+ hijack_auth = True
+ user_id = "@bob:test"
+
+ def setUp(self):
+ super().setUp()
+ self.store = self.hs.get_datastores().main
+
+ def test_un_partial_stated_room_unblocks_over_replication(self) -> None:
+ """
+ Tests that, when a room is un-partial-stated on another worker,
+ pending calls to `await_full_state` get unblocked.
+ """
+
+ # Make a room.
+ room_id = self.helper.create_room_as("@bob:test")
+ # Mark the room as partial-stated.
+ self.get_success(
+ self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1")
+ )
+
+ worker = self.make_worker_hs("synapse.app.generic_worker")
+
+ # On the worker, attempt to get the current hosts in the room
+ d = ensureDeferred(
+ worker.get_storage_controllers().state.get_current_hosts_in_room(room_id)
+ )
+
+ self.reactor.advance(0.1)
+
+ # This should block
+ self.assertFalse(
+ d.called, "get_current_hosts_in_room/await_full_state did not block"
+ )
+
+ # On the master, clear the partial state flag.
+ self.get_success(self.store.clear_partial_state_room(room_id))
+
+ self.reactor.advance(0.1)
+
+ # The worker should have unblocked
+ self.assertTrue(
+ d.called, "get_current_hosts_in_room/await_full_state did not unblock"
+ )
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index 1e299d2d67..6e4055cc21 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
+from synapse.replication.tcp.commands import PositionCommand
+
from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -71,3 +75,68 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(
len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1
)
+
+ def test_wait_for_stream_position(self) -> None:
+ """Check that wait for stream position correctly waits for an update from the
+ correct instance.
+ """
+ store = self.hs.get_datastores().main
+ cmd_handler = self.hs.get_replication_command_handler()
+ data_handler = self.hs.get_replication_data_handler()
+
+ worker1 = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ extra_config={
+ "worker_name": "worker1",
+ "run_background_tasks_on": "worker1",
+ "redis": {"enabled": True},
+ },
+ )
+
+ cache_id_gen = worker1.get_datastores().main._cache_id_gen
+ assert cache_id_gen is not None
+
+ self.replicate()
+
+ # First, make sure the master knows that `worker1` exists.
+ initial_token = cache_id_gen.get_current_token()
+ cmd_handler.send_command(
+ PositionCommand("caches", "worker1", initial_token, initial_token)
+ )
+ self.replicate()
+
+ # Next send out a normal RDATA, and check that waiting for that stream
+ # ID returns immediately.
+ ctx = cache_id_gen.get_next()
+ next_token = self.get_success(ctx.__aenter__())
+ self.get_success(ctx.__aexit__(None, None, None))
+
+ self.get_success(
+ data_handler.wait_for_stream_position("worker1", "caches", next_token)
+ )
+
+ # `wait_for_stream_position` should only return once master receives a
+ # notification that `next_token` has persisted.
+ ctx_worker1 = cache_id_gen.get_next()
+ next_token = self.get_success(ctx_worker1.__aenter__())
+
+ d = defer.ensureDeferred(
+ data_handler.wait_for_stream_position("worker1", "caches", next_token)
+ )
+ self.assertFalse(d.called)
+
+ # ... updating the cache ID gen on the master still shouldn't cause the
+ # deferred to wake up.
+ ctx = store._cache_id_gen.get_next()
+ self.get_success(ctx.__aenter__())
+ self.get_success(ctx.__aexit__(None, None, None))
+
+ d = defer.ensureDeferred(
+ data_handler.wait_for_stream_position("worker1", "caches", next_token)
+ )
+ self.assertFalse(d.called)
+
+ # ... but worker1 finishing (and so sending an update) should.
+ self.get_success(ctx_worker1.__aexit__(None, None, None))
+
+ self.assertTrue(d.called)
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
index 43a16bb141..5d7a89e0c7 100644
--- a/tests/replication/test_auth.py
+++ b/tests/replication/test_auth.py
@@ -38,7 +38,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
- config["worker_app"] = "synapse.app.client_reader"
+ config["worker_app"] = "synapse.app.generic_worker"
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
@@ -53,7 +53,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
4. Return the final request.
"""
- worker_hs = self.make_worker_hs("synapse.app.client_reader")
+ worker_hs = self.make_worker_hs("synapse.app.generic_worker")
site = self._hs_to_site[worker_hs]
channel_1 = make_request(
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 995097d72c..eb5b376534 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -22,20 +22,20 @@ logger = logging.getLogger(__name__)
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
- """Test using one or more client readers for registration."""
+ """Test using one or more generic workers for registration."""
servlets = [register.register_servlets]
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
- config["worker_app"] = "synapse.app.client_reader"
+ config["worker_app"] = "synapse.app.generic_worker"
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
return config
def test_register_single_worker(self):
- """Test that registration works when using a single client reader worker."""
- worker_hs = self.make_worker_hs("synapse.app.client_reader")
+ """Test that registration works when using a single generic worker."""
+ worker_hs = self.make_worker_hs("synapse.app.generic_worker")
site = self._hs_to_site[worker_hs]
channel_1 = make_request(
@@ -64,9 +64,9 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
def test_register_multi_worker(self):
- """Test that registration works when using multiple client reader workers."""
- worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
- worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
+ """Test that registration works when using multiple generic workers."""
+ worker_hs_1 = self.make_worker_hs("synapse.app.generic_worker")
+ worker_hs_2 = self.make_worker_hs("synapse.app.generic_worker")
site_1 = self._hs_to_site[worker_hs_1]
channel_1 = make_request(
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 26b8bd512a..63b1dd40b5 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -25,8 +25,9 @@ from tests.unittest import HomeserverTestCase
class FederationAckTestCase(HomeserverTestCase):
def default_config(self) -> dict:
config = super().default_config()
- config["worker_app"] = "synapse.app.federation_sender"
- config["send_federation"] = False
+ config["worker_app"] = "synapse.app.generic_worker"
+ config["worker_name"] = "federation_sender1"
+ config["federation_sender_instances"] = ["federation_sender1"]
return config
def make_homeserver(self, reactor, clock):
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 6104a55aa1..c28073b8f7 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -27,17 +27,19 @@ logger = logging.getLogger(__name__)
class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
+ """
+ Various tests for federation sending on workers.
+
+ Federation sending is disabled by default, it will be enabled in each test by
+ updating 'federation_sender_instances'.
+ """
+
servlets = [
login.register_servlets,
register_servlets_for_client_rest_resource,
room.register_servlets,
]
- def default_config(self):
- conf = super().default_config()
- conf["send_federation"] = False
- return conf
-
def test_send_event_single_sender(self):
"""Test that using a single federation sender worker correctly sends a
new event.
@@ -46,8 +48,11 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client.put_json.return_value = make_awaitable({})
self.make_worker_hs(
- "synapse.app.federation_sender",
- {"send_federation": False},
+ "synapse.app.generic_worker",
+ {
+ "worker_name": "federation_sender1",
+ "federation_sender_instances": ["federation_sender1"],
+ },
federation_http_client=mock_client,
)
@@ -73,11 +78,13 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.return_value = make_awaitable({})
self.make_worker_hs(
- "synapse.app.federation_sender",
+ "synapse.app.generic_worker",
{
- "send_federation": True,
- "worker_name": "sender1",
- "federation_sender_instances": ["sender1", "sender2"],
+ "worker_name": "federation_sender1",
+ "federation_sender_instances": [
+ "federation_sender1",
+ "federation_sender2",
+ ],
},
federation_http_client=mock_client1,
)
@@ -85,11 +92,13 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.return_value = make_awaitable({})
self.make_worker_hs(
- "synapse.app.federation_sender",
+ "synapse.app.generic_worker",
{
- "send_federation": True,
- "worker_name": "sender2",
- "federation_sender_instances": ["sender1", "sender2"],
+ "worker_name": "federation_sender2",
+ "federation_sender_instances": [
+ "federation_sender1",
+ "federation_sender2",
+ ],
},
federation_http_client=mock_client2,
)
@@ -136,11 +145,13 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.return_value = make_awaitable({})
self.make_worker_hs(
- "synapse.app.federation_sender",
+ "synapse.app.generic_worker",
{
- "send_federation": True,
- "worker_name": "sender1",
- "federation_sender_instances": ["sender1", "sender2"],
+ "worker_name": "federation_sender1",
+ "federation_sender_instances": [
+ "federation_sender1",
+ "federation_sender2",
+ ],
},
federation_http_client=mock_client1,
)
@@ -148,11 +159,13 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.return_value = make_awaitable({})
self.make_worker_hs(
- "synapse.app.federation_sender",
+ "synapse.app.generic_worker",
{
- "send_federation": True,
- "worker_name": "sender2",
- "federation_sender_instances": ["sender1", "sender2"],
+ "worker_name": "federation_sender2",
+ "federation_sender_instances": [
+ "federation_sender1",
+ "federation_sender2",
+ ],
},
federation_http_client=mock_client2,
)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 59fea93e49..ca18ad6553 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -38,11 +38,6 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
- def default_config(self):
- conf = super().default_config()
- conf["start_pushers"] = False
- return conf
-
def _create_pusher_and_send_msg(self, localpart):
# Create a user that will get push notifications
user_id = self.register_user(localpart, "pass")
@@ -92,8 +87,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
)
self.make_worker_hs(
- "synapse.app.pusher",
- {"start_pushers": False},
+ "synapse.app.generic_worker",
+ {"worker_name": "pusher1", "pusher_instances": ["pusher1"]},
proxied_blacklisted_http_client=http_client_mock,
)
@@ -122,9 +117,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
)
self.make_worker_hs(
- "synapse.app.pusher",
+ "synapse.app.generic_worker",
{
- "start_pushers": True,
"worker_name": "pusher1",
"pusher_instances": ["pusher1", "pusher2"],
},
@@ -137,9 +131,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
)
self.make_worker_hs(
- "synapse.app.pusher",
+ "synapse.app.generic_worker",
{
- "start_pushers": True,
"worker_name": "pusher2",
"pusher_instances": ["pusher1", "pusher2"],
},
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index d52aee8f92..03f2112b07 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -19,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes
+from synapse.handlers.device import DeviceHandler
from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
@@ -34,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.handler = handler
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 8a4e5c3f77..233eba3516 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -280,7 +280,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
- self.assertEqual("Unknown direction: bar", channel.json_body["error"])
+ self.assertEqual(
+ "Query parameter 'dir' must be one of ['b', 'f']",
+ channel.json_body["error"],
+ )
def test_limit_is_negative(self) -> None:
"""
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index d156be82b0..453a6e979c 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1831,7 +1831,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
def test_topo_token_is_accepted(self) -> None:
"""Test Topo Token is accepted."""
- token = "t1-0_0_0_0_0_0_0_0_0"
+ token = "t1-0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET",
"/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
@@ -1845,7 +1845,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
"""Test that stream token is accepted for forward pagination."""
- token = "s0_0_0_0_0_0_0_0_0"
+ token = "s0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET",
"/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
@@ -1857,6 +1857,46 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
self.assertIn("chunk", channel.json_body)
self.assertIn("end", channel.json_body)
+ def test_room_messages_backward(self) -> None:
+ """Test room messages can be retrieved by an admin that isn't in the room."""
+ latest_event_id = self.helper.send(
+ self.room_id, body="message 1", tok=self.user_tok
+ )["event_id"]
+
+ # Check that we get the first and second message when querying /messages.
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?dir=b" % (self.room_id,),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 6, [event["content"] for event in chunk])
+
+ # in backwards, this is the first event
+ self.assertEqual(chunk[0]["event_id"], latest_event_id)
+
+ def test_room_messages_forward(self) -> None:
+ """Test room messages can be retrieved by an admin that isn't in the room."""
+ latest_event_id = self.helper.send(
+ self.room_id, body="message 1", tok=self.user_tok
+ )["event_id"]
+
+ # Check that we get the first and second message when querying /messages.
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?dir=f" % (self.room_id,),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 6, [event["content"] for event in chunk])
+
+ # in forward, this is the last event
+ self.assertEqual(chunk[5]["event_id"], latest_event_id)
+
def test_room_messages_purge(self) -> None:
"""Test room messages can be retrieved by an admin that isn't in the room."""
store = self.hs.get_datastores().main
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index e8c9457794..5c1ced355f 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -3994,7 +3994,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
"""
Tests that shadow-banning for a user that is not a local returns a 400
"""
- url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
+ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/shadow_ban"
channel = self.make_request(method, url, access_token=self.admin_user_tok)
self.assertEqual(400, channel.code, msg=channel.json_body)
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index c1a7fb2f8a..88f255c9ee 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -690,42 +690,22 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.hs.config.registration.enable_3pid_changes = False
client_secret = "foobar"
- session_id = self._request_token(self.email, client_secret)
-
- self.assertEqual(len(self.email_attempts), 1)
- link = self._get_link_from_email()
-
- self._validate_token(link)
-
channel = self.make_request(
"POST",
- b"/_matrix/client/unstable/account/3pid/add",
+ b"/_matrix/client/unstable/account/3pid/email/requestToken",
{
"client_secret": client_secret,
- "sid": session_id,
- "auth": {
- "type": "m.login.password",
- "user": self.user_id,
- "password": "test",
- },
+ "email": "test@example.com",
+ "send_attempt": 1,
},
- access_token=self.user_id_tok,
)
+
self.assertEqual(
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
+
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- # Get user
- channel = self.make_request(
- "GET",
- self.url_3pid,
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
- self.assertFalse(channel.json_body["threepids"])
-
def test_delete_email(self) -> None:
"""Test deleting an email from profile"""
# Add a threepid
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index c2e1e08811..6aedc1a11c 100644
--- a/tests/rest/client/test_login_token_request.py
+++ b/tests/rest/client/test_login_token_request.py
@@ -48,13 +48,13 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
def test_disabled(self) -> None:
channel = self.make_request("POST", endpoint, {}, access_token=None)
- self.assertEqual(channel.code, 400)
+ self.assertEqual(channel.code, 404)
self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
channel = self.make_request("POST", endpoint, {}, access_token=token)
- self.assertEqual(channel.code, 400)
+ self.assertEqual(channel.code, 404)
@override_config({"experimental_features": {"msc3882_enabled": True}})
def test_require_auth(self) -> None:
diff --git a/tests/rest/client/test_receipts.py b/tests/rest/client/test_receipts.py
new file mode 100644
index 0000000000..2a7fcea386
--- /dev/null
+++ b/tests/rest/client/test_receipts.py
@@ -0,0 +1,76 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+import synapse.rest.admin
+from synapse.rest.client import login, receipts, register
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+
+
+class ReceiptsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ register.register_servlets,
+ receipts.register_servlets,
+ synapse.rest.admin.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.owner = self.register_user("owner", "pass")
+ self.owner_tok = self.login("owner", "pass")
+
+ def test_send_receipt(self) -> None:
+ channel = self.make_request(
+ "POST",
+ "/rooms/!abc:beep/receipt/m.read/$def",
+ content={},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ def test_send_receipt_invalid_room_id(self) -> None:
+ channel = self.make_request(
+ "POST",
+ "/rooms/not-a-room-id/receipt/m.read/$def",
+ content={},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["error"], "A valid room ID and event ID must be specified"
+ )
+
+ def test_send_receipt_invalid_event_id(self) -> None:
+ channel = self.make_request(
+ "POST",
+ "/rooms/!abc:beep/receipt/m.read/not-an-event-id",
+ content={},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["error"], "A valid room ID and event ID must be specified"
+ )
+
+ def test_send_receipt_invalid_receipt_type(self) -> None:
+ channel = self.make_request(
+ "POST",
+ "/rooms/!abc:beep/receipt/invalid-receipt-type/$def",
+ content={},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index e3d801f7a8..c8a6911d5e 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -30,6 +30,7 @@ from tests import unittest
from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_event
+from tests.unittest import override_config
class BaseRelationsTestCase(unittest.HomeserverTestCase):
@@ -355,30 +356,67 @@ class RelationsTestCase(BaseRelationsTestCase):
self.assertEqual(200, channel.code, channel.json_body)
self.assertNotIn("m.relations", channel.json_body["unsigned"])
+ def _assert_edit_bundle(
+ self, event_json: JsonDict, edit_event_id: str, edit_event_content: JsonDict
+ ) -> None:
+ """
+ Assert that the given event has a correctly-serialised edit event in its
+ bundled aggregations
+
+ Args:
+ event_json: the serialised event to be checked
+ edit_event_id: the ID of the edit event that we expect to be bundled
+ edit_event_content: the content of that event, excluding the 'm.relates_to`
+ property
+ """
+ relations_dict = event_json["unsigned"].get("m.relations")
+ self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+ m_replace_dict = relations_dict[RelationTypes.REPLACE]
+ for key in [
+ "event_id",
+ "sender",
+ "origin_server_ts",
+ "content",
+ "type",
+ "unsigned",
+ ]:
+ self.assertIn(key, m_replace_dict)
+
+ expected_edit_content = {
+ "m.relates_to": {
+ "event_id": event_json["event_id"],
+ "rel_type": "m.replace",
+ }
+ }
+ expected_edit_content.update(edit_event_content)
+
+ self.assert_dict(
+ {
+ "event_id": edit_event_id,
+ "sender": self.user_id,
+ "content": expected_edit_content,
+ "type": "m.room.message",
+ },
+ m_replace_dict,
+ )
+
def test_edit(self) -> None:
"""Test that a simple edit works."""
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ edit_event_content = {
+ "msgtype": "m.text",
+ "body": "foo",
+ "m.new_content": new_body,
+ }
channel = self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
- content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ content=edit_event_content,
)
edit_event_id = channel.json_body["event_id"]
- def assert_bundle(event_json: JsonDict) -> None:
- """Assert the expected values of the bundled aggregations."""
- relations_dict = event_json["unsigned"].get("m.relations")
- self.assertIn(RelationTypes.REPLACE, relations_dict)
-
- m_replace_dict = relations_dict[RelationTypes.REPLACE]
- for key in ["event_id", "sender", "origin_server_ts"]:
- self.assertIn(key, m_replace_dict)
-
- self.assert_dict(
- {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
- )
-
# /event should return the *original* event
channel = self.make_request(
"GET",
@@ -389,7 +427,7 @@ class RelationsTestCase(BaseRelationsTestCase):
self.assertEqual(
channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"}
)
- assert_bundle(channel.json_body)
+ self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content)
# Request the room messages.
channel = self.make_request(
@@ -398,7 +436,11 @@ class RelationsTestCase(BaseRelationsTestCase):
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
+ self._assert_edit_bundle(
+ self._find_event_in_chunk(channel.json_body["chunk"]),
+ edit_event_id,
+ edit_event_content,
+ )
# Request the room context.
# /context should return the edited event.
@@ -408,7 +450,9 @@ class RelationsTestCase(BaseRelationsTestCase):
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body["event"])
+ self._assert_edit_bundle(
+ channel.json_body["event"], edit_event_id, edit_event_content
+ )
self.assertEqual(channel.json_body["event"]["content"], new_body)
# Request sync, but limit the timeline so it becomes limited (and includes
@@ -420,7 +464,11 @@ class RelationsTestCase(BaseRelationsTestCase):
self.assertEqual(200, channel.code, channel.json_body)
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
self.assertTrue(room_timeline["limited"])
- assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
+ self._assert_edit_bundle(
+ self._find_event_in_chunk(room_timeline["events"]),
+ edit_event_id,
+ edit_event_content,
+ )
# Request search.
channel = self.make_request(
@@ -437,7 +485,45 @@ class RelationsTestCase(BaseRelationsTestCase):
"results"
]
]
- assert_bundle(self._find_event_in_chunk(chunk))
+ self._assert_edit_bundle(
+ self._find_event_in_chunk(chunk),
+ edit_event_id,
+ edit_event_content,
+ )
+
+ @override_config({"experimental_features": {"msc3925_inhibit_edit": True}})
+ def test_edit_inhibit_replace(self) -> None:
+ """
+ If msc3925_inhibit_edit is enabled, then the original event should not be
+ replaced.
+ """
+
+ new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ edit_event_content = {
+ "msgtype": "m.text",
+ "body": "foo",
+ "m.new_content": new_body,
+ }
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ content=edit_event_content,
+ )
+ edit_event_id = channel.json_body["event_id"]
+
+ # /context should return the *original* event.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/context/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertEqual(
+ channel.json_body["event"]["content"], {"body": "Hi!", "msgtype": "m.text"}
+ )
+ self._assert_edit_bundle(
+ channel.json_body["event"], edit_event_id, edit_event_content
+ )
def test_multi_edit(self) -> None:
"""Test that multiple edits, including attempts by people who
@@ -455,10 +541,15 @@ class RelationsTestCase(BaseRelationsTestCase):
)
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ edit_event_content = {
+ "msgtype": "m.text",
+ "body": "foo",
+ "m.new_content": new_body,
+ }
channel = self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
- content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ content=edit_event_content,
)
edit_event_id = channel.json_body["event_id"]
@@ -480,16 +571,8 @@ class RelationsTestCase(BaseRelationsTestCase):
self.assertEqual(200, channel.code, channel.json_body)
self.assertEqual(channel.json_body["event"]["content"], new_body)
-
- relations_dict = channel.json_body["event"]["unsigned"].get("m.relations")
- self.assertIn(RelationTypes.REPLACE, relations_dict)
-
- m_replace_dict = relations_dict[RelationTypes.REPLACE]
- for key in ["event_id", "sender", "origin_server_ts"]:
- self.assertIn(key, m_replace_dict)
-
- self.assert_dict(
- {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+ self._assert_edit_bundle(
+ channel.json_body["event"], edit_event_id, edit_event_content
)
def test_edit_reply(self) -> None:
@@ -502,11 +585,15 @@ class RelationsTestCase(BaseRelationsTestCase):
)
reply = channel.json_body["event_id"]
- new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ edit_event_content = {
+ "msgtype": "m.text",
+ "body": "foo",
+ "m.new_content": {"msgtype": "m.text", "body": "I've been edited!"},
+ }
channel = self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
- content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ content=edit_event_content,
parent_id=reply,
)
edit_event_id = channel.json_body["event_id"]
@@ -549,28 +636,22 @@ class RelationsTestCase(BaseRelationsTestCase):
# We expect that the edit relation appears in the unsigned relations
# section.
- relations_dict = result_event_dict["unsigned"].get("m.relations")
- self.assertIn(RelationTypes.REPLACE, relations_dict, desc)
-
- m_replace_dict = relations_dict[RelationTypes.REPLACE]
- for key in ["event_id", "sender", "origin_server_ts"]:
- self.assertIn(key, m_replace_dict, desc)
-
- self.assert_dict(
- {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+ self._assert_edit_bundle(
+ result_event_dict, edit_event_id, edit_event_content
)
def test_edit_edit(self) -> None:
"""Test that an edit cannot be edited."""
new_body = {"msgtype": "m.text", "body": "Initial edit"}
+ edit_event_content = {
+ "msgtype": "m.text",
+ "body": "Wibble",
+ "m.new_content": new_body,
+ }
channel = self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
- content={
- "msgtype": "m.text",
- "body": "Wibble",
- "m.new_content": new_body,
- },
+ content=edit_event_content,
)
edit_event_id = channel.json_body["event_id"]
@@ -599,8 +680,7 @@ class RelationsTestCase(BaseRelationsTestCase):
)
# The relations information should not include the edit to the edit.
- relations_dict = channel.json_body["unsigned"].get("m.relations")
- self.assertIn(RelationTypes.REPLACE, relations_dict)
+ self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content)
# /context should return the event updated for the *first* edit
# (The edit to the edit should be ignored.)
@@ -611,13 +691,8 @@ class RelationsTestCase(BaseRelationsTestCase):
)
self.assertEqual(200, channel.code, channel.json_body)
self.assertEqual(channel.json_body["event"]["content"], new_body)
-
- m_replace_dict = relations_dict[RelationTypes.REPLACE]
- for key in ["event_id", "sender", "origin_server_ts"]:
- self.assertIn(key, m_replace_dict)
-
- self.assert_dict(
- {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+ self._assert_edit_bundle(
+ channel.json_body["event"], edit_event_id, edit_event_content
)
# Directly requesting the edit should not have the edit to the edit applied.
@@ -1108,7 +1183,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# The "user" sent the root event and is making queries for the bundled
# aggregations: they have participated.
- self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9)
+ self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7)
# The "user2" sent replies in the thread and is making queries for the
# bundled aggregations: they have participated.
#
@@ -1170,7 +1245,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations["latest_event"].get("unsigned"),
)
- self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7)
def test_nested_thread(self) -> None:
"""
diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py
index ad00a476e1..c0eb5d01a6 100644
--- a/tests/rest/client/test_rendezvous.py
+++ b/tests/rest/client/test_rendezvous.py
@@ -36,7 +36,7 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase):
def test_disabled(self) -> None:
channel = self.make_request("POST", endpoint, {}, access_token=None)
- self.assertEqual(channel.code, 400)
+ self.assertEqual(channel.code, 404)
@override_config({"experimental_features": {"msc3886_endpoint": "/asd"}})
def test_redirect(self) -> None:
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index e919e089cb..9222cab198 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -1987,7 +1987,7 @@ class RoomMessageListTestCase(RoomBase):
self.room_id = self.helper.create_room_as(self.user_id)
def test_topo_token_is_accepted(self) -> None:
- token = "t1-0_0_0_0_0_0_0_0_0"
+ token = "t1-0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
@@ -1998,7 +1998,7 @@ class RoomMessageListTestCase(RoomBase):
self.assertTrue("end" in channel.json_body)
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
- token = "s0_0_0_0_0_0_0_0_0"
+ token = "s0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
@@ -2728,7 +2728,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"""Test that we can filter by a label on a /messages request."""
self._send_labelled_messages_in_room()
- token = "s0_0_0_0_0_0_0_0_0"
+ token = "s0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
@@ -2745,7 +2745,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"""Test that we can filter by the absence of a label on a /messages request."""
self._send_labelled_messages_in_room()
- token = "s0_0_0_0_0_0_0_0_0"
+ token = "s0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
@@ -2768,7 +2768,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"""
self._send_labelled_messages_in_room()
- token = "s0_0_0_0_0_0_0_0_0"
+ token = "s0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
@@ -3546,11 +3546,6 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def default_config(self) -> JsonDict:
- config = super().default_config()
- config["experimental_features"] = {"msc3030_enabled": True}
- return config
-
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self._storage_controllers = self.hs.get_storage_controllers()
@@ -3592,7 +3587,7 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- f"/_matrix/client/unstable/org.matrix.msc3030/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}",
+ f"/_matrix/client/v1/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}",
access_token=self.room_owner_tok,
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 0af643ecd9..b9047194dd 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -294,9 +294,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
self.make_request("GET", sync_url % (access_token, next_batch))
-class SyncKnockTestCase(
- unittest.HomeserverTestCase, KnockingStrippedStateEventHelperMixin
-):
+class SyncKnockTestCase(KnockingStrippedStateEventHelperMixin):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -913,7 +911,9 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase):
# We need to manually append the room ID, because we can't know the ID before
# creating the room, and we can't set the config after starting the homeserver.
- self.hs.get_sync_handler().rooms_to_exclude.append(self.excluded_room_id)
+ self.hs.get_sync_handler().rooms_to_exclude_globally.append(
+ self.excluded_room_id
+ )
def test_join_leave(self) -> None:
"""Tests that rooms are correctly excluded from the 'join' and 'leave' sections of
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 21a1ca2a68..3086e1b565 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -13,18 +13,22 @@
# limitations under the License.
from http import HTTPStatus
+from typing import Any, Generator, Tuple, cast
from unittest.mock import Mock, call
-from twisted.internet import defer, reactor
+from twisted.internet import defer, reactor as _reactor
from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache
+from synapse.types import ISynapseReactor, JsonDict
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
from tests.utils import MockClock
+reactor = cast(ISynapseReactor, _reactor)
+
class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self) -> None:
@@ -34,11 +38,13 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
self.hs.get_auth = Mock()
self.cache = HttpTransactionCache(self.hs)
- self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
+ self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"})
self.mock_key = "foo"
@defer.inlineCallbacks
- def test_executes_given_function(self):
+ def test_executes_given_function(
+ self,
+ ) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg"
@@ -47,7 +53,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
self.assertEqual(res, self.mock_http_response)
@defer.inlineCallbacks
- def test_deduplicates_based_on_key(self):
+ def test_deduplicates_based_on_key(
+ self,
+ ) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute(
@@ -58,18 +66,20 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0)
@defer.inlineCallbacks
- def test_logcontexts_with_async_result(self):
+ def test_logcontexts_with_async_result(
+ self,
+ ) -> Generator["defer.Deferred[Any]", object, None]:
@defer.inlineCallbacks
- def cb():
+ def cb() -> Generator["defer.Deferred[object]", object, Tuple[int, JsonDict]]:
yield Clock(reactor).sleep(0)
- return "yay"
+ return 1, {}
@defer.inlineCallbacks
- def test():
+ def test() -> Generator["defer.Deferred[Any]", object, None]:
with LoggingContext("c") as c1:
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
self.assertIs(current_context(), c1)
- self.assertEqual(res, "yay")
+ self.assertEqual(res, (1, {}))
# run the test twice in parallel
d = defer.gatherResults([test(), test()])
@@ -78,13 +88,15 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
self.assertIs(current_context(), SENTINEL_CONTEXT)
@defer.inlineCallbacks
- def test_does_not_cache_exceptions(self):
+ def test_does_not_cache_exceptions(
+ self,
+ ) -> Generator["defer.Deferred[Any]", object, None]:
"""Checks that, if the callback throws an exception, it is called again
for the next request.
"""
called = [False]
- def cb():
+ def cb() -> "defer.Deferred[Tuple[int, JsonDict]]":
if called[0]:
# return a valid result the second time
return defer.succeed(self.mock_http_response)
@@ -104,13 +116,15 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
self.assertIs(current_context(), test_context)
@defer.inlineCallbacks
- def test_does_not_cache_failures(self):
+ def test_does_not_cache_failures(
+ self,
+ ) -> Generator["defer.Deferred[Any]", object, None]:
"""Checks that, if the callback returns a failure, it is called again
for the next request.
"""
called = [False]
- def cb():
+ def cb() -> "defer.Deferred[Tuple[int, JsonDict]]":
if called[0]:
# return a valid result the second time
return defer.succeed(self.mock_http_response)
@@ -130,7 +144,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
self.assertIs(current_context(), test_context)
@defer.inlineCallbacks
- def test_cleans_up(self):
+ def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
# should NOT have cleaned up yet
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 5e7bf97482..5ec343dd7f 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -199,9 +199,15 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
def test_stringy_power_levels(self) -> None:
"""The room upgrade converts stringy power levels to proper integers."""
+ # Create a room on room version < 10.
+ room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_token, room_version="9"
+ )
+ self.helper.join(room_id, self.other, tok=self.other_token)
+
# Retrieve the room's current power levels.
power_levels = self.helper.get_state(
- self.room_id,
+ room_id,
"m.room.power_levels",
tok=self.creator_token,
)
@@ -217,14 +223,14 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
# conscience, we ought to ensure it's upgrading from a sufficiently old
# version of room.
self.helper.send_state(
- self.room_id,
+ room_id,
"m.room.power_levels",
body=power_levels,
tok=self.creator_token,
)
# Upgrade the room. Check the homeserver reports success.
- channel = self._upgrade_room()
+ channel = self._upgrade_room(room_id=room_id)
self.assertEqual(200, channel.code, channel.result)
# Extract the new room ID.
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 7f1fba1086..2bb6e27d94 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import urllib.parse
from io import BytesIO, StringIO
from typing import Any, Dict, Optional, Union
from unittest.mock import Mock
@@ -65,9 +64,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
self.assertTrue(ignore_backoff)
self.assertEqual(destination, server_name)
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
- self.assertEqual(
- path, "/_matrix/key/v2/server/%s" % (urllib.parse.quote(key_id),)
- )
+ self.assertEqual(path, "/_matrix/key/v2/server")
response = {
"server_name": server_name,
diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py
index 319ae8b1cc..3f7f1dbab9 100644
--- a/tests/rest/media/v1/test_oembed.py
+++ b/tests/rest/media/v1/test_oembed.py
@@ -150,3 +150,13 @@ class OEmbedTests(HomeserverTestCase):
result = self.parse_response({"type": "link"})
self.assertIn("og:type", result.open_graph_result)
self.assertEqual(result.open_graph_result["og:type"], "website")
+
+ def test_title_html_entities(self) -> None:
+ """Test HTML entities in title"""
+ result = self.parse_response(
+ {"title": "Why JSON isn’t a Good Configuration Language"}
+ )
+ self.assertEqual(
+ result.open_graph_result["og:title"],
+ "Why JSON isn’t a Good Configuration Language",
+ )
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 7cbc40736c..dadc6efcbf 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -69,7 +69,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(1000)
)
- self._rlsn._server_notices_manager.send_notice = Mock(
+ self._rlsn._server_notices_manager.send_notice = Mock( # type: ignore[assignment]
return_value=make_awaitable(Mock())
)
self._send_notice = self._rlsn._server_notices_manager.send_notice
@@ -82,8 +82,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = Mock(
return_value=make_awaitable("!something:localhost")
)
- self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None))
- self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
+ self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
+ self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment]
@override_config({"hs_disabled": True})
def test_maybe_send_server_notice_disabled_hs(self):
@@ -361,9 +361,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
tok: The access token of the user that joined the room.
room_id: The ID of the room that's been joined.
"""
- user_id = None
- tok = None
- invites = []
+ # We need at least one user to process
+ self.assertGreater(self.hs.config.server.max_mau_value, 0)
+
+ invites = {}
# Register as many users as the MAU limit allows.
for i in range(self.hs.config.server.max_mau_value):
diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py
index 50c20c5b92..373707b275 100644
--- a/tests/storage/databases/main/test_deviceinbox.py
+++ b/tests/storage/databases/main/test_deviceinbox.py
@@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.rest import admin
from synapse.rest.client import devices
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -25,11 +29,11 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
devices.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass")
- def test_background_remove_deleted_devices_from_device_inbox(self):
+ def test_background_remove_deleted_devices_from_device_inbox(self) -> None:
"""Test that the background task to delete old device_inboxes works properly."""
# create a valid device
@@ -89,7 +93,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.assertEqual(1, len(res))
self.assertEqual(res[0], "cur_device")
- def test_background_remove_hidden_devices_from_device_inbox(self):
+ def test_background_remove_hidden_devices_from_device_inbox(self) -> None:
"""Test that the background task to delete hidden devices
from device_inboxes works properly."""
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 5773172ab8..9f33afcca0 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -45,7 +45,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs
self.store: EventsWorkerStore = hs.get_datastores().main
@@ -68,7 +68,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
self.event_ids.append(event.event_id)
- def test_simple(self):
+ def test_simple(self) -> None:
with LoggingContext(name="test") as ctx:
res = self.get_success(
self.store.have_seen_events(
@@ -90,7 +90,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
self.assertEqual(res, {self.event_ids[0]})
self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
- def test_persisting_event_invalidates_cache(self):
+ def test_persisting_event_invalidates_cache(self) -> None:
"""
Test to make sure that the `have_seen_event` cache
is invalidated after we persist an event and returns
@@ -138,7 +138,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# That should result in a single db query to lookup
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
- def test_invalidate_cache_by_room_id(self):
+ def test_invalidate_cache_by_room_id(self) -> None:
"""
Test to make sure that all events associated with the given `(room_id,)`
are invalidated in the `have_seen_event` cache.
@@ -175,7 +175,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store: EventsWorkerStore = hs.get_datastores().main
self.user = self.register_user("user", "pass")
@@ -189,7 +189,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# Reset the event cache so the tests start with it empty
self.get_success(self.store._get_event_cache.clear())
- def test_simple(self):
+ def test_simple(self) -> None:
"""Test that we cache events that we pull from the DB."""
with LoggingContext("test") as ctx:
@@ -198,7 +198,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# We should have fetched the event from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
- def test_event_ref(self):
+ def test_event_ref(self) -> None:
"""Test that we reuse events that are still in memory but have fallen
out of the cache, rather than requesting them from the DB.
"""
@@ -223,7 +223,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0)
- def test_dedupe(self):
+ def test_dedupe(self) -> None:
"""Test that if we request the same event multiple times we only pull it
out once.
"""
@@ -241,7 +241,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
class DatabaseOutageTestCase(unittest.HomeserverTestCase):
"""Test event fetching during a database outage."""
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store: EventsWorkerStore = hs.get_datastores().main
self.room_id = f"!room:{hs.hostname}"
@@ -377,7 +377,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store: EventsWorkerStore = hs.get_datastores().main
self.user = self.register_user("user", "pass")
@@ -412,7 +412,8 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
unblock: "Deferred[None]" = Deferred()
original_runWithConnection = self.store.db_pool.runWithConnection
- async def runWithConnection(*args, **kwargs):
+ # Don't bother with the types here, we just pass into the original function.
+ async def runWithConnection(*args, **kwargs): # type: ignore[no-untyped-def]
await unblock
return await original_runWithConnection(*args, **kwargs)
@@ -441,7 +442,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1)
self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0)
- def test_first_get_event_cancelled(self):
+ def test_first_get_event_cancelled(self) -> None:
"""Test cancellation of the first `get_event` call sharing a database fetch.
The first `get_event` call is the one which initiates the fetch. We expect the
@@ -467,7 +468,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
# The second `get_event` call should complete successfully.
self.get_success(get_event2)
- def test_second_get_event_cancelled(self):
+ def test_second_get_event_cancelled(self) -> None:
"""Test cancellation of the second `get_event` call sharing a database fetch."""
with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
# Cancel the second `get_event` call.
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 3cc2a58d8d..56cb49d9b5 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -15,18 +15,20 @@
from twisted.internet import defer, reactor
from twisted.internet.base import ReactorBase
from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS
+from synapse.util import Clock
from tests import unittest
class LockTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- def test_acquire_contention(self):
+ def test_acquire_contention(self) -> None:
# Track the number of tasks holding the lock.
# Should be at most 1.
in_lock = 0
@@ -34,7 +36,7 @@ class LockTestCase(unittest.HomeserverTestCase):
release_lock: "Deferred[None]" = Deferred()
- async def task():
+ async def task() -> None:
nonlocal in_lock
nonlocal max_in_lock
@@ -76,7 +78,7 @@ class LockTestCase(unittest.HomeserverTestCase):
# At most one task should have held the lock at a time.
self.assertEqual(max_in_lock, 1)
- def test_simple_lock(self):
+ def test_simple_lock(self) -> None:
"""Test that we can take out a lock and that while we hold it nobody
else can take it out.
"""
@@ -103,7 +105,7 @@ class LockTestCase(unittest.HomeserverTestCase):
self.get_success(lock3.__aenter__())
self.get_success(lock3.__aexit__(None, None, None))
- def test_maintain_lock(self):
+ def test_maintain_lock(self) -> None:
"""Test that we don't time out locks while they're still active"""
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
@@ -119,7 +121,7 @@ class LockTestCase(unittest.HomeserverTestCase):
self.get_success(lock.__aexit__(None, None, None))
- def test_timeout_lock(self):
+ def test_timeout_lock(self) -> None:
"""Test that we time out locks if they're not updated for ages"""
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
@@ -139,7 +141,7 @@ class LockTestCase(unittest.HomeserverTestCase):
self.assertFalse(self.get_success(lock.is_still_valid()))
- def test_drop(self):
+ def test_drop(self) -> None:
"""Test that dropping the context manager means we stop renewing the lock"""
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
@@ -153,7 +155,7 @@ class LockTestCase(unittest.HomeserverTestCase):
lock2 = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock2)
- def test_shutdown(self):
+ def test_shutdown(self) -> None:
"""Test that shutting down Synapse releases the locks"""
# Acquire two locks
lock = self.get_success(self.store.try_acquire_lock("name", "key1"))
diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py
index c4f12d81d7..ac77aec003 100644
--- a/tests/storage/databases/main/test_receipts.py
+++ b/tests/storage/databases/main/test_receipts.py
@@ -33,7 +33,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass")
@@ -47,7 +47,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
table: str,
receipts: Dict[Tuple[str, str, str], Sequence[Dict[str, Any]]],
expected_unique_receipts: Dict[Tuple[str, str, str], Optional[Dict[str, Any]]],
- ):
+ ) -> None:
"""Test that the background update to uniqueify non-thread receipts in
the given receipts table works properly.
@@ -154,7 +154,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
f"Background update did not remove all duplicate receipts from {table}",
)
- def test_background_receipts_linearized_unique_index(self):
+ def test_background_receipts_linearized_unique_index(self) -> None:
"""Test that the background update to uniqueify non-thread receipts in
`receipts_linearized` works properly.
"""
@@ -168,7 +168,9 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
{"stream_id": 6, "event_id": "$some_event"},
],
(self.other_room_id, "m.read", self.user_id): [
- {"stream_id": 7, "event_id": "$some_event"}
+ # It is possible for stream IDs to be duplicated.
+ {"stream_id": 7, "event_id": "$some_event"},
+ {"stream_id": 7, "event_id": "$some_event"},
],
},
expected_unique_receipts={
@@ -177,7 +179,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
},
)
- def test_background_receipts_graph_unique_index(self):
+ def test_background_receipts_graph_unique_index(self) -> None:
"""Test that the background update to uniqueify non-thread receipts in
`receipts_graph` works properly.
"""
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index 1edb619630..3108ca3444 100644
--- a/tests/storage/databases/main/test_room.py
+++ b/tests/storage/databases/main/test_room.py
@@ -14,10 +14,14 @@
import json
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import RoomTypes
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.storage.databases.main.room import _BackgroundUpdates
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -30,17 +34,31 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass")
def _generate_room(self) -> str:
- room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ """Create a room and return the room ID."""
+ return self.helper.create_room_as(self.user_id, tok=self.token)
- return room_id
+ def run_background_updates(self, update_name: str) -> None:
+ """Insert and run the background update."""
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": update_name, "progress_json": "{}"},
+ )
+ )
- def test_background_populate_rooms_creator_column(self):
+ # ... and tell the DataStore that it hasn't finished all updates yet
+ self.store.db_pool.updates._all_done = False
+
+ # Now let's actually drive the updates to completion
+ self.wait_for_background_updates()
+
+ def test_background_populate_rooms_creator_column(self) -> None:
"""Test that the background update to populate the rooms creator column
works properly.
"""
@@ -67,22 +85,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(room_creator_before, None)
- # Insert and run the background update.
- self.get_success(
- self.store.db_pool.simple_insert(
- "background_updates",
- {
- "update_name": _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN,
- "progress_json": "{}",
- },
- )
- )
-
- # ... and tell the DataStore that it hasn't finished all updates yet
- self.store.db_pool.updates._all_done = False
-
- # Now let's actually drive the updates to completion
- self.wait_for_background_updates()
+ self.run_background_updates(_BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN)
# Make sure the background update filled in the room creator
room_creator_after = self.get_success(
@@ -95,7 +98,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(room_creator_after, self.user_id)
- def test_background_add_room_type_column(self):
+ def test_background_add_room_type_column(self) -> None:
"""Test that the background update to populate the `room_type` column in
`room_stats_state` works properly.
"""
@@ -133,22 +136,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
)
- # Insert and run the background update
- self.get_success(
- self.store.db_pool.simple_insert(
- "background_updates",
- {
- "update_name": _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
- "progress_json": "{}",
- },
- )
- )
-
- # ... and tell the DataStore that it hasn't finished all updates yet
- self.store.db_pool.updates._all_done = False
-
- # Now let's actually drive the updates to completion
- self.wait_for_background_updates()
+ self.run_background_updates(_BackgroundUpdates.ADD_ROOM_TYPE_COLUMN)
# Make sure the background update filled in the room type
room_type_after = self.get_success(
@@ -160,3 +148,39 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
)
self.assertEqual(room_type_after, RoomTypes.SPACE)
+
+ def test_populate_stats_broken_rooms(self) -> None:
+ """Ensure that re-populating room stats skips broken rooms."""
+
+ # Create a good room.
+ good_room_id = self._generate_room()
+
+ # Create a room and then break it by having no room version.
+ room_id = self._generate_room()
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"room_version": None},
+ desc="test",
+ )
+ )
+
+ # Nuke any current stats in the database.
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="room_stats_state", keyvalues={"1": 1}, desc="test"
+ )
+ )
+
+ self.run_background_updates("populate_stats_process_rooms")
+
+ # Only the good room appears in the stats tables.
+ results = self.get_success(
+ self.store.db_pool.simple_select_onecol(
+ table="room_stats_state",
+ keyvalues={},
+ retcol="room_id",
+ )
+ )
+ self.assertEqual(results, [good_room_id])
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 09cb06d614..8bbf936ae9 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -106,7 +106,7 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase):
{(1, "user1", "hello"), (2, "user2", "bleb")},
)
- def test_simple_update_many(self):
+ def test_simple_update_many(self) -> None:
"""
simple_update_many performs many updates at once.
"""
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 72bf5b3d31..1bfd11ceae 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -14,13 +14,17 @@
from typing import Iterable, Optional, Set
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import AccountDataTypes
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
class IgnoredUsersTestCase(unittest.HomeserverTestCase):
- def prepare(self, hs, reactor, clock):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.user = "@user:test"
@@ -55,7 +59,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
expected_ignored_user_ids,
)
- def test_ignoring_users(self):
+ def test_ignoring_users(self) -> None:
"""Basic adding/removing of users from the ignore list."""
self._update_ignore_list("@other:test", "@another:remote")
self.assert_ignored(self.user, {"@other:test", "@another:remote"})
@@ -82,7 +86,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
# Check the removed user.
self.assert_ignorers("@another:remote", {self.user})
- def test_caching(self):
+ def test_caching(self) -> None:
"""Ensure that caching works properly between different users."""
# The first user ignores a user.
self._update_ignore_list("@other:test")
@@ -99,7 +103,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", {"@second:test"})
- def test_invalid_data(self):
+ def test_invalid_data(self) -> None:
"""Invalid data ends up clearing out the ignored users list."""
# Add some data and ensure it is there.
self._update_ignore_list("@other:test")
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 1047ed09c8..5e1324a169 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -26,7 +26,7 @@ from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.config._base import ConfigError
from synapse.events import EventBase
from synapse.server import HomeServer
-from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection, make_conn
from synapse.storage.databases.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
@@ -39,7 +39,7 @@ from tests.test_utils import make_awaitable
class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
- def setUp(self):
+ def setUp(self) -> None:
super(ApplicationServiceStoreTestCase, self).setUp()
self.as_yaml_files: List[str] = []
@@ -73,7 +73,9 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
super(ApplicationServiceStoreTestCase, self).tearDown()
- def _add_appservice(self, as_token, id, url, hs_token, sender) -> None:
+ def _add_appservice(
+ self, as_token: str, id: str, url: str, hs_token: str, sender: str
+ ) -> None:
as_yaml = {
"url": url,
"as_token": as_token,
@@ -135,7 +137,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
database, make_conn(db_config, self.engine, "test"), self.hs
)
- def _add_service(self, url, as_token, id) -> None:
+ def _add_service(self, url: str, as_token: str, id: str) -> None:
as_yaml = {
"url": url,
"as_token": as_token,
@@ -149,7 +151,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token)
- def _set_state(self, id: str, state: ApplicationServiceState):
+ def _set_state(self, id: str, state: ApplicationServiceState) -> defer.Deferred:
return self.db_pool.runOperation(
self.engine.convert_param_style(
"INSERT INTO application_services_state(as_id, state) VALUES(?,?)"
@@ -157,7 +159,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
(id, state.value),
)
- def _insert_txn(self, as_id, txn_id, events):
+ def _insert_txn(
+ self, as_id: str, txn_id: int, events: List[Mock]
+ ) -> "defer.Deferred[None]":
return self.db_pool.runOperation(
self.engine.convert_param_style(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
@@ -448,12 +452,14 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
- def __init__(self, database: DatabasePool, db_conn, hs) -> None:
+ def __init__(
+ self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: HomeServer
+ ) -> None:
super().__init__(database, db_conn, hs)
class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase):
- def _write_config(self, suffix, **kwargs) -> str:
+ def _write_config(self, suffix: str, **kwargs: str) -> str:
vals = {
"id": "id" + suffix,
"url": "url" + suffix,
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 40e58f8199..256d28e4c9 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from collections import OrderedDict
+from typing import Generator
from unittest.mock import Mock
from twisted.internet import defer
@@ -30,7 +30,7 @@ from tests.utils import default_config
class SQLBaseStoreTestCase(unittest.TestCase):
"""Test the "simple" SQL generating methods in SQLBaseStore."""
- def setUp(self):
+ def setUp(self) -> None:
self.db_pool = Mock(spec=["runInteraction"])
self.mock_txn = Mock()
self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"])
@@ -38,12 +38,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_conn.rollback.return_value = None
# Our fake runInteraction just runs synchronously inline
- def runInteraction(func, *args, **kwargs):
+ def runInteraction(func, *args, **kwargs) -> defer.Deferred: # type: ignore[no-untyped-def]
return defer.succeed(func(self.mock_txn, *args, **kwargs))
self.db_pool.runInteraction = runInteraction
- def runWithConnection(func, *args, **kwargs):
+ def runWithConnection(func, *args, **kwargs): # type: ignore[no-untyped-def]
return defer.succeed(func(self.mock_conn, *args, **kwargs))
self.db_pool.runWithConnection = runWithConnection
@@ -62,7 +62,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type]
@defer.inlineCallbacks
- def test_insert_1col(self):
+ def test_insert_1col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
@@ -76,7 +76,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_insert_3cols(self):
+ def test_insert_3cols(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
@@ -92,7 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_select_one_1col(self):
+ def test_select_one_1col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
@@ -108,7 +108,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_select_one_3col(self):
+ def test_select_one_3col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
@@ -126,7 +126,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_select_one_missing(self):
+ def test_select_one_missing(
+ self,
+ ) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
@@ -142,7 +144,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.assertFalse(ret)
@defer.inlineCallbacks
- def test_select_list(self):
+ def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 3
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
@@ -159,7 +161,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_update_one_1col(self):
+ def test_update_one_1col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
@@ -176,7 +178,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_update_one_4cols(self):
+ def test_update_one_4cols(
+ self,
+ ) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
@@ -193,7 +197,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_delete_one(self):
+ def test_delete_one(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index b998ad42d9..d570684c99 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -15,11 +15,16 @@
import os.path
from unittest.mock import Mock, patch
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.storage import prepare_database
+from synapse.storage.types import Cursor
from synapse.types import UserID, create_requester
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -29,7 +34,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
Test the background update to clean forward extremities table.
"""
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
self.room_creator = homeserver.get_room_creation_handler()
@@ -39,7 +46,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
- def run_background_update(self):
+ def run_background_update(self) -> None:
"""Re run the background update to clean up the extremities."""
# Make sure we don't clash with in progress updates.
self.assertTrue(
@@ -54,7 +61,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
"delete_forward_extremities.sql",
)
- def run_delta_file(txn):
+ def run_delta_file(txn: Cursor) -> None:
prepare_database.executescript(txn, schema_path)
self.get_success(
@@ -84,7 +91,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
(room_id,)
)
- def test_soft_failed_extremities_handled_correctly(self):
+ def test_soft_failed_extremities_handled_correctly(self) -> None:
"""Test that extremities are correctly calculated in the presence of
soft failed events.
@@ -114,7 +121,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.assertEqual(latest_event_ids, [event_id_4])
- def test_basic_cleanup(self):
+ def test_basic_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of
soft failed events.
@@ -149,7 +156,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(latest_event_ids, [event_id_b])
- def test_chain_of_fail_cleanup(self):
+ def test_chain_of_fail_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of
soft failed events.
@@ -187,7 +194,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(latest_event_ids, [event_id_b])
- def test_forked_graph_cleanup(self):
+ def test_forked_graph_cleanup(self) -> None:
r"""Test that extremities are correctly calculated in the presence of
soft failed events.
@@ -252,12 +259,14 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["cleanup_extremities_with_dummy_events"] = True
return self.setup_test_homeserver(config=config)
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
self.room_creator = homeserver.get_room_creation_handler()
self.event_creator_handler = homeserver.get_event_creation_handler()
@@ -273,7 +282,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.event_creator = homeserver.get_event_creation_handler()
homeserver.config.consent.user_consent_version = self.CONSENT_VERSION
- def test_send_dummy_event(self):
+ def test_send_dummy_event(self) -> None:
self._create_extremity_rich_graph()
# Pump the reactor repeatedly so that the background updates have a
@@ -286,7 +295,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
- def test_send_dummy_events_when_insufficient_power(self):
+ def test_send_dummy_events_when_insufficient_power(self) -> None:
self._create_extremity_rich_graph()
# Criple power levels
self.helper.send_state(
@@ -317,7 +326,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250)
- def test_expiry_logic(self):
+ def test_expiry_logic(self) -> None:
"""Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
expires old entries correctly.
"""
@@ -357,7 +366,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
0,
)
- def _create_extremity_rich_graph(self):
+ def _create_extremity_rich_graph(self) -> None:
"""Helper method to create bushy graph on demand"""
event_id_start = self.create_and_send_event(self.room_id, self.user)
@@ -372,7 +381,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
)
self.assertEqual(len(latest_event_ids), 50)
- def _enable_consent_checking(self):
+ def _enable_consent_checking(self) -> None:
"""Helper method to enable consent checking"""
self.event_creator._block_events_without_consent_error = "No consent from user"
consent_uri_builder = Mock()
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 49ad3c1324..7f7f4ef892 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -13,15 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict
from unittest.mock import Mock
from parameterized import parameterized
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
from synapse.rest.client import login
+from synapse.server import HomeServer
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.types import UserID
+from synapse.util import Clock
from tests import unittest
from tests.server import make_request
@@ -30,14 +35,10 @@ from tests.unittest import override_config
class ClientIpStoreTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver()
- return hs
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
- def prepare(self, hs, reactor, clock):
- self.store = self.hs.get_datastores().main
-
- def test_insert_new_client_ip(self):
+ def test_insert_new_client_ip(self) -> None:
self.reactor.advance(12345678)
user_id = "@user:id"
@@ -76,7 +77,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
r,
)
- def test_insert_new_client_ip_none_device_id(self):
+ def test_insert_new_client_ip_none_device_id(self) -> None:
"""
An insert with a device ID of NULL will not create a new entry, but
update an existing entry in the user_ips table.
@@ -148,7 +149,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
@parameterized.expand([(False,), (True,)])
- def test_get_last_client_ip_by_device(self, after_persisting: bool):
+ def test_get_last_client_ip_by_device(self, after_persisting: bool) -> None:
"""Test `get_last_client_ip_by_device` for persisted and unpersisted data"""
self.reactor.advance(12345678)
@@ -211,7 +212,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
},
)
- def test_get_last_client_ip_by_device_combined_data(self):
+ def test_get_last_client_ip_by_device_combined_data(self) -> None:
"""Test that `get_last_client_ip_by_device` combines persisted and unpersisted
data together correctly
"""
@@ -310,7 +311,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
@parameterized.expand([(False,), (True,)])
- def test_get_user_ip_and_agents(self, after_persisting: bool):
+ def test_get_user_ip_and_agents(self, after_persisting: bool) -> None:
"""Test `get_user_ip_and_agents` for persisted and unpersisted data"""
self.reactor.advance(12345678)
@@ -350,7 +351,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
],
)
- def test_get_user_ip_and_agents_combined_data(self):
+ def test_get_user_ip_and_agents_combined_data(self) -> None:
"""Test that `get_user_ip_and_agents` combines persisted and unpersisted data
together correctly
"""
@@ -427,7 +428,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
@override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
- def test_disabled_monthly_active_user(self):
+ def test_disabled_monthly_active_user(self) -> None:
user_id = "@user:server"
self.get_success(
self.store.insert_client_ip(
@@ -438,7 +439,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertFalse(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
- def test_adding_monthly_active_user_when_full(self):
+ def test_adding_monthly_active_user_when_full(self) -> None:
lots_of_users = 100
user_id = "@user:server"
@@ -454,7 +455,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertFalse(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
- def test_adding_monthly_active_user_when_space(self):
+ def test_adding_monthly_active_user_when_space(self) -> None:
user_id = "@user:server"
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
@@ -471,7 +472,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertTrue(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
- def test_updating_monthly_active_user_when_space(self):
+ def test_updating_monthly_active_user_when_space(self) -> None:
user_id = "@user:server"
self.get_success(self.store.register_user(user_id=user_id, password_hash=None))
@@ -489,7 +490,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active)
- def test_devices_last_seen_bg_update(self):
+ def test_devices_last_seen_bg_update(self) -> None:
# First make sure we have completed all updates.
self.wait_for_background_updates()
@@ -574,7 +575,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
r,
)
- def test_old_user_ips_pruned(self):
+ def test_old_user_ips_pruned(self) -> None:
# First make sure we have completed all updates.
self.wait_for_background_updates()
@@ -637,11 +638,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertEqual(result, [])
# But we should still get the correct values for the device
- result = self.get_success(
+ result2 = self.get_success(
self.store.get_last_client_ip_by_device(user_id, device_id)
)
- r = result[(user_id, device_id)]
+ r = result2[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
@@ -661,15 +662,11 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver()
- return hs
-
- def prepare(self, hs, reactor, clock):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.user_id = self.register_user("bob", "abc123", True)
- def test_request_with_xforwarded(self):
+ def test_request_with_xforwarded(self) -> None:
"""
The IP in X-Forwarded-For is entered into the client IPs table.
"""
@@ -679,14 +676,19 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
{"request": XForwardedForRequest},
)
- def test_request_from_getPeer(self):
+ def test_request_from_getPeer(self) -> None:
"""
The IP returned by getPeer is entered into the client IPs table, if
there's no X-Forwarded-For header.
"""
self._runtest({}, "127.0.0.1", {})
- def _runtest(self, headers, expected_ip, make_request_args):
+ def _runtest(
+ self,
+ headers: Dict[bytes, bytes],
+ expected_ip: str,
+ make_request_args: Dict[str, Any],
+ ) -> None:
device_id = "bleb"
access_token = self.login("bob", "abc123", device_id=device_id)
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
index a40fc20ef9..8cd7c89ca2 100644
--- a/tests/storage/test_database.py
+++ b/tests/storage/test_database.py
@@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.database import (
DatabasePool,
+ LoggingDatabaseConnection,
LoggingTransaction,
make_tuple_comparison_clause,
)
@@ -31,12 +32,107 @@ from tests import unittest
class TupleComparisonClauseTestCase(unittest.TestCase):
- def test_native_tuple_comparison(self):
+ def test_native_tuple_comparison(self) -> None:
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
self.assertEqual(clause, "(a,b) > (?,?)")
self.assertEqual(args, [1, 2])
+class ExecuteScriptTestCase(unittest.HomeserverTestCase):
+ """Tests for `BaseDatabaseEngine.executescript` implementations."""
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.db_pool: DatabasePool = self.store.db_pool
+ self.get_success(
+ self.db_pool.runInteraction(
+ "create",
+ lambda txn: txn.execute("CREATE TABLE foo (name TEXT PRIMARY KEY)"),
+ )
+ )
+
+ def test_transaction(self) -> None:
+ """Test that all statements are run in a single transaction."""
+
+ def run(conn: LoggingDatabaseConnection) -> None:
+ cur = conn.cursor(txn_name="test_transaction")
+ self.db_pool.engine.executescript(
+ cur,
+ ";".join(
+ [
+ "INSERT INTO foo (name) VALUES ('transaction test')",
+ # This next statement will fail. When `executescript` is not
+ # transactional, the previous row will be observed later.
+ "INSERT INTO foo (name) VALUES ('transaction test')",
+ ]
+ ),
+ )
+
+ self.get_failure(
+ self.db_pool.runWithConnection(run),
+ self.db_pool.engine.module.IntegrityError,
+ )
+
+ self.assertIsNone(
+ self.get_success(
+ self.db_pool.simple_select_one_onecol(
+ "foo",
+ keyvalues={"name": "transaction test"},
+ retcol="name",
+ allow_none=True,
+ )
+ ),
+ "executescript is not running statements inside a transaction",
+ )
+
+ def test_commit(self) -> None:
+ """Test that the script transaction remains open and can be committed."""
+
+ def run(conn: LoggingDatabaseConnection) -> None:
+ cur = conn.cursor(txn_name="test_commit")
+ self.db_pool.engine.executescript(
+ cur, "INSERT INTO foo (name) VALUES ('commit test')"
+ )
+ cur.execute("COMMIT")
+
+ self.get_success(self.db_pool.runWithConnection(run))
+
+ self.assertIsNotNone(
+ self.get_success(
+ self.db_pool.simple_select_one_onecol(
+ "foo",
+ keyvalues={"name": "commit test"},
+ retcol="name",
+ allow_none=True,
+ )
+ ),
+ )
+
+ def test_rollback(self) -> None:
+ """Test that the script transaction remains open and can be rolled back."""
+
+ def run(conn: LoggingDatabaseConnection) -> None:
+ cur = conn.cursor(txn_name="test_rollback")
+ self.db_pool.engine.executescript(
+ cur, "INSERT INTO foo (name) VALUES ('rollback test')"
+ )
+ cur.execute("ROLLBACK")
+
+ self.get_success(self.db_pool.runWithConnection(run))
+
+ self.assertIsNone(
+ self.get_success(
+ self.db_pool.simple_select_one_onecol(
+ "foo",
+ keyvalues={"name": "rollback test"},
+ retcol="name",
+ allow_none=True,
+ )
+ ),
+ "executescript is not leaving the script transaction open",
+ )
+
+
class CallbacksTestCase(unittest.HomeserverTestCase):
"""Tests for transaction callbacks."""
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index f37505b6cf..f03807c8f9 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -12,23 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Collection, List, Tuple
+
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.api.errors
from synapse.api.constants import EduTypes
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class DeviceStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- def add_device_change(self, user_id, device_ids, host):
+ def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
"""Add a device list change for the given device to
`device_lists_outbound_pokes` table.
"""
for device_id in device_ids:
- stream_id = self.get_success(
+ self.get_success(
self.store.add_device_change_to_streams(
user_id, [device_id], ["!some:room"]
)
@@ -39,18 +46,18 @@ class DeviceStoreTestCase(HomeserverTestCase):
user_id=user_id,
device_id=device_id,
room_id="!some:room",
- stream_id=stream_id,
hosts=[host],
context={},
)
)
- def test_store_new_device(self):
+ def test_store_new_device(self) -> None:
self.get_success(
self.store.store_device("user_id", "device_id", "display_name")
)
res = self.get_success(self.store.get_device("user_id", "device_id"))
+ assert res is not None
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -60,7 +67,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
res,
)
- def test_get_devices_by_user(self):
+ def test_get_devices_by_user(self) -> None:
self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
@@ -90,7 +97,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
res["device2"],
)
- def test_count_devices_by_users(self):
+ def test_count_devices_by_users(self) -> None:
self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
@@ -115,7 +122,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
)
self.assertEqual(3, res)
- def test_get_device_updates_by_remote(self):
+ def test_get_device_updates_by_remote(self) -> None:
device_ids = ["device_id1", "device_id2"]
# Add two device updates with sequential `stream_id`s
@@ -129,7 +136,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)
- def test_get_device_updates_by_remote_can_limit_properly(self):
+ def test_get_device_updates_by_remote_can_limit_properly(self) -> None:
"""
Tests that `get_device_updates_by_remote` returns an appropriate
stream_id to resume fetching from (without skipping any results).
@@ -281,7 +288,11 @@ class DeviceStoreTestCase(HomeserverTestCase):
)
self.assertEqual(device_updates, [])
- def _check_devices_in_updates(self, expected_device_ids, device_updates):
+ def _check_devices_in_updates(
+ self,
+ expected_device_ids: Collection[str],
+ device_updates: List[Tuple[str, JsonDict]],
+ ) -> None:
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
@@ -290,17 +301,19 @@ class DeviceStoreTestCase(HomeserverTestCase):
}
self.assertEqual(received_device_ids, set(expected_device_ids))
- def test_update_device(self):
+ def test_update_device(self) -> None:
self.get_success(
self.store.store_device("user_id", "device_id", "display_name 1")
)
res = self.get_success(self.store.get_device("user_id", "device_id"))
+ assert res is not None
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
self.get_success(self.store.update_device("user_id", "device_id"))
res = self.get_success(self.store.get_device("user_id", "device_id"))
+ assert res is not None
self.assertEqual("display_name 1", res["display_name"])
# do the update
@@ -312,9 +325,10 @@ class DeviceStoreTestCase(HomeserverTestCase):
# check it worked
res = self.get_success(self.store.get_device("user_id", "device_id"))
+ assert res is not None
self.assertEqual("display_name 2", res["display_name"])
- def test_update_unknown_device(self):
+ def test_update_unknown_device(self) -> None:
exc = self.get_failure(
self.store.update_device(
"user_id", "unknown_device_id", new_display_name="display_name 2"
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 20bf3ca17b..8bedc6bdf3 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -12,19 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
from synapse.types import RoomAlias, RoomID
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class DirectoryStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test")
- def test_room_to_alias(self):
+ def test_room_to_alias(self) -> None:
self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
@@ -36,7 +40,7 @@ class DirectoryStoreTestCase(HomeserverTestCase):
(self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
)
- def test_alias_to_room(self):
+ def test_alias_to_room(self) -> None:
self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
@@ -48,7 +52,7 @@ class DirectoryStoreTestCase(HomeserverTestCase):
(self.get_success(self.store.get_association_from_room_alias(self.alias))),
)
- def test_delete_alias(self):
+ def test_delete_alias(self) -> None:
self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index fb96ab3a2f..9cb326d90a 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
from synapse.storage.databases.main.e2e_room_keys import RoomKey
+from synapse.util import Clock
from tests import unittest
@@ -26,12 +30,12 @@ room_key: RoomKey = {
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
self.store = hs.get_datastores().main
return hs
- def test_room_keys_version_delete(self):
+ def test_room_keys_version_delete(self) -> None:
# test that deleting a room key backup deletes the keys
version1 = self.get_success(
self.store.create_e2e_room_keys_version(
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 0f04493ad0..5fde3b9c78 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.util import Clock
+
from tests.unittest import HomeserverTestCase
class EndToEndKeyStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- def test_key_without_device_name(self):
+ def test_key_without_device_name(self) -> None:
now = 1470174257070
json = {"key": "value"}
@@ -35,7 +40,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
dev = res["user"]["device"]
self.assertDictContainsSubset(json, dev)
- def test_reupload_key(self):
+ def test_reupload_key(self) -> None:
now = 1470174257070
json = {"key": "value"}
@@ -53,7 +58,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
)
self.assertFalse(changed)
- def test_get_key_with_device_name(self):
+ def test_get_key_with_device_name(self) -> None:
now = 1470174257070
json = {"key": "value"}
@@ -70,7 +75,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
{"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
)
- def test_multiple_devices(self):
+ def test_multiple_devices(self) -> None:
now = 1470174257070
self.get_success(self.store.store_device("user1", "device1", None))
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index de9f4af2de..c070278db8 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -14,6 +14,7 @@
from typing import Dict, List, Set, Tuple
+from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
from synapse.api.constants import EventTypes
@@ -22,18 +23,22 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.events import _LinkMap
+from synapse.storage.types import Cursor
from synapse.types import create_requester
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class EventChainStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self._next_stream_ordering = 1
- def test_simple(self):
+ def test_simple(self) -> None:
"""Test that the example in `docs/auth_chain_difference_algorithm.md`
works.
"""
@@ -232,7 +237,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
),
)
- def test_out_of_order_events(self):
+ def test_out_of_order_events(self) -> None:
"""Test that we handle persisting events that we don't have the full
auth chain for yet (which should only happen for out of band memberships).
"""
@@ -378,7 +383,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
def persist(
self,
events: List[EventBase],
- ):
+ ) -> None:
"""Persist the given events and check that the links generated match
those given.
"""
@@ -389,7 +394,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
e.internal_metadata.stream_ordering = self._next_stream_ordering
self._next_stream_ordering += 1
- def _persist(txn):
+ def _persist(txn: LoggingTransaction) -> None:
# We need to persist the events to the events and state_events
# tables.
persist_events_store._store_event_txn(
@@ -456,7 +461,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
class LinkMapTestCase(unittest.TestCase):
- def test_simple(self):
+ def test_simple(self) -> None:
"""Basic tests for the LinkMap."""
link_map = _LinkMap()
@@ -492,7 +497,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass")
@@ -559,7 +564,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
# Delete the chain cover info.
- def _delete_tables(txn):
+ def _delete_tables(txn: Cursor) -> None:
txn.execute("DELETE FROM event_auth_chains")
txn.execute("DELETE FROM event_auth_chain_links")
@@ -567,7 +572,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
return room_id, [state1, state2]
- def test_background_update_single_room(self):
+ def test_background_update_single_room(self) -> None:
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
@@ -602,7 +607,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
)
)
- def test_background_update_multiple_rooms(self):
+ def test_background_update_multiple_rooms(self) -> None:
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
@@ -640,7 +645,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
)
)
- def test_background_update_single_large_room(self):
+ def test_background_update_single_large_room(self) -> None:
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
@@ -693,7 +698,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
)
)
- def test_background_update_multiple_large_room(self):
+ def test_background_update_multiple_large_room(self) -> None:
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 853db930d6..7fd3e01364 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,7 +13,7 @@
# limitations under the License.
import datetime
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union, cast
import attr
from parameterized import parameterized
@@ -26,11 +26,12 @@ from synapse.api.room_versions import (
EventFormatVersions,
RoomVersion,
)
-from synapse.events import _EventInternalMetadata
+from synapse.events import EventBase, _EventInternalMetadata
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
+from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import Clock, json_encoder
@@ -54,11 +55,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- def test_get_prev_events_for_room(self):
+ def test_get_prev_events_for_room(self) -> None:
room_id = "@ROOM:local"
# add a bunch of events and hashes to act as forward extremities
- def insert_event(txn, i):
+ def insert_event(txn: Cursor, i: int) -> None:
event_id = "$event_%i:local" % i
txn.execute(
@@ -90,12 +91,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
for i in range(0, 10):
self.assertEqual("$event_%i:local" % (19 - i), r[i])
- def test_get_rooms_with_many_extremities(self):
+ def test_get_rooms_with_many_extremities(self) -> None:
room1 = "#room1"
room2 = "#room2"
room3 = "#room3"
- def insert_event(txn, i, room_id):
+ def insert_event(txn: Cursor, i: int, room_id: str) -> None:
event_id = "$event_%i:local" % i
txn.execute(
(
@@ -155,7 +156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# | |
# K J
- auth_graph = {
+ auth_graph: Dict[str, List[str]] = {
"a": ["e"],
"b": ["e"],
"c": ["g", "i"],
@@ -185,7 +186,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# Mark the room as maybe having a cover index.
- def store_room(txn):
+ def store_room(txn: LoggingTransaction) -> None:
self.store.db_pool.simple_insert_txn(
txn,
"rooms",
@@ -203,7 +204,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.
- def insert_event(txn):
+ def insert_event(txn: LoggingTransaction) -> None:
stream_ordering = 0
for event_id in auth_graph:
@@ -228,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
[
- FakeEvent(event_id, room_id, auth_graph[event_id])
+ cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
for event_id in auth_graph
],
)
@@ -243,7 +244,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
return room_id
@parameterized.expand([(True,), (False,)])
- def test_auth_chain_ids(self, use_chain_cover_index: bool):
+ def test_auth_chain_ids(self, use_chain_cover_index: bool) -> None:
room_id = self._setup_auth_chain(use_chain_cover_index)
# a and b have the same auth chain.
@@ -308,7 +309,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertCountEqual(auth_chain_ids, ["i", "j"])
@parameterized.expand([(True,), (False,)])
- def test_auth_difference(self, use_chain_cover_index: bool):
+ def test_auth_difference(self, use_chain_cover_index: bool) -> None:
room_id = self._setup_auth_chain(use_chain_cover_index)
# Now actually test that various combinations give the right result:
@@ -353,7 +354,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
self.assertSetEqual(difference, set())
- def test_auth_difference_partial_cover(self):
+ def test_auth_difference_partial_cover(self) -> None:
"""Test that we correctly handle rooms where not all events have a chain
cover calculated. This can happen in some obscure edge cases, including
during the background update that calculates the chain cover for old
@@ -377,7 +378,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# | |
# K J
- auth_graph = {
+ auth_graph: Dict[str, List[str]] = {
"a": ["e"],
"b": ["e"],
"c": ["g", "i"],
@@ -408,7 +409,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.
- def insert_event(txn):
+ def insert_event(txn: LoggingTransaction) -> None:
# First insert the room and mark it as having a chain cover.
self.store.db_pool.simple_insert_txn(
txn,
@@ -447,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
[
- FakeEvent(event_id, room_id, auth_graph[event_id])
+ cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
for event_id in auth_graph
if event_id != "b"
],
@@ -465,7 +466,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
- [FakeEvent("b", room_id, auth_graph["b"])],
+ [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
)
self.store.db_pool.simple_update_txn(
@@ -527,7 +528,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
@parameterized.expand(
[(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()]
)
- def test_prune_inbound_federation_queue(self, room_version: RoomVersion):
+ def test_prune_inbound_federation_queue(self, room_version: RoomVersion) -> None:
"""Test that pruning of inbound federation queues work"""
room_id = "some_room_id"
@@ -686,7 +687,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
stream_ordering += 1
- def populate_db(txn: LoggingTransaction):
+ def populate_db(txn: LoggingTransaction) -> None:
# Insert the room to satisfy the foreign key constraint of
# `event_failed_pull_attempts`
self.store.db_pool.simple_insert_txn(
@@ -760,7 +761,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
- def test_get_backfill_points_in_room(self):
+ def test_get_backfill_points_in_room(self) -> None:
"""
Test to make sure only backfill points that are older and come before
the `current_depth` are returned.
@@ -787,7 +788,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_backfill_points_in_room_excludes_events_we_have_attempted(
self,
- ):
+ ) -> None:
"""
Test to make sure that events we have attempted to backfill (and within
backoff timeout duration) do not show up as an event to backfill again.
@@ -824,7 +825,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration(
self,
- ):
+ ) -> None:
"""
Test to make sure after we fake attempt to backfill event "b3" many times,
we can see retry and see the "b3" again after the backoff timeout duration
@@ -941,7 +942,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"5": 7,
}
- def populate_db(txn: LoggingTransaction):
+ def populate_db(txn: LoggingTransaction) -> None:
# Insert the room to satisfy the foreign key constraint of
# `event_failed_pull_attempts`
self.store.db_pool.simple_insert_txn(
@@ -996,7 +997,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
- def test_get_insertion_event_backward_extremities_in_room(self):
+ def test_get_insertion_event_backward_extremities_in_room(self) -> None:
"""
Test to make sure only insertion event backward extremities that are
older and come before the `current_depth` are returned.
@@ -1027,7 +1028,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
self,
- ):
+ ) -> None:
"""
Test to make sure that insertion events we have attempted to backfill
(and within backoff timeout duration) do not show up as an event to
@@ -1060,7 +1061,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
self,
- ):
+ ) -> None:
"""
Test to make sure after we fake attempt to backfill event
"insertion_eventA" many times, we can see retry and see the
@@ -1130,9 +1131,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
self.assertEqual(backfill_event_ids, ["insertion_eventA"])
- def test_get_event_ids_to_not_pull_from_backoff(
- self,
- ):
+ def test_get_event_ids_to_not_pull_from_backoff(self) -> None:
"""
Test to make sure only event IDs we should backoff from are returned.
"""
@@ -1157,7 +1156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration(
self,
- ):
+ ) -> None:
"""
Test to make sure no event IDs are returned after the backoff duration has
elapsed.
@@ -1187,19 +1186,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(event_ids_to_backoff, [])
-@attr.s
+@attr.s(auto_attribs=True)
class FakeEvent:
- event_id = attr.ib()
- room_id = attr.ib()
- auth_events = attr.ib()
+ event_id: str
+ room_id: str
+ auth_events: List[str]
type = "foo"
state_key = "foo"
internal_metadata = _EventInternalMetadata({})
- def auth_event_ids(self):
+ def auth_event_ids(self) -> List[str]:
return self.auth_events
- def is_state(self):
+ def is_state(self) -> bool:
return True
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 088fbb247b..a91411168c 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -11,15 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from prometheus_client import generate_latest
-from synapse.metrics import REGISTRY, generate_latest
+from synapse.metrics import REGISTRY
from synapse.types import UserID, create_requester
from tests.unittest import HomeserverTestCase
class ExtremStatisticsTestCase(HomeserverTestCase):
- def test_exposed_to_prometheus(self):
+ def test_exposed_to_prometheus(self) -> None:
"""
Forward extremity counts are exposed via Prometheus.
"""
@@ -53,8 +54,8 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
items = list(
filter(
- lambda x: b"synapse_forward_extremities_" in x,
- generate_latest(REGISTRY, emit_help=False).split(b"\n"),
+ lambda x: b"synapse_forward_extremities_" in x and b"# HELP" not in x,
+ generate_latest(REGISTRY).split(b"\n"),
)
)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index ee48920f84..76c06a9d1e 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -154,9 +154,9 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
# Create a user to receive notifications and send receipts.
user_id, token, _, other_token, room_id = self._create_users_and_room()
- last_event_id: str
+ last_event_id = ""
- def _assert_counts(noitf_count: int, highlight_count: int) -> None:
+ def _assert_counts(notif_count: int, highlight_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
@@ -168,13 +168,22 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual(
counts.main_timeline,
NotifCounts(
- notify_count=noitf_count,
+ notify_count=notif_count,
unread_count=0,
highlight_count=highlight_count,
),
)
self.assertEqual(counts.threads, {})
+ aggregate_counts = self.get_success(
+ self.store.db_pool.runInteraction(
+ "get-aggregate-unread-counts",
+ self.store._get_unread_counts_by_room_for_user_txn,
+ user_id,
+ )
+ )
+ self.assertEqual(aggregate_counts[room_id], notif_count)
+
def _create_event(highlight: bool = False) -> str:
result = self.helper.send_event(
room_id,
@@ -280,10 +289,10 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
user_id, token, _, other_token, room_id = self._create_users_and_room()
thread_id: str
- last_event_id: str
+ last_event_id = ""
def _assert_counts(
- noitf_count: int,
+ notif_count: int,
highlight_count: int,
thread_notif_count: int,
thread_highlight_count: int,
@@ -299,7 +308,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual(
counts.main_timeline,
NotifCounts(
- notify_count=noitf_count,
+ notify_count=notif_count,
unread_count=0,
highlight_count=highlight_count,
),
@@ -318,6 +327,17 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
else:
self.assertEqual(counts.threads, {})
+ aggregate_counts = self.get_success(
+ self.store.db_pool.runInteraction(
+ "get-aggregate-unread-counts",
+ self.store._get_unread_counts_by_room_for_user_txn,
+ user_id,
+ )
+ )
+ self.assertEqual(
+ aggregate_counts[room_id], notif_count + thread_notif_count
+ )
+
def _create_event(
highlight: bool = False, thread_id: Optional[str] = None
) -> str:
@@ -451,10 +471,10 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
user_id, token, _, other_token, room_id = self._create_users_and_room()
thread_id: str
- last_event_id: str
+ last_event_id = ""
def _assert_counts(
- noitf_count: int,
+ notif_count: int,
highlight_count: int,
thread_notif_count: int,
thread_highlight_count: int,
@@ -470,7 +490,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual(
counts.main_timeline,
NotifCounts(
- notify_count=noitf_count,
+ notify_count=notif_count,
unread_count=0,
highlight_count=highlight_count,
),
@@ -489,6 +509,17 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
else:
self.assertEqual(counts.threads, {})
+ aggregate_counts = self.get_success(
+ self.store.db_pool.runInteraction(
+ "get-aggregate-unread-counts",
+ self.store._get_unread_counts_by_room_for_user_txn,
+ user_id,
+ )
+ )
+ self.assertEqual(
+ aggregate_counts[room_id], notif_count + thread_notif_count
+ )
+
def _create_event(
highlight: bool = False, thread_id: Optional[str] = None
) -> str:
@@ -646,7 +677,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
)
return result["event_id"]
- def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
+ def _assert_counts(notif_count: int, thread_notif_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
@@ -658,7 +689,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual(
counts.main_timeline,
NotifCounts(
- notify_count=noitf_count, unread_count=0, highlight_count=0
+ notify_count=notif_count, unread_count=0, highlight_count=0
),
)
if thread_notif_count:
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 3ce4f35cb7..05661a537d 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -12,12 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
from synapse.federation.federation_base import event_from_pdu_json
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import StateMap
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -29,7 +36,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.state = self.hs.get_state_handler()
self._persistence = self.hs.get_storage_controllers().persistence
self._state_storage_controller = self.hs.get_storage_controllers().state
@@ -67,7 +76,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check that the current extremities is the remote event.
self.assert_extremities([self.remote_event_1.event_id])
- def persist_event(self, event, state=None):
+ def persist_event(
+ self, event: EventBase, state: Optional[StateMap[str]] = None
+ ) -> None:
"""Persist the event, with optional state"""
context = self.get_success(
self.state.compute_event_context(
@@ -78,14 +89,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
)
self.get_success(self._persistence.persist_event(event, context))
- def assert_extremities(self, expected_extremities):
+ def assert_extremities(self, expected_extremities: List[str]) -> None:
"""Assert the current extremities for the room"""
extremities = self.get_success(
self.store.get_prev_events_for_room(self.room_id)
)
self.assertCountEqual(extremities, expected_extremities)
- def test_prune_gap(self):
+ def test_prune_gap(self) -> None:
"""Test that we drop extremities after a gap when we see an event from
the same domain.
"""
@@ -117,7 +128,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
- def test_do_not_prune_gap_if_state_different(self):
+ def test_do_not_prune_gap_if_state_different(self) -> None:
"""Test that we don't prune extremities after a gap if the resolved
state is different.
"""
@@ -161,7 +172,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check that we haven't dropped the old extremity.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
- def test_prune_gap_if_old(self):
+ def test_prune_gap_if_old(self) -> None:
"""Test that we drop extremities after a gap when the previous extremity
is "old"
"""
@@ -197,7 +208,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
- def test_do_not_prune_gap_if_other_server(self):
+ def test_do_not_prune_gap_if_other_server(self) -> None:
"""Test that we do not drop extremities after a gap when we see an event
from a different domain.
"""
@@ -229,7 +240,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
- def test_prune_gap_if_dummy_remote(self):
+ def test_prune_gap_if_dummy_remote(self) -> None:
"""Test that we drop extremities after a gap when the previous extremity
is a local dummy event and only points to remote events.
"""
@@ -271,7 +282,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
- def test_prune_gap_if_dummy_local(self):
+ def test_prune_gap_if_dummy_local(self) -> None:
"""Test that we don't drop extremities after a gap when the previous
extremity is a local dummy event and points to local events.
"""
@@ -315,7 +326,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
- def test_do_not_prune_gap_if_not_dummy(self):
+ def test_do_not_prune_gap_if_not_dummy(self) -> None:
"""Test that we do not drop extremities after a gap when the previous extremity
is not a dummy event.
"""
@@ -359,12 +370,14 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.state = self.hs.get_state_handler()
self._persistence = self.hs.get_storage_controllers().persistence
self.store = self.hs.get_datastores().main
- def test_remote_user_rooms_cache_invalidated(self):
+ def test_remote_user_rooms_cache_invalidated(self) -> None:
"""Test that if the server leaves a room the `get_rooms_for_user` cache
is invalidated for remote users.
"""
@@ -411,7 +424,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
self.assertEqual(set(rooms), set())
- def test_room_remote_user_cache_invalidated(self):
+ def test_room_remote_user_cache_invalidated(self) -> None:
"""Test that if the server leaves a room the `get_users_in_room` cache
is invalidated for remote users.
"""
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index d6a2b8d274..9174fb0964 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -52,6 +52,7 @@ class StreamIdGeneratorTestCase(HomeserverTestCase):
def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
return StreamIdGenerator(
db_conn=conn,
+ notifier=self.hs.get_replication_notifier(),
table="foobar",
column="stream_id",
)
@@ -196,6 +197,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
return MultiWriterIdGenerator(
conn,
self.db_pool,
+ notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")],
@@ -349,8 +351,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# The first ID gen will notice that it can advance its token to 7 as it
# has no in progress writes...
- self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7})
- self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
# ... but the second ID gen doesn't know that.
@@ -366,8 +368,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(
- first_id_gen.get_positions(), {"first": 7, "second": 7}
+ first_id_gen.get_positions(), {"first": 3, "second": 7}
)
+ self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
self.get_success(_get_next_async())
@@ -473,7 +476,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator("first", writers=["first", "second"])
- self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
+ self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
@@ -629,6 +632,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
return MultiWriterIdGenerator(
conn,
self.db_pool,
+ notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")],
@@ -720,7 +724,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(_get_next_async2())
- self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2})
+ self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
@@ -765,6 +769,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
return MultiWriterIdGenerator(
conn,
self.db_pool,
+ notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
tables=[
@@ -816,15 +821,12 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
- # The first ID gen will notice that it can advance its token to 7 as it
- # has no in progress writes...
- self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6})
- self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
- # ... but the second ID gen doesn't know that.
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
- self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
+ self.assertEqual(second_id_gen.get_persisted_upto_position(), 7)
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 9059095525..aa4b5bd3b1 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -13,6 +13,7 @@
# limitations under the License.
import signedjson.key
+import signedjson.types
import unpaddedbase64
from twisted.internet.defer import Deferred
@@ -22,7 +23,9 @@ from synapse.storage.keys import FetchKeyResult
import tests.unittest
-def decode_verify_key_base64(key_id: str, key_base64: str):
+def decode_verify_key_base64(
+ key_id: str, key_base64: str
+) -> signedjson.types.VerifyKey:
key_bytes = unpaddedbase64.decode_base64(key_base64)
return signedjson.key.decode_verify_key_bytes(key_id, key_bytes)
@@ -36,7 +39,7 @@ KEY_2 = decode_verify_key_base64(
class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
- def test_get_server_verify_keys(self):
+ def test_get_server_verify_keys(self) -> None:
store = self.hs.get_datastores().main
key_id_1 = "ed25519:key1"
@@ -71,7 +74,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
# non-existent result gives None
self.assertIsNone(res[("server1", "ed25519:key3")])
- def test_cache(self):
+ def test_cache(self) -> None:
"""Check that updates correctly invalidate the cache."""
store = self.hs.get_datastores().main
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index c55c4db970..2827738379 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -53,7 +53,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.reactor.advance(FORTY_DAYS)
@override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
- def test_initialise_reserved_users(self):
+ def test_initialise_reserved_users(self) -> None:
threepids = self.hs.config.server.mau_limits_reserved_threepids
# register three users, of which two have reserved 3pids, and a third
@@ -133,7 +133,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
active_count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(active_count, 3)
- def test_can_insert_and_count_mau(self):
+ def test_can_insert_and_count_mau(self) -> None:
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 0)
@@ -143,7 +143,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 1)
- def test_appservice_user_not_counted_in_mau(self):
+ def test_appservice_user_not_counted_in_mau(self) -> None:
self.get_success(
self.store.register_user(
user_id="@appservice_user:server", appservice_id="wibble"
@@ -158,7 +158,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 0)
- def test_user_last_seen_monthly_active(self):
+ def test_user_last_seen_monthly_active(self) -> None:
user_id1 = "@user1:server"
user_id2 = "@user2:server"
user_id3 = "@user3:server"
@@ -177,7 +177,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertIsNone(result)
@override_config({"max_mau_value": 5})
- def test_reap_monthly_active_users(self):
+ def test_reap_monthly_active_users(self) -> None:
initial_users = 10
for i in range(initial_users):
self.get_success(
@@ -204,7 +204,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# Note that below says mau_limit (no s), this is the name of the config
# value, although it gets stored on the config object as mau_limits.
@override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
- def test_reap_monthly_active_users_reserved_users(self):
+ def test_reap_monthly_active_users_reserved_users(self) -> None:
"""Tests that reaping correctly handles reaping where reserved users are
present"""
threepids = self.hs.config.server.mau_limits_reserved_threepids
@@ -244,7 +244,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, self.hs.config.server.max_mau_value)
- def test_populate_monthly_users_is_guest(self):
+ def test_populate_monthly_users_is_guest(self) -> None:
# Test that guest users are not added to mau list
user_id = "@user_id:host"
@@ -260,7 +260,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
- def test_populate_monthly_users_should_update(self):
+ def test_populate_monthly_users_should_update(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
@@ -273,7 +273,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_called_once()
- def test_populate_monthly_users_should_not_update(self):
+ def test_populate_monthly_users_should_not_update(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
@@ -286,7 +286,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
- def test_get_reserved_real_user_account(self):
+ def test_get_reserved_real_user_account(self) -> None:
# Test no reserved users, or reserved threepids
users = self.get_success(self.store.get_registered_reserved_users())
self.assertEqual(len(users), 0)
@@ -326,7 +326,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
users = self.get_success(self.store.get_registered_reserved_users())
self.assertEqual(len(users), len(threepids))
- def test_support_user_not_add_to_mau_limits(self):
+ def test_support_user_not_add_to_mau_limits(self) -> None:
support_user_id = "@support:test"
count = self.get_success(self.store.get_monthly_active_count())
@@ -347,7 +347,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config(
{"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1}
)
- def test_track_monthly_users_without_cap(self):
+ def test_track_monthly_users_without_cap(self) -> None:
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(0, count)
@@ -358,14 +358,14 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertEqual(2, count)
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
- def test_no_users_when_not_tracking(self):
+ def test_no_users_when_not_tracking(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
self.store.upsert_monthly_active_user.assert_not_called()
- def test_get_monthly_active_count_by_service(self):
+ def test_get_monthly_active_count_by_service(self) -> None:
appservice1_user1 = "@appservice1_user1:example.com"
appservice1_user2 = "@appservice1_user2:example.com"
@@ -413,7 +413,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertEqual(result[service2], 1)
self.assertEqual(result[native], 1)
- def test_get_monthly_active_users_by_service(self):
+ def test_get_monthly_active_users_by_service(self) -> None:
# (No users, no filtering) -> empty result
result = self.get_success(self.store.get_monthly_active_users_by_service())
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 9c1182ed16..010cc74c31 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.errors import NotFoundError, SynapseError
from synapse.rest.client import room
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -23,17 +27,17 @@ class PurgeTests(HomeserverTestCase):
user_id = "@red:server"
servlets = [room.register_servlets]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id)
self.store = hs.get_datastores().main
self._storage_controllers = self.hs.get_storage_controllers()
- def test_purge_history(self):
+ def test_purge_history(self) -> None:
"""
Purging a room history will delete everything before the topological point.
"""
@@ -63,7 +67,7 @@ class PurgeTests(HomeserverTestCase):
self.get_failure(self.store.get_event(third["event_id"]), NotFoundError)
self.get_success(self.store.get_event(last["event_id"]))
- def test_purge_history_wont_delete_extrems(self):
+ def test_purge_history_wont_delete_extrems(self) -> None:
"""
Purging a room history will delete everything before the topological point.
"""
@@ -77,6 +81,7 @@ class PurgeTests(HomeserverTestCase):
token = self.get_success(
self.store.get_topological_token_for_event(last["event_id"])
)
+ assert token.topological is not None
event = f"t{token.topological + 1}-{token.stream + 1}"
# Purge everything before this topological token
@@ -94,7 +99,7 @@ class PurgeTests(HomeserverTestCase):
self.get_success(self.store.get_event(third["event_id"]))
self.get_success(self.store.get_event(last["event_id"]))
- def test_purge_room(self):
+ def test_purge_room(self) -> None:
"""
Purging a room will delete everything about it.
"""
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index 81253d0361..d8d84152dc 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -14,8 +14,12 @@
from typing import Collection, Optional
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import ReceiptTypes
+from synapse.server import HomeServer
from synapse.types import UserID, create_requester
+from synapse.util import Clock
from tests.test_utils.event_injection import create_event
from tests.unittest import HomeserverTestCase
@@ -25,7 +29,9 @@ OUR_USER_ID = "@our:test"
class ReceiptTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver) -> None:
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
super().prepare(reactor, clock, homeserver)
self.store = homeserver.get_datastores().main
@@ -135,11 +141,11 @@ class ReceiptTestCase(HomeserverTestCase):
)
self.assertEqual(res, {})
- res = self.get_last_unthreaded_receipt(
+ res2 = self.get_last_unthreaded_receipt(
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
- self.assertEqual(res, None)
+ self.assertIsNone(res2)
def test_get_receipts_for_user(self) -> None:
# Send some events into the first room
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 6c4e63b77c..df4740f9d9 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -11,27 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional
+from typing import List, Optional, cast
from canonicaljson import json
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
-from synapse.types import RoomID, UserID
+from synapse.events import EventBase, _EventInternalMetadata
+from synapse.events.builder import EventBuilder
+from synapse.server import HomeServer
+from synapse.types import JsonDict, RoomID, UserID
+from synapse.util import Clock
from tests import unittest
from tests.utils import create_room
class RedactionTestCase(unittest.HomeserverTestCase):
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = super().default_config()
config["redaction_retention_period"] = "30d"
return config
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- self._storage = hs.get_storage_controllers()
+ storage = hs.get_storage_controllers()
+ assert storage.persistence is not None
+ self._persistence = storage.persistence
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -46,14 +54,13 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.depth = 1
- def inject_room_member(
+ def inject_room_member( # type: ignore[override]
self,
- room,
- user,
- membership,
- replaces_state=None,
- extra_content: Optional[dict] = None,
- ):
+ room: RoomID,
+ user: UserID,
+ membership: str,
+ extra_content: Optional[JsonDict] = None,
+ ) -> EventBase:
content = {"membership": membership}
content.update(extra_content or {})
builder = self.event_builder_factory.for_room_version(
@@ -71,11 +78,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self._storage.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
return event
- def inject_message(self, room, user, body):
+ def inject_message(self, room: RoomID, user: UserID, body: str) -> EventBase:
self.depth += 1
builder = self.event_builder_factory.for_room_version(
@@ -93,11 +100,13 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self._storage.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
return event
- def inject_redaction(self, room, event_id, user, reason):
+ def inject_redaction(
+ self, room: RoomID, event_id: str, user: UserID, reason: str
+ ) -> EventBase:
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
@@ -114,11 +123,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self._storage.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
return event
- def test_redact(self):
+ def test_redact(self) -> None:
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
msg_event = self.inject_message(self.room1, self.u_alice, "t")
@@ -165,7 +174,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
event.unsigned["redacted_because"],
)
- def test_redact_join(self):
+ def test_redact_join(self) -> None:
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
msg_event = self.inject_room_member(
@@ -213,12 +222,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
event.unsigned["redacted_because"],
)
- def test_circular_redaction(self):
+ def test_circular_redaction(self) -> None:
redaction_event_id1 = "$redaction1_id:test"
redaction_event_id2 = "$redaction2_id:test"
class EventIdManglingBuilder:
- def __init__(self, base_builder, event_id):
+ def __init__(self, base_builder: EventBuilder, event_id: str):
self._base_builder = base_builder
self._event_id = event_id
@@ -227,67 +236,73 @@ class RedactionTestCase(unittest.HomeserverTestCase):
prev_event_ids: List[str],
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
- ):
+ ) -> EventBase:
built_event = await self._base_builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
)
- built_event._event_id = self._event_id
+ built_event._event_id = self._event_id # type: ignore[attr-defined]
built_event._dict["event_id"] = self._event_id
assert built_event.event_id == self._event_id
return built_event
@property
- def room_id(self):
+ def room_id(self) -> str:
return self._base_builder.room_id
@property
- def type(self):
+ def type(self) -> str:
return self._base_builder.type
@property
- def internal_metadata(self):
+ def internal_metadata(self) -> _EventInternalMetadata:
return self._base_builder.internal_metadata
event_1, context_1 = self.get_success(
self.event_creation_handler.create_new_client_event(
- EventIdManglingBuilder(
- self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": EventTypes.Redaction,
- "sender": self.u_alice.to_string(),
- "room_id": self.room1.to_string(),
- "content": {"reason": "test"},
- "redacts": redaction_event_id2,
- },
+ cast(
+ EventBuilder,
+ EventIdManglingBuilder(
+ self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Redaction,
+ "sender": self.u_alice.to_string(),
+ "room_id": self.room1.to_string(),
+ "content": {"reason": "test"},
+ "redacts": redaction_event_id2,
+ },
+ ),
+ redaction_event_id1,
),
- redaction_event_id1,
)
)
)
- self.get_success(self._storage.persistence.persist_event(event_1, context_1))
+ self.get_success(self._persistence.persist_event(event_1, context_1))
event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
- EventIdManglingBuilder(
- self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": EventTypes.Redaction,
- "sender": self.u_alice.to_string(),
- "room_id": self.room1.to_string(),
- "content": {"reason": "test"},
- "redacts": redaction_event_id1,
- },
+ cast(
+ EventBuilder,
+ EventIdManglingBuilder(
+ self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Redaction,
+ "sender": self.u_alice.to_string(),
+ "room_id": self.room1.to_string(),
+ "content": {"reason": "test"},
+ "redacts": redaction_event_id1,
+ },
+ ),
+ redaction_event_id2,
),
- redaction_event_id2,
)
)
)
- self.get_success(self._storage.persistence.persist_event(event_2, context_2))
+ self.get_success(self._persistence.persist_event(event_2, context_2))
# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))
@@ -298,7 +313,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
fetched.unsigned["redacted_because"].event_id, redaction_event_id2
)
- def test_redact_censor(self):
+ def test_redact_censor(self) -> None:
"""Test that a redacted event gets censored in the DB after a month"""
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -364,7 +379,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.assert_dict({"content": {}}, json.loads(event_json))
- def test_redact_redaction(self):
+ def test_redact_redaction(self) -> None:
"""Tests that we can redact a redaction and can fetch it again."""
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -391,7 +406,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.store.get_event(first_redact_event.event_id, allow_none=True)
)
- def test_store_redacted_redaction(self):
+ def test_store_redacted_redaction(self) -> None:
"""Tests that we can store a redacted redaction."""
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -410,9 +425,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(
- self._storage.persistence.persist_event(redaction_event, context)
- )
+ self.get_success(self._persistence.persist_event(redaction_event, context))
# Now lets jump to the future where we have censored the redaction event
# in the DB.
diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py
index 0baa54312e..966aafea6f 100644
--- a/tests/storage/test_rollback_worker.py
+++ b/tests/storage/test_rollback_worker.py
@@ -14,10 +14,15 @@
from typing import List
from unittest import mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.app.generic_worker import GenericWorkerServer
+from synapse.server import HomeServer
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database
from synapse.storage.schema import SCHEMA_VERSION
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -39,13 +44,13 @@ def fake_listdir(filepath: str) -> List[str]:
class WorkerSchemaTests(HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(
federation_http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs
- def default_config(self):
+ def default_config(self) -> JsonDict:
conf = super().default_config()
# Mark this as a worker app.
@@ -53,7 +58,7 @@ class WorkerSchemaTests(HomeserverTestCase):
return conf
- def test_rolling_back(self):
+ def test_rolling_back(self) -> None:
"""Test that workers can start if the DB is a newer schema version"""
db_pool = self.hs.get_datastores().main.db_pool
@@ -70,7 +75,7 @@ class WorkerSchemaTests(HomeserverTestCase):
prepare_database(db_conn, db_pool.engine, self.hs.config)
- def test_not_upgraded_old_schema_version(self):
+ def test_not_upgraded_old_schema_version(self) -> None:
"""Test that workers don't start if the DB has an older schema version"""
db_pool = self.hs.get_datastores().main.db_pool
db_conn = LoggingDatabaseConnection(
@@ -87,7 +92,7 @@ class WorkerSchemaTests(HomeserverTestCase):
with self.assertRaises(PrepareDatabaseException):
prepare_database(db_conn, db_pool.engine, self.hs.config)
- def test_not_upgraded_current_schema_version_with_outstanding_deltas(self):
+ def test_not_upgraded_current_schema_version_with_outstanding_deltas(self) -> None:
"""
Test that workers don't start if the DB is on the current schema version,
but there are still outstanding delta migrations to run.
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 3405efb6a8..71ec74eadc 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.room_versions import RoomVersions
+from synapse.server import HomeServer
from synapse.types import RoomAlias, RoomID, UserID
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class RoomStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# We can't test RoomStore on its own without the DirectoryStore, for
# management of the 'room_aliases' table
self.store = hs.get_datastores().main
@@ -37,30 +41,34 @@ class RoomStoreTestCase(HomeserverTestCase):
)
)
- def test_get_room(self):
+ def test_get_room(self) -> None:
+ res = self.get_success(self.store.get_room(self.room.to_string()))
+ assert res is not None
self.assertDictContainsSubset(
{
"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
"is_public": True,
},
- (self.get_success(self.store.get_room(self.room.to_string()))),
+ res,
)
- def test_get_room_unknown_room(self):
+ def test_get_room_unknown_room(self) -> None:
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))
- def test_get_room_with_stats(self):
+ def test_get_room_with_stats(self) -> None:
+ res = self.get_success(self.store.get_room_with_stats(self.room.to_string()))
+ assert res is not None
self.assertDictContainsSubset(
{
"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
"public": True,
},
- (self.get_success(self.store.get_room_with_stats(self.room.to_string()))),
+ res,
)
- def test_get_room_with_stats_unknown_room(self):
+ def test_get_room_with_stats_unknown_room(self) -> None:
self.assertIsNone(
- (self.get_success(self.store.get_room_with_stats("!uknown:test"))),
+ self.get_success(self.store.get_room_with_stats("!uknown:test"))
)
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index ef850daa73..14d872514d 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -39,7 +39,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
room.register_servlets,
]
- def test_null_byte(self):
+ def test_null_byte(self) -> None:
"""
Postgres/SQLite don't like null bytes going into the search tables. Internally
we replace those with a space.
@@ -86,7 +86,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
if isinstance(store.database_engine, PostgresEngine):
self.assertIn("alice", result.get("highlights"))
- def test_non_string(self):
+ def test_non_string(self) -> None:
"""Test that non-string `value`s are not inserted into `event_search`.
This is particularly important when using sqlite, since a sqlite column can hold
@@ -157,7 +157,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
self.assertEqual(f.value.code, 404)
@skip_unless(not USE_POSTGRES_FOR_TESTS, "requires sqlite")
- def test_sqlite_non_string_deletion_background_update(self):
+ def test_sqlite_non_string_deletion_background_update(self) -> None:
"""Test the background update to delete bad rows from `event_search`."""
store = self.hs.get_datastores().main
@@ -350,7 +350,7 @@ class MessageSearchTest(HomeserverTestCase):
"results array length should match count",
)
- def test_postgres_web_search_for_phrase(self):
+ def test_postgres_web_search_for_phrase(self) -> None:
"""
Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery.
This test is skipped unless the postgres instance supports websearch_to_tsquery.
@@ -364,7 +364,7 @@ class MessageSearchTest(HomeserverTestCase):
self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES)
- def test_sqlite_search(self):
+ def test_sqlite_search(self) -> None:
"""
Test sqlite searching for phrases.
"""
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 5564161750..bad7f0bc60 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -16,18 +16,23 @@ import logging
from frozendict import frozendict
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
-from synapse.storage.state import StateFilter
-from synapse.types import RoomID, UserID
+from synapse.events import EventBase
+from synapse.server import HomeServer
+from synapse.types import JsonDict, RoomID, StateMap, UserID
+from synapse.types.state import StateFilter
+from synapse.util import Clock
-from tests.unittest import HomeserverTestCase, TestCase
+from tests.unittest import HomeserverTestCase
logger = logging.getLogger(__name__)
class StateStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.storage = hs.get_storage_controllers()
self.state_datastore = self.storage.state.stores.state
@@ -48,7 +53,9 @@ class StateStoreTestCase(HomeserverTestCase):
)
)
- def inject_state_event(self, room, sender, typ, state_key, content):
+ def inject_state_event(
+ self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
+ ) -> EventBase:
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
@@ -64,24 +71,29 @@ class StateStoreTestCase(HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
+ assert self.storage.persistence is not None
self.get_success(self.storage.persistence.persist_event(event, context))
return event
- def assertStateMapEqual(self, s1, s2):
+ def assertStateMapEqual(
+ self, s1: StateMap[EventBase], s2: StateMap[EventBase]
+ ) -> None:
for t in s1:
# just compare event IDs for simplicity
self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2))
- def test_get_state_groups_ids(self):
+ def test_get_state_groups_ids(self) -> None:
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
state_group_map = self.get_success(
- self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
+ self.storage.state.get_state_groups_ids(
+ self.room.to_string(), [e2.event_id]
+ )
)
self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0]
@@ -90,21 +102,21 @@ class StateStoreTestCase(HomeserverTestCase):
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
)
- def test_get_state_groups(self):
+ def test_get_state_groups(self) -> None:
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
state_group_map = self.get_success(
- self.storage.state.get_state_groups(self.room, [e2.event_id])
+ self.storage.state.get_state_groups(self.room.to_string(), [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
- def test_get_state_for_event(self):
+ def test_get_state_for_event(self) -> None:
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
@@ -482,622 +494,3 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
-
-
-class StateFilterDifferenceTestCase(TestCase):
- def assert_difference(
- self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
- ):
- self.assertEqual(
- minuend.approx_difference(subtrahend),
- expected,
- f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
- )
-
- def test_state_filter_difference_no_include_other_minus_no_include_other(self):
- """
- Tests the StateFilter.approx_difference method
- where, in a.approx_difference(b), both a and b do not have the
- include_others flag set.
- """
- # (wildcard on state keys) - (wildcard on state keys):
- self.assert_difference(
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.Create: None},
- include_others=False,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
- include_others=False,
- ),
- StateFilter.freeze({EventTypes.Create: None}, include_others=False),
- )
-
- # (wildcard on state keys) - (specific state keys)
- # This one is an over-approximation because we can't represent
- # 'all state keys except a few named examples'
- self.assert_difference(
- StateFilter.freeze({EventTypes.Member: None}, include_others=False),
- StateFilter.freeze(
- {EventTypes.Member: {"@wombat:spqr"}},
- include_others=False,
- ),
- StateFilter.freeze({EventTypes.Member: None}, include_others=False),
- )
-
- # (wildcard on state keys) - (no state keys)
- self.assert_difference(
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: set(),
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ),
- )
-
- # (specific state keys) - (wildcard on state keys):
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ),
- StateFilter.freeze(
- {EventTypes.CanonicalAlias: {""}},
- include_others=False,
- ),
- )
-
- # (specific state keys) - (specific state keys)
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr"},
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=False,
- ),
- )
-
- # (specific state keys) - (no state keys)
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: set(),
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=False,
- ),
- )
-
- def test_state_filter_difference_include_other_minus_no_include_other(self):
- """
- Tests the StateFilter.approx_difference method
- where, in a.approx_difference(b), only a has the include_others flag set.
- """
- # (wildcard on state keys) - (wildcard on state keys):
- self.assert_difference(
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.Create: None},
- include_others=True,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Create: None,
- EventTypes.Member: set(),
- EventTypes.CanonicalAlias: set(),
- },
- include_others=True,
- ),
- )
-
- # (wildcard on state keys) - (specific state keys)
- # This one is an over-approximation because we can't represent
- # 'all state keys except a few named examples'
- # This also shows that the resultant state filter is normalised.
- self.assert_difference(
- StateFilter.freeze({EventTypes.Member: None}, include_others=True),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr"},
- EventTypes.Create: {""},
- },
- include_others=False,
- ),
- StateFilter(types=frozendict(), include_others=True),
- )
-
- # (wildcard on state keys) - (no state keys)
- self.assert_difference(
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: set(),
- },
- include_others=False,
- ),
- StateFilter(
- types=frozendict(),
- include_others=True,
- ),
- )
-
- # (specific state keys) - (wildcard on state keys):
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.CanonicalAlias: {""},
- EventTypes.Member: set(),
- },
- include_others=True,
- ),
- )
-
- # (specific state keys) - (specific state keys)
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr"},
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=True,
- ),
- )
-
- # (specific state keys) - (no state keys)
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: set(),
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=True,
- ),
- )
-
- def test_state_filter_difference_include_other_minus_include_other(self):
- """
- Tests the StateFilter.approx_difference method
- where, in a.approx_difference(b), both a and b have the include_others
- flag set.
- """
- # (wildcard on state keys) - (wildcard on state keys):
- self.assert_difference(
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.Create: None},
- include_others=True,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
- include_others=True,
- ),
- StateFilter(types=frozendict(), include_others=False),
- )
-
- # (wildcard on state keys) - (specific state keys)
- # This one is an over-approximation because we can't represent
- # 'all state keys except a few named examples'
- self.assert_difference(
- StateFilter.freeze({EventTypes.Member: None}, include_others=True),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
- include_others=False,
- ),
- )
-
- # (wildcard on state keys) - (no state keys)
- self.assert_difference(
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: set(),
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ),
- )
-
- # (specific state keys) - (wildcard on state keys):
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=True,
- ),
- StateFilter(
- types=frozendict(),
- include_others=False,
- ),
- )
-
- # (specific state keys) - (specific state keys)
- # This one is an over-approximation because we can't represent
- # 'all state keys except a few named examples'
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- EventTypes.Create: {""},
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr"},
- EventTypes.Create: set(),
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@spqr:spqr"},
- EventTypes.Create: {""},
- },
- include_others=False,
- ),
- )
-
- # (specific state keys) - (no state keys)
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: set(),
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- },
- include_others=False,
- ),
- )
-
- def test_state_filter_difference_no_include_other_minus_include_other(self):
- """
- Tests the StateFilter.approx_difference method
- where, in a.approx_difference(b), only b has the include_others flag set.
- """
- # (wildcard on state keys) - (wildcard on state keys):
- self.assert_difference(
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.Create: None},
- include_others=False,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
- include_others=True,
- ),
- StateFilter(types=frozendict(), include_others=False),
- )
-
- # (wildcard on state keys) - (specific state keys)
- # This one is an over-approximation because we can't represent
- # 'all state keys except a few named examples'
- self.assert_difference(
- StateFilter.freeze({EventTypes.Member: None}, include_others=False),
- StateFilter.freeze(
- {EventTypes.Member: {"@wombat:spqr"}},
- include_others=True,
- ),
- StateFilter.freeze({EventTypes.Member: None}, include_others=False),
- )
-
- # (wildcard on state keys) - (no state keys)
- self.assert_difference(
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: set(),
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ),
- )
-
- # (specific state keys) - (wildcard on state keys):
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=True,
- ),
- StateFilter(
- types=frozendict(),
- include_others=False,
- ),
- )
-
- # (specific state keys) - (specific state keys)
- # This one is an over-approximation because we can't represent
- # 'all state keys except a few named examples'
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr"},
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@spqr:spqr"},
- },
- include_others=False,
- ),
- )
-
- # (specific state keys) - (no state keys)
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: set(),
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- },
- include_others=False,
- ),
- )
-
- def test_state_filter_difference_simple_cases(self):
- """
- Tests some very simple cases of the StateFilter approx_difference,
- that are not explicitly tested by the more in-depth tests.
- """
-
- self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
-
- self.assert_difference(
- StateFilter.all(),
- StateFilter.none(),
- StateFilter.all(),
- )
-
-
-class StateFilterTestCase(TestCase):
- def test_return_expanded(self):
- """
- Tests the behaviour of the return_expanded() function that expands
- StateFilters to include more state types (for the sake of cache hit rate).
- """
-
- self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
-
- self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
-
- # Concrete-only state filters stay the same
- # (Case: mixed filter)
- self.assertEqual(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:test", "@alicia:test"},
- "some.other.state.type": {""},
- },
- include_others=False,
- ).return_expanded(),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:test", "@alicia:test"},
- "some.other.state.type": {""},
- },
- include_others=False,
- ),
- )
-
- # Concrete-only state filters stay the same
- # (Case: non-member-only filter)
- self.assertEqual(
- StateFilter.freeze(
- {"some.other.state.type": {""}}, include_others=False
- ).return_expanded(),
- StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
- )
-
- # Concrete-only state filters stay the same
- # (Case: member-only filter)
- self.assertEqual(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:test", "@alicia:test"},
- },
- include_others=False,
- ).return_expanded(),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:test", "@alicia:test"},
- },
- include_others=False,
- ),
- )
-
- # Wildcard member-only state filters stay the same
- self.assertEqual(
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ).return_expanded(),
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ),
- )
-
- # If there is a wildcard in the non-member portion of the filter,
- # it's expanded to include ALL non-member events.
- # (Case: mixed filter)
- self.assertEqual(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:test", "@alicia:test"},
- "some.other.state.type": None,
- },
- include_others=False,
- ).return_expanded(),
- StateFilter.freeze(
- {EventTypes.Member: {"@wombat:test", "@alicia:test"}},
- include_others=True,
- ),
- )
-
- # If there is a wildcard in the non-member portion of the filter,
- # it's expanded to include ALL non-member events.
- # (Case: non-member-only filter)
- self.assertEqual(
- StateFilter.freeze(
- {
- "some.other.state.type": None,
- },
- include_others=False,
- ).return_expanded(),
- StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
- )
- self.assertEqual(
- StateFilter.freeze(
- {
- "some.other.state.type": None,
- "yet.another.state.type": {"wombat"},
- },
- include_others=False,
- ).return_expanded(),
- StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
- )
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index 34fa810cf6..bc090ebce0 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -14,11 +14,15 @@
from typing import List
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.filtering import Filter
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -37,12 +41,14 @@ class PaginationTestCase(HomeserverTestCase):
login.register_servlets,
]
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = {"msc3874_enabled": True}
return config
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
@@ -130,7 +136,7 @@ class PaginationTestCase(HomeserverTestCase):
return [ev.event_id for ev in events]
- def test_filter_relation_senders(self):
+ def test_filter_relation_senders(self) -> None:
# Messages which second user reacted to.
filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
@@ -146,7 +152,7 @@ class PaginationTestCase(HomeserverTestCase):
chunk = self._filter_messages(filter)
self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
- def test_filter_relation_type(self):
+ def test_filter_relation_type(self) -> None:
# Messages which have annotations.
filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
@@ -167,7 +173,7 @@ class PaginationTestCase(HomeserverTestCase):
chunk = self._filter_messages(filter)
self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
- def test_filter_relation_senders_and_type(self):
+ def test_filter_relation_senders_and_type(self) -> None:
# Messages which second user reacted to.
filter = {
"related_by_senders": [self.second_user_id],
@@ -176,7 +182,7 @@ class PaginationTestCase(HomeserverTestCase):
chunk = self._filter_messages(filter)
self.assertEqual(chunk, [self.event_id_1])
- def test_duplicate_relation(self):
+ def test_duplicate_relation(self) -> None:
"""An event should only be returned once if there are multiple relations to it."""
self.helper.send_event(
room_id=self.room_id,
diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py
index e05daa285e..db9ee9955e 100644
--- a/tests/storage/test_transactions.py
+++ b/tests/storage/test_transactions.py
@@ -12,17 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
from synapse.storage.databases.main.transactions import DestinationRetryTimings
+from synapse.util import Clock
from synapse.util.retryutils import MAX_RETRY_INTERVAL
from tests.unittest import HomeserverTestCase
class TransactionStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
- def test_get_set_transactions(self):
+ def test_get_set_transactions(self) -> None:
"""Tests that we can successfully get a non-existent entry for
destination retries, as well as testing tht we can set and get
correctly.
@@ -44,18 +50,18 @@ class TransactionStoreTestCase(HomeserverTestCase):
r,
)
- def test_initial_set_transactions(self):
+ def test_initial_set_transactions(self) -> None:
"""Tests that we can successfully set the destination retries (there
was a bug around invalidating the cache that broke this)
"""
d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
self.get_success(d)
- def test_large_destination_retry(self):
+ def test_large_destination_retry(self) -> None:
d = self.store.set_destination_retry_timings(
"example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL
)
self.get_success(d)
- d = self.store.get_destination_retry_timings("example.com")
- self.get_success(d)
+ d2 = self.store.get_destination_retry_timings("example.com")
+ self.get_success(d2)
diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py
index ace82cbf42..15ea4770bd 100644
--- a/tests/storage/test_txn_limit.py
+++ b/tests/storage/test_txn_limit.py
@@ -12,21 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.storage.types import Cursor
+from synapse.util import Clock
+
from tests import unittest
class SQLTransactionLimitTestCase(unittest.HomeserverTestCase):
"""Test SQL transaction limit doesn't break transactions."""
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(db_txn_limit=1000)
- def test_config(self):
+ def test_config(self) -> None:
db_config = self.hs.config.database.get_single_database()
self.assertEqual(db_config.config["txn_limit"], 1000)
- def test_select(self):
- def do_select(txn):
+ def test_select(self) -> None:
+ def do_select(txn: Cursor) -> None:
txn.execute("SELECT 1")
db_pool = self.hs.get_datastores().databases[0]
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 5b60cf5285..f1ca523d23 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
from typing import Any, Dict, Set, Tuple
from unittest import mock
from unittest.mock import Mock, patch
@@ -27,9 +28,16 @@ from synapse.storage.background_updates import _BackgroundUpdateHandler
from synapse.storage.roommember import ProfileInfo
from synapse.util import Clock
+from tests.server import ThreadedMemoryReactorClock
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
+try:
+ import icu
+except ImportError:
+ icu = None # type: ignore
+
+
ALICE = "@alice:a"
BOB = "@bob:b"
BOBBY = "@bobby:a"
@@ -131,7 +139,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
register.register_servlets,
]
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ def make_homeserver(
+ self, reactor: ThreadedMemoryReactorClock, clock: Clock
+ ) -> HomeServer:
self.appservice = ApplicationService(
token="i_am_an_app_service",
id="1234",
@@ -448,6 +458,12 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
{"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
)
+ @override_config({"user_directory": {"search_all_users": True}})
+ def test_search_user_limit_correct(self) -> None:
+ r = self.get_success(self.store.search_user_dir(ALICE, "bob", 1))
+ self.assertTrue(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+
@override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_stop_words(self) -> None:
"""Tests that a user can look up another user by searching for the start if its
@@ -461,3 +477,39 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
r["results"][0],
{"user_id": BELA, "display_name": "Bela", "avatar_url": None},
)
+
+
+class UserDirectoryICUTestCase(HomeserverTestCase):
+ if not icu:
+ skip = "Requires PyICU"
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.user_dir_helper = GetUserDirectoryTables(self.store)
+
+ def test_icu_word_boundary(self) -> None:
+ """Tests that we correctly detect word boundaries when ICU (International
+ Components for Unicode) support is available.
+ """
+
+ display_name = "Gáo"
+
+ # This word is not broken down correctly by Python's regular expressions,
+ # likely because á is actually a lowercase a followed by a U+0301 combining
+ # acute accent. This is specifically something that ICU support fixes.
+ matches = re.findall(r"([\w\-]+)", display_name, re.UNICODE)
+ self.assertEqual(len(matches), 2)
+
+ self.get_success(
+ self.store.update_profile_in_user_dir(ALICE, display_name, None)
+ )
+ self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE,)))
+
+ # Check that searching for this user yields the correct result.
+ r = self.get_success(self.store.search_user_dir(BOB, display_name, 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(len(r["results"]), 1)
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": ALICE, "display_name": display_name, "avatar_url": None},
+ )
diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py
index cae14151c0..0e3fc2a77f 100644
--- a/tests/storage/util/test_partial_state_events_tracker.py
+++ b/tests/storage/util/test_partial_state_events_tracker.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict
+from typing import Collection, Dict
from unittest import mock
from twisted.internet.defer import CancelledError, ensureDeferred
@@ -31,7 +31,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
# the results to be returned by the mocked get_partial_state_events
self._events_dict: Dict[str, bool] = {}
- async def get_partial_state_events(events):
+ async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]:
return {e: self._events_dict[e] for e in events}
self.mock_store = mock.Mock(spec_set=["get_partial_state_events"])
@@ -39,7 +39,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
self.tracker = PartialStateEventsTracker(self.mock_store)
- def test_does_not_block_for_full_state_events(self):
+ def test_does_not_block_for_full_state_events(self) -> None:
self._events_dict = {"event1": False, "event2": False}
self.successResultOf(
@@ -50,7 +50,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
["event1", "event2"]
)
- def test_blocks_for_partial_state_events(self):
+ def test_blocks_for_partial_state_events(self) -> None:
self._events_dict = {"event1": True, "event2": False}
d = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
@@ -62,12 +62,12 @@ class PartialStateEventsTrackerTestCase(TestCase):
self.tracker.notify_un_partial_stated("event1")
self.successResultOf(d)
- def test_un_partial_state_race(self):
+ def test_un_partial_state_race(self) -> None:
# if the event is un-partial-stated between the initial check and the
# registration of the listener, it should not block.
self._events_dict = {"event1": True, "event2": False}
- async def get_partial_state_events(events):
+ async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]:
res = {e: self._events_dict[e] for e in events}
# change the result for next time
self._events_dict = {"event1": False, "event2": False}
@@ -79,19 +79,19 @@ class PartialStateEventsTrackerTestCase(TestCase):
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
)
- def test_un_partial_state_during_get_partial_state_events(self):
+ def test_un_partial_state_during_get_partial_state_events(self) -> None:
# we should correctly handle a call to notify_un_partial_stated during the
# second call to get_partial_state_events.
self._events_dict = {"event1": True, "event2": False}
- async def get_partial_state_events1(events):
+ async def get_partial_state_events1(events: Collection[str]) -> Dict[str, bool]:
self.mock_store.get_partial_state_events.side_effect = (
get_partial_state_events2
)
return {e: self._events_dict[e] for e in events}
- async def get_partial_state_events2(events):
+ async def get_partial_state_events2(events: Collection[str]) -> Dict[str, bool]:
self.tracker.notify_un_partial_stated("event1")
self._events_dict["event1"] = False
return {e: self._events_dict[e] for e in events}
@@ -102,7 +102,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
)
- def test_cancellation(self):
+ def test_cancellation(self) -> None:
self._events_dict = {"event1": True, "event2": False}
d1 = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
@@ -127,12 +127,12 @@ class PartialCurrentStateTrackerTestCase(TestCase):
self.tracker = PartialCurrentStateTracker(self.mock_store)
- def test_does_not_block_for_full_state_rooms(self):
+ def test_does_not_block_for_full_state_rooms(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
- def test_blocks_for_partial_room_state(self):
+ def test_blocks_for_partial_room_state(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
d = ensureDeferred(self.tracker.await_full_state("room_id"))
@@ -144,10 +144,10 @@ class PartialCurrentStateTrackerTestCase(TestCase):
self.tracker.notify_un_partial_stated("room_id")
self.successResultOf(d)
- def test_un_partial_state_race(self):
+ def test_un_partial_state_race(self) -> None:
# We should correctly handle race between awaiting the state and us
# un-partialling the state
- async def is_partial_state_room(events):
+ async def is_partial_state_room(room_id: str) -> bool:
self.tracker.notify_un_partial_stated("room_id")
return True
@@ -155,7 +155,7 @@ class PartialCurrentStateTrackerTestCase(TestCase):
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
- def test_cancellation(self):
+ def test_cancellation(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index f4d9fba0a1..0a7937f1cc 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -13,7 +13,7 @@
# limitations under the License.
import unittest
-from typing import Collection, Dict, Iterable, List, Optional
+from typing import Any, Collection, Dict, Iterable, List, Optional
from parameterized import parameterized
@@ -728,6 +728,36 @@ class EventAuthTestCase(unittest.TestCase):
pl_event.room_version, pl_event2, {("fake_type", "fake_key"): pl_event}
)
+ def test_room_v10_rejects_other_non_integer_power_levels(self) -> None:
+ """We should reject PLs that are non-integer, non-string JSON values.
+
+ test_room_v10_rejects_string_power_levels above handles the string case.
+ """
+
+ def create_event(pl_event_content: Dict[str, Any]) -> EventBase:
+ return make_event_from_dict(
+ {
+ "room_id": TEST_ROOM_ID,
+ **_maybe_get_event_id_dict_for_room_version(RoomVersions.V10),
+ "type": "m.room.power_levels",
+ "sender": "@test:test.com",
+ "state_key": "",
+ "content": pl_event_content,
+ "signatures": {"test.com": {"ed25519:0": "some9signature"}},
+ },
+ room_version=RoomVersions.V10,
+ )
+
+ contents: Iterable[Dict[str, Any]] = [
+ {"notifications": {"room": None}},
+ {"users": {"@alice:wonderland": []}},
+ {"users_default": {}},
+ ]
+ for content in contents:
+ event = create_event(content)
+ with self.assertRaises(SynapseError):
+ event_auth._check_power_levels(event.room_version, event, {})
+
# helpers for making events
TEST_DOMAIN = "example.com"
diff --git a/tests/test_server.py b/tests/test_server.py
index 6f35966d0c..27537758c4 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -174,7 +174,7 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar"
)
- self.assertEqual(channel.code, 400)
+ self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["error"], "Unrecognized request")
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 9228454c9e..304c7b98c5 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -17,6 +17,7 @@ import os
import twisted.logger
from synapse.logging.context import LoggingContextFilter
+from synapse.synapse_rust import reset_logging_config
class ToTwistedHandler(logging.Handler):
@@ -52,3 +53,5 @@ def setup_logging():
log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")
root_logger.setLevel(log_level)
+
+ reset_logging_config()
diff --git a/tests/types/__init__.py b/tests/types/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/types/test_state.py b/tests/types/test_state.py
new file mode 100644
index 0000000000..eb809f9fb7
--- /dev/null
+++ b/tests/types/test_state.py
@@ -0,0 +1,627 @@
+from frozendict import frozendict
+
+from synapse.api.constants import EventTypes
+from synapse.types.state import StateFilter
+
+from tests.unittest import TestCase
+
+
+class StateFilterDifferenceTestCase(TestCase):
+ def assert_difference(
+ self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
+ ) -> None:
+ self.assertEqual(
+ minuend.approx_difference(subtrahend),
+ expected,
+ f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
+ )
+
+ def test_state_filter_difference_no_include_other_minus_no_include_other(
+ self,
+ ) -> None:
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), both a and b do not have the
+ include_others flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=False,
+ ),
+ StateFilter.freeze({EventTypes.Create: None}, include_others=False),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ StateFilter.freeze(
+ {EventTypes.Member: {"@wombat:spqr"}},
+ include_others=False,
+ ),
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.CanonicalAlias: {""}},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), only a has the include_others flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Create: None,
+ EventTypes.Member: set(),
+ EventTypes.CanonicalAlias: set(),
+ },
+ include_others=True,
+ ),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ # This also shows that the resultant state filter is normalised.
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ EventTypes.Create: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter(types=frozendict(), include_others=True),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter(
+ types=frozendict(),
+ include_others=True,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.CanonicalAlias: {""},
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ )
+
+ def test_state_filter_difference_include_other_minus_include_other(self) -> None:
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), both a and b have the include_others
+ flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=True,
+ ),
+ StateFilter(types=frozendict(), include_others=False),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=False,
+ ),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter(
+ types=frozendict(),
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ EventTypes.Create: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ EventTypes.Create: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ EventTypes.Create: {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ },
+ include_others=False,
+ ),
+ )
+
+ def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), only b has the include_others flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=True,
+ ),
+ StateFilter(types=frozendict(), include_others=False),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ StateFilter.freeze(
+ {EventTypes.Member: {"@wombat:spqr"}},
+ include_others=True,
+ ),
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter(
+ types=frozendict(),
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ },
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ },
+ include_others=False,
+ ),
+ )
+
+ def test_state_filter_difference_simple_cases(self) -> None:
+ """
+ Tests some very simple cases of the StateFilter approx_difference,
+ that are not explicitly tested by the more in-depth tests.
+ """
+
+ self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
+
+ self.assert_difference(
+ StateFilter.all(),
+ StateFilter.none(),
+ StateFilter.all(),
+ )
+
+
+class StateFilterTestCase(TestCase):
+ def test_return_expanded(self) -> None:
+ """
+ Tests the behaviour of the return_expanded() function that expands
+ StateFilters to include more state types (for the sake of cache hit rate).
+ """
+
+ self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
+
+ self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
+
+ # Concrete-only state filters stay the same
+ # (Case: mixed filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ "some.other.state.type": {""},
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ "some.other.state.type": {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ # Concrete-only state filters stay the same
+ # (Case: non-member-only filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {"some.other.state.type": {""}}, include_others=False
+ ).return_expanded(),
+ StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
+ )
+
+ # Concrete-only state filters stay the same
+ # (Case: member-only filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ },
+ include_others=False,
+ ),
+ )
+
+ # Wildcard member-only state filters stay the same
+ self.assertEqual(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # If there is a wildcard in the non-member portion of the filter,
+ # it's expanded to include ALL non-member events.
+ # (Case: mixed filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ "some.other.state.type": None,
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {EventTypes.Member: {"@wombat:test", "@alicia:test"}},
+ include_others=True,
+ ),
+ )
+
+ # If there is a wildcard in the non-member portion of the filter,
+ # it's expanded to include ALL non-member events.
+ # (Case: non-member-only filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ "some.other.state.type": None,
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
+ )
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ "some.other.state.type": None,
+ "yet.another.state.type": {"wombat"},
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
+ )
diff --git a/tests/unittest.py b/tests/unittest.py
index 532b92d43a..50aa5abda9 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -75,6 +75,7 @@ from synapse.util.httpresourcetree import create_resource_tree
from tests.server import (
CustomHeaderType,
FakeChannel,
+ ThreadedMemoryReactorClock,
get_clock,
make_request,
setup_test_homeserver,
@@ -360,7 +361,7 @@ class HomeserverTestCase(TestCase):
store.db_pool.updates.do_next_background_update(False), by=0.1
)
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock):
+ def make_homeserver(self, reactor: ThreadedMemoryReactorClock, clock: Clock):
"""
Make and return a homeserver.
diff --git a/tests/util/caches/test_cached_call.py b/tests/util/caches/test_cached_call.py
index 80b97167ba..9266f12590 100644
--- a/tests/util/caches/test_cached_call.py
+++ b/tests/util/caches/test_cached_call.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import NoReturn
from unittest.mock import Mock
from twisted.internet import defer
@@ -23,14 +24,14 @@ from tests.unittest import TestCase
class CachedCallTestCase(TestCase):
- def test_get(self):
+ def test_get(self) -> None:
"""
Happy-path test case: makes a couple of calls and makes sure they behave
correctly
"""
- d = Deferred()
+ d: "Deferred[int]" = Deferred()
- async def f():
+ async def f() -> int:
return await d
slow_call = Mock(side_effect=f)
@@ -43,7 +44,7 @@ class CachedCallTestCase(TestCase):
# now fire off a couple of calls
completed_results = []
- async def r():
+ async def r() -> None:
res = await cached_call.get()
completed_results.append(res)
@@ -69,12 +70,12 @@ class CachedCallTestCase(TestCase):
self.assertEqual(r3, 123)
slow_call.assert_not_called()
- def test_fast_call(self):
+ def test_fast_call(self) -> None:
"""
Test the behaviour when the underlying function completes immediately
"""
- async def f():
+ async def f() -> int:
return 12
fast_call = Mock(side_effect=f)
@@ -92,12 +93,12 @@ class CachedCallTestCase(TestCase):
class RetryOnExceptionCachedCallTestCase(TestCase):
- def test_get(self):
+ def test_get(self) -> None:
# set up the RetryOnExceptionCachedCall around a function which will fail
# (after a while)
- d = Deferred()
+ d: "Deferred[int]" = Deferred()
- async def f1():
+ async def f1() -> NoReturn:
await d
raise ValueError("moo")
@@ -110,7 +111,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase):
# now fire off a couple of calls
completed_results = []
- async def r():
+ async def r() -> None:
try:
await cached_call.get()
except Exception as e1:
@@ -137,7 +138,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase):
# to the getter
d = Deferred()
- async def f2():
+ async def f2() -> int:
return await d
slow_call.reset_mock()
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index 02b99b466a..f74d82b1dc 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial
+from typing import List, Tuple
from twisted.internet import defer
@@ -22,20 +23,20 @@ from tests.unittest import TestCase
class DeferredCacheTestCase(TestCase):
- def test_empty(self):
- cache = DeferredCache("test")
+ def test_empty(self) -> None:
+ cache: DeferredCache[str, int] = DeferredCache("test")
with self.assertRaises(KeyError):
cache.get("foo")
- def test_hit(self):
- cache = DeferredCache("test")
+ def test_hit(self) -> None:
+ cache: DeferredCache[str, int] = DeferredCache("test")
cache.prefill("foo", 123)
self.assertEqual(self.successResultOf(cache.get("foo")), 123)
- def test_hit_deferred(self):
- cache = DeferredCache("test")
- origin_d = defer.Deferred()
+ def test_hit_deferred(self) -> None:
+ cache: DeferredCache[str, int] = DeferredCache("test")
+ origin_d: "defer.Deferred[int]" = defer.Deferred()
set_d = cache.set("k1", origin_d)
# get should return an incomplete deferred
@@ -43,7 +44,7 @@ class DeferredCacheTestCase(TestCase):
self.assertFalse(get_d.called)
# add a callback that will make sure that the set_d gets called before the get_d
- def check1(r):
+ def check1(r: str) -> str:
self.assertTrue(set_d.called)
return r
@@ -55,16 +56,16 @@ class DeferredCacheTestCase(TestCase):
self.assertEqual(self.successResultOf(set_d), 99)
self.assertEqual(self.successResultOf(get_d), 99)
- def test_callbacks(self):
+ def test_callbacks(self) -> None:
"""Invalidation callbacks are called at the right time"""
- cache = DeferredCache("test")
+ cache: DeferredCache[str, int] = DeferredCache("test")
callbacks = set()
# start with an entry, with a callback
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
# now replace that entry with a pending result
- origin_d = defer.Deferred()
+ origin_d: "defer.Deferred[int]" = defer.Deferred()
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
# ... and also make a get request
@@ -89,15 +90,15 @@ class DeferredCacheTestCase(TestCase):
cache.prefill("k1", 30)
self.assertEqual(callbacks, {"set", "get"})
- def test_set_fail(self):
- cache = DeferredCache("test")
+ def test_set_fail(self) -> None:
+ cache: DeferredCache[str, int] = DeferredCache("test")
callbacks = set()
# start with an entry, with a callback
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
# now replace that entry with a pending result
- origin_d = defer.Deferred()
+ origin_d: defer.Deferred = defer.Deferred()
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
# ... and also make a get request
@@ -126,9 +127,9 @@ class DeferredCacheTestCase(TestCase):
cache.prefill("k1", 30)
self.assertEqual(callbacks, {"prefill", "get2"})
- def test_get_immediate(self):
- cache = DeferredCache("test")
- d1 = defer.Deferred()
+ def test_get_immediate(self) -> None:
+ cache: DeferredCache[str, int] = DeferredCache("test")
+ d1: "defer.Deferred[int]" = defer.Deferred()
cache.set("key1", d1)
# get_immediate should return default
@@ -142,27 +143,27 @@ class DeferredCacheTestCase(TestCase):
v = cache.get_immediate("key1", 1)
self.assertEqual(v, 2)
- def test_invalidate(self):
- cache = DeferredCache("test")
+ def test_invalidate(self) -> None:
+ cache: DeferredCache[Tuple[str], int] = DeferredCache("test")
cache.prefill(("foo",), 123)
cache.invalidate(("foo",))
with self.assertRaises(KeyError):
cache.get(("foo",))
- def test_invalidate_all(self):
- cache = DeferredCache("testcache")
+ def test_invalidate_all(self) -> None:
+ cache: DeferredCache[str, str] = DeferredCache("testcache")
callback_record = [False, False]
- def record_callback(idx):
+ def record_callback(idx: int) -> None:
callback_record[idx] = True
# add a couple of pending entries
- d1 = defer.Deferred()
+ d1: "defer.Deferred[str]" = defer.Deferred()
cache.set("key1", d1, partial(record_callback, 0))
- d2 = defer.Deferred()
+ d2: "defer.Deferred[str]" = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return pending deferreds
@@ -193,8 +194,8 @@ class DeferredCacheTestCase(TestCase):
with self.assertRaises(KeyError):
cache.get("key1", None)
- def test_eviction(self):
- cache = DeferredCache(
+ def test_eviction(self) -> None:
+ cache: DeferredCache[int, str] = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False
)
@@ -208,8 +209,8 @@ class DeferredCacheTestCase(TestCase):
cache.get(2)
cache.get(3)
- def test_eviction_lru(self):
- cache = DeferredCache(
+ def test_eviction_lru(self) -> None:
+ cache: DeferredCache[int, str] = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False
)
@@ -227,8 +228,8 @@ class DeferredCacheTestCase(TestCase):
cache.get(1)
cache.get(3)
- def test_eviction_iterable(self):
- cache = DeferredCache(
+ def test_eviction_iterable(self) -> None:
+ cache: DeferredCache[int, List[str]] = DeferredCache(
"test",
max_entries=3,
apply_cache_factor_from_config=False,
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 43475a307f..13f1edd533 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Iterable, Set, Tuple
+from typing import Iterable, Set, Tuple, cast
from unittest import mock
from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError, Deferred
+from twisted.internet.interfaces import IReactorTime
from synapse.api.errors import SynapseError
from synapse.logging.context import (
@@ -37,8 +38,8 @@ logger = logging.getLogger(__name__)
def run_on_reactor():
- d = defer.Deferred()
- reactor.callLater(0, d.callback, 0)
+ d: "Deferred[int]" = defer.Deferred()
+ cast(IReactorTime, reactor).callLater(0, d.callback, 0)
return make_deferred_yieldable(d)
@@ -224,7 +225,8 @@ class DescriptorTestCase(unittest.TestCase):
callbacks: Set[str] = set()
# set off an asynchronous request
- obj.result = origin_d = defer.Deferred()
+ origin_d: Deferred = defer.Deferred()
+ obj.result = origin_d
d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
self.assertFalse(d1.called)
@@ -262,7 +264,7 @@ class DescriptorTestCase(unittest.TestCase):
"""Check that logcontexts are set and restored correctly when
using the cache."""
- complete_lookup = defer.Deferred()
+ complete_lookup: Deferred = defer.Deferred()
class Cls:
@descriptors.cached()
@@ -772,10 +774,14 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1, arg2):
- assert current_context().name == "c1"
+ context = current_context()
+ assert isinstance(context, LoggingContext)
+ assert context.name == "c1"
# we want this to behave like an asynchronous function
await run_on_reactor()
- assert current_context().name == "c1"
+ context = current_context()
+ assert isinstance(context, LoggingContext)
+ assert context.name == "c1"
return self.mock(args1, arg2)
with LoggingContext("c1") as c1:
@@ -834,7 +840,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
return self.mock(args1)
obj = Cls()
- deferred_result = Deferred()
+ deferred_result: "Deferred[dict]" = Deferred()
obj.mock.return_value = deferred_result
# start off several concurrent lookups of the same key
diff --git a/tests/util/caches/test_response_cache.py b/tests/util/caches/test_response_cache.py
index 025b73e32f..f09eeecada 100644
--- a/tests/util/caches/test_response_cache.py
+++ b/tests/util/caches/test_response_cache.py
@@ -35,7 +35,7 @@ class ResponseCacheTestCase(TestCase):
(These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock)
"""
- def setUp(self):
+ def setUp(self) -> None:
self.reactor, self.clock = get_clock()
def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
@@ -49,7 +49,7 @@ class ResponseCacheTestCase(TestCase):
await self.clock.sleep(1)
return o
- def test_cache_hit(self):
+ def test_cache_hit(self) -> None:
cache = self.with_cache("keeping_cache", ms=9001)
expected_result = "howdy"
@@ -74,7 +74,7 @@ class ResponseCacheTestCase(TestCase):
"cache should still have the result",
)
- def test_cache_miss(self):
+ def test_cache_miss(self) -> None:
cache = self.with_cache("trashing_cache", ms=0)
expected_result = "howdy"
@@ -90,7 +90,7 @@ class ResponseCacheTestCase(TestCase):
)
self.assertCountEqual([], cache.keys(), "cache should not have the result now")
- def test_cache_expire(self):
+ def test_cache_expire(self) -> None:
cache = self.with_cache("short_cache", ms=1000)
expected_result = "howdy"
@@ -115,7 +115,7 @@ class ResponseCacheTestCase(TestCase):
self.reactor.pump((2,))
self.assertCountEqual([], cache.keys(), "cache should not have the result now")
- def test_cache_wait_hit(self):
+ def test_cache_wait_hit(self) -> None:
cache = self.with_cache("neutral_cache")
expected_result = "howdy"
@@ -131,7 +131,7 @@ class ResponseCacheTestCase(TestCase):
self.assertEqual(expected_result, self.successResultOf(wrap_d))
- def test_cache_wait_expire(self):
+ def test_cache_wait_expire(self) -> None:
cache = self.with_cache("medium_cache", ms=3000)
expected_result = "howdy"
@@ -162,7 +162,7 @@ class ResponseCacheTestCase(TestCase):
self.assertCountEqual([], cache.keys(), "cache should not have the result now")
@parameterized.expand([(True,), (False,)])
- def test_cache_context_nocache(self, should_cache: bool):
+ def test_cache_context_nocache(self, should_cache: bool) -> None:
"""If the callback clears the should_cache bit, the result should not be cached"""
cache = self.with_cache("medium_cache", ms=3000)
@@ -170,7 +170,7 @@ class ResponseCacheTestCase(TestCase):
call_count = 0
- async def non_caching(o: str, cache_context: ResponseCacheContext[int]):
+ async def non_caching(o: str, cache_context: ResponseCacheContext[int]) -> str:
nonlocal call_count
call_count += 1
await self.clock.sleep(1)
diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py
index fe8314057d..679d1eb36b 100644
--- a/tests/util/caches/test_ttlcache.py
+++ b/tests/util/caches/test_ttlcache.py
@@ -20,11 +20,11 @@ from tests import unittest
class CacheTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.mock_timer = Mock(side_effect=lambda: 100.0)
- self.cache = TTLCache("test_cache", self.mock_timer)
+ self.cache: TTLCache[str, str] = TTLCache("test_cache", self.mock_timer)
- def test_get(self):
+ def test_get(self) -> None:
"""simple set/get tests"""
self.cache.set("one", "1", 10)
self.cache.set("two", "2", 20)
@@ -59,7 +59,7 @@ class CacheTestCase(unittest.TestCase):
self.assertEqual(self.cache._metrics.hits, 4)
self.assertEqual(self.cache._metrics.misses, 5)
- def test_expiry(self):
+ def test_expiry(self) -> None:
self.cache.set("one", "1", 10)
self.cache.set("two", "2", 20)
self.cache.set("three", "3", 30)
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index 9d5010bf92..91cac9822a 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
+from typing import Generator, List, NoReturn, Optional
from parameterized import parameterized_class
@@ -41,8 +42,8 @@ from tests.unittest import TestCase
class ObservableDeferredTest(TestCase):
- def test_succeed(self):
- origin_d = Deferred()
+ def test_succeed(self) -> None:
+ origin_d: "Deferred[int]" = Deferred()
observable = ObservableDeferred(origin_d)
observer1 = observable.observe()
@@ -52,16 +53,18 @@ class ObservableDeferredTest(TestCase):
self.assertFalse(observer2.called)
# check the first observer is called first
- def check_called_first(res):
+ def check_called_first(res: int) -> int:
self.assertFalse(observer2.called)
return res
observer1.addBoth(check_called_first)
# store the results
- results = [None, None]
+ results: List[Optional[ObservableDeferred[int]]] = [None, None]
- def check_val(res, idx):
+ def check_val(
+ res: ObservableDeferred[int], idx: int
+ ) -> ObservableDeferred[int]:
results[idx] = res
return res
@@ -72,8 +75,8 @@ class ObservableDeferredTest(TestCase):
self.assertEqual(results[0], 123, "observer 1 callback result")
self.assertEqual(results[1], 123, "observer 2 callback result")
- def test_failure(self):
- origin_d = Deferred()
+ def test_failure(self) -> None:
+ origin_d: Deferred = Deferred()
observable = ObservableDeferred(origin_d, consumeErrors=True)
observer1 = observable.observe()
@@ -83,16 +86,16 @@ class ObservableDeferredTest(TestCase):
self.assertFalse(observer2.called)
# check the first observer is called first
- def check_called_first(res):
+ def check_called_first(res: int) -> int:
self.assertFalse(observer2.called)
return res
observer1.addBoth(check_called_first)
# store the results
- results = [None, None]
+ results: List[Optional[ObservableDeferred[str]]] = [None, None]
- def check_val(res, idx):
+ def check_val(res: ObservableDeferred[str], idx: int) -> None:
results[idx] = res
return None
@@ -103,10 +106,12 @@ class ObservableDeferredTest(TestCase):
raise Exception("gah!")
except Exception as e:
origin_d.errback(e)
+ assert results[0] is not None
self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
+ assert results[1] is not None
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
- def test_cancellation(self):
+ def test_cancellation(self) -> None:
"""Test that cancelling an observer does not affect other observers."""
origin_d: "Deferred[int]" = Deferred()
observable = ObservableDeferred(origin_d, consumeErrors=True)
@@ -136,37 +141,38 @@ class ObservableDeferredTest(TestCase):
class TimeoutDeferredTest(TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.clock = Clock()
- def test_times_out(self):
+ def test_times_out(self) -> None:
"""Basic test case that checks that the original deferred is cancelled and that
the timing-out deferred is errbacked
"""
- cancelled = [False]
+ cancelled = False
- def canceller(_d):
- cancelled[0] = True
+ def canceller(_d: Deferred) -> None:
+ nonlocal cancelled
+ cancelled = True
- non_completing_d = Deferred(canceller)
+ non_completing_d: Deferred = Deferred(canceller)
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
self.assertNoResult(timing_out_d)
- self.assertFalse(cancelled[0], "deferred was cancelled prematurely")
+ self.assertFalse(cancelled, "deferred was cancelled prematurely")
self.clock.pump((1.0,))
- self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
+ self.assertTrue(cancelled, "deferred was not cancelled by timeout")
self.failureResultOf(timing_out_d, defer.TimeoutError)
- def test_times_out_when_canceller_throws(self):
+ def test_times_out_when_canceller_throws(self) -> None:
"""Test that we have successfully worked around
https://twistedmatrix.com/trac/ticket/9534"""
- def canceller(_d):
+ def canceller(_d: Deferred) -> None:
raise Exception("can't cancel this deferred")
- non_completing_d = Deferred(canceller)
+ non_completing_d: Deferred = Deferred(canceller)
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
self.assertNoResult(timing_out_d)
@@ -175,22 +181,24 @@ class TimeoutDeferredTest(TestCase):
self.failureResultOf(timing_out_d, defer.TimeoutError)
- def test_logcontext_is_preserved_on_cancellation(self):
- blocking_was_cancelled = [False]
+ def test_logcontext_is_preserved_on_cancellation(self) -> None:
+ blocking_was_cancelled = False
@defer.inlineCallbacks
- def blocking():
- non_completing_d = Deferred()
+ def blocking() -> Generator["Deferred[object]", object, None]:
+ nonlocal blocking_was_cancelled
+
+ non_completing_d: Deferred = Deferred()
with PreserveLoggingContext():
try:
yield non_completing_d
except CancelledError:
- blocking_was_cancelled[0] = True
+ blocking_was_cancelled = True
raise
with LoggingContext("one") as context_one:
# the errbacks should be run in the test logcontext
- def errback(res, deferred_name):
+ def errback(res: Failure, deferred_name: str) -> Failure:
self.assertIs(
current_context(),
context_one,
@@ -209,7 +217,7 @@ class TimeoutDeferredTest(TestCase):
self.clock.pump((1.0,))
self.assertTrue(
- blocking_was_cancelled[0], "non-completing deferred was not cancelled"
+ blocking_was_cancelled, "non-completing deferred was not cancelled"
)
self.failureResultOf(timing_out_d, defer.TimeoutError)
self.assertIs(current_context(), context_one)
@@ -220,13 +228,13 @@ class _TestException(Exception):
class ConcurrentlyExecuteTest(TestCase):
- def test_limits_runners(self):
+ def test_limits_runners(self) -> None:
"""If we have more tasks than runners, we should get the limit of runners"""
started = 0
waiters = []
processed = []
- async def callback(v):
+ async def callback(v: int) -> None:
# when we first enter, bump the start count
nonlocal started
started += 1
@@ -235,7 +243,7 @@ class ConcurrentlyExecuteTest(TestCase):
processed.append(v)
# wait for the goahead before returning
- d2 = Deferred()
+ d2: "Deferred[int]" = Deferred()
waiters.append(d2)
await d2
@@ -265,16 +273,16 @@ class ConcurrentlyExecuteTest(TestCase):
self.assertCountEqual(processed, [1, 2, 3, 4, 5])
self.successResultOf(d2)
- def test_preserves_stacktraces(self):
+ def test_preserves_stacktraces(self) -> None:
"""Test that the stacktrace from an exception thrown in the callback is preserved"""
- d1 = Deferred()
+ d1: "Deferred[int]" = Deferred()
- async def callback(v):
+ async def callback(v: int) -> None:
# alas, this doesn't work at all without an await here
await d1
raise _TestException("bah")
- async def caller():
+ async def caller() -> None:
try:
await concurrently_execute(callback, [1], 2)
except _TestException as e:
@@ -290,17 +298,17 @@ class ConcurrentlyExecuteTest(TestCase):
d1.callback(0)
self.successResultOf(d2)
- def test_preserves_stacktraces_on_preformed_failure(self):
+ def test_preserves_stacktraces_on_preformed_failure(self) -> None:
"""Test that the stacktrace on a Failure returned by the callback is preserved"""
- d1 = Deferred()
+ d1: "Deferred[int]" = Deferred()
f = Failure(_TestException("bah"))
- async def callback(v):
+ async def callback(v: int) -> None:
# alas, this doesn't work at all without an await here
await d1
await defer.fail(f)
- async def caller():
+ async def caller() -> None:
try:
await concurrently_execute(callback, [1], 2)
except _TestException as e:
@@ -336,7 +344,7 @@ class CancellationWrapperTests(TestCase):
else:
raise ValueError(f"Unsupported wrapper type: {self.wrapper}")
- def test_succeed(self):
+ def test_succeed(self) -> None:
"""Test that the new `Deferred` receives the result."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = self.wrap_deferred(deferred)
@@ -346,7 +354,7 @@ class CancellationWrapperTests(TestCase):
self.assertTrue(wrapper_deferred.called)
self.assertEqual("success", self.successResultOf(wrapper_deferred))
- def test_failure(self):
+ def test_failure(self) -> None:
"""Test that the new `Deferred` receives the `Failure`."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = self.wrap_deferred(deferred)
@@ -361,7 +369,7 @@ class CancellationWrapperTests(TestCase):
class StopCancellationTests(TestCase):
"""Tests for the `stop_cancellation` function."""
- def test_cancellation(self):
+ def test_cancellation(self) -> None:
"""Test that cancellation of the new `Deferred` leaves the original running."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = stop_cancellation(deferred)
@@ -384,7 +392,7 @@ class StopCancellationTests(TestCase):
class DelayCancellationTests(TestCase):
"""Tests for the `delay_cancellation` function."""
- def test_deferred_cancellation(self):
+ def test_deferred_cancellation(self) -> None:
"""Test that cancellation of the new `Deferred` waits for the original."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred)
@@ -405,12 +413,12 @@ class DelayCancellationTests(TestCase):
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
- def test_coroutine_cancellation(self):
+ def test_coroutine_cancellation(self) -> None:
"""Test that cancellation of the new `Deferred` waits for the original."""
blocking_deferred: "Deferred[None]" = Deferred()
completion_deferred: "Deferred[None]" = Deferred()
- async def task():
+ async def task() -> NoReturn:
await blocking_deferred
completion_deferred.callback(None)
# Raise an exception. Twisted should consume it, otherwise unwanted
@@ -434,7 +442,7 @@ class DelayCancellationTests(TestCase):
# Now that the original coroutine has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
- def test_suppresses_second_cancellation(self):
+ def test_suppresses_second_cancellation(self) -> None:
"""Test that a second cancellation is suppressed.
Identical to `test_cancellation` except the new `Deferred` is cancelled twice.
@@ -459,7 +467,7 @@ class DelayCancellationTests(TestCase):
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
- def test_propagates_cancelled_error(self):
+ def test_propagates_cancelled_error(self) -> None:
"""Test that a `CancelledError` from the original `Deferred` gets propagated."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred)
@@ -472,14 +480,14 @@ class DelayCancellationTests(TestCase):
self.assertTrue(wrapper_deferred.called)
self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value)
- def test_preserves_logcontext(self):
+ def test_preserves_logcontext(self) -> None:
"""Test that logging contexts are preserved."""
blocking_d: "Deferred[None]" = Deferred()
- async def inner():
+ async def inner() -> None:
await make_deferred_yieldable(blocking_d)
- async def outer():
+ async def outer() -> None:
with LoggingContext("c") as c:
try:
await delay_cancellation(inner())
@@ -503,7 +511,7 @@ class DelayCancellationTests(TestCase):
class AwakenableSleeperTests(TestCase):
"Tests AwakenableSleeper"
- def test_sleep(self):
+ def test_sleep(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
@@ -518,7 +526,7 @@ class AwakenableSleeperTests(TestCase):
reactor.advance(0.6)
self.assertTrue(d.called)
- def test_explicit_wake(self):
+ def test_explicit_wake(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
@@ -535,7 +543,7 @@ class AwakenableSleeperTests(TestCase):
reactor.advance(0.6)
- def test_multiple_sleepers_timeout(self):
+ def test_multiple_sleepers_timeout(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
@@ -555,7 +563,7 @@ class AwakenableSleeperTests(TestCase):
reactor.advance(0.6)
self.assertTrue(d2.called)
- def test_multiple_sleepers_wake(self):
+ def test_multiple_sleepers_wake(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py
index 07be57d72c..94ef91f645 100644
--- a/tests/util/test_batching_queue.py
+++ b/tests/util/test_batching_queue.py
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Tuple
+
+from prometheus_client import Gauge
+
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable
@@ -26,7 +30,7 @@ from tests.unittest import TestCase
class BatchingQueueTestCase(TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.clock, hs_clock = get_clock()
# We ensure that we remove any existing metrics for "test_queue".
@@ -37,25 +41,27 @@ class BatchingQueueTestCase(TestCase):
except KeyError:
pass
- self._pending_calls = []
- self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
+ self._pending_calls: List[Tuple[List[str], defer.Deferred]] = []
+ self.queue: BatchingQueue[str, str] = BatchingQueue(
+ "test_queue", hs_clock, self._process_queue
+ )
- async def _process_queue(self, values):
- d = defer.Deferred()
+ async def _process_queue(self, values: List[str]) -> str:
+ d: "defer.Deferred[str]" = defer.Deferred()
self._pending_calls.append((values, d))
return await make_deferred_yieldable(d)
- def _get_sample_with_name(self, metric, name) -> int:
+ def _get_sample_with_name(self, metric: Gauge, name: str) -> float:
"""For a prometheus metric get the value of the sample that has a
matching "name" label.
"""
- for sample in metric.collect()[0].samples:
+ for sample in next(iter(metric.collect())).samples:
if sample.labels.get("name") == name:
return sample.value
self.fail("Found no matching sample")
- def _assert_metrics(self, queued, keys, in_flight):
+ def _assert_metrics(self, queued: int, keys: int, in_flight: int) -> None:
"""Assert that the metrics are correct"""
sample = self._get_sample_with_name(number_queued, self.queue._name)
@@ -75,7 +81,7 @@ class BatchingQueueTestCase(TestCase):
"number_in_flight",
)
- def test_simple(self):
+ def test_simple(self) -> None:
"""Tests the basic case of calling `add_to_queue` once and having
`_process_queue` return.
"""
@@ -106,7 +112,7 @@ class BatchingQueueTestCase(TestCase):
self._assert_metrics(queued=0, keys=0, in_flight=0)
- def test_batching(self):
+ def test_batching(self) -> None:
"""Test that multiple calls at the same time get batched up into one
call to `_process_queue`.
"""
@@ -134,7 +140,7 @@ class BatchingQueueTestCase(TestCase):
self.assertEqual(self.successResultOf(queue_d2), "bar")
self._assert_metrics(queued=0, keys=0, in_flight=0)
- def test_queuing(self):
+ def test_queuing(self) -> None:
"""Test that we queue up requests while a `_process_queue` is being
called.
"""
@@ -184,7 +190,7 @@ class BatchingQueueTestCase(TestCase):
self.assertEqual(self.successResultOf(queue_d3), "bar2")
self._assert_metrics(queued=0, keys=0, in_flight=0)
- def test_different_keys(self):
+ def test_different_keys(self) -> None:
"""Test that calls to different keys get processed in parallel."""
self.assertFalse(self._pending_calls)
diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
index 6913de24b9..aa20fe6780 100644
--- a/tests/util/test_check_dependencies.py
+++ b/tests/util/test_check_dependencies.py
@@ -1,5 +1,20 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from contextlib import contextmanager
-from typing import Generator, Optional
+from os import PathLike
+from typing import Generator, Optional, Union
from unittest.mock import patch
from synapse.util.check_dependencies import (
@@ -12,17 +27,17 @@ from tests.unittest import TestCase
class DummyDistribution(metadata.Distribution):
- def __init__(self, version: object):
+ def __init__(self, version: str):
self._version = version
@property
- def version(self):
+ def version(self) -> str:
return self._version
- def locate_file(self, path):
+ def locate_file(self, path: Union[str, PathLike]) -> PathLike:
raise NotImplementedError()
- def read_text(self, filename):
+ def read_text(self, filename: str) -> None:
raise NotImplementedError()
@@ -30,7 +45,7 @@ old = DummyDistribution("0.1.2")
old_release_candidate = DummyDistribution("0.1.2rc3")
new = DummyDistribution("1.2.3")
new_release_candidate = DummyDistribution("1.2.3rc4")
-distribution_with_no_version = DummyDistribution(None)
+distribution_with_no_version = DummyDistribution(None) # type: ignore[arg-type]
# could probably use stdlib TestCase --- no need for twisted here
@@ -45,7 +60,7 @@ class TestDependencyChecker(TestCase):
If `distribution = None`, we pretend that the package is not installed.
"""
- def mock_distribution(name: str):
+ def mock_distribution(name: str) -> DummyDistribution:
if distribution is None:
raise metadata.PackageNotFoundError
else:
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index e8b6246ab5..acb251bfea 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -19,10 +19,12 @@ from tests import unittest
class DictCacheTestCase(unittest.TestCase):
- def setUp(self):
- self.cache = DictionaryCache("foobar", max_entries=10)
+ def setUp(self) -> None:
+ self.cache: DictionaryCache[str, str, str] = DictionaryCache(
+ "foobar", max_entries=10
+ )
- def test_simple_cache_hit_full(self):
+ def test_simple_cache_hit_full(self) -> None:
key = "test_simple_cache_hit_full"
v = self.cache.get(key)
@@ -37,7 +39,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key)
self.assertEqual(test_value, c.value)
- def test_simple_cache_hit_partial(self):
+ def test_simple_cache_hit_partial(self) -> None:
key = "test_simple_cache_hit_partial"
seq = self.cache.sequence
@@ -47,7 +49,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key, ["test"])
self.assertEqual(test_value, c.value)
- def test_simple_cache_miss_partial(self):
+ def test_simple_cache_miss_partial(self) -> None:
key = "test_simple_cache_miss_partial"
seq = self.cache.sequence
@@ -57,7 +59,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key, ["test2"])
self.assertEqual({}, c.value)
- def test_simple_cache_hit_miss_partial(self):
+ def test_simple_cache_hit_miss_partial(self) -> None:
key = "test_simple_cache_hit_miss_partial"
seq = self.cache.sequence
@@ -71,7 +73,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key, ["test2"])
self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
- def test_multi_insert(self):
+ def test_multi_insert(self) -> None:
key = "test_simple_cache_hit_miss_partial"
seq = self.cache.sequence
@@ -92,7 +94,7 @@ class DictCacheTestCase(unittest.TestCase):
)
self.assertEqual(c.full, False)
- def test_invalidation(self):
+ def test_invalidation(self) -> None:
"""Test that the partial dict and full dicts get invalidated
separately.
"""
@@ -106,7 +108,7 @@ class DictCacheTestCase(unittest.TestCase):
# entry for "a" warm.
for i in range(20):
self.cache.get(key, ["a"])
- self.cache.update(seq, f"key{i}", {1: 2})
+ self.cache.update(seq, f"key{i}", {"1": "2"})
# We should have evicted the full dict...
r = self.cache.get(key)
diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py
index 7f60aae5ba..9cf920daf8 100644
--- a/tests/util/test_expiring_cache.py
+++ b/tests/util/test_expiring_cache.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, cast
+from synapse.util import Clock
from synapse.util.caches.expiringcache import ExpiringCache
from tests.utils import MockClock
@@ -21,17 +23,21 @@ from .. import unittest
class ExpiringCacheTestCase(unittest.HomeserverTestCase):
- def test_get_set(self):
+ def test_get_set(self) -> None:
clock = MockClock()
- cache = ExpiringCache("test", clock, max_len=1)
+ cache: ExpiringCache[str, str] = ExpiringCache(
+ "test", cast(Clock, clock), max_len=1
+ )
cache["key"] = "value"
self.assertEqual(cache.get("key"), "value")
self.assertEqual(cache["key"], "value")
- def test_eviction(self):
+ def test_eviction(self) -> None:
clock = MockClock()
- cache = ExpiringCache("test", clock, max_len=2)
+ cache: ExpiringCache[str, str] = ExpiringCache(
+ "test", cast(Clock, clock), max_len=2
+ )
cache["key"] = "value"
cache["key2"] = "value2"
@@ -43,9 +49,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key2"), "value2")
self.assertEqual(cache.get("key3"), "value3")
- def test_iterable_eviction(self):
+ def test_iterable_eviction(self) -> None:
clock = MockClock()
- cache = ExpiringCache("test", clock, max_len=5, iterable=True)
+ cache: ExpiringCache[str, List[int]] = ExpiringCache(
+ "test", cast(Clock, clock), max_len=5, iterable=True
+ )
cache["key"] = [1]
cache["key2"] = [2, 3]
@@ -61,9 +69,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key3"), [4, 5])
self.assertEqual(cache.get("key4"), [6, 7])
- def test_time_eviction(self):
+ def test_time_eviction(self) -> None:
clock = MockClock()
- cache = ExpiringCache("test", clock, expiry_ms=1000)
+ cache: ExpiringCache[str, int] = ExpiringCache(
+ "test", cast(Clock, clock), expiry_ms=1000
+ )
cache["key"] = 1
clock.advance_time(0.5)
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index 3bb4695405..4f3c983c15 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -12,22 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import threading
-from io import StringIO
+from io import BytesIO
+from typing import BinaryIO, Generator, Optional, cast
from unittest.mock import NonCallableMock
-from twisted.internet import defer, reactor
+from zope.interface import implementer
+from twisted.internet import defer, reactor as _reactor
+from twisted.internet.interfaces import IPullProducer
+
+from synapse.types import ISynapseReactor
from synapse.util.file_consumer import BackgroundFileConsumer
from tests import unittest
+reactor = cast(ISynapseReactor, _reactor)
+
class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
- def test_pull_consumer(self):
- string_file = StringIO()
+ def test_pull_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
+ string_file = BytesIO()
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
@@ -35,55 +41,57 @@ class FileConsumerTests(unittest.TestCase):
yield producer.register_with_consumer(consumer)
- yield producer.write_and_wait("Foo")
+ yield producer.write_and_wait(b"Foo")
- self.assertEqual(string_file.getvalue(), "Foo")
+ self.assertEqual(string_file.getvalue(), b"Foo")
- yield producer.write_and_wait("Bar")
+ yield producer.write_and_wait(b"Bar")
- self.assertEqual(string_file.getvalue(), "FooBar")
+ self.assertEqual(string_file.getvalue(), b"FooBar")
finally:
consumer.unregisterProducer()
- yield consumer.wait()
+ yield consumer.wait() # type: ignore[misc]
self.assertTrue(string_file.closed)
@defer.inlineCallbacks
- def test_push_consumer(self):
- string_file = BlockingStringWrite()
- consumer = BackgroundFileConsumer(string_file, reactor=reactor)
+ def test_push_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
+ string_file = BlockingBytesWrite()
+ consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
try:
producer = NonCallableMock(spec_set=[])
consumer.registerProducer(producer, True)
- consumer.write("Foo")
- yield string_file.wait_for_n_writes(1)
+ consumer.write(b"Foo")
+ yield string_file.wait_for_n_writes(1) # type: ignore[misc]
- self.assertEqual(string_file.buffer, "Foo")
+ self.assertEqual(string_file.buffer, b"Foo")
- consumer.write("Bar")
- yield string_file.wait_for_n_writes(2)
+ consumer.write(b"Bar")
+ yield string_file.wait_for_n_writes(2) # type: ignore[misc]
- self.assertEqual(string_file.buffer, "FooBar")
+ self.assertEqual(string_file.buffer, b"FooBar")
finally:
consumer.unregisterProducer()
- yield consumer.wait()
+ yield consumer.wait() # type: ignore[misc]
self.assertTrue(string_file.closed)
@defer.inlineCallbacks
- def test_push_producer_feedback(self):
- string_file = BlockingStringWrite()
- consumer = BackgroundFileConsumer(string_file, reactor=reactor)
+ def test_push_producer_feedback(
+ self,
+ ) -> Generator["defer.Deferred[object]", object, None]:
+ string_file = BlockingBytesWrite()
+ consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
try:
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
- resume_deferred = defer.Deferred()
+ resume_deferred: defer.Deferred = defer.Deferred()
producer.resumeProducing.side_effect = lambda: resume_deferred.callback(
None
)
@@ -93,65 +101,72 @@ class FileConsumerTests(unittest.TestCase):
number_writes = 0
with string_file.write_lock:
for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
- consumer.write("Foo")
+ consumer.write(b"Foo")
number_writes += 1
producer.pauseProducing.assert_called_once()
- yield string_file.wait_for_n_writes(number_writes)
+ yield string_file.wait_for_n_writes(number_writes) # type: ignore[misc]
yield resume_deferred
producer.resumeProducing.assert_called_once()
finally:
consumer.unregisterProducer()
- yield consumer.wait()
+ yield consumer.wait() # type: ignore[misc]
self.assertTrue(string_file.closed)
+@implementer(IPullProducer)
class DummyPullProducer:
- def __init__(self):
- self.consumer = None
- self.deferred = defer.Deferred()
+ def __init__(self) -> None:
+ self.consumer: Optional[BackgroundFileConsumer] = None
+ self.deferred: "defer.Deferred[object]" = defer.Deferred()
- def resumeProducing(self):
+ def resumeProducing(self) -> None:
d = self.deferred
self.deferred = defer.Deferred()
d.callback(None)
- def write_and_wait(self, bytes):
+ def stopProducing(self) -> None:
+ raise RuntimeError("Unexpected call")
+
+ def write_and_wait(self, write_bytes: bytes) -> "defer.Deferred[object]":
+ assert self.consumer is not None
d = self.deferred
- self.consumer.write(bytes)
+ self.consumer.write(write_bytes)
return d
- def register_with_consumer(self, consumer):
+ def register_with_consumer(
+ self, consumer: BackgroundFileConsumer
+ ) -> "defer.Deferred[object]":
d = self.deferred
self.consumer = consumer
self.consumer.registerProducer(self, False)
return d
-class BlockingStringWrite:
- def __init__(self):
- self.buffer = ""
+class BlockingBytesWrite:
+ def __init__(self) -> None:
+ self.buffer = b""
self.closed = False
self.write_lock = threading.Lock()
- self._notify_write_deferred = None
+ self._notify_write_deferred: Optional[defer.Deferred] = None
self._number_of_writes = 0
- def write(self, bytes):
+ def write(self, write_bytes: bytes) -> None:
with self.write_lock:
- self.buffer += bytes
+ self.buffer += write_bytes
self._number_of_writes += 1
reactor.callFromThread(self._notify_write)
- def close(self):
+ def close(self) -> None:
self.closed = True
- def _notify_write(self):
+ def _notify_write(self) -> None:
"Called by write to indicate a write happened"
with self.write_lock:
if not self._notify_write_deferred:
@@ -161,7 +176,9 @@ class BlockingStringWrite:
d.callback(None)
@defer.inlineCallbacks
- def wait_for_n_writes(self, n):
+ def wait_for_n_writes(
+ self, n: int
+ ) -> Generator["defer.Deferred[object]", object, None]:
"Wait for n writes to have happened"
while True:
with self.write_lock:
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 3c0ddd4f18..406c16cdcf 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -19,7 +19,7 @@ from tests.unittest import TestCase
class ChunkSeqTests(TestCase):
- def test_short_seq(self):
+ def test_short_seq(self) -> None:
parts = chunk_seq("123", 8)
self.assertEqual(
@@ -27,7 +27,7 @@ class ChunkSeqTests(TestCase):
["123"],
)
- def test_long_seq(self):
+ def test_long_seq(self) -> None:
parts = chunk_seq("abcdefghijklmnop", 8)
self.assertEqual(
@@ -35,7 +35,7 @@ class ChunkSeqTests(TestCase):
["abcdefgh", "ijklmnop"],
)
- def test_uneven_parts(self):
+ def test_uneven_parts(self) -> None:
parts = chunk_seq("abcdefghijklmnop", 5)
self.assertEqual(
@@ -43,7 +43,7 @@ class ChunkSeqTests(TestCase):
["abcde", "fghij", "klmno", "p"],
)
- def test_empty_input(self):
+ def test_empty_input(self) -> None:
parts: Iterable[Sequence] = chunk_seq([], 5)
self.assertEqual(
@@ -53,13 +53,13 @@ class ChunkSeqTests(TestCase):
class SortTopologically(TestCase):
- def test_empty(self):
+ def test_empty(self) -> None:
"Test that an empty graph works correctly"
graph: Dict[int, List[int]] = {}
self.assertEqual(list(sorted_topologically([], graph)), [])
- def test_handle_empty_graph(self):
+ def test_handle_empty_graph(self) -> None:
"Test that a graph where a node doesn't have an entry is treated as empty"
graph: Dict[int, List[int]] = {}
@@ -67,7 +67,7 @@ class SortTopologically(TestCase):
# For disconnected nodes the output is simply sorted.
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
- def test_disconnected(self):
+ def test_disconnected(self) -> None:
"Test that a graph with no edges work"
graph: Dict[int, List[int]] = {1: [], 2: []}
@@ -75,20 +75,20 @@ class SortTopologically(TestCase):
# For disconnected nodes the output is simply sorted.
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
- def test_linear(self):
+ def test_linear(self) -> None:
"Test that a simple `4 -> 3 -> 2 -> 1` graph works"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
- def test_subset(self):
+ def test_subset(self) -> None:
"Test that only sorting a subset of the graph works"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
- def test_fork(self):
+ def test_fork(self) -> None:
"Test that a forked graph works"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]}
@@ -96,13 +96,13 @@ class SortTopologically(TestCase):
# always get the same one.
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
- def test_duplicates(self):
+ def test_duplicates(self) -> None:
"Test that a graph with duplicate edges work"
graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
- def test_multiple_paths(self):
+ def test_multiple_paths(self) -> None:
"Test that a graph with multiple paths between two nodes work"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 2ad321e184..d64c162e1d 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -1,5 +1,21 @@
+# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Generator, cast
+
import twisted.python.failure
-from twisted.internet import defer, reactor
+from twisted.internet import defer, reactor as _reactor
from synapse.logging.context import (
SENTINEL_CONTEXT,
@@ -10,25 +26,30 @@ from synapse.logging.context import (
nested_logging_context,
run_in_background,
)
+from synapse.types import ISynapseReactor
from synapse.util import Clock
from .. import unittest
+reactor = cast(ISynapseReactor, _reactor)
+
class LoggingContextTestCase(unittest.TestCase):
- def _check_test_key(self, value):
- self.assertEqual(current_context().name, value)
+ def _check_test_key(self, value: str) -> None:
+ context = current_context()
+ assert isinstance(context, LoggingContext)
+ self.assertEqual(context.name, value)
- def test_with_context(self):
+ def test_with_context(self) -> None:
with LoggingContext("test"):
self._check_test_key("test")
@defer.inlineCallbacks
- def test_sleep(self):
+ def test_sleep(self) -> Generator["defer.Deferred[object]", object, None]:
clock = Clock(reactor)
@defer.inlineCallbacks
- def competing_callback():
+ def competing_callback() -> Generator["defer.Deferred[object]", object, None]:
with LoggingContext("competing"):
yield clock.sleep(0)
self._check_test_key("competing")
@@ -39,17 +60,18 @@ class LoggingContextTestCase(unittest.TestCase):
yield clock.sleep(0)
self._check_test_key("one")
- def _test_run_in_background(self, function):
+ def _test_run_in_background(self, function: Callable[[], object]) -> defer.Deferred:
sentinel_context = current_context()
- callback_completed = [False]
+ callback_completed = False
with LoggingContext("one"):
# fire off function, but don't wait on it.
d2 = run_in_background(function)
- def cb(res):
- callback_completed[0] = True
+ def cb(res: object) -> object:
+ nonlocal callback_completed
+ callback_completed = True
return res
d2.addCallback(cb)
@@ -60,8 +82,8 @@ class LoggingContextTestCase(unittest.TestCase):
# the logcontext is left in a sane state.
d2 = defer.Deferred()
- def check_logcontext():
- if not callback_completed[0]:
+ def check_logcontext() -> None:
+ if not callback_completed:
reactor.callLater(0.01, check_logcontext)
return
@@ -78,31 +100,31 @@ class LoggingContextTestCase(unittest.TestCase):
# test is done once d2 finishes
return d2
- def test_run_in_background_with_blocking_fn(self):
+ def test_run_in_background_with_blocking_fn(self) -> defer.Deferred:
@defer.inlineCallbacks
- def blocking_function():
+ def blocking_function() -> Generator["defer.Deferred[object]", object, None]:
yield Clock(reactor).sleep(0)
return self._test_run_in_background(blocking_function)
- def test_run_in_background_with_non_blocking_fn(self):
+ def test_run_in_background_with_non_blocking_fn(self) -> defer.Deferred:
@defer.inlineCallbacks
- def nonblocking_function():
+ def nonblocking_function() -> Generator["defer.Deferred[object]", object, None]:
with PreserveLoggingContext():
yield defer.succeed(None)
return self._test_run_in_background(nonblocking_function)
- def test_run_in_background_with_chained_deferred(self):
+ def test_run_in_background_with_chained_deferred(self) -> defer.Deferred:
# a function which returns a deferred which looks like it has been
# called, but is actually paused
- def testfunc():
+ def testfunc() -> defer.Deferred:
return make_deferred_yieldable(_chained_deferred_function())
return self._test_run_in_background(testfunc)
- def test_run_in_background_with_coroutine(self):
- async def testfunc():
+ def test_run_in_background_with_coroutine(self) -> defer.Deferred:
+ async def testfunc() -> None:
self._check_test_key("one")
d = Clock(reactor).sleep(0)
self.assertIs(current_context(), SENTINEL_CONTEXT)
@@ -111,18 +133,20 @@ class LoggingContextTestCase(unittest.TestCase):
return self._test_run_in_background(testfunc)
- def test_run_in_background_with_nonblocking_coroutine(self):
- async def testfunc():
+ def test_run_in_background_with_nonblocking_coroutine(self) -> defer.Deferred:
+ async def testfunc() -> None:
self._check_test_key("one")
return self._test_run_in_background(testfunc)
@defer.inlineCallbacks
- def test_make_deferred_yieldable(self):
+ def test_make_deferred_yieldable(
+ self,
+ ) -> Generator["defer.Deferred[object]", object, None]:
# a function which returns an incomplete deferred, but doesn't follow
# the synapse rules.
- def blocking_function():
- d = defer.Deferred()
+ def blocking_function() -> defer.Deferred:
+ d: defer.Deferred = defer.Deferred()
reactor.callLater(0, d.callback, None)
return d
@@ -139,7 +163,9 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one")
@defer.inlineCallbacks
- def test_make_deferred_yieldable_with_chained_deferreds(self):
+ def test_make_deferred_yieldable_with_chained_deferreds(
+ self,
+ ) -> Generator["defer.Deferred[object]", object, None]:
sentinel_context = current_context()
with LoggingContext("one"):
@@ -152,7 +178,7 @@ class LoggingContextTestCase(unittest.TestCase):
# now it should be restored
self._check_test_key("one")
- def test_nested_logging_context(self):
+ def test_nested_logging_context(self) -> None:
with LoggingContext("foo"):
nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.name, "foo-bar")
@@ -161,11 +187,11 @@ class LoggingContextTestCase(unittest.TestCase):
# a function which returns a deferred which has been "called", but
# which had a function which returned another incomplete deferred on
# its callback list, so won't yet call any other new callbacks.
-def _chained_deferred_function():
+def _chained_deferred_function() -> defer.Deferred:
d = defer.succeed(None)
- def cb(res):
- d2 = defer.Deferred()
+ def cb(res: object) -> defer.Deferred:
+ d2: defer.Deferred = defer.Deferred()
reactor.callLater(0, d2.callback, res)
return d2
diff --git a/tests/util/test_logformatter.py b/tests/util/test_logformatter.py
index a2e08281e6..0dee69a6fe 100644
--- a/tests/util/test_logformatter.py
+++ b/tests/util/test_logformatter.py
@@ -23,7 +23,7 @@ class TestException(Exception):
class LogFormatterTestCase(unittest.TestCase):
- def test_formatter(self):
+ def test_formatter(self) -> None:
formatter = LogFormatter()
try:
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 67173a4f5b..1fc5a473f0 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -13,10 +13,11 @@
# limitations under the License.
-from typing import List
+from typing import List, Tuple
from unittest.mock import Mock, patch
from synapse.metrics.jemalloc import JemallocStats
+from synapse.types import JsonDict
from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries
from synapse.util.caches.treecache import TreeCache
@@ -25,14 +26,14 @@ from tests.unittest import override_config
class LruCacheTestCase(unittest.HomeserverTestCase):
- def test_get_set(self):
- cache = LruCache(1)
+ def test_get_set(self) -> None:
+ cache: LruCache[str, str] = LruCache(1)
cache["key"] = "value"
self.assertEqual(cache.get("key"), "value")
self.assertEqual(cache["key"], "value")
- def test_eviction(self):
- cache = LruCache(2)
+ def test_eviction(self) -> None:
+ cache: LruCache[int, int] = LruCache(2)
cache[1] = 1
cache[2] = 2
@@ -45,8 +46,8 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(2), 2)
self.assertEqual(cache.get(3), 3)
- def test_setdefault(self):
- cache = LruCache(1)
+ def test_setdefault(self) -> None:
+ cache: LruCache[str, int] = LruCache(1)
self.assertEqual(cache.setdefault("key", 1), 1)
self.assertEqual(cache.get("key"), 1)
self.assertEqual(cache.setdefault("key", 2), 1)
@@ -54,14 +55,15 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
cache["key"] = 2 # Make sure overriding works.
self.assertEqual(cache.get("key"), 2)
- def test_pop(self):
- cache = LruCache(1)
+ def test_pop(self) -> None:
+ cache: LruCache[str, int] = LruCache(1)
cache["key"] = 1
self.assertEqual(cache.pop("key"), 1)
self.assertEqual(cache.pop("key"), None)
- def test_del_multi(self):
- cache = LruCache(4, cache_type=TreeCache)
+ def test_del_multi(self) -> None:
+ # The type here isn't quite correct as they don't handle TreeCache well.
+ cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache)
cache[("animal", "cat")] = "mew"
cache[("animal", "dog")] = "woof"
cache[("vehicles", "car")] = "vroom"
@@ -71,7 +73,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(("animal", "cat")), "mew")
self.assertEqual(cache.get(("vehicles", "car")), "vroom")
- cache.del_multi(("animal",))
+ cache.del_multi(("animal",)) # type: ignore[arg-type]
self.assertEqual(len(cache), 2)
self.assertEqual(cache.get(("animal", "cat")), None)
self.assertEqual(cache.get(("animal", "dog")), None)
@@ -79,22 +81,22 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(("vehicles", "train")), "chuff")
# Man from del_multi say "Yes".
- def test_clear(self):
- cache = LruCache(1)
+ def test_clear(self) -> None:
+ cache: LruCache[str, int] = LruCache(1)
cache["key"] = 1
cache.clear()
self.assertEqual(len(cache), 0)
@override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
- def test_special_size(self):
- cache = LruCache(10, "mycache")
+ def test_special_size(self) -> None:
+ cache: LruCache = LruCache(10, "mycache")
self.assertEqual(cache.max_size, 100)
class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
- def test_get(self):
+ def test_get(self) -> None:
m = Mock()
- cache = LruCache(1)
+ cache: LruCache[str, str] = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
@@ -111,9 +113,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.set("key", "value")
self.assertEqual(m.call_count, 1)
- def test_multi_get(self):
+ def test_multi_get(self) -> None:
m = Mock()
- cache = LruCache(1)
+ cache: LruCache[str, str] = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
@@ -130,9 +132,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.set("key", "value")
self.assertEqual(m.call_count, 1)
- def test_set(self):
+ def test_set(self) -> None:
m = Mock()
- cache = LruCache(1)
+ cache: LruCache[str, str] = LruCache(1)
cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
@@ -146,9 +148,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.set("key", "value")
self.assertEqual(m.call_count, 1)
- def test_pop(self):
+ def test_pop(self) -> None:
m = Mock()
- cache = LruCache(1)
+ cache: LruCache[str, str] = LruCache(1)
cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
@@ -162,12 +164,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.pop("key")
self.assertEqual(m.call_count, 1)
- def test_del_multi(self):
+ def test_del_multi(self) -> None:
m1 = Mock()
m2 = Mock()
m3 = Mock()
m4 = Mock()
- cache = LruCache(4, cache_type=TreeCache)
+ # The type here isn't quite correct as they don't handle TreeCache well.
+ cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache)
cache.set(("a", "1"), "value", callbacks=[m1])
cache.set(("a", "2"), "value", callbacks=[m2])
@@ -179,17 +182,17 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
self.assertEqual(m3.call_count, 0)
self.assertEqual(m4.call_count, 0)
- cache.del_multi(("a",))
+ cache.del_multi(("a",)) # type: ignore[arg-type]
self.assertEqual(m1.call_count, 1)
self.assertEqual(m2.call_count, 1)
self.assertEqual(m3.call_count, 0)
self.assertEqual(m4.call_count, 0)
- def test_clear(self):
+ def test_clear(self) -> None:
m1 = Mock()
m2 = Mock()
- cache = LruCache(5)
+ cache: LruCache[str, str] = LruCache(5)
cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2])
@@ -202,11 +205,11 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
self.assertEqual(m1.call_count, 1)
self.assertEqual(m2.call_count, 1)
- def test_eviction(self):
+ def test_eviction(self) -> None:
m1 = Mock(name="m1")
m2 = Mock(name="m2")
m3 = Mock(name="m3")
- cache = LruCache(2)
+ cache: LruCache[str, str] = LruCache(2)
cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2])
@@ -241,8 +244,8 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
class LruCacheSizedTestCase(unittest.HomeserverTestCase):
- def test_evict(self):
- cache = LruCache(5, size_callback=len)
+ def test_evict(self) -> None:
+ cache: LruCache[str, List[int]] = LruCache(5, size_callback=len)
cache["key1"] = [0]
cache["key2"] = [1, 2]
cache["key3"] = [3]
@@ -269,6 +272,7 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
cache["key1"] = []
self.assertEqual(len(cache), 0)
+ assert isinstance(cache.cache, dict)
cache.cache["key1"].drop_from_cache()
self.assertIsNone(
cache.pop("key1"), "Cache entry should have been evicted but wasn't"
@@ -278,17 +282,17 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
class TimeEvictionTestCase(unittest.HomeserverTestCase):
"""Test that time based eviction works correctly."""
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = super().default_config()
config.setdefault("caches", {})["expiry_time"] = "30m"
return config
- def test_evict(self):
+ def test_evict(self) -> None:
setup_expire_lru_cache_entries(self.hs)
- cache = LruCache(5, clock=self.hs.get_clock())
+ cache: LruCache[str, int] = LruCache(5, clock=self.hs.get_clock())
# Check that we evict entries we haven't accessed for 30 minutes.
cache["key1"] = 1
@@ -332,7 +336,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
}
)
@patch("synapse.util.caches.lrucache.get_jemalloc_stats")
- def test_evict_memory(self, jemalloc_interface) -> None:
+ def test_evict_memory(self, jemalloc_interface: Mock) -> None:
mock_jemalloc_class = Mock(spec=JemallocStats)
jemalloc_interface.return_value = mock_jemalloc_class
@@ -340,7 +344,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
mock_jemalloc_class.get_stat.return_value = 924288000
setup_expire_lru_cache_entries(self.hs)
- cache = LruCache(4, clock=self.hs.get_clock())
+ cache: LruCache[str, int] = LruCache(4, clock=self.hs.get_clock())
cache["key1"] = 1
cache["key2"] = 2
diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py
index 40754a4711..e56ec2c860 100644
--- a/tests/util/test_macaroons.py
+++ b/tests/util/test_macaroons.py
@@ -21,14 +21,14 @@ from tests.unittest import TestCase
class MacaroonGeneratorTestCase(TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.reactor, hs_clock = get_clock()
self.macaroon_generator = MacaroonGenerator(hs_clock, "tesths", b"verysecret")
self.other_macaroon_generator = MacaroonGenerator(
hs_clock, "tesths", b"anothersecretkey"
)
- def test_guest_access_token(self):
+ def test_guest_access_token(self) -> None:
"""Test the generation and verification of guest access tokens"""
token = self.macaroon_generator.generate_guest_access_token("@user:tesths")
user_id = self.macaroon_generator.verify_guest_token(token)
@@ -47,7 +47,7 @@ class MacaroonGeneratorTestCase(TestCase):
with self.assertRaises(MacaroonVerificationFailedException):
self.macaroon_generator.verify_guest_token(token)
- def test_delete_pusher_token(self):
+ def test_delete_pusher_token(self) -> None:
"""Test the generation and verification of delete_pusher tokens"""
token = self.macaroon_generator.generate_delete_pusher_token(
"@user:tesths", "m.mail", "john@example.com"
@@ -84,7 +84,7 @@ class MacaroonGeneratorTestCase(TestCase):
)
self.assertEqual(user_id, "@user:tesths")
- def test_oidc_session_token(self):
+ def test_oidc_session_token(self) -> None:
"""Test the generation and verification of OIDC session cookies"""
state = "arandomstate"
session_data = OidcSessionData(
@@ -92,6 +92,7 @@ class MacaroonGeneratorTestCase(TestCase):
nonce="nonce",
client_redirect_url="https://example.com/",
ui_auth_session_id="",
+ code_verifier="",
)
token = self.macaroon_generator.generate_oidc_session_token(
state, session_data, duration_in_ms=2 * 60 * 1000
diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py
index 89d8656634..fe4961dcf3 100644
--- a/tests/util/test_ratelimitutils.py
+++ b/tests/util/test_ratelimitutils.py
@@ -13,16 +13,20 @@
# limitations under the License.
from typing import Optional
+from twisted.internet import defer
+from twisted.internet.defer import Deferred
+
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.ratelimiting import FederationRatelimitSettings
from synapse.util.ratelimitutils import FederationRateLimiter
-from tests.server import get_clock
+from tests.server import ThreadedMemoryReactorClock, get_clock
from tests.unittest import TestCase
from tests.utils import default_config
class FederationRateLimiterTestCase(TestCase):
- def test_ratelimit(self):
+ def test_ratelimit(self) -> None:
"""A simple test with the default values"""
reactor, clock = get_clock()
rc_config = build_rc_config()
@@ -32,7 +36,7 @@ class FederationRateLimiterTestCase(TestCase):
# shouldn't block
self.successResultOf(d1)
- def test_concurrent_limit(self):
+ def test_concurrent_limit(self) -> None:
"""Test what happens when we hit the concurrent limit"""
reactor, clock = get_clock()
rc_config = build_rc_config({"rc_federation": {"concurrent": 2}})
@@ -54,9 +58,10 @@ class FederationRateLimiterTestCase(TestCase):
# ... until we complete an earlier request
cm2.__exit__(None, None, None)
+ reactor.advance(0.0)
self.successResultOf(d3)
- def test_sleep_limit(self):
+ def test_sleep_limit(self) -> None:
"""Test what happens when we hit the sleep limit"""
reactor, clock = get_clock()
rc_config = build_rc_config(
@@ -78,8 +83,45 @@ class FederationRateLimiterTestCase(TestCase):
sleep_time = _await_resolution(reactor, d3)
self.assertAlmostEqual(sleep_time, 500, places=3)
+ def test_lots_of_queued_things(self) -> None:
+ """Tests lots of synchronous things queued up behind a slow thing.
-def _await_resolution(reactor, d):
+ The stack should *not* explode when the slow thing completes.
+ """
+ reactor, clock = get_clock()
+ rc_config = build_rc_config(
+ {
+ "rc_federation": {
+ "sleep_limit": 1000000000, # never sleep
+ "reject_limit": 1000000000, # never reject requests
+ "concurrent": 1,
+ }
+ }
+ )
+ ratelimiter = FederationRateLimiter(clock, rc_config)
+
+ with ratelimiter.ratelimit("testhost") as d:
+ # shouldn't block
+ self.successResultOf(d)
+
+ async def task() -> None:
+ with ratelimiter.ratelimit("testhost") as d:
+ await d
+
+ for _ in range(1, 100):
+ defer.ensureDeferred(task())
+
+ last_task = defer.ensureDeferred(task())
+
+ # Upon exiting the context manager, all the synchronous things will resume.
+ # If a stack overflow occurs, the final task will not complete.
+
+ # Wait for all the things to complete.
+ reactor.advance(0.0)
+ self.successResultOf(last_task)
+
+
+def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float:
"""advance the clock until the deferred completes.
Returns the number of milliseconds it took to complete.
@@ -90,7 +132,7 @@ def _await_resolution(reactor, d):
return (reactor.seconds() - start_time) * 1000
-def build_rc_config(settings: Optional[dict] = None):
+def build_rc_config(settings: Optional[dict] = None) -> FederationRatelimitSettings:
config_dict = default_config("test")
config_dict.update(settings or {})
config = HomeServerConfig()
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 26cb71c640..9529ee53c8 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -22,7 +22,7 @@ from tests.unittest import HomeserverTestCase
class RetryLimiterTestCase(HomeserverTestCase):
- def test_new_destination(self):
+ def test_new_destination(self) -> None:
"""A happy-path case with a new destination and a successful operation"""
store = self.hs.get_datastores().main
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
@@ -36,7 +36,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)
- def test_limiter(self):
+ def test_limiter(self) -> None:
"""General test case which walks through the process of a failing request"""
store = self.hs.get_datastores().main
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index 5da04362a9..bc93de62eb 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -49,7 +49,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
acquired_d: "Deferred[None]" = Deferred()
unblock_d: "Deferred[None]" = Deferred()
- async def reader_or_writer():
+ async def reader_or_writer() -> str:
async with read_or_write(key):
acquired_d.callback(None)
await unblock_d
@@ -134,7 +134,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
d.called, msg="deferred %d was unexpectedly resolved" % (i + n)
)
- def test_rwlock(self):
+ def test_rwlock(self) -> None:
rwlock = ReadWriteLock()
key = "key"
@@ -197,7 +197,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
_, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader")
self.assertTrue(acquired_d.called)
- def test_lock_handoff_to_nonblocking_writer(self):
+ def test_lock_handoff_to_nonblocking_writer(self) -> None:
"""Test a writer handing the lock to another writer that completes instantly."""
rwlock = ReadWriteLock()
key = "key"
@@ -216,7 +216,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed")
self.assertTrue(d3.called)
- def test_cancellation_while_holding_read_lock(self):
+ def test_cancellation_while_holding_read_lock(self) -> None:
"""Test cancellation while holding a read lock.
A waiting writer should be given the lock when the reader holding the lock is
@@ -242,7 +242,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
)
self.assertEqual("write completed", self.successResultOf(writer_d))
- def test_cancellation_while_holding_write_lock(self):
+ def test_cancellation_while_holding_write_lock(self) -> None:
"""Test cancellation while holding a write lock.
A waiting reader should be given the lock when the writer holding the lock is
@@ -268,7 +268,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
)
self.assertEqual("read completed", self.successResultOf(reader_d))
- def test_cancellation_while_waiting_for_read_lock(self):
+ def test_cancellation_while_waiting_for_read_lock(self) -> None:
"""Test cancellation while waiting for a read lock.
Tests that cancelling a waiting reader:
@@ -319,7 +319,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
)
self.assertEqual("write 2 completed", self.successResultOf(writer2_d))
- def test_cancellation_while_waiting_for_write_lock(self):
+ def test_cancellation_while_waiting_for_write_lock(self) -> None:
"""Test cancellation while waiting for a write lock.
Tests that cancelling a waiting writer:
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 9ed01f7e0c..3df053493b 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -8,7 +8,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
Tests for StreamChangeCache.
"""
- def test_prefilled_cache(self):
+ def test_prefilled_cache(self) -> None:
"""
Providing a prefilled cache to StreamChangeCache will result in a cache
with the prefilled-cache entered in.
@@ -16,7 +16,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2})
self.assertTrue(cache.has_entity_changed("user@foo.com", 1))
- def test_has_entity_changed(self):
+ def test_has_entity_changed(self) -> None:
"""
StreamChangeCache.entity_has_changed will mark entities as changed, and
has_entity_changed will observe the changed entities.
@@ -51,8 +51,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# return True, whether it's a known entity or not.
self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
self.assertTrue(cache.has_entity_changed("not@here.website", 0))
+ self.assertTrue(cache.has_entity_changed("user@foo.com", 3))
+ self.assertTrue(cache.has_entity_changed("not@here.website", 3))
- def test_entity_has_changed_pops_off_start(self):
+ def test_entity_has_changed_pops_off_start(self) -> None:
"""
StreamChangeCache.entity_has_changed will respect the max size and
purge the oldest items upon reaching that max size.
@@ -65,15 +67,16 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# The cache is at the max size, 2
self.assertEqual(len(cache._cache), 2)
+ # The cache's earliest known position is 2.
+ self.assertEqual(cache._earliest_known_stream_pos, 2)
# The oldest item has been popped off
self.assertTrue("user@foo.com" not in cache._entity_to_key)
self.assertEqual(
- cache.get_all_entities_changed(2),
- ["bar@baz.net", "user@elsewhere.org"],
+ cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"]
)
- self.assertIsNone(cache.get_all_entities_changed(1))
+ self.assertFalse(cache.get_all_entities_changed(2).hit)
# If we update an existing entity, it keeps the two existing entities
cache.entity_has_changed("bar@baz.net", 5)
@@ -81,12 +84,12 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
{"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
)
self.assertEqual(
- cache.get_all_entities_changed(2),
+ cache.get_all_entities_changed(3).entities,
["user@elsewhere.org", "bar@baz.net"],
)
- self.assertIsNone(cache.get_all_entities_changed(1))
+ self.assertFalse(cache.get_all_entities_changed(2).hit)
- def test_get_all_entities_changed(self):
+ def test_get_all_entities_changed(self) -> None:
"""
StreamChangeCache.get_all_entities_changed will return all changed
entities since the given position. If the position is before the start
@@ -99,28 +102,17 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
cache.entity_has_changed("anotheruser@foo.com", 3)
cache.entity_has_changed("user@elsewhere.org", 4)
- r = cache.get_all_entities_changed(1)
-
- # either of these are valid
- ok1 = [
- "user@foo.com",
- "bar@baz.net",
- "anotheruser@foo.com",
- "user@elsewhere.org",
- ]
- ok2 = [
- "user@foo.com",
- "anotheruser@foo.com",
- "bar@baz.net",
- "user@elsewhere.org",
- ]
- self.assertTrue(r == ok1 or r == ok2)
-
r = cache.get_all_entities_changed(2)
- self.assertTrue(r == ok1[1:] or r == ok2[1:])
- self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
- self.assertEqual(cache.get_all_entities_changed(0), None)
+ # Results are ordered so either of these are valid.
+ ok1 = ["bar@baz.net", "anotheruser@foo.com", "user@elsewhere.org"]
+ ok2 = ["anotheruser@foo.com", "bar@baz.net", "user@elsewhere.org"]
+ self.assertTrue(r.entities == ok1 or r.entities == ok2)
+
+ self.assertEqual(
+ cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"]
+ )
+ self.assertFalse(cache.get_all_entities_changed(1).hit)
# ... later, things gest more updates
cache.entity_has_changed("user@foo.com", 5)
@@ -140,9 +132,9 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
"anotheruser@foo.com",
]
r = cache.get_all_entities_changed(3)
- self.assertTrue(r == ok1 or r == ok2)
+ self.assertTrue(r.entities == ok1 or r.entities == ok2)
- def test_has_any_entity_changed(self):
+ def test_has_any_entity_changed(self) -> None:
"""
StreamChangeCache.has_any_entity_changed will return True if any
entities have been changed since the provided stream position, and
@@ -152,9 +144,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
"""
cache = StreamChangeCache("#test", 1)
- # With no entities, it returns False for the past, present, and future.
- self.assertFalse(cache.has_any_entity_changed(0))
- self.assertFalse(cache.has_any_entity_changed(1))
+ # With no entities, it returns True for the past, present, and False for
+ # the future.
+ self.assertTrue(cache.has_any_entity_changed(0))
+ self.assertTrue(cache.has_any_entity_changed(1))
self.assertFalse(cache.has_any_entity_changed(2))
# We add an entity
@@ -168,7 +161,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertFalse(cache.has_any_entity_changed(2))
self.assertFalse(cache.has_any_entity_changed(3))
- def test_get_entities_changed(self):
+ def test_get_entities_changed(self) -> None:
"""
StreamChangeCache.get_entities_changed will return the entities in the
given list that have changed since the provided stream ID. If the
@@ -228,7 +221,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
{"bar@baz.net"},
)
- def test_max_pos(self):
+ def test_max_pos(self) -> None:
"""
StreamChangeCache.get_max_pos_of_last_change will return the most
recent point where the entity could have changed. If the entity is not
diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py
index ad4dd7f007..f137e05191 100644
--- a/tests/util/test_stringutils.py
+++ b/tests/util/test_stringutils.py
@@ -19,7 +19,7 @@ from .. import unittest
class StringUtilsTestCase(unittest.TestCase):
- def test_client_secret_regex(self):
+ def test_client_secret_regex(self) -> None:
"""Ensure that client_secret does not contain illegal characters"""
good = [
"abcde12345",
@@ -46,7 +46,7 @@ class StringUtilsTestCase(unittest.TestCase):
with self.assertRaises(SynapseError):
assert_valid_client_secret(client_secret)
- def test_base62_encode(self):
+ def test_base62_encode(self) -> None:
self.assertEqual("0", base62_encode(0))
self.assertEqual("10", base62_encode(62))
self.assertEqual("1c", base62_encode(100))
diff --git a/tests/util/test_threepids.py b/tests/util/test_threepids.py
index d957b953bb..3b35b8e4ec 100644
--- a/tests/util/test_threepids.py
+++ b/tests/util/test_threepids.py
@@ -18,31 +18,31 @@ from tests.unittest import HomeserverTestCase
class CanonicaliseEmailTests(HomeserverTestCase):
- def test_no_at(self):
+ def test_no_at(self) -> None:
with self.assertRaises(ValueError):
canonicalise_email("address-without-at.bar")
- def test_two_at(self):
+ def test_two_at(self) -> None:
with self.assertRaises(ValueError):
canonicalise_email("foo@foo@test.bar")
- def test_bad_format(self):
+ def test_bad_format(self) -> None:
with self.assertRaises(ValueError):
canonicalise_email("user@bad.example.net@good.example.com")
- def test_valid_format(self):
+ def test_valid_format(self) -> None:
self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar")
- def test_domain_to_lower(self):
+ def test_domain_to_lower(self) -> None:
self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar")
- def test_domain_with_umlaut(self):
+ def test_domain_with_umlaut(self) -> None:
self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com")
- def test_address_casefold(self):
+ def test_address_casefold(self) -> None:
self.assertEqual(
canonicalise_email("Strauß@Example.com"), "strauss@example.com"
)
- def test_address_trim(self):
+ def test_address_trim(self) -> None:
self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar")
diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py
index 567cb18468..fe3b4dc6a4 100644
--- a/tests/util/test_treecache.py
+++ b/tests/util/test_treecache.py
@@ -19,7 +19,7 @@ from .. import unittest
class TreeCacheTestCase(unittest.TestCase):
- def test_get_set_onelevel(self):
+ def test_get_set_onelevel(self) -> None:
cache = TreeCache()
cache[("a",)] = "A"
cache[("b",)] = "B"
@@ -27,7 +27,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.get(("b",)), "B")
self.assertEqual(len(cache), 2)
- def test_pop_onelevel(self):
+ def test_pop_onelevel(self) -> None:
cache = TreeCache()
cache[("a",)] = "A"
cache[("b",)] = "B"
@@ -36,7 +36,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.get(("b",)), "B")
self.assertEqual(len(cache), 1)
- def test_get_set_twolevel(self):
+ def test_get_set_twolevel(self) -> None:
cache = TreeCache()
cache[("a", "a")] = "AA"
cache[("a", "b")] = "AB"
@@ -46,7 +46,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.get(("b", "a")), "BA")
self.assertEqual(len(cache), 3)
- def test_pop_twolevel(self):
+ def test_pop_twolevel(self) -> None:
cache = TreeCache()
cache[("a", "a")] = "AA"
cache[("a", "b")] = "AB"
@@ -58,7 +58,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.pop(("b", "a")), None)
self.assertEqual(len(cache), 1)
- def test_pop_mixedlevel(self):
+ def test_pop_mixedlevel(self) -> None:
cache = TreeCache()
cache[("a", "a")] = "AA"
cache[("a", "b")] = "AB"
@@ -72,14 +72,14 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped)))
- def test_clear(self):
+ def test_clear(self) -> None:
cache = TreeCache()
cache[("a",)] = "A"
cache[("b",)] = "B"
cache.clear()
self.assertEqual(len(cache), 0)
- def test_contains(self):
+ def test_contains(self) -> None:
cache = TreeCache()
cache[("a",)] = "A"
self.assertTrue(("a",) in cache)
diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py
index 0d5039de04..c9d22b6d8c 100644
--- a/tests/util/test_wheel_timer.py
+++ b/tests/util/test_wheel_timer.py
@@ -18,8 +18,8 @@ from .. import unittest
class WheelTimerTestCase(unittest.TestCase):
- def test_single_insert_fetch(self):
- wheel = WheelTimer(bucket_size=5)
+ def test_single_insert_fetch(self) -> None:
+ wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj = object()
wheel.insert(100, obj, 150)
@@ -32,8 +32,8 @@ class WheelTimerTestCase(unittest.TestCase):
self.assertListEqual(wheel.fetch(156), [obj])
self.assertListEqual(wheel.fetch(170), [])
- def test_multi_insert(self):
- wheel = WheelTimer(bucket_size=5)
+ def test_multi_insert(self) -> None:
+ wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj1 = object()
obj2 = object()
@@ -50,15 +50,15 @@ class WheelTimerTestCase(unittest.TestCase):
self.assertListEqual(wheel.fetch(200), [obj3])
self.assertListEqual(wheel.fetch(210), [])
- def test_insert_past(self):
- wheel = WheelTimer(bucket_size=5)
+ def test_insert_past(self) -> None:
+ wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj = object()
wheel.insert(100, obj, 50)
self.assertListEqual(wheel.fetch(120), [obj])
- def test_insert_past_multi(self):
- wheel = WheelTimer(bucket_size=5)
+ def test_insert_past_multi(self) -> None:
+ wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj1 = object()
obj2 = object()
diff --git a/tests/utils.py b/tests/utils.py
index 045a8b5fa7..d76bf9716a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -125,7 +125,8 @@ def default_config(
"""
config_dict = {
"server_name": name,
- "send_federation": False,
+ # Setting this to an empty list turns off federation sending.
+ "federation_sender_instances": [],
"media_store_path": "media",
# the test signing key is just an arbitrary ed25519 key to keep the config
# parser happy
@@ -183,8 +184,9 @@ def default_config(
# rooms will fail.
"default_room_version": DEFAULT_ROOM_VERSION,
# disable user directory updates, because they get done in the
- # background, which upsets the test runner.
- "update_user_directory": False,
+ # background, which upsets the test runner. Setting this to an
+ # (obviously) fake worker name disables updating the user directory.
+ "update_user_directory_from_worker": "does_not_exist_worker_name",
"caches": {"global_factor": 1, "sync_response_cache_duration": 0},
"listeners": [{"port": 0, "type": "http"}],
}