From 94a396e7c4b4488d7f0ca08672114a4a586cf42c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 22 Feb 2022 14:52:56 +0000 Subject: [PATCH 01/40] Prune setup.cfg some more (#12059) * Remove `trial` section from setup.cfg This was added in the initial commit from 2014. I can't see that it does anything. Maybe it's there so that you can run `trial` without any extra args, but if I do that then I just get the `--help` message. * Move flake8's config to its own file --- .flake8 | 11 +++++++++++ MANIFEST.in | 1 + changelog.d/12052.misc | 2 +- changelog.d/12059.misc | 1 + setup.cfg | 12 ------------ 5 files changed, 14 insertions(+), 13 deletions(-) create mode 100644 .flake8 create mode 100644 changelog.d/12059.misc diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000..acb118c86e --- /dev/null +++ b/.flake8 @@ -0,0 +1,11 @@ +# TODO: incorporate this into pyproject.toml if flake8 supports it in the future. +# See https://github.com/PyCQA/flake8/issues/234 +[flake8] +# see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes +# for error codes. The ones we ignore are: +# W503: line break before binary operator +# W504: line break after binary operator +# E203: whitespace before ':' (which is contrary to pep8?) +# E731: do not assign a lambda expression, use a def +# E501: Line too long (black enforces this for us) +ignore=W503,W504,E203,E731,E501 diff --git a/MANIFEST.in b/MANIFEST.in index c24786c3b3..76d14eb642 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -45,6 +45,7 @@ include book.toml include pyproject.toml recursive-include changelog.d * +include .flake8 prune .circleci prune .github prune .ci diff --git a/changelog.d/12052.misc b/changelog.d/12052.misc index fbaff67e95..11755ae61b 100644 --- a/changelog.d/12052.misc +++ b/changelog.d/12052.misc @@ -1 +1 @@ -Move `isort` configuration to `pyproject.toml`. +Move configuration out of `setup.cfg`. diff --git a/changelog.d/12059.misc b/changelog.d/12059.misc new file mode 100644 index 0000000000..9ba4759d99 --- /dev/null +++ b/changelog.d/12059.misc @@ -0,0 +1 @@ +Move configuration out of `setup.cfg`. \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index a0506572d9..6213f3265b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,3 @@ -[trial] -test_suite = tests - [check-manifest] ignore = .git-blame-ignore-revs @@ -10,12 +7,3 @@ ignore = pylint.cfg tox.ini -[flake8] -# see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes -# for error codes. The ones we ignore are: -# W503: line break before binary operator -# W504: line break after binary operator -# E203: whitespace before ':' (which is contrary to pep8?) -# E731: do not assign a lambda expression, use a def -# E501: Line too long (black enforces this for us) -ignore=W503,W504,E203,E731,E501 From 250104d357c17a1c87fa46af35bbf3612f4ef171 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 22 Feb 2022 16:10:10 +0100 Subject: [PATCH 02/40] Implement account status endpoints (MSC3720) (#12001) See matrix-org/matrix-doc#3720 Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/12001.feature | 1 + synapse/config/experimental.py | 3 + synapse/federation/federation_client.py | 60 +++++- synapse/federation/transport/client.py | 19 +- .../federation/transport/server/__init__.py | 8 + .../federation/transport/server/federation.py | 35 +++ synapse/handlers/account.py | 144 +++++++++++++ synapse/rest/client/account.py | 33 +++ synapse/rest/client/capabilities.py | 5 + synapse/server.py | 5 + tests/rest/client/test_account.py | 204 +++++++++++++++++- 11 files changed, 511 insertions(+), 6 deletions(-) create mode 100644 changelog.d/12001.feature create mode 100644 synapse/handlers/account.py diff --git a/changelog.d/12001.feature b/changelog.d/12001.feature new file mode 100644 index 0000000000..dc1153c49e --- /dev/null +++ b/changelog.d/12001.feature @@ -0,0 +1 @@ +Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index bcdeb9ee23..772eb35013 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -65,3 +65,6 @@ class ExperimentalConfig(Config): # experimental support for faster joins over federation (msc2775, msc3706) # requires a target server with msc3706_enabled enabled. self.faster_joins_enabled: bool = experimental.get("faster_joins", False) + + # MSC3720 (Account status endpoint) + self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c2997997da..2121e92e3a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -56,7 +56,7 @@ from synapse.api.room_versions import ( from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.transport.client import SendJoinResponse -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -1610,6 +1610,64 @@ class FederationClient(FederationBase): except ValueError as e: raise InvalidResponseError(str(e)) + async def get_account_status( + self, destination: str, user_ids: List[str] + ) -> Tuple[JsonDict, List[str]]: + """Retrieves account statuses for a given list of users on a given remote + homeserver. + + If the request fails for any reason, all user IDs for this destination are marked + as failed. + + Args: + destination: the destination to contact + user_ids: the user ID(s) for which to request account status(es) + + Returns: + The account statuses, as well as the list of user IDs for which it was not + possible to retrieve a status. + """ + try: + res = await self.transport_layer.get_account_status(destination, user_ids) + except Exception: + # If the query failed for any reason, mark all the users as failed. + return {}, user_ids + + statuses = res.get("account_statuses", {}) + failures = res.get("failures", []) + + if not isinstance(statuses, dict) or not isinstance(failures, list): + # Make sure we're not feeding back malformed data back to the caller. + logger.warning( + "Destination %s responded with malformed data to account_status query", + destination, + ) + return {}, user_ids + + for user_id in user_ids: + # Any account whose status is missing is a user we failed to receive the + # status of. + if user_id not in statuses and user_id not in failures: + failures.append(user_id) + + # Filter out any user ID that doesn't belong to the remote server that sent its + # status (or failure). + def filter_user_id(user_id: str) -> bool: + try: + return UserID.from_string(user_id).domain == destination + except SynapseError: + # If the user ID doesn't parse, ignore it. + return False + + filtered_statuses = dict( + # item is a (key, value) tuple, so item[0] is the user ID. + filter(lambda item: filter_user_id(item[0]), statuses.items()) + ) + + filtered_failures = list(filter(filter_user_id, failures)) + + return filtered_statuses, filtered_failures + @attr.s(frozen=True, slots=True, auto_attribs=True) class TimestampToEventResponse: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 7e510e224a..69998de520 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -258,8 +258,9 @@ class TransportLayerClient: args: dict, retry_on_dns_fail: bool, ignore_backoff: bool = False, + prefix: str = FEDERATION_V1_PREFIX, ) -> JsonDict: - path = _create_v1_path("/query/%s", query_type) + path = _create_path(prefix, "/query/%s", query_type) return await self.client.get_json( destination=destination, @@ -1247,6 +1248,22 @@ class TransportLayerClient: args={"suggested_only": "true" if suggested_only else "false"}, ) + async def get_account_status( + self, destination: str, user_ids: List[str] + ) -> JsonDict: + """ + Args: + destination: The remote server. + user_ids: The user ID(s) for which to request account status(es). + """ + path = _create_path( + FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc3720/account_status" + ) + + return await self.client.post_json( + destination=destination, path=path, data={"user_ids": user_ids} + ) + def _create_path(federation_prefix: str, path: str, *args: str) -> str: """ diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index db4fe2c798..67a6347907 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -24,6 +24,7 @@ from synapse.federation.transport.server._base import ( ) from synapse.federation.transport.server.federation import ( FEDERATION_SERVLET_CLASSES, + FederationAccountStatusServlet, FederationTimestampLookupServlet, ) from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES @@ -336,6 +337,13 @@ def register_servlets( ): continue + # Only allow the `/account_status` servlet if msc3720 is enabled + if ( + servletclass == FederationAccountStatusServlet + and not hs.config.experimental.msc3720_enabled + ): + continue + servletclass( hs=hs, authenticator=authenticator, diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index e85a8eda5b..4d75e58bfc 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -766,6 +766,40 @@ class RoomComplexityServlet(BaseFederationServlet): return 200, complexity +class FederationAccountStatusServlet(BaseFederationServerServlet): + PATH = "/query/account_status" + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3720" + + def __init__( + self, + hs: "HomeServer", + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self._account_handler = hs.get_account_handler() + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Mapping[bytes, Sequence[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + if "user_ids" not in content: + raise SynapseError( + 400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM + ) + + statuses, failures = await self._account_handler.get_account_statuses( + content["user_ids"], + allow_remote=False, + ) + + return 200, {"account_statuses": statuses, "failures": failures} + + FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationSendServlet, FederationEventServlet, @@ -797,4 +831,5 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationRoomHierarchyUnstableServlet, FederationV1SendKnockServlet, FederationMakeKnockServlet, + FederationAccountStatusServlet, ) diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py new file mode 100644 index 0000000000..f8cfe9f6de --- /dev/null +++ b/synapse/handlers/account.py @@ -0,0 +1,144 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict, List, Tuple + +from synapse.api.errors import Codes, SynapseError +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class AccountHandler: + def __init__(self, hs: "HomeServer"): + self._store = hs.get_datastore() + self._is_mine = hs.is_mine + self._federation_client = hs.get_federation_client() + + async def get_account_statuses( + self, + user_ids: List[str], + allow_remote: bool, + ) -> Tuple[JsonDict, List[str]]: + """Get account statuses for a list of user IDs. + + If one or more account(s) belong to remote homeservers, retrieve their status(es) + over federation if allowed. + + Args: + user_ids: The list of accounts to retrieve the status of. + allow_remote: Whether to try to retrieve the status of remote accounts, if + any. + + Returns: + The account statuses as well as the list of users whose statuses could not be + retrieved. + + Raises: + SynapseError if a required parameter is missing or malformed, or if one of + the accounts isn't local to this homeserver and allow_remote is False. + """ + statuses = {} + failures = [] + remote_users: List[UserID] = [] + + for raw_user_id in user_ids: + try: + user_id = UserID.from_string(raw_user_id) + except SynapseError: + raise SynapseError( + 400, + f"Not a valid Matrix user ID: {raw_user_id}", + Codes.INVALID_PARAM, + ) + + if self._is_mine(user_id): + status = await self._get_local_account_status(user_id) + statuses[user_id.to_string()] = status + else: + if not allow_remote: + raise SynapseError( + 400, + f"Not a local user: {raw_user_id}", + Codes.INVALID_PARAM, + ) + + remote_users.append(user_id) + + if allow_remote and len(remote_users) > 0: + remote_statuses, remote_failures = await self._get_remote_account_statuses( + remote_users, + ) + + statuses.update(remote_statuses) + failures += remote_failures + + return statuses, failures + + async def _get_local_account_status(self, user_id: UserID) -> JsonDict: + """Retrieve the status of a local account. + + Args: + user_id: The account to retrieve the status of. + + Returns: + The account's status. + """ + status = {"exists": False} + + userinfo = await self._store.get_userinfo_by_id(user_id.to_string()) + + if userinfo is not None: + status = { + "exists": True, + "deactivated": userinfo.is_deactivated, + } + + return status + + async def _get_remote_account_statuses( + self, remote_users: List[UserID] + ) -> Tuple[JsonDict, List[str]]: + """Send out federation requests to retrieve the statuses of remote accounts. + + Args: + remote_users: The accounts to retrieve the statuses of. + + Returns: + The statuses of the accounts, and a list of accounts for which no status + could be retrieved. + """ + # Group remote users by destination, so we only send one request per remote + # homeserver. + by_destination: Dict[str, List[str]] = {} + for user in remote_users: + if user.domain not in by_destination: + by_destination[user.domain] = [] + + by_destination[user.domain].append(user.to_string()) + + # Retrieve the statuses and failures for remote accounts. + final_statuses: JsonDict = {} + final_failures: List[str] = [] + for destination, users in by_destination.items(): + statuses, failures = await self._federation_client.get_account_status( + destination, + users, + ) + + final_statuses.update(statuses) + final_failures += failures + + return final_statuses, final_failures diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index efe299e698..5802de5b7c 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -896,6 +896,36 @@ class WhoamiRestServlet(RestServlet): return 200, response +class AccountStatusRestServlet(RestServlet): + PATTERNS = client_patterns( + "/org.matrix.msc3720/account_status$", unstable=True, releases=() + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._auth = hs.get_auth() + self._store = hs.get_datastore() + self._is_mine = hs.is_mine + self._federation_client = hs.get_federation_client() + self._account_handler = hs.get_account_handler() + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self._auth.get_user_by_req(request) + + body = parse_json_object_from_request(request) + if "user_ids" not in body: + raise SynapseError( + 400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM + ) + + statuses, failures = await self._account_handler.get_account_statuses( + body["user_ids"], + allow_remote=True, + ) + + return 200, {"account_statuses": statuses, "failures": failures} + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: EmailPasswordRequestTokenRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) @@ -910,3 +940,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ThreepidUnbindRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server) WhoamiRestServlet(hs).register(http_server) + + if hs.config.experimental.msc3720_enabled: + AccountStatusRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index e05c926b6f..b80fdd3712 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -75,6 +75,11 @@ class CapabilitiesRestServlet(RestServlet): if self.config.experimental.msc3440_enabled: response["capabilities"]["io.element.thread"] = {"enabled": True} + if self.config.experimental.msc3720_enabled: + response["capabilities"]["org.matrix.msc3720.account_status"] = { + "enabled": True, + } + return HTTPStatus.OK, response diff --git a/synapse/server.py b/synapse/server.py index 564afdcb96..4c07f21015 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -62,6 +62,7 @@ from synapse.federation.sender import AbstractFederationSender, FederationSender from synapse.federation.transport.client import TransportLayerClient from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler +from synapse.handlers.account import AccountHandler from synapse.handlers.account_data import AccountDataHandler from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.admin import AdminHandler @@ -807,6 +808,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_external_cache(self) -> ExternalCache: return ExternalCache(self) + @cache_in_self + def get_account_handler(self) -> AccountHandler: + return AccountHandler(self) + @cache_in_self def get_outbound_redis_connection(self) -> "RedisProtocol": """ diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 51146c471d..afaa597f65 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -1,6 +1,4 @@ -# Copyright 2015-2016 OpenMarket Ltd -# Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,16 +15,22 @@ import json import os import re from email.parser import Parser -from typing import Optional +from typing import Dict, List, Optional +from unittest.mock import Mock import pkg_resources +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import LoginType, Membership from synapse.api.errors import Codes, HttpResponseException from synapse.appservice import ApplicationService +from synapse.rest import admin from synapse.rest.client import account, login, register, room from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeSite, make_request @@ -1040,3 +1044,195 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} self.assertIn(expected_email, threepids) + + +class AccountStatusTestCase(unittest.HomeserverTestCase): + servlets = [ + account.register_servlets, + admin.register_servlets, + login.register_servlets, + ] + + url = "/_matrix/client/unstable/org.matrix.msc3720/account_status" + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["experimental_features"] = {"msc3720_enabled": True} + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + self.requester = self.register_user("requester", "password") + self.requester_tok = self.login("requester", "password") + self.server_name = homeserver.config.server.server_name + + def test_missing_mxid(self): + """Tests that not providing any MXID raises an error.""" + self._test_status( + users=None, + expected_status_code=400, + expected_errcode=Codes.MISSING_PARAM, + ) + + def test_invalid_mxid(self): + """Tests that providing an invalid MXID raises an error.""" + self._test_status( + users=["bad:test"], + expected_status_code=400, + expected_errcode=Codes.INVALID_PARAM, + ) + + def test_local_user_not_exists(self): + """Tests that the account status endpoints correctly reports that a user doesn't + exist. + """ + user = "@unknown:" + self.hs.config.server.server_name + + self._test_status( + users=[user], + expected_statuses={ + user: { + "exists": False, + }, + }, + expected_failures=[], + ) + + def test_local_user_exists(self): + """Tests that the account status endpoint correctly reports that a user doesn't + exist. + """ + user = self.register_user("someuser", "password") + + self._test_status( + users=[user], + expected_statuses={ + user: { + "exists": True, + "deactivated": False, + }, + }, + expected_failures=[], + ) + + def test_local_user_deactivated(self): + """Tests that the account status endpoint correctly reports a deactivated user.""" + user = self.register_user("someuser", "password") + self.get_success( + self.hs.get_datastore().set_user_deactivated_status(user, deactivated=True) + ) + + self._test_status( + users=[user], + expected_statuses={ + user: { + "exists": True, + "deactivated": True, + }, + }, + expected_failures=[], + ) + + def test_mixed_local_and_remote_users(self): + """Tests that if some users are remote the account status endpoint correctly + merges the remote responses with the local result. + """ + # We use 3 users: one doesn't exist but belongs on the local homeserver, one is + # deactivated and belongs on one remote homeserver, and one belongs to another + # remote homeserver that didn't return any result (the federation code should + # mark that user as a failure). + users = [ + "@unknown:" + self.hs.config.server.server_name, + "@deactivated:remote", + "@failed:otherremote", + "@bad:badremote", + ] + + async def post_json(destination, path, data, *a, **kwa): + if destination == "remote": + return { + "account_statuses": { + users[1]: { + "exists": True, + "deactivated": True, + }, + } + } + if destination == "otherremote": + return {} + if destination == "badremote": + # badremote tries to overwrite the status of a user that doesn't belong + # to it (i.e. users[1]) with false data, which Synapse is expected to + # ignore. + return { + "account_statuses": { + users[3]: { + "exists": False, + }, + users[1]: { + "exists": False, + }, + } + } + + # Register a mock that will return the expected result depending on the remote. + self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) + + # Check that we've got the correct response from the client-side endpoint. + self._test_status( + users=users, + expected_statuses={ + users[0]: { + "exists": False, + }, + users[1]: { + "exists": True, + "deactivated": True, + }, + users[3]: { + "exists": False, + }, + }, + expected_failures=[users[2]], + ) + + def _test_status( + self, + users: Optional[List[str]], + expected_status_code: int = 200, + expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None, + expected_failures: Optional[List[str]] = None, + expected_errcode: Optional[str] = None, + ): + """Send a request to the account status endpoint and check that the response + matches with what's expected. + + Args: + users: The account(s) to request the status of, if any. If set to None, no + `user_id` query parameter will be included in the request. + expected_status_code: The expected HTTP status code. + expected_statuses: The expected account statuses, if any. + expected_failures: The expected failures, if any. + expected_errcode: The expected Matrix error code, if any. + """ + content = {} + if users is not None: + content["user_ids"] = users + + channel = self.make_request( + method="POST", + path=self.url, + content=content, + access_token=self.requester_tok, + ) + + self.assertEqual(channel.code, expected_status_code) + + if expected_statuses is not None: + self.assertEqual(channel.json_body["account_statuses"], expected_statuses) + + if expected_failures is not None: + self.assertEqual(channel.json_body["failures"], expected_failures) + + if expected_errcode is not None: + self.assertEqual(channel.json_body["errcode"], expected_errcode) From 6d14b3dabfe38c6ae487d0f663e294056b6cc056 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 22 Feb 2022 15:52:08 +0000 Subject: [PATCH 03/40] Better error message when failing to request from another process (#12060) --- changelog.d/12060.misc | 1 + synapse/replication/http/_base.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12060.misc diff --git a/changelog.d/12060.misc b/changelog.d/12060.misc new file mode 100644 index 0000000000..d771e6a1b3 --- /dev/null +++ b/changelog.d/12060.misc @@ -0,0 +1 @@ +Fix error message when a worker process fails to talk to another worker process. diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index bc1d28dd19..2e697c74a6 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -268,7 +268,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): raise e.to_synapse_error() except Exception as e: _outgoing_request_counter.labels(cls.NAME, "ERR").inc() - raise SynapseError(502, "Failed to talk to main process") from e + raise SynapseError( + 502, f"Failed to talk to {instance_name} process" + ) from e _outgoing_request_counter.labels(cls.NAME, 200).inc() return result From e3fe6347be1da930b6a0ed2005b565369800a327 Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Tue, 22 Feb 2022 11:35:01 -0700 Subject: [PATCH 04/40] Remove excess condition on `knock->leave` check (#11900) --- changelog.d/11900.misc | 1 + synapse/event_auth.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 changelog.d/11900.misc diff --git a/changelog.d/11900.misc b/changelog.d/11900.misc new file mode 100644 index 0000000000..edd2852fd4 --- /dev/null +++ b/changelog.d/11900.misc @@ -0,0 +1 @@ +Remove unnecessary condition on knock->leave auth rule check. \ No newline at end of file diff --git a/synapse/event_auth.py b/synapse/event_auth.py index eca00bc975..621a3efccc 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -374,9 +374,9 @@ def _is_membership_change_allowed( return # Require the user to be in the room for membership changes other than join/knock. - if Membership.JOIN != membership and ( - RoomVersion.msc2403_knocking and Membership.KNOCK != membership - ): + # Note that the room version check for knocking is done implicitly by `caller_knocked` + # and the ability to set a membership of `knock` in the first place. + if Membership.JOIN != membership and Membership.KNOCK != membership: # If the user has been invited or has knocked, they are allowed to change their # membership event to leave if ( From c1ac2a81350f3b5b86f4c53a585eccd17e3b8e75 Mon Sep 17 00:00:00 2001 From: Nicolas Werner <89468146+nico-famedly@users.noreply.github.com> Date: Wed, 23 Feb 2022 10:06:18 +0000 Subject: [PATCH 05/40] Rename default branch of complement.sh to main (#12063) The complement.sh script relies on the name of the ref matching the name of the unpacked folder. The branch redirect from renaming the default branch breaks that assumption. Signed-off-by: Nicolas Werner --- changelog.d/12063.misc | 1 + scripts-dev/complement.sh | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12063.misc diff --git a/changelog.d/12063.misc b/changelog.d/12063.misc new file mode 100644 index 0000000000..e48c5dd08b --- /dev/null +++ b/changelog.d/12063.misc @@ -0,0 +1 @@ +Fix using the complement.sh script without specifying a dir or a branch. Contributed by Nico on behalf of Famedly. diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index e08ffedaf3..0aecb3daf1 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -5,7 +5,7 @@ # It makes a Synapse image which represents the current checkout, # builds a synapse-complement image on top, then runs tests with it. # -# By default the script will fetch the latest Complement master branch and +# By default the script will fetch the latest Complement main branch and # run tests with that. This can be overridden to use a custom Complement # checkout by setting the COMPLEMENT_DIR environment variable to the # filepath of a local Complement checkout or by setting the COMPLEMENT_REF @@ -32,7 +32,7 @@ cd "$(dirname $0)/.." # Check for a user-specified Complement checkout if [[ -z "$COMPLEMENT_DIR" ]]; then - COMPLEMENT_REF=${COMPLEMENT_REF:-master} + COMPLEMENT_REF=${COMPLEMENT_REF:-main} echo "COMPLEMENT_DIR not set. Fetching Complement checkout from ${COMPLEMENT_REF}..." wget -Nq https://github.com/matrix-org/complement/archive/${COMPLEMENT_REF}.tar.gz tar -xzf ${COMPLEMENT_REF}.tar.gz From e24ff8ebe3d4119d377355402245947f7de61c00 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 23 Feb 2022 11:04:02 +0000 Subject: [PATCH 06/40] Remove `HomeServer.get_datastore()` (#12031) The presence of this method was confusing, and mostly present for backwards compatibility. Let's get rid of it. Part of #11733 --- changelog.d/12031.misc | 1 + docs/manhole.md | 2 +- scripts/update_synapse_database | 2 +- synapse/api/auth.py | 2 +- synapse/api/auth_blocking.py | 2 +- synapse/api/filtering.py | 4 +-- synapse/app/_base.py | 2 +- synapse/app/generic_worker.py | 2 +- synapse/app/homeserver.py | 2 +- synapse/app/phone_stats_home.py | 14 +++++--- synapse/appservice/scheduler.py | 2 +- synapse/crypto/keyring.py | 4 +-- synapse/events/builder.py | 2 +- synapse/events/third_party_rules.py | 2 +- synapse/federation/federation_base.py | 2 +- synapse/federation/sender/__init__.py | 2 +- .../sender/per_destination_queue.py | 9 ++--- .../federation/sender/transaction_manager.py | 2 +- synapse/federation/transport/server/_base.py | 2 +- .../federation/transport/server/federation.py | 2 +- synapse/groups/attestations.py | 2 +- synapse/groups/groups_server.py | 2 +- synapse/handlers/account_data.py | 4 +-- synapse/handlers/account_validity.py | 2 +- synapse/handlers/admin.py | 2 +- synapse/handlers/appservice.py | 2 +- synapse/handlers/auth.py | 4 +-- synapse/handlers/cas.py | 2 +- synapse/handlers/deactivate_account.py | 2 +- synapse/handlers/device.py | 4 +-- synapse/handlers/devicemessage.py | 2 +- synapse/handlers/directory.py | 2 +- synapse/handlers/e2e_keys.py | 4 +-- synapse/handlers/e2e_room_keys.py | 2 +- synapse/handlers/event_auth.py | 2 +- synapse/handlers/events.py | 4 +-- synapse/handlers/federation.py | 2 +- synapse/handlers/federation_event.py | 2 +- synapse/handlers/groups_local.py | 2 +- synapse/handlers/identity.py | 2 +- synapse/handlers/initial_sync.py | 2 +- synapse/handlers/message.py | 4 +-- synapse/handlers/oidc.py | 2 +- synapse/handlers/pagination.py | 2 +- synapse/handlers/presence.py | 4 +-- synapse/handlers/profile.py | 2 +- synapse/handlers/read_marker.py | 2 +- synapse/handlers/receipts.py | 4 +-- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 10 +++--- synapse/handlers/room_batch.py | 2 +- synapse/handlers/room_list.py | 2 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/room_summary.py | 2 +- synapse/handlers/saml.py | 2 +- synapse/handlers/search.py | 2 +- synapse/handlers/set_password.py | 2 +- synapse/handlers/sso.py | 2 +- synapse/handlers/state_deltas.py | 2 +- synapse/handlers/stats.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/handlers/typing.py | 4 +-- synapse/handlers/ui_auth/checkers.py | 4 +-- synapse/handlers/user_directory.py | 2 +- synapse/http/matrixfederationclient.py | 2 +- synapse/module_api/__init__.py | 8 +++-- synapse/notifier.py | 2 +- synapse/push/__init__.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 4 +-- synapse/push/emailpusher.py | 2 +- synapse/push/httppusher.py | 4 +-- synapse/push/mailer.py | 2 +- synapse/push/pusherpool.py | 2 +- synapse/replication/http/devices.py | 2 +- synapse/replication/http/federation.py | 10 +++--- synapse/replication/http/membership.py | 10 +++--- synapse/replication/http/register.py | 4 +-- synapse/replication/http/send_event.py | 2 +- synapse/replication/tcp/client.py | 4 +-- synapse/replication/tcp/handler.py | 2 +- synapse/replication/tcp/resource.py | 2 +- synapse/replication/tcp/streams/_base.py | 24 ++++++------- synapse/replication/tcp/streams/events.py | 2 +- synapse/rest/admin/__init__.py | 2 +- synapse/rest/admin/background_updates.py | 2 +- synapse/rest/admin/devices.py | 6 ++-- synapse/rest/admin/event_reports.py | 4 +-- synapse/rest/admin/federation.py | 8 ++--- synapse/rest/admin/media.py | 20 +++++------ synapse/rest/admin/registration_tokens.py | 6 ++-- synapse/rest/admin/rooms.py | 16 ++++----- synapse/rest/admin/statistics.py | 2 +- synapse/rest/admin/users.py | 24 ++++++------- synapse/rest/client/account.py | 18 +++++----- synapse/rest/client/account_data.py | 4 +-- synapse/rest/client/directory.py | 6 ++-- synapse/rest/client/events.py | 2 +- synapse/rest/client/groups.py | 8 ++--- synapse/rest/client/initial_sync.py | 2 +- synapse/rest/client/keys.py | 2 +- synapse/rest/client/login.py | 4 +-- synapse/rest/client/notifications.py | 2 +- synapse/rest/client/openid.py | 2 +- synapse/rest/client/push_rule.py | 2 +- synapse/rest/client/pusher.py | 4 ++- synapse/rest/client/register.py | 10 +++--- synapse/rest/client/relations.py | 6 ++-- synapse/rest/client/report_event.py | 2 +- synapse/rest/client/room.py | 12 +++---- synapse/rest/client/room_batch.py | 2 +- synapse/rest/client/shared_rooms.py | 2 +- synapse/rest/client/sync.py | 2 +- synapse/rest/client/tags.py | 2 +- synapse/rest/consent/consent_resource.py | 2 +- synapse/rest/key/v2/remote_key_resource.py | 2 +- synapse/rest/media/v1/media_repository.py | 2 +- synapse/rest/media/v1/preview_url_resource.py | 2 +- synapse/rest/media/v1/thumbnail_resource.py | 2 +- synapse/rest/media/v1/upload_resource.py | 2 +- synapse/rest/synapse/client/password_reset.py | 2 +- synapse/server.py | 16 +++------ .../server_notices/consent_server_notices.py | 2 +- .../resource_limits_server_notices.py | 2 +- .../server_notices/server_notices_manager.py | 2 +- synapse/state/__init__.py | 2 +- .../databases/main/monthly_active_users.py | 2 +- synapse/streams/events.py | 2 +- tests/api/test_auth.py | 8 +++-- tests/api/test_filtering.py | 2 +- tests/api/test_ratelimiting.py | 18 +++++----- tests/app/test_phone_stats_home.py | 34 ++++++++++--------- tests/crypto/test_keyring.py | 10 +++--- tests/events/test_snapshot.py | 2 +- tests/federation/test_complexity.py | 4 +-- tests/federation/test_federation_catch_up.py | 26 +++++++------- tests/federation/test_federation_sender.py | 6 ++-- tests/federation/transport/test_knocking.py | 2 +- tests/handlers/test_appservice.py | 10 +++--- tests/handlers/test_auth.py | 12 +++---- tests/handlers/test_cas.py | 2 +- tests/handlers/test_deactivate_account.py | 2 +- tests/handlers/test_device.py | 4 +-- tests/handlers/test_directory.py | 6 ++-- tests/handlers/test_e2e_keys.py | 2 +- tests/handlers/test_federation.py | 2 +- tests/handlers/test_message.py | 2 +- tests/handlers/test_oidc.py | 6 ++-- tests/handlers/test_presence.py | 4 +-- tests/handlers/test_profile.py | 4 +-- tests/handlers/test_register.py | 6 ++-- tests/handlers/test_saml.py | 4 +-- tests/handlers/test_stats.py | 2 +- tests/handlers/test_sync.py | 4 +-- tests/handlers/test_typing.py | 2 +- tests/handlers/test_user_directory.py | 2 +- tests/module_api/test_api.py | 2 +- tests/push/test_email.py | 22 ++++++------ tests/push/test_http.py | 22 ++++++------ tests/replication/_base.py | 6 ++-- tests/replication/slave/storage/_base.py | 4 +-- .../tcp/streams/test_account_data.py | 4 +-- tests/replication/tcp/streams/test_events.py | 4 +-- .../replication/tcp/streams/test_receipts.py | 4 +-- .../test_federation_sender_shard.py | 2 +- tests/replication/test_pusher_shard.py | 2 +- .../test_sharded_event_persister.py | 8 ++--- tests/rest/admin/test_background_updates.py | 2 +- tests/rest/admin/test_federation.py | 4 +-- tests/rest/admin/test_media.py | 4 +-- tests/rest/admin/test_registration_tokens.py | 2 +- tests/rest/admin/test_room.py | 6 ++-- tests/rest/admin/test_server_notice.py | 2 +- tests/rest/admin/test_user.py | 20 +++++------ tests/rest/client/test_account.py | 10 +++--- tests/rest/client/test_filter.py | 2 +- tests/rest/client/test_login.py | 4 +-- tests/rest/client/test_profile.py | 2 +- tests/rest/client/test_register.py | 28 ++++++++------- tests/rest/client/test_relations.py | 4 +-- tests/rest/client/test_retention.py | 4 +-- tests/rest/client/test_rooms.py | 10 +++--- tests/rest/client/test_shadow_banned.py | 2 +- tests/rest/client/test_shared_rooms.py | 2 +- tests/rest/client/test_sync.py | 2 +- tests/rest/client/test_typing.py | 2 +- tests/rest/client/test_upgrade_room.py | 2 +- tests/rest/media/v1/test_media_storage.py | 2 +- .../test_resource_limits_server_notices.py | 2 +- .../databases/main/test_deviceinbox.py | 2 +- .../databases/main/test_events_worker.py | 6 ++-- tests/storage/databases/main/test_lock.py | 2 +- tests/storage/databases/main/test_room.py | 2 +- tests/storage/test__base.py | 2 +- tests/storage/test_account_data.py | 2 +- tests/storage/test_appservice.py | 2 +- tests/storage/test_background_update.py | 8 ++--- tests/storage/test_cleanup_extrems.py | 4 +-- tests/storage/test_client_ips.py | 4 +-- tests/storage/test_devices.py | 2 +- tests/storage/test_directory.py | 2 +- tests/storage/test_e2e_room_keys.py | 2 +- tests/storage/test_end_to_end_keys.py | 2 +- tests/storage/test_event_chain.py | 4 +-- tests/storage/test_event_federation.py | 2 +- tests/storage/test_event_push_actions.py | 2 +- tests/storage/test_events.py | 4 +-- tests/storage/test_id_generators.py | 6 ++-- tests/storage/test_keys.py | 4 +-- tests/storage/test_main.py | 2 +- tests/storage/test_monthly_active_users.py | 2 +- tests/storage/test_profile.py | 2 +- tests/storage/test_purge.py | 4 +-- tests/storage/test_redaction.py | 2 +- tests/storage/test_registration.py | 2 +- tests/storage/test_rollback_worker.py | 6 ++-- tests/storage/test_room.py | 4 +-- tests/storage/test_room_search.py | 2 +- tests/storage/test_roommember.py | 4 +-- tests/storage/test_state.py | 2 +- tests/storage/test_stream.py | 2 +- tests/storage/test_transactions.py | 2 +- tests/storage/test_user_directory.py | 4 +-- tests/test_federation.py | 8 +++-- tests/test_mau.py | 2 +- tests/test_state.py | 4 +-- tests/test_utils/event_injection.py | 4 ++- tests/test_visibility.py | 4 ++- tests/unittest.py | 14 ++++---- tests/util/test_retryutils.py | 4 +-- tests/utils.py | 2 +- 230 files changed, 526 insertions(+), 500 deletions(-) create mode 100644 changelog.d/12031.misc diff --git a/changelog.d/12031.misc b/changelog.d/12031.misc new file mode 100644 index 0000000000..d4bedc6b97 --- /dev/null +++ b/changelog.d/12031.misc @@ -0,0 +1 @@ +Remove legacy `HomeServer.get_datastore()`. diff --git a/docs/manhole.md b/docs/manhole.md index 715ed840f2..a82fad0f0f 100644 --- a/docs/manhole.md +++ b/docs/manhole.md @@ -94,6 +94,6 @@ As a simple example, retrieving an event from the database: ```pycon >>> from twisted.internet import defer ->>> defer.ensureDeferred(hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org')) +>>> defer.ensureDeferred(hs.get_datastores().main.get_event('$1416420717069yeQaw:matrix.org')) > ``` diff --git a/scripts/update_synapse_database b/scripts/update_synapse_database index 5c6453d77f..f43676afaa 100755 --- a/scripts/update_synapse_database +++ b/scripts/update_synapse_database @@ -44,7 +44,7 @@ class MockHomeserver(HomeServer): def run_background_updates(hs): - store = hs.get_datastore() + store = hs.get_datastores().main async def run_background_updates(): await store.db_pool.updates.run_background_updates(sleep=False) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 683241201c..01c32417d8 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -60,7 +60,7 @@ class Auth: def __init__(self, hs: "HomeServer"): self.hs = hs self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self._account_validity_handler = hs.get_account_validity_handler() diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index 08fe160c98..22348d2d86 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class AuthBlocking: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._server_notices_mxid = hs.config.servernotices.server_notices_mxid self._hs_disabled = hs.config.server.hs_disabled diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index d087c816db..fe4cc2e8ee 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -150,7 +150,7 @@ def matrix_user_id_validator(user_id_str: str) -> UserID: class Filtering: def __init__(self, hs: "HomeServer"): self._hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {}) @@ -294,7 +294,7 @@ class FilterCollection: class Filter: def __init__(self, hs: "HomeServer", filter_json: JsonDict): self._hs = hs - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.filter_json = filter_json self.limit = filter_json.get("limit", 10) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 452c0c09d5..3e59805baa 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -448,7 +448,7 @@ async def start(hs: "HomeServer") -> None: # It is now safe to start your Synapse. hs.start_listening() - hs.get_datastore().db_pool.start_profiling() + hs.get_datastores().main.db_pool.start_profiling() hs.get_pusherpool().start() # Log when we start the shut down process. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index aadc882bf8..1536a42723 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -142,7 +142,7 @@ class KeyUploadServlet(RestServlet): """ super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.http_client = hs.get_simple_http_client() self.main_uri = hs.config.worker.worker_main_http_uri diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index bfb30003c2..b9931001c2 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -372,7 +372,7 @@ def setup(config_options: List[str]) -> SynapseHomeServer: await _base.start(hs) - hs.get_datastore().db_pool.updates.start_doing_background_updates() + hs.get_datastores().main.db_pool.updates.start_doing_background_updates() register_start(start) diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 899dba5c3d..40dbdace8e 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -82,7 +82,7 @@ async def phone_stats_home( # General statistics # - store = hs.get_datastore() + store = hs.get_datastores().main stats["homeserver"] = hs.config.server.server_name stats["server_context"] = hs.config.server.server_context @@ -170,18 +170,22 @@ def start_phone_stats_home(hs: "HomeServer") -> None: # Rather than update on per session basis, batch up the requests. # If you increase the loop period, the accuracy of user_daily_visits # table will decrease - clock.looping_call(hs.get_datastore().generate_user_daily_visits, 5 * 60 * 1000) + clock.looping_call( + hs.get_datastores().main.generate_user_daily_visits, 5 * 60 * 1000 + ) # monthly active user limiting functionality - clock.looping_call(hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60) - hs.get_datastore().reap_monthly_active_users() + clock.looping_call( + hs.get_datastores().main.reap_monthly_active_users, 1000 * 60 * 60 + ) + hs.get_datastores().main.reap_monthly_active_users() @wrap_as_background_process("generate_monthly_active_users") async def generate_monthly_active_users() -> None: current_mau_count = 0 current_mau_count_by_service = {} reserved_users: Sized = () - store = hs.get_datastore() + store = hs.get_datastores().main if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: current_mau_count = await store.get_monthly_active_count() current_mau_count_by_service = ( diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index c42fa32fff..b4e602e880 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -92,7 +92,7 @@ class ApplicationServiceScheduler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.as_api = hs.get_application_service_api() self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 72d4a69aac..93d56c077a 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -476,7 +476,7 @@ class StoreKeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def _fetch_keys( self, keys_to_fetch: List[_FetchKeyRequest] @@ -498,7 +498,7 @@ class BaseV2KeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.config = hs.config async def process_v2_response( diff --git a/synapse/events/builder.py b/synapse/events/builder.py index eb39e0ae32..1ea1bb7d37 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -189,7 +189,7 @@ class EventBuilderFactory: self.hostname = hs.hostname self.signing_key = hs.signing_key - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 1bb8ca7145..71ec100a7f 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -143,7 +143,7 @@ class ThirdPartyEventRules: def __init__(self, hs: "HomeServer"): self.third_party_rules = None - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = [] self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = [] diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index fab6da3c08..41ac49fdc8 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -39,7 +39,7 @@ class FederationBase: self.server_name = hs.hostname self.keyring = hs.get_keyring() self.spam_checker = hs.get_spam_checker() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._clock = hs.get_clock() async def _check_sigs_and_hash( diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 720d7bd74d..6106a486d1 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -228,7 +228,7 @@ class FederationSender(AbstractFederationSender): self.hs = hs self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self.clock = hs.get_clock() diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index c3132f7319..c8768f22bc 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -76,7 +76,7 @@ class PerDestinationQueue: ): self._server_name = hs.hostname self._clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._transaction_manager = transaction_manager self._instance_name = hs.get_instance_name() self._federation_shard_config = hs.config.worker.federation_shard_config @@ -381,9 +381,8 @@ class PerDestinationQueue: ) ) - last_successful_stream_ordering = self._last_successful_stream_ordering - - if last_successful_stream_ordering is None: + _tmp_last_successful_stream_ordering = self._last_successful_stream_ordering + if _tmp_last_successful_stream_ordering is None: # if it's still None, then this means we don't have the information # in our database ­ we haven't successfully sent a PDU to this server # (at least since the introduction of the feature tracking @@ -393,6 +392,8 @@ class PerDestinationQueue: self._catching_up = False return + last_successful_stream_ordering: int = _tmp_last_successful_stream_ordering + # get at most 50 catchup room/PDUs while True: event_ids = await self._store.get_catch_up_room_event_ids( diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 742ee57255..0c1cad86ab 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -53,7 +53,7 @@ class TransactionManager: def __init__(self, hs: "synapse.server.HomeServer"): self._server_name = hs.hostname self.clock = hs.get_clock() # nb must be called this for @measure_func - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._transaction_actions = TransactionActions(self._store) self._transport_layer = hs.get_federation_transport_client() diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index dff2b68359..87e99c7ddf 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -55,7 +55,7 @@ class Authenticator: self._clock = hs.get_clock() self.keyring = hs.get_keyring() self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation_domain_whitelist = ( hs.config.federation.federation_domain_whitelist ) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 4d75e58bfc..9cc9a7339d 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -746,7 +746,7 @@ class RoomComplexityServlet(BaseFederationServlet): server_name: str, ): super().__init__(hs, authenticator, ratelimiter, server_name) - self._store = self.hs.get_datastore() + self._store = self.hs.get_datastores().main async def on_GET( self, diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index a87896e538..ed26d6a6ce 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -140,7 +140,7 @@ class GroupAttestionRenewer: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.assestations = hs.get_groups_attestation_signing() self.transport_client = hs.get_federation_transport_client() self.is_mine_id = hs.is_mine_id diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 449bbc7004..4c3a5a6e24 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -45,7 +45,7 @@ MAX_LONG_DESC_LEN = 10000 class GroupsServerWorkerHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_list_handler = hs.get_room_list_handler() self.auth = hs.get_auth() self.clock = hs.get_clock() diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index bad48713bc..177b4f8991 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: class AccountDataHandler: def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._instance_name = hs.get_instance_name() self._notifier = hs.get_notifier() @@ -166,7 +166,7 @@ class AccountDataHandler: class AccountDataEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def get_current_key(self, direction: str = "f") -> int: return self.store.get_max_account_data_stream_id() diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 87e415df75..9d0975f636 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -43,7 +43,7 @@ class AccountValidityHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.config = hs.config - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.send_email_handler = self.hs.get_send_email_handler() self.clock = self.hs.get_clock() diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 00ab5e79bf..96376963f2 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class AdminHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index a42c3558e4..e6461cc3c9 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -47,7 +47,7 @@ events_processed_counter = Counter("synapse_handlers_appservice_events_processed class ApplicationServicesHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine_id = hs.is_mine_id self.appservice_api = hs.get_application_service_api() self.scheduler = hs.get_application_service_scheduler() diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 572f54b1e3..3e29c96a49 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -194,7 +194,7 @@ class AuthHandler: SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.clock = hs.get_clock() self.checkers: Dict[str, UserInteractiveAuthChecker] = {} @@ -1183,7 +1183,7 @@ class AuthHandler: # No password providers were able to handle this 3pid # Check local store - user_id = await self.hs.get_datastore().get_user_id_by_threepid( + user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( medium, address ) if not user_id: diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index 5d8f6c50a9..7163af8004 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -61,7 +61,7 @@ class CasHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self._hostname = hs.hostname - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 7a13d76a68..e4eae03056 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -29,7 +29,7 @@ class DeactivateAccountHandler: """Handler which deals with deactivating user accounts.""" def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.hs = hs self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 36c05f8363..934b5bd734 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -63,7 +63,7 @@ class DeviceWorkerHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.state = hs.get_state_handler() self.state_store = hs.get_storage().state @@ -628,7 +628,7 @@ class DeviceListUpdater: "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() self.device_handler = device_handler diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index b582266af9..4cb725d027 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -43,7 +43,7 @@ class DeviceMessageHandler: Args: hs: server """ - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.is_mine = hs.is_mine diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 082f521791..b7064c6624 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -44,7 +44,7 @@ class DirectoryHandler: self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.config = hs.config self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.require_membership = hs.config.server.require_membership_for_aliases diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d4dfddf63f..d96456cd40 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -47,7 +47,7 @@ logger = logging.getLogger(__name__) class E2eKeysHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() self.is_mine = hs.is_mine @@ -1335,7 +1335,7 @@ class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() self.e2e_keys_handler = e2e_keys_handler diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 12614b2c5d..52e44a2d42 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -45,7 +45,7 @@ class E2eRoomKeysHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main # Used to lock whenever a client is uploading key data. This prevents collisions # between clients trying to upload the details of a new session, given all diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 365063ebdf..d441ebb0ab 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -43,7 +43,7 @@ class EventAuthHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._server_name = hs.hostname async def check_auth_rules_from_context( diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index bac5de0526..97e75e60c3 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) class EventStreamHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.hs = hs @@ -134,7 +134,7 @@ class EventStreamHandler: class EventHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() async def get_event( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index e9ac920bcc..c055c26eca 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -107,7 +107,7 @@ class FederationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state self.federation_client = hs.get_federation_client() diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 7683246bef..09d0de1ead 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -95,7 +95,7 @@ class FederationEventHandler: """ def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._storage = hs.get_storage() self._state_store = self._storage.state diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 9e270d461b..e7a399787b 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -63,7 +63,7 @@ def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]: class GroupsLocalWorkerHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_list_handler = hs.get_room_list_handler() self.groups_server_handler = hs.get_groups_server_handler() self.transport_client = hs.get_federation_transport_client() diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index c83eaea359..57c9fdfe62 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -49,7 +49,7 @@ id_server_scheme = "https://" class IdentityHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main # An HTTP client for contacting trusted URLs. self.http_client = SimpleHttpClient(hs) # An HTTP client for contacting identity servers specified by clients. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 346a06ff49..344f20f37c 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -46,7 +46,7 @@ logger = logging.getLogger(__name__) class InitialSyncHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.hs = hs diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4d0da84287..a9c964cd75 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -75,7 +75,7 @@ class MessageHandler: self.auth = hs.get_auth() self.clock = hs.get_clock() self.state = hs.get_state_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() @@ -397,7 +397,7 @@ class EventCreationHandler: self.hs = hs self.auth = hs.get_auth() self._event_auth_handler = hs.get_event_auth_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state = hs.get_state_handler() self.clock = hs.get_clock() diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 8f71d975e9..593a2aac66 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -273,7 +273,7 @@ class OidcProvider: token_generator: "OidcSessionTokenGenerator", provider: OidcProviderConfig, ): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._token_generator = token_generator diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 973f262964..5c01a426ff 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -127,7 +127,7 @@ class PaginationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state self.clock = hs.get_clock() diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b223b72623..c155098bee 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -133,7 +133,7 @@ class BasePresenceHandler(abc.ABC): def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.presence_router = hs.get_presence_router() self.state = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -1541,7 +1541,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): self.get_presence_handler = hs.get_presence_handler self.get_presence_router = hs.get_presence_router self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def get_new_events( self, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 36e3ad2ba9..dd27f0accc 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -54,7 +54,7 @@ class ProfileHandler: PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.hs = hs diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 58593e570e..bad1acc634 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class ReadMarkerHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.config.server.server_name - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.account_data_handler = hs.get_account_data_handler() self.read_marker_linearizer = Linearizer(name="read_marker") diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 5cb1ff749d..b4132c353a 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -29,7 +29,7 @@ class ReceiptsHandler: def __init__(self, hs: "HomeServer"): self.notifier = hs.get_notifier() self.server_name = hs.config.server.server_name - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_auth_handler = hs.get_event_auth_handler() self.hs = hs @@ -163,7 +163,7 @@ class ReceiptsHandler: class ReceiptEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.config = hs.config @staticmethod diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 80320d2c07..05bb1e0225 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -86,7 +86,7 @@ class LoginDict(TypedDict): class RegistrationHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.hs = hs self.auth = hs.get_auth() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a990727fc5..7b965b4b96 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -105,7 +105,7 @@ class EventContext: class RoomCreationHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.clock = hs.get_clock() self.hs = hs @@ -1115,7 +1115,7 @@ class RoomContextHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state @@ -1246,7 +1246,7 @@ class RoomContextHandler: class TimestampLookupHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_handler = hs.get_state_handler() self.federation_client = hs.get_federation_client() @@ -1386,7 +1386,7 @@ class TimestampLookupHandler: class RoomEventSource(EventSource[RoomStreamToken, EventBase]): def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def get_new_events( self, @@ -1476,7 +1476,7 @@ class RoomShutdownHandler: self._room_creation_handler = hs.get_room_creation_handler() self._replication = hs.get_replication_data_handler() self.event_creation_handler = hs.get_event_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def shutdown_room( self, diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index f8137ec04c..abbf7b7b27 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) class RoomBatchHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_store = hs.get_storage().state self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 1a33211a1f..f3577b5d5a 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -49,7 +49,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.hs = hs self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.response_cache: ResponseCache[ diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index b2adc0f48b..a582837cf0 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -66,7 +66,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.config = hs.config diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 4844b69a03..2e61d1cbe9 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -90,7 +90,7 @@ class RoomSummaryHandler: def __init__(self, hs: "HomeServer"): self._event_auth_handler = hs.get_event_auth_handler() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._event_serializer = hs.get_event_client_serializer() self._server_name = hs.hostname self._federation_client = hs.get_federation_client() diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 727d75a50c..9602f0d0bb 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -52,7 +52,7 @@ class Saml2SessionData: class SamlHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.hostname self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 0e0e58de02..aa16e417eb 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -49,7 +49,7 @@ class _SearchResult: class SearchHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_handler = hs.get_state_handler() self.clock = hs.get_clock() self.hs = hs diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 706ad72761..73861bbd40 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -27,7 +27,7 @@ class SetPasswordHandler: """Handler which deals with changing user account passwords""" def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 0bb8b0929e..ff5b5169ca 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -180,7 +180,7 @@ class SsoHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._server_name = hs.hostname self._registration_handler = hs.get_registration_handler() self._auth_handler = hs.get_auth_handler() diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index d30ba2b724..2d197282ed 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -30,7 +30,7 @@ class MatchChange(Enum): class StateDeltasHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def _get_key_change( self, diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 29e41a4c79..436cd971ce 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -39,7 +39,7 @@ class StatsHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self.server_name = hs.hostname self.clock = hs.get_clock() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index e6050cbce6..98eaad3318 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -266,7 +266,7 @@ class SyncResult: class SyncHandler: def __init__(self, hs: "HomeServer"): self.hs_config = hs.config - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index e4bed1c937..843c68eb0f 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -57,7 +57,7 @@ class FollowerTypingHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.server_name = hs.config.server.server_name self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id @@ -446,7 +446,7 @@ class TypingWriterHandler(FollowerTypingHandler): class TypingNotificationEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self._main_store = hs.get_datastore() + self._main_store = hs.get_datastores().main self.clock = hs.get_clock() # We can't call get_typing_handler here because there's a cycle: # diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 184730ebe8..014754a630 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -139,7 +139,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): class _BaseThreepidAuthChecker: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def _check_threepid(self, medium: str, authdict: dict) -> dict: if "threepid_creds" not in authdict: @@ -255,7 +255,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): super().__init__(hs) self.hs = hs self._enabled = bool(hs.config.registration.registration_requires_token) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def is_enabled(self) -> bool: return self._enabled diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 1565e034cb..d27ed2be6a 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -55,7 +55,7 @@ class UserDirectoryHandler(StateDeltasHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.server_name = hs.hostname self.clock = hs.get_clock() self.notifier = hs.get_notifier() diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index e7656fbb9f..40bf1e06d6 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -351,7 +351,7 @@ class MatrixFederationHttpClient: ) self.clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.version_string_bytes = hs.version_string.encode("ascii") self.default_timeout = 60 diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 07020bfb8d..902916d800 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -172,7 +172,9 @@ class ModuleApi: # TODO: Fix this type hint once the types for the data stores have been ironed # out. - self._store: Union[DataStore, "GenericWorkerSlavedStore"] = hs.get_datastore() + self._store: Union[ + DataStore, "GenericWorkerSlavedStore" + ] = hs.get_datastores().main self._auth = hs.get_auth() self._auth_handler = auth_handler self._server_name = hs.hostname @@ -926,7 +928,7 @@ class ModuleApi: ) # Try to retrieve the resulting event. - event = await self._hs.get_datastore().get_event(event_id) + event = await self._hs.get_datastores().main.get_event(event_id) # update_membership is supposed to always return after the event has been # successfully persisted. @@ -1270,7 +1272,7 @@ class PublicRoomListManager: """ def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def room_is_in_public_room_list(self, room_id: str) -> bool: """Checks whether a room is in the public room list. diff --git a/synapse/notifier.py b/synapse/notifier.py index 753dd6b6a5..16d15a1f33 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -222,7 +222,7 @@ class Notifier: self.hs = hs self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.pending_new_room_events: List[_PendingRoomEventEntry] = [] # Called when there are new things to stream over replication diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 5176a1c186..a1b7711098 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -68,7 +68,7 @@ class ThrottleParams: class Pusher(metaclass=abc.ABCMeta): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): self.hs = hs - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.clock = self.hs.get_clock() self.pusher_id = pusher_config.id diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index bee660893b..fecf86034e 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -103,7 +103,7 @@ class BulkPushRuleEvaluator: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._event_auth_handler = hs.get_event_auth_handler() # Used by `RulesForRoom` to ensure only one thing mutates the cache at a @@ -366,7 +366,7 @@ class RulesForRoom: """ self.room_id = room_id self.is_mine_id = hs.is_mine_id - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_push_rule_cache_metrics = room_push_rule_cache_metrics # Used to ensure only one thing mutates the cache at a time. Keyed off diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 39bb2acae4..1710dd51b9 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -66,7 +66,7 @@ class EmailPusher(Pusher): super().__init__(hs, pusher_config) self.mailer = mailer - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.email = pusher_config.pushkey self.timed_call: Optional[IDelayedCall] = None self.throttle_params: Dict[str, ThrottleParams] = {} diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 52c7ff3572..5818344520 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -133,7 +133,7 @@ class HttpPusher(Pusher): # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # to be largely redundant. perhaps we can remove it. badge = await push_tools.get_badge_count( - self.hs.get_datastore(), + self.hs.get_datastores().main, self.user_id, group_by_room=self._group_unread_count_by_room, ) @@ -283,7 +283,7 @@ class HttpPusher(Pusher): tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions) badge = await push_tools.get_badge_count( - self.hs.get_datastore(), + self.hs.get_datastores().main, self.user_id, group_by_room=self._group_unread_count_by_room, ) diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 3df8452eec..649a4f49d0 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -112,7 +112,7 @@ class Mailer: self.template_text = template_text self.send_email_handler = hs.get_send_email_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.state_store = self.hs.get_storage().state self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 7912311d24..d0cc657b44 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -59,7 +59,7 @@ class PusherPool: def __init__(self, hs: "HomeServer"): self.hs = hs self.pusher_factory = PusherFactory(hs) - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.clock = self.hs.get_clock() # We shard the handling of push notifications by user ID. diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index f2f40129fe..3d63645726 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -63,7 +63,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): super().__init__(hs) self.device_list_updater = hs.get_device_handler().device_list_updater - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() @staticmethod diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index d529c8a19f..3e7300b4a1 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -68,7 +68,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.clock = hs.get_clock() self.federation_event_handler = hs.get_federation_event_handler() @@ -167,7 +167,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.registry = hs.get_federation_registry() @@ -214,7 +214,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.registry = hs.get_federation_registry() @@ -260,7 +260,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main @staticmethod async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override] @@ -297,7 +297,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main @staticmethod async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDict: # type: ignore[override] diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 0145858e47..663bff5738 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -50,7 +50,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): super().__init__(hs) self.federation_handler = hs.get_federation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() @staticmethod @@ -119,7 +119,7 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint): super().__init__(hs) self.federation_handler = hs.get_federation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() @staticmethod @@ -188,7 +188,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.member_handler = hs.get_room_member_handler() @@ -258,7 +258,7 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.member_handler = hs.get_room_member_handler() @@ -325,7 +325,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): super().__init__(hs) self.registeration_handler = hs.get_registration_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.distributor = hs.get_distributor() diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index c7f751b70d..6c8f8388fd 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -36,7 +36,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() @staticmethod @@ -112,7 +112,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() @staticmethod diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 33e98daf8a..ce78176836 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -69,7 +69,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.clock = hs.get_clock() diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index d59ce7ccf9..1b8479b0b4 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -111,7 +111,7 @@ class ReplicationDataHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self._reactor = hs.get_reactor() self._clock = hs.get_clock() @@ -340,7 +340,7 @@ class FederationSenderHandler: def __init__(self, hs: "HomeServer"): assert hs.should_send_federation() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._is_mine_id = hs.is_mine_id self._hs = hs diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 17e1572393..0d2013a3cf 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -95,7 +95,7 @@ class ReplicationCommandHandler: def __init__(self, hs: "HomeServer"): self._replication_data_handler = hs.get_replication_data_handler() self._presence_handler = hs.get_presence_handler() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._notifier = hs.get_notifier() self._clock = hs.get_clock() self._instance_id = hs.get_instance_id() diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index ecd6190f5b..494e42a2be 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -72,7 +72,7 @@ class ReplicationStreamer: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.notifier = hs.get_notifier() self._instance_name = hs.get_instance_name() diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 914b9eae84..23d631a769 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -239,7 +239,7 @@ class BackfillStream(Stream): ROW_TYPE = BackfillStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), self._current_token, @@ -267,7 +267,7 @@ class PresenceStream(Stream): ROW_TYPE = PresenceStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main if hs.get_instance_name() in hs.config.worker.writers.presence: # on the presence writer, query the presence handler @@ -355,7 +355,7 @@ class ReceiptsStream(Stream): ROW_TYPE = ReceiptsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_receipt_stream_id), @@ -374,7 +374,7 @@ class PushRulesStream(Stream): ROW_TYPE = PushRulesStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), @@ -401,7 +401,7 @@ class PushersStream(Stream): ROW_TYPE = PushersStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), @@ -434,7 +434,7 @@ class CachesStream(Stream): ROW_TYPE = CachesStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), store.get_cache_stream_token_for_writer, @@ -455,7 +455,7 @@ class DeviceListsStream(Stream): ROW_TYPE = DeviceListsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), @@ -474,7 +474,7 @@ class ToDeviceStream(Stream): ROW_TYPE = ToDeviceStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_to_device_stream_token), @@ -495,7 +495,7 @@ class TagAccountDataStream(Stream): ROW_TYPE = TagAccountDataStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_account_data_stream_id), @@ -516,7 +516,7 @@ class AccountDataStream(Stream): ROW_TYPE = AccountDataStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(self.store.get_max_account_data_stream_id), @@ -585,7 +585,7 @@ class GroupServerStream(Stream): ROW_TYPE = GroupsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_group_stream_token), @@ -604,7 +604,7 @@ class UserSignatureStream(Stream): ROW_TYPE = UserSignatureStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 50c4a5ba03..26f4fa7cfd 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -124,7 +124,7 @@ class EventsStream(Stream): NAME = "events" def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main super().__init__( hs.get_instance_name(), self._store._stream_id_gen.get_current_token_for_writer, diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index ba0d989d81..6de302f813 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -116,7 +116,7 @@ class PurgeHistoryRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.pagination_handler = hs.get_pagination_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py index e9bce22a34..93a78db811 100644 --- a/synapse/rest/admin/background_updates.py +++ b/synapse/rest/admin/background_updates.py @@ -112,7 +112,7 @@ class BackgroundUpdateStartJobRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self._auth, request) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index d9905ff560..cef46ba0dd 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -44,7 +44,7 @@ class DeviceRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine = hs.is_mine async def on_GET( @@ -113,7 +113,7 @@ class DevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine = hs.is_mine async def on_GET( @@ -144,7 +144,7 @@ class DeleteDevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine = hs.is_mine async def on_POST( diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index 38477f8ead..6d634eef70 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -53,7 +53,7 @@ class EventReportsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -115,7 +115,7 @@ class EventReportDetailRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, report_id: str diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index d162e0081e..023ed92144 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -48,7 +48,7 @@ class ListDestinationsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self._auth, request) @@ -105,7 +105,7 @@ class DestinationRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, destination: str @@ -165,7 +165,7 @@ class DestinationMembershipRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, destination: str @@ -221,7 +221,7 @@ class DestinationResetConnectionRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._authenticator = Authenticator(hs) async def on_POST( diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 299f5c9eb0..8ca57bdb28 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -47,7 +47,7 @@ class QuarantineMediaInRoom(RestServlet): ] def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -74,7 +74,7 @@ class QuarantineMediaByUser(RestServlet): PATTERNS = admin_patterns("/user/(?P[^/]*)/media/quarantine$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -103,7 +103,7 @@ class QuarantineMediaByID(RestServlet): ) def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -132,7 +132,7 @@ class UnquarantineMediaByID(RestServlet): ) def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -156,7 +156,7 @@ class ProtectMediaByID(RestServlet): PATTERNS = admin_patterns("/media/protect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -178,7 +178,7 @@ class UnprotectMediaByID(RestServlet): PATTERNS = admin_patterns("/media/unprotect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -200,7 +200,7 @@ class ListMediaInRoom(RestServlet): PATTERNS = admin_patterns("/room/(?P[^/]*)/media$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_GET( @@ -251,7 +251,7 @@ class DeleteMediaByID(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]*)/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() @@ -283,7 +283,7 @@ class DeleteMediaByDateSize(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]*)/delete$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() @@ -352,7 +352,7 @@ class UserMediaRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repository = hs.get_media_repository() async def on_GET( diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index 04948b6408..af606e9252 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -71,7 +71,7 @@ class ListRegistrationTokensRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -109,7 +109,7 @@ class NewRegistrationTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() # A string of all the characters allowed to be in a registration_token self.allowed_chars = string.ascii_letters + string.digits + "._~-" @@ -260,7 +260,7 @@ class RegistrationTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]: """Retrieve a registration token.""" diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 5b706efbcf..f4736a3dad 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -65,7 +65,7 @@ class RoomRestV2Servlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._pagination_handler = hs.get_pagination_handler() async def on_DELETE( @@ -188,7 +188,7 @@ class ListRoomRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -278,7 +278,7 @@ class RoomRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() @@ -382,7 +382,7 @@ class RoomMembersRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -408,7 +408,7 @@ class RoomStateRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() @@ -525,7 +525,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_creation_handler = hs.get_event_creation_handler() self.state_handler = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -670,7 +670,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_DELETE( self, request: SynapseRequest, room_identifier: str @@ -781,7 +781,7 @@ class BlockRoomRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index 7a6546372e..3b142b8402 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -38,7 +38,7 @@ class UserMediaStatisticsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index c2617ee30c..8e29ada8a0 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -66,7 +66,7 @@ class UsersRestServletV2(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -156,7 +156,7 @@ class UserRestServletV2(RestServlet): self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.set_password_handler = hs.get_set_password_handler() @@ -588,7 +588,7 @@ class DeactivateAccountRestServlet(RestServlet): self._deactivate_account_handler = hs.get_deactivate_account_handler() self.auth = hs.get_auth() self.is_mine = hs.is_mine - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_POST( self, request: SynapseRequest, target_user_id: str @@ -674,7 +674,7 @@ class ResetPasswordRestServlet(RestServlet): PATTERNS = admin_patterns("/reset_password/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self._set_password_handler = hs.get_set_password_handler() @@ -717,7 +717,7 @@ class SearchUsersRestServlet(RestServlet): PATTERNS = admin_patterns("/search_users/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine = hs.is_mine @@ -775,7 +775,7 @@ class UserAdminServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/admin$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine = hs.is_mine @@ -835,7 +835,7 @@ class UserMembershipRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, user_id: str @@ -864,7 +864,7 @@ class PushersRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_GET( @@ -905,7 +905,7 @@ class UserTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/login$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.is_mine_id = hs.is_mine_id @@ -974,7 +974,7 @@ class ShadowBanRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine_id = hs.is_mine_id @@ -1026,7 +1026,7 @@ class RateLimitRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/override_ratelimit$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine_id = hs.is_mine_id @@ -1129,7 +1129,7 @@ class AccountDataRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._is_mine_id = hs.is_mine_id async def on_GET( diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 5802de5b7c..4b217882e8 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -60,7 +60,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs - self.datastore = hs.get_datastore() + self.datastore = hs.get_datastores().main self.config = hs.config self.identity_handler = hs.get_identity_handler() @@ -114,7 +114,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # This avoids a potential account hijack by requesting a password reset to # an email address which is controlled by the attacker but which, after # canonicalisation, matches the one in our database. - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( "email", email ) @@ -168,7 +168,7 @@ class PasswordRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastore() + self.datastore = self.hs.get_datastores().main self.password_policy_handler = hs.get_password_policy_handler() self._set_password_handler = hs.get_set_password_handler() @@ -347,7 +347,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): self.hs = hs self.config = hs.config self.identity_handler = hs.get_identity_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self.mailer = Mailer( @@ -450,7 +450,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.hs = hs super().__init__() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: @@ -533,7 +533,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): super().__init__() self.config = hs.config self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self._failure_email_template = ( self.config.email.email_add_threepid_template_failure_html @@ -600,7 +600,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): super().__init__() self.config = hs.config self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: @@ -634,7 +634,7 @@ class ThreepidRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastore() + self.datastore = self.hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -768,7 +768,7 @@ class ThreepidUnbindRestServlet(RestServlet): self.hs = hs self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() - self.datastore = self.hs.get_datastore() + self.datastore = self.hs.get_datastores().main async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """Unbind the given 3pid from a specific identity server, or identity servers that are diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index 58b8adbd32..bfe985939b 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -42,7 +42,7 @@ class AccountDataServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_account_data_handler() async def on_PUT( @@ -90,7 +90,7 @@ class RoomAccountDataServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_account_data_handler() async def on_PUT( diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py index ee247e3d1e..e181a0dde2 100644 --- a/synapse/rest/client/directory.py +++ b/synapse/rest/client/directory.py @@ -47,7 +47,7 @@ class ClientDirectoryServer(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() @@ -129,7 +129,7 @@ class ClientDirectoryListServer(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() @@ -173,7 +173,7 @@ class ClientAppserviceDirectoryListServer(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 672c821061..916f5230f1 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -39,7 +39,7 @@ class EventStreamRestServlet(RestServlet): super().__init__() self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py index a7e9aa3e9b..7e1149c7f4 100644 --- a/synapse/rest/client/groups.py +++ b/synapse/rest/client/groups.py @@ -705,7 +705,7 @@ class GroupAdminUsersInviteServlet(RestServlet): self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine_id = hs.is_mine_id @_validate_group_id @@ -854,7 +854,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main @_validate_group_id async def on_PUT( @@ -879,7 +879,7 @@ class PublicisedGroupsForUserServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.groups_handler = hs.get_groups_local_handler() async def on_GET( @@ -901,7 +901,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.groups_handler = hs.get_groups_local_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py index 49b1037b28..cfadcb8e50 100644 --- a/synapse/rest/client/initial_sync.py +++ b/synapse/rest/client/initial_sync.py @@ -33,7 +33,7 @@ class InitialSyncRestServlet(RestServlet): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 730c18f08f..ce806e3c11 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -198,7 +198,7 @@ class KeyChangesServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index f9994658c4..c9d44c5964 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -104,13 +104,13 @@ class LoginRestServlet(RestServlet): self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( - store=hs.get_datastore(), + store=hs.get_datastores().main, clock=hs.get_clock(), rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second, burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count, ) self._account_ratelimiter = Ratelimiter( - store=hs.get_datastore(), + store=hs.get_datastores().main, clock=hs.get_clock(), rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second, burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count, diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index 8e427a96a3..20377a9ac6 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -35,7 +35,7 @@ class NotificationsServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() diff --git a/synapse/rest/client/openid.py b/synapse/rest/client/openid.py index add56d6998..820682ec42 100644 --- a/synapse/rest/client/openid.py +++ b/synapse/rest/client/openid.py @@ -67,7 +67,7 @@ class IdTokenServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.config.server.server_name diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 8fe75bd750..a93f6fd5e0 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -57,7 +57,7 @@ class PushRuleRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self._is_worker = hs.config.worker.worker_app is not None diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index 98604a9388..d6487c31dd 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -46,7 +46,9 @@ class PushersRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) user = requester.user - pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) + pushers = await self.hs.get_datastores().main.get_pushers_by_user_id( + user.to_string() + ) filtered_pushers = [p.as_dict() for p in pushers] diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index b8a5135e02..70baf50fa4 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -123,7 +123,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): request, "email", email ) - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( "email", email ) @@ -203,7 +203,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): request, "msisdn", msisdn ) - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( "msisdn", msisdn ) @@ -258,7 +258,7 @@ class RegistrationSubmitTokenServlet(RestServlet): self.auth = hs.get_auth() self.config = hs.config self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self._failure_email_template = ( @@ -385,7 +385,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), @@ -415,7 +415,7 @@ class RegisterRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.identity_handler = hs.get_identity_handler() diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 2cab83c4e6..487ea38b55 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -85,7 +85,7 @@ class RelationPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() @@ -190,7 +190,7 @@ class RelationAggregationPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_handler = hs.get_event_handler() async def on_GET( @@ -282,7 +282,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py index d4a4adb50c..6e962a4532 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py @@ -38,7 +38,7 @@ class ReportEventRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_POST( self, request: SynapseRequest, room_id: str, event_id: str diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 90355e44b2..5ccfe5a92f 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -477,7 +477,7 @@ class RoomMemberListRestServlet(RestServlet): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -553,7 +553,7 @@ class RoomMessageListRestServlet(RestServlet): self._hs = hs self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -621,7 +621,7 @@ class RoomInitialSyncRestServlet(RestServlet): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -642,7 +642,7 @@ class RoomEventServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() @@ -1027,7 +1027,7 @@ class JoinedRoomsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: @@ -1116,7 +1116,7 @@ class TimestampLookupRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.timestamp_lookup_handler = hs.get_timestamp_lookup_handler() async def on_GET( diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 4b6be38327..0048973e59 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -75,7 +75,7 @@ class RoomBatchSendEventRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() self.room_batch_handler = hs.get_room_batch_handler() diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py index 09a46737de..e669fa7890 100644 --- a/synapse/rest/client/shared_rooms.py +++ b/synapse/rest/client/shared_rooms.py @@ -41,7 +41,7 @@ class UserSharedRoomsServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_directory_active = hs.config.server.update_user_directory async def on_GET( diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index f9615da525..f3018ff690 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -103,7 +103,7 @@ class SyncRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.sync_handler = hs.get_sync_handler() self.clock = hs.get_clock() self.filtering = hs.get_filtering() diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py index c88cb9367c..ca638755c7 100644 --- a/synapse/rest/client/tags.py +++ b/synapse/rest/client/tags.py @@ -39,7 +39,7 @@ class TagListServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, user_id: str, room_id: str diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 3d2afacc50..25f9ea285b 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -78,7 +78,7 @@ class ConsentResource(DirectServeHtmlResource): super().__init__() self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() # this is required by the request_handler wrapper diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 3923ba8439..3525d6ae54 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -94,7 +94,7 @@ class RemoteKey(DirectServeJsonResource): super().__init__() self.fetcher = ServerKeyFetcher(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.federation_domain_whitelist = ( hs.config.federation.federation_domain_whitelist diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 71b9a34b14..6c414402bd 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -75,7 +75,7 @@ class MediaRepository: self.client = hs.get_federation_http_client() self.clock = hs.get_clock() self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.max_upload_size = hs.config.media.max_upload_size self.max_image_pixels = hs.config.media.max_image_pixels diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index c08b60d10a..14ea88b240 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -134,7 +134,7 @@ class PreviewUrlResource(DirectServeJsonResource): self.filepaths = media_repo.filepaths self.max_spider_size = hs.config.media.max_spider_size self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.client = SimpleHttpClient( hs, treq_args={"browser_like_redirects": True}, diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index ed91ef5a42..53b1565243 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -50,7 +50,7 @@ class ThumbnailResource(DirectServeJsonResource): ): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repo = media_repo self.media_storage = media_storage self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index fde28d08cb..e73e431dc9 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -37,7 +37,7 @@ class UploadResource(DirectServeJsonResource): self.media_repo = media_repo self.filepaths = media_repo.filepaths - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.hostname self.auth = hs.get_auth() diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py index 28a67f04e3..6ac9dbc7c9 100644 --- a/synapse/rest/synapse/client/password_reset.py +++ b/synapse/rest/synapse/client/password_reset.py @@ -44,7 +44,7 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource): super().__init__() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._local_threepid_handling_disabled_due_to_email_config = ( hs.config.email.local_threepid_handling_disabled_due_to_email_config diff --git a/synapse/server.py b/synapse/server.py index 4c07f21015..b5e2a319bc 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -17,7 +17,7 @@ # homeservers; either as a full homeserver as a real application, or a small # partial one for unit test mocking. -# Imports required for the default HomeServer() implementation + import abc import functools import logging @@ -134,7 +134,7 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import Databases, DataStore, Storage +from synapse.storage import Databases, Storage from synapse.streams.events import EventSources from synapse.types import DomainSpecificString, ISynapseReactor from synapse.util import Clock @@ -225,7 +225,7 @@ class HomeServer(metaclass=abc.ABCMeta): # This is overridden in derived application classes # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be - # instantiated during setup() for future return by get_datastore() + # instantiated during setup() for future return by get_datastores() DATASTORE_CLASS = abc.abstractproperty() tls_server_context_factory: Optional[IOpenSSLContextFactory] @@ -355,12 +355,6 @@ class HomeServer(metaclass=abc.ABCMeta): def get_clock(self) -> Clock: return Clock(self._reactor) - def get_datastore(self) -> DataStore: - if not self.datastores: - raise Exception("HomeServer.setup must be called before getting datastores") - - return self.datastores.main - def get_datastores(self) -> Databases: if not self.datastores: raise Exception("HomeServer.setup must be called before getting datastores") @@ -374,7 +368,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: return Ratelimiter( - store=self.get_datastore(), + store=self.get_datastores().main, clock=self.get_clock(), rate_hz=self.config.ratelimiting.rc_registration.per_second, burst_count=self.config.ratelimiting.rc_registration.burst_count, @@ -847,7 +841,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_request_ratelimiter(self) -> RequestRatelimiter: return RequestRatelimiter( - self.get_datastore(), + self.get_datastores().main, self.get_clock(), self.config.ratelimiting.rc_message, self.config.ratelimiting.rc_admin_redaction, diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index e09a25591f..698ca742ed 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -32,7 +32,7 @@ class ConsentServerNotices: def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._users_in_progress: Set[str] = set() diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 8522930b50..015dd08f05 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -36,7 +36,7 @@ class ResourceLimitsServerNotices: def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._auth = hs.get_auth() self._config = hs.config self._resouce_limited = False diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 0cf60236f8..7b4814e049 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -29,7 +29,7 @@ SERVER_NOTICE_ROOM_TAG = "m.server_notice" class ServerNoticesManager: def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._config = hs.config self._account_data_handler = hs.get_account_data_handler() self._room_creation_handler = hs.get_room_creation_handler() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 67e8bc6ec2..fcc24ad129 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -126,7 +126,7 @@ class StateHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_store = hs.get_storage().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 8f09dd8e87..e9a0cdc6be 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -112,7 +112,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): for tp in self.hs.config.server.mau_limits_reserved_threepids[ : self.hs.config.server.max_mau_value ]: - user_id = await self.hs.get_datastore().get_user_id_by_threepid( + user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( tp["medium"], canonicalise_email(tp["address"]) ) if user_id: diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 4ec2a713cf..fb8fe17295 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -48,7 +48,7 @@ class EventSources: # all the attributes of `_EventSourcesInner` are annotated. *(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner)) # type: ignore[misc] ) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def get_current_token(self) -> StreamToken: push_rules_key = self.store.get_max_push_rules_stream_id() diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 4b53b6d40b..686d17c0de 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -16,6 +16,8 @@ from unittest.mock import Mock import pymacaroons +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import ( @@ -26,8 +28,10 @@ from synapse.api.errors import ( ResourceLimitError, ) from synapse.appservice import ApplicationService +from synapse.server import HomeServer from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import Requester +from synapse.util import Clock from tests import unittest from tests.test_utils import simple_async_mock @@ -36,10 +40,10 @@ from tests.utils import mock_getRawHeaders class AuthTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): self.store = Mock() - hs.get_datastore = Mock(return_value=self.store) + hs.datastores.main = self.store hs.get_auth_handler().store = self.store self.auth = Auth(hs) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index b7fc33dc94..973f0f7fa1 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -40,7 +40,7 @@ def MockEvent(**kwargs): class FilteringTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.filtering = hs.get_filtering() - self.datastore = hs.get_datastore() + self.datastore = hs.get_datastores().main def test_errors_on_invalid_filters(self): invalid_filters = [ diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index dcf0110c16..4ef754a186 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -8,7 +8,7 @@ from tests import unittest class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_via_can_do_action(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=0) @@ -39,7 +39,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): as_requester = create_requester("@user:example.com", app_service=appservice) limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) @@ -70,7 +70,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): as_requester = create_requester("@user:example.com", app_service=appservice) limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) @@ -92,7 +92,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_via_ratelimit(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) # Shouldn't raise @@ -116,7 +116,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) # First attempt should be allowed @@ -162,7 +162,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) # First attempt should be allowed @@ -190,7 +190,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_pruning(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) self.get_success_or_raise( limiter.can_do_action(None, key="test_id_1", _time_now_s=0) @@ -208,7 +208,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): """Test that users that have ratelimiting disabled in the DB aren't ratelimited. """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user_id = "@user:test" requester = create_requester(user_id) @@ -233,7 +233,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_multiple_actions(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=3 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 ) # Test that 4 actions aren't allowed with a maximum burst of 3. allowed, time_allowed = self.get_success_or_raise( diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py index 19eb4c79d0..df731eb599 100644 --- a/tests/app/test_phone_stats_home.py +++ b/tests/app/test_phone_stats_home.py @@ -32,7 +32,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.helper.send(room_id, "message", tok=access_token) # Check the R30 results do not count that user. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Advance 30 days (+ 1 second, because strict inequality causes issues if we are @@ -40,7 +40,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) # (Make sure the user isn't somehow counted by this point.) - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Send a message (this counts as activity) @@ -51,21 +51,21 @@ class PhoneHomeTestCase(HomeserverTestCase): self.reactor.advance(2 * 60 * 60) # *Now* the user is counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance 29 days. The user has now not posted for 29 days. self.reactor.advance(29 * ONE_DAY_IN_SECONDS) # The user is still counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance another day. The user has now not posted for 30 days. self.reactor.advance(ONE_DAY_IN_SECONDS) # The user is now no longer counted in R30. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) def test_r30_minimum_usage_using_default_config(self): @@ -84,7 +84,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.helper.send(room_id, "message", tok=access_token) # Check the R30 results do not count that user. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Advance 30 days (+ 1 second, because strict inequality causes issues if we are @@ -92,7 +92,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) # (Make sure the user isn't somehow counted by this point.) - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Send a message (this counts as activity) @@ -103,14 +103,14 @@ class PhoneHomeTestCase(HomeserverTestCase): self.reactor.advance(2 * 60 * 60) # *Now* the user is counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance 27 days. The user has now not posted for 27 days. self.reactor.advance(27 * ONE_DAY_IN_SECONDS) # The user is still counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance another day. The user has now not posted for 28 days. @@ -119,7 +119,7 @@ class PhoneHomeTestCase(HomeserverTestCase): # The user is now no longer counted in R30. # (This is because the user_ips table has been pruned, which by default # only preserves the last 28 days of entries.) - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) def test_r30_user_must_be_retained_for_at_least_a_month(self): @@ -135,7 +135,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.helper.send(room_id, "message", tok=access_token) # Check the user does not contribute to R30 yet. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) for _ in range(30): @@ -144,14 +144,16 @@ class PhoneHomeTestCase(HomeserverTestCase): self.helper.send(room_id, "I'm still here", tok=access_token) # Notice that the user *still* does not contribute to R30! - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success( + self.hs.get_datastores().main.count_r30_users() + ) self.assertEqual(r30_results, {"all": 0}) self.reactor.advance(ONE_DAY_IN_SECONDS) self.helper.send(room_id, "Still here!", tok=access_token) # *Now* the user appears in R30. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) @@ -196,7 +198,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): # (user_daily_visits is updated every 5 minutes using a looping call.) self.reactor.advance(FIVE_MINUTES_IN_SECONDS) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check the R30 results do not count that user. r30_results = self.get_success(store.count_r30v2_users()) @@ -275,7 +277,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): # (user_daily_visits is updated every 5 minutes using a looping call.) self.reactor.advance(FIVE_MINUTES_IN_SECONDS) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check the user does not contribute to R30 yet. r30_results = self.get_success(store.count_r30v2_users()) @@ -347,7 +349,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): # (user_daily_visits is updated every 5 minutes using a looping call.) self.reactor.advance(FIVE_MINUTES_IN_SECONDS) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check that the user does not contribute to R30v2, even though it's been # more than 30 days since registration. diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 17a9fb63a1..3a4d502719 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -179,7 +179,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.get_datastore().store_server_verify_keys( + r = self.hs.get_datastores().main.store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], @@ -272,7 +272,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.get_datastore().store_server_verify_keys( + r = self.hs.get_datastores().main.store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))], @@ -448,7 +448,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) res = key_json[lookup_triplet] self.assertEqual(len(res), 1) @@ -564,7 +564,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) res = key_json[lookup_triplet] self.assertEqual(len(res), 1) @@ -683,7 +683,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) res = key_json[lookup_triplet] self.assertEqual(len(res), 1) diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index ca27388ae8..defbc68c18 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -28,7 +28,7 @@ class TestEventContext(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.user_id = self.register_user("u1", "pass") diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index e40ef95874..9336181c96 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -55,7 +55,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): self.assertTrue(complexity > 0, complexity) # Artificially raise the complexity - store = self.hs.get_datastore() + store = self.hs.get_datastores().main store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23) # Get the room complexity again -- make sure it's our artificial value @@ -149,7 +149,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): ) # Artificially raise the complexity - self.hs.get_datastore().get_current_state_event_counts = ( + self.hs.get_datastores().main.get_current_state_event_counts = ( lambda x: make_awaitable(600) ) diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index f0aa8ed9db..2873b4d430 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -64,7 +64,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): Dictionary of { event_id: str, stream_ordering: int } """ event_id, stream_ordering = self.get_success( - self.hs.get_datastore().db_pool.execute( + self.hs.get_datastores().main.db_pool.execute( "test:get_destination_rooms", None, """ @@ -125,7 +125,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): self.pump() lsso_1 = self.get_success( - self.hs.get_datastore().get_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.get_destination_last_successful_stream_ordering( "host2" ) ) @@ -141,7 +141,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"] lsso_2 = self.get_success( - self.hs.get_datastore().get_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.get_destination_last_successful_stream_ordering( "host2" ) ) @@ -216,7 +216,9 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # let's also clear any backoffs self.get_success( - self.hs.get_datastore().set_destination_retry_timings("host2", None, 0, 0) + self.hs.get_datastores().main.set_destination_retry_timings( + "host2", None, 0, 0 + ) ) # bring the remote online and clear the received pdu list @@ -296,13 +298,13 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # destination_rooms should already be populated, but let us pretend that we already # sent (successfully) up to and including event id 2 - event_2 = self.get_success(self.hs.get_datastore().get_event(event_id_2)) + event_2 = self.get_success(self.hs.get_datastores().main.get_event(event_id_2)) # also fetch event 5 so we know its last_successful_stream_ordering later - event_5 = self.get_success(self.hs.get_datastore().get_event(event_id_5)) + event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5)) self.get_success( - self.hs.get_datastore().set_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( "host2", event_2.internal_metadata.stream_ordering ) ) @@ -359,7 +361,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # ASSERT: # - All servers are up to date so none should have outstanding catch-up outstanding_when_successful = self.get_success( - self.hs.get_datastore().get_catch_up_outstanding_destinations(None) + self.hs.get_datastores().main.get_catch_up_outstanding_destinations(None) ) self.assertEqual(outstanding_when_successful, []) @@ -370,7 +372,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # - Mark zzzerver as being backed-off from now = self.clock.time_msec() self.get_success( - self.hs.get_datastore().set_destination_retry_timings( + self.hs.get_datastores().main.set_destination_retry_timings( "zzzerver", now, now, 24 * 60 * 60 * 1000 # retry in 1 day ) ) @@ -382,14 +384,14 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # - all remotes are outstanding # - they are returned in batches of 25, in order outstanding_1 = self.get_success( - self.hs.get_datastore().get_catch_up_outstanding_destinations(None) + self.hs.get_datastores().main.get_catch_up_outstanding_destinations(None) ) self.assertEqual(len(outstanding_1), 25) self.assertEqual(outstanding_1, server_names[0:25]) outstanding_2 = self.get_success( - self.hs.get_datastore().get_catch_up_outstanding_destinations( + self.hs.get_datastores().main.get_catch_up_outstanding_destinations( outstanding_1[-1] ) ) @@ -457,7 +459,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): ) self.get_success( - self.hs.get_datastore().set_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( "host2", event_1.internal_metadata.stream_ordering ) ) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index b2376e2db9..60e0c31f43 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -176,7 +176,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): def get_users_who_share_room_with_user(user_id): return defer.succeed({"@user2:host2"}) - hs.get_datastore().get_users_who_share_room_with_user = ( + hs.get_datastores().main.get_users_who_share_room_with_user = ( get_users_who_share_room_with_user ) @@ -395,7 +395,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # run the prune job self.reactor.advance(10) self.get_success( - self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + self.hs.get_datastores().main._prune_old_outbound_device_pokes(prune_age=1) ) # recover the server @@ -445,7 +445,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # run the prune job self.reactor.advance(10) self.get_success( - self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + self.hs.get_datastores().main._prune_old_outbound_device_pokes(prune_age=1) ) # recover the server diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 686f42ab48..adf0535d97 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -198,7 +198,7 @@ class FederationKnockingTestCase( ] def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main # We're not going to be properly signing events as our remote homeserver is fake, # therefore disable event signature checks. diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index fe57ff2671..9918ff6807 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -38,7 +38,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api = Mock() self.mock_scheduler = Mock() hs = Mock() - hs.get_datastore.return_value = self.mock_store + hs.get_datastores.return_value = Mock(main=self.mock_store) self.mock_store.get_received_ts.return_value = make_awaitable(0) self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable( @@ -355,7 +355,9 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastore().get_app_services = Mock(return_value=self._services) + self.hs.get_datastores().main.get_app_services = Mock( + return_value=self._services + ) # A user on the homeserver. self.local_user_device_id = "local_device" @@ -494,7 +496,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Create a fake device per message. We can't send to-device messages to # a device that doesn't exist. self.get_success( - self.hs.get_datastore().db_pool.simple_insert_many( + self.hs.get_datastores().main.db_pool.simple_insert_many( desc="test_application_services_receive_burst_of_to_device", table="devices", keys=("user_id", "device_id"), @@ -510,7 +512,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Seed the device_inbox table with our fake messages self.get_success( - self.hs.get_datastore().add_messages_to_device_inbox(messages, {}) + self.hs.get_datastores().main.add_messages_to_device_inbox(messages, {}) ) # Now have local_user send a final to-device message to exclusive_as_user. All unsent diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 03b8b8615c..0c6e55e725 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -129,7 +129,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_mau_limits_exceeded_large(self): self.auth_blocking._limit_usage_by_mau = True - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) @@ -140,7 +140,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ResourceLimitError, ) - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) self.get_failure( @@ -156,7 +156,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._limit_usage_by_mau = True # Set the server to be at the edge of too many users. - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.auth_blocking._max_mau_value) ) @@ -175,7 +175,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) # If in monthly active cohort - self.hs.get_datastore().user_last_seen_monthly_active = Mock( + self.hs.get_datastores().main.user_last_seen_monthly_active = Mock( return_value=make_awaitable(self.clock.time_msec()) ) self.get_success( @@ -192,7 +192,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_mau_limits_not_exceeded(self): self.auth_blocking._limit_usage_by_mau = True - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.small_number_of_users) ) # Ensure does not raise exception @@ -202,7 +202,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) ) - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.small_number_of_users) ) self.get_success( diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 8705ff8943..a267228846 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -77,7 +77,7 @@ class CasHandlerTestCase(HomeserverTestCase): def test_map_cas_user_to_existing_user(self): """Existing users can log in with CAS account.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py index 01096a1581..ddda36c5a9 100644 --- a/tests/handlers/test_deactivate_account.py +++ b/tests/handlers/test_deactivate_account.py @@ -34,7 +34,7 @@ class DeactivateAccountTestCase(HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.user = self.register_user("user", "pass") self.token = self.login("user", "pass") diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 43031e07ea..683677fd07 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -28,7 +28,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver("server", federation_http_client=None) self.handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main return hs def prepare(self, reactor, clock, hs): @@ -263,7 +263,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): self.handler = hs.get_device_handler() self.registration = hs.get_registration_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main return hs def test_dehydrate_and_rehydrate_device(self): diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 0ea4e753e2..65ab107d0e 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -46,7 +46,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.handler = hs.get_directory_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.my_room = RoomAlias.from_string("#my-room:test") self.your_room = RoomAlias.from_string("#your-room:test") @@ -174,7 +174,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() @@ -289,7 +289,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 734ed84d78..9338ab92e9 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -34,7 +34,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = hs.get_e2e_keys_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main def test_query_local_devices_no_devices(self): """If the user has no devices, we expect an empty list.""" diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 496b581726..e8b4e39d1a 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -45,7 +45,7 @@ class FederationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_store = hs.get_storage().state self._event_auth_handler = hs.get_event_auth_handler() return hs diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 5816295d8b..f4f7ab4845 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -44,7 +44,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) self.info = self.get_success( - self.hs.get_datastore().get_user_by_access_token( + self.hs.get_datastores().main.get_user_by_access_token( self.access_token, ) ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index a552d8182e..e8418b6638 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -856,7 +856,7 @@ class OidcHandlerTestCase(HomeserverTestCase): auth_handler.complete_sso_login.reset_mock() # Test if the mxid is already taken - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user3 = UserID.from_string("@test_user_3:test") self.get_success( store.register_user(user_id=user3.to_string(), password_hash=None) @@ -872,7 +872,7 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}}) def test_map_userinfo_to_existing_user(self): """Existing users can log in with OpenID Connect when allow_existing_users is True.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user = UserID.from_string("@test_user:test") self.get_success( store.register_user(user_id=user.to_string(), password_hash=None) @@ -996,7 +996,7 @@ class OidcHandlerTestCase(HomeserverTestCase): auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 671dc7d083..61d28603ae 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -43,7 +43,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): servlets = [admin.register_servlets] def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main def test_offline_to_online(self): wheel_timer = Mock() @@ -891,7 +891,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # self.event_builder_for_2 = EventBuilderFactory(hs) # self.event_builder_for_2.hostname = "test2" - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 60235e5699..69e299fc17 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -48,7 +48,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor, clock, hs: HomeServer): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.frank = UserID.from_string("@1234abcd:test") self.bob = UserID.from_string("@4567:test") @@ -325,7 +325,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): properties are "mimetype" (for the file's type) and "size" (for the file's size). """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main for name, props in names_and_props.items(): self.get_success( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index cd6f2c77ae..51ee667ab4 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -154,7 +154,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = self.hs.get_registration_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.lots_of_users = 100 self.small_number_of_users = 1 @@ -172,7 +172,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertGreater(len(result_token), 20) def test_if_user_exists(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main frank = UserID.from_string("@frank:test") self.get_success( store.register_user(user_id=frank.to_string(), password_hash=None) @@ -760,7 +760,7 @@ class RemoteAutoJoinTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = self.hs.get_registration_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main @override_config({"auto_join_rooms": ["#room:remotetest"]}) def test_auto_create_auto_join_remote_room(self): diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 50551aa6e3..23941abed8 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -142,7 +142,7 @@ class SamlHandlerTestCase(HomeserverTestCase): @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) def test_map_saml_response_to_existing_user(self): """Existing users can log in with SAML account.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) @@ -217,7 +217,7 @@ class SamlHandlerTestCase(HomeserverTestCase): sso_handler.render_error = Mock(return_value=None) # register a user to occupy the first-choice MXID - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 56207f4db6..ecd78fa369 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -33,7 +33,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = self.hs.get_stats_handler() def _add_background_updates(self): diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 07a760e91a..66b0bd4d1a 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -41,7 +41,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs: HomeServer): self.sync_handler = self.hs.get_sync_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main # AuthBlocking reads from the hs' config on initialization. We need to # modify its config instead of the hs' @@ -248,7 +248,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # the prev_events used when creating the join event, such that the ban does not # precede the join. mocked_get_prev_events = patch.object( - self.hs.get_datastore(), + self.hs.get_datastores().main, "get_prev_events_for_room", new_callable=MagicMock, return_value=make_awaitable([last_room_creation_event_id]), diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 000f9b9fde..e461e03599 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -91,7 +91,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.event_source = hs.get_event_sources().sources.typing - self.datastore = hs.get_datastore() + self.datastore = hs.get_datastores().main self.datastore.get_destination_retry_timings = Mock( return_value=defer.succeed(None) ) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 482c90ef68..e159169e22 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -77,7 +77,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_user_directory_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.event_creation_handler = self.hs.get_event_creation_handler() diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index d16cd141a7..c3f20f9692 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -41,7 +41,7 @@ class ModuleApiTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.module_api = homeserver.get_module_api() self.event_creation_handler = homeserver.get_event_creation_handler() self.sync_handler = homeserver.get_sync_handler() diff --git a/tests/push/test_email.py b/tests/push/test_email.py index f8cba7b645..7a3b0d6755 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -102,13 +102,13 @@ class EmailPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(self.access_token) + self.hs.get_datastores().main.get_user_by_access_token(self.access_token) ) self.token_id = user_tuple.token_id # We need to add email to account before we can create a pusher. self.get_success( - hs.get_datastore().user_add_threepid( + hs.get_datastores().main.user_add_threepid( self.user_id, "email", "a@example.com", 0, 0 ) ) @@ -128,7 +128,7 @@ class EmailPusherTests(HomeserverTestCase): ) self.auth_handler = hs.get_auth_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_need_validated_email(self): """Test that we can only add an email pusher if the user has validated @@ -375,7 +375,7 @@ class EmailPusherTests(HomeserverTestCase): # check that the pusher for that email address has been deleted pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 0) @@ -388,14 +388,14 @@ class EmailPusherTests(HomeserverTestCase): # This resembles the old behaviour, which the background update below is intended # to clean up. self.get_success( - self.hs.get_datastore().user_delete_threepid( + self.hs.get_datastores().main.user_delete_threepid( self.user_id, "email", "a@example.com" ) ) # Run the "remove_deleted_email_pushers" background job self.get_success( - self.hs.get_datastore().db_pool.simple_insert( + self.hs.get_datastores().main.db_pool.simple_insert( table="background_updates", values={ "update_name": "remove_deleted_email_pushers", @@ -406,14 +406,14 @@ class EmailPusherTests(HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.hs.get_datastore().db_pool.updates._all_done = False + self.hs.get_datastores().main.db_pool.updates._all_done = False # Now let's actually drive the updates to completion self.wait_for_background_updates() # Check that all pushers with unlinked addresses were deleted pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 0) @@ -428,7 +428,7 @@ class EmailPusherTests(HomeserverTestCase): """ # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -439,7 +439,7 @@ class EmailPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -458,7 +458,7 @@ class EmailPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index e1e3fb97c5..c284beb37c 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -62,7 +62,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -108,7 +108,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -138,7 +138,7 @@ class HTTPPusherTests(HomeserverTestCase): # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -149,7 +149,7 @@ class HTTPPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -170,7 +170,7 @@ class HTTPPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -192,7 +192,7 @@ class HTTPPusherTests(HomeserverTestCase): # The stream ordering has increased, again pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -224,7 +224,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -344,7 +344,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -430,7 +430,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -507,7 +507,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -613,7 +613,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 9fc50f8852..a7a05a564f 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -68,7 +68,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Since we use sqlite in memory databases we need to make sure the # databases objects are the same. - self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool + self.worker_hs.get_datastores().main.db_pool = hs.get_datastores().main.db_pool # Normally we'd pass in the handler to `setup_test_homeserver`, which would # eventually hit "Install @cache_in_self attributes" in tests/utils.py. @@ -233,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # We may have an attempt to connect to redis for the external cache already. self.connect_any_redis_attempts() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.database_pool = store.db_pool self.reactor.lookups["testserv"] = "1.2.3.4" @@ -332,7 +332,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): lambda: self._handle_http_replication_attempt(worker_hs, port), ) - store = worker_hs.get_datastore() + store = worker_hs.get_datastores().main store.db_pool._db_pool = self.database_pool._db_pool # Set up TCP replication between master and the new worker if we don't diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 83e89383f6..85be79d19d 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -30,8 +30,8 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase): self.reconnect() - self.master_store = hs.get_datastore() - self.slaved_store = self.worker_hs.get_datastore() + self.master_store = hs.get_datastores().main + self.slaved_store = self.worker_hs.get_datastores().main self.storage = hs.get_storage() def replicate(self): diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py index cdd052001b..50fbff5f32 100644 --- a/tests/replication/tcp/streams/test_account_data.py +++ b/tests/replication/tcp/streams/test_account_data.py @@ -23,7 +23,7 @@ from tests.replication._base import BaseStreamTestCase class AccountDataStreamTestCase(BaseStreamTestCase): def test_update_function_room_account_data_limit(self): """Test replication with many room account data updates""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # generate lots of account data updates updates = [] @@ -69,7 +69,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase): def test_update_function_global_account_data_limit(self): """Test replication with many global account data updates""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # generate lots of account data updates updates = [] diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index f198a94887..f9d5da723c 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -136,7 +136,7 @@ class EventsStreamTestCase(BaseStreamTestCase): # this is the point in the DAG where we make a fork fork_point: List[str] = self.get_success( - self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) events = [ @@ -291,7 +291,7 @@ class EventsStreamTestCase(BaseStreamTestCase): # this is the point in the DAG where we make a fork fork_point: List[str] = self.get_success( - self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) events: List[EventBase] = [] diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index 38e292c1ab..eb00117845 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -32,7 +32,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # tell the master to send a new receipt self.get_success( - self.hs.get_datastore().insert_receipt( + self.hs.get_datastores().main.insert_receipt( "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} ) ) @@ -56,7 +56,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.test_handler.on_rdata.reset_mock() self.get_success( - self.hs.get_datastore().insert_receipt( + self.hs.get_datastores().main.insert_receipt( "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} ) ) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 92a5b53e11..ba1a63c0d6 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -204,7 +204,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): def create_room_with_remote_server(self, user, token, remote_server="other_server"): room = self.helper.create_room_as(user, tok=token) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main federation = self.hs.get_federation_event_handler() prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 4094a75f36..8f4f6688ce 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -50,7 +50,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): # Register a pusher user_dict = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_dict.token_id diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 596ba5a0c9..5f142e84c3 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -47,7 +47,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.other_access_token = self.login("otheruser", "pass") self.room_creator = self.hs.get_room_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def default_config(self): conf = super().default_config() @@ -99,7 +99,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): persisted_on_1 = False persisted_on_2 = False - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user_id = self.register_user("user", "pass") access_token = self.login("user", "pass") @@ -166,7 +166,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): user_id = self.register_user("user", "pass") access_token = self.login("user", "pass") - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Create two room on the different workers. self._create_room(room_id1, user_id, access_token) @@ -194,7 +194,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # # Worker2's event stream position will not advance until we call # __aexit__ again. - worker_store2 = worker_hs2.get_datastore() + worker_store2 = worker_hs2.get_datastores().main assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator) actx = worker_store2._stream_id_gen.get_next() diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index 1e3fe9c62c..fb36aa9940 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -36,7 +36,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index 71068d16cd..929bbdc37d 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -35,7 +35,7 @@ class FederationTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -537,7 +537,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 86aff7575c..0d47dd0aff 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -634,7 +634,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository_resource() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.server_name = hs.hostname self.admin_user = self.register_user("admin", "pass", admin=True) @@ -767,7 +767,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository_resource() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 8513b1d2df..8354250ec2 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -34,7 +34,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 23da0ad736..09c48e85c7 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -50,7 +50,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" self.event_creation_handler._consent_uri_builder = consent_uri_builder - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -465,7 +465,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" self.event_creation_handler._consent_uri_builder = consent_uri_builder - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -2239,7 +2239,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 3c59f5f766..2c855bff99 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -38,7 +38,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() self.server_notices_manager = self.hs.get_server_notices_manager() diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 272637e965..a60ea0a563 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -410,7 +410,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): even if the MAU limit is reached. """ handler = self.hs.get_registration_handler() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Set monthly active users to the limit store.get_monthly_active_count = Mock( @@ -455,7 +455,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -913,7 +913,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -1167,7 +1167,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() # create users and get access tokens @@ -2609,7 +2609,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -2737,7 +2737,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repo = hs.get_media_repository_resource() self.filepaths = MediaFilePaths(hs.config.media.media_store_path) @@ -3317,7 +3317,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -3609,7 +3609,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -3687,7 +3687,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -3913,7 +3913,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index afaa597f65..aa019c9a44 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -77,7 +77,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.submit_token_resource = PasswordResetSubmitTokenResource(hs) def test_basic_password_reset(self): @@ -398,7 +398,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.deactivate(user_id, tok) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check that the user has been marked as deactivated. self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) @@ -409,7 +409,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): def test_pending_invites(self): """Tests that deactivating a user rejects every pending invite for them.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main inviter_id = self.register_user("inviter", "test") inviter_tok = self.login("inviter", "test") @@ -527,7 +527,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase): namespaces={"users": [{"regex": user_id, "exclusive": True}]}, sender=user_id, ) - self.hs.get_datastore().services_cache.append(appservice) + self.hs.get_datastores().main.services_cache.append(appservice) whoami = self._whoami(as_token) self.assertEqual( @@ -586,7 +586,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): return self.hs def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("kermit", "test") self.user_id_tok = self.login("kermit", "test") diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 475c6bed3d..a573cc3c2e 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -32,7 +32,7 @@ class FilterTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.filtering = hs.get_filtering() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_add_filter(self): channel = self.make_request( diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 19f5e46537..26d0d83e00 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -1101,8 +1101,8 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): }, ) - self.hs.get_datastore().services_cache.append(self.service) - self.hs.get_datastore().services_cache.append(self.another_service) + self.hs.get_datastores().main.services_cache.append(self.service) + self.hs.get_datastores().main.services_cache.append(self.another_service) return self.hs def test_login_appservice_user(self): diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index ead883ded8..b9647d5bd8 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -292,7 +292,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): properties are "mimetype" (for the file's type) and "size" (for the file's size). """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main for name, props in names_and_props.items(): self.get_success( diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 0f1c47dcbb..2835d86e5b 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -56,7 +56,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): sender="@as:test", ) - self.hs.get_datastore().services_cache.append(appservice) + self.hs.get_datastores().main.services_cache.append(appservice) request_data = json.dumps( {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} ) @@ -80,7 +80,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): sender="@as:test", ) - self.hs.get_datastore().services_cache.append(appservice) + self.hs.get_datastores().main.services_cache.append(appservice) request_data = json.dumps({"username": "as_user_kermit"}) channel = self.make_request( @@ -210,7 +210,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): username = "kermit" device_id = "frogfone" token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -316,7 +316,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): @override_config({"registration_requires_token": True}) def test_POST_registration_token_limit_uses(self): token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Create token that can be used once self.get_success( store.db_pool.simple_insert( @@ -391,7 +391,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_registration_token_expiry(self): token = "abcd" now = self.hs.get_clock().time_msec() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Create token that expired yesterday self.get_success( store.db_pool.simple_insert( @@ -439,7 +439,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_registration_token_session_expiry(self): """Test `pending` is decremented when an uncompleted session expires.""" token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -530,7 +530,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): 3. Expire the session """ token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -657,7 +657,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Add a threepid self.get_success( - self.hs.get_datastore().user_add_threepid( + self.hs.get_datastores().main.user_add_threepid( user_id=user_id, medium="email", address=email, @@ -941,7 +941,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.email_attempts = [] self.hs.get_send_email_handler()._sendmail = sendmail - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main return self.hs @@ -1126,10 +1126,12 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): # We need to set these directly, instead of in the homeserver config dict above. # This is due to account validity-related config options not being read by # Synapse when account_validity.enabled is False. - self.hs.get_datastore()._account_validity_period = self.validity_period - self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta + self.hs.get_datastores().main._account_validity_period = self.validity_period + self.hs.get_datastores().main._account_validity_startup_job_max_delta = ( + self.max_delta + ) - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main return self.hs @@ -1163,7 +1165,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): def test_GET_token_valid(self): token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index dfd9ffcb93..5687dea48d 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -53,7 +53,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id, self.user_token = self._create_user("alice") self.user2_id, self.user2_token = self._create_user("bob") @@ -107,7 +107,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # Unless that event is referenced from another event! self.get_success( - self.hs.get_datastore().db_pool.simple_insert( + self.hs.get_datastores().main.db_pool.simple_insert( table="event_relations", values={ "event_id": "bar", diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index fe5b536d97..c41a1c14a1 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -51,7 +51,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.user_id = self.register_user("user", "password") self.token = self.login("user", "password") - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.serializer = self.hs.get_event_client_serializer() self.clock = self.hs.get_clock() @@ -114,7 +114,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): """Tests that synapse.visibility.filter_events_for_client correctly filters out outdated events """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main storage = self.hs.get_storage() room_id = self.helper.create_room_as(self.user_id, tok=self.token) events = [] diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index b7f086927b..1afd96b8f5 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -65,7 +65,7 @@ class RoomBase(unittest.HomeserverTestCase): async def _insert_client_ip(*args, **kwargs): return None - self.hs.get_datastore().insert_client_ip = _insert_client_ip + self.hs.get_datastores().main.insert_client_ip = _insert_client_ip return self.hs @@ -667,7 +667,7 @@ class RoomsCreateTestCase(RoomBase): # Add the current user to the ratelimit overrides, allowing them no ratelimiting. self.get_success( - self.hs.get_datastore().set_ratelimit_for_user(self.user_id, 0, 0) + self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) ) # Test that the invites aren't ratelimited anymore. @@ -1060,7 +1060,9 @@ class RoomJoinRatelimitTestCase(RoomBase): user_id = self.register_user("testuser", "password") # Check that the new user successfully joined the four rooms - rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id)) + rooms = self.get_success( + self.hs.get_datastores().main.get_rooms_for_user(user_id) + ) self.assertEqual(len(rooms), 4) @@ -1184,7 +1186,7 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("end" in channel.json_body) def test_room_messages_purge(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main pagination_handler = self.hs.get_pagination_handler() # Send a first message in the room, which will be removed by the purge. diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index b0c44af033..7d0e66b534 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -34,7 +34,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase): self.banned_user_id = self.register_user("banned", "test") self.banned_access_token = self.login("banned", "test") - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.get_success( self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True) diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py index 283eccd53f..c42c8aff6c 100644 --- a/tests/rest/client/test_shared_rooms.py +++ b/tests/rest/client/test_shared_rooms.py @@ -36,7 +36,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): return self.setup_test_homeserver(config=config) def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_user_directory_handler() def _get_shared_rooms(self, token, other_user) -> FakeChannel: diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index cd4af2b1f3..e062561365 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -299,7 +299,7 @@ class SyncKnockTestCase( ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.url = "/sync?since=%s" self.next_batch = "s0" diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index ee0abd5295..de312cb63c 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -57,7 +57,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): async def _insert_client_ip(*args, **kwargs): return None - hs.get_datastore().insert_client_ip = _insert_client_ip + hs.get_datastores().main.insert_client_ip = _insert_client_ip return hs diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index a42388b26f..7f79336abc 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -32,7 +32,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.creator = self.register_user("creator", "pass") self.creator_token = self.login(self.creator, "pass") diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 4cf1ed5ddf..6878ccddbf 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -243,7 +243,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): media_resource = hs.get_media_repository_resource() self.download_resource = media_resource.children[b"download"] self.thumbnail_resource = media_resource.children[b"thumbnail"] - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repo = hs.get_media_repository() self.media_id = "example.com/12345" diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 36c495954f..02b96c9e6e 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -242,7 +242,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): return c def prepare(self, reactor, clock, hs): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.server_notices_sender = self.hs.get_server_notices_sender() self.server_notices_manager = self.hs.get_server_notices_manager() self.event_source = self.hs.get_event_sources() diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py index 36c933b9e9..50c20c5b92 100644 --- a/tests/storage/databases/main/test_deviceinbox.py +++ b/tests/storage/databases/main/test_deviceinbox.py @@ -26,7 +26,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") def test_background_remove_deleted_devices_from_device_inbox(self): diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 5ae491ff5a..59def6e59c 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -37,7 +37,7 @@ from tests import unittest class HaveSeenEventsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store: EventsWorkerStore = hs.get_datastore() + self.store: EventsWorkerStore = hs.get_datastores().main # insert some test data for rid in ("room1", "room2"): @@ -122,7 +122,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store: EventsWorkerStore = hs.get_datastore() + self.store: EventsWorkerStore = hs.get_datastores().main self.user = self.register_user("user", "pass") self.token = self.login(self.user, "pass") @@ -163,7 +163,7 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase): """Test event fetching during a database outage.""" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): - self.store: EventsWorkerStore = hs.get_datastore() + self.store: EventsWorkerStore = hs.get_datastores().main self.room_id = f"!room:{hs.hostname}" self.event_ids = [f"event{i}" for i in range(20)] diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index d326a1d6a6..3ac4646969 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -20,7 +20,7 @@ from tests import unittest class LockTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs: HomeServer): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_simple_lock(self): """Test that we can take out a lock and that while we hold it nobody diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py index 7496974da3..9abd0cb446 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py @@ -28,7 +28,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") self.token = self.login("foo", "pass") diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 200b9198f9..4899cd5c36 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -20,7 +20,7 @@ from tests import unittest class UpsertManyTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.storage = hs.get_datastore() + self.storage = hs.get_datastores().main self.table_name = "table_" + secrets.token_hex(6) self.get_success( diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index d697d2bc1e..272cd35402 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -21,7 +21,7 @@ from tests import unittest class IgnoredUsersTestCase(unittest.HomeserverTestCase): def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.user = "@user:test" def _update_ignore_list( diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index ddcb7f5549..50703ccaee 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -467,7 +467,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer ) -> None: self.service = Mock(id="foo") - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.get_success( self.store.set_appservice_state(self.service, ApplicationServiceState.UP) ) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 6156dfac4e..39dcc094bd 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -24,7 +24,7 @@ from tests.test_utils import make_awaitable, simple_async_mock class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates + self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) @@ -42,7 +42,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # the target runtime for each bg update target_background_update_duration_ms = 100 - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "background_updates", @@ -102,7 +102,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates + self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) @@ -138,7 +138,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): ) def test_controller(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "background_updates", diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index a59c28f896..ce89c96912 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -30,7 +30,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """ def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() # Create a test user and room @@ -242,7 +242,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): return self.setup_test_homeserver(config=config) def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() self.event_creator_handler = homeserver.get_event_creation_handler() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index c8ac67e35b..49ad3c1324 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -35,7 +35,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): return hs def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main def test_insert_new_client_ip(self): self.reactor.advance(12345678) @@ -666,7 +666,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): return hs def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.user_id = self.register_user("bob", "abc123", True) def test_request_with_xforwarded(self): diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index b547bf8d99..21ffc5a909 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -19,7 +19,7 @@ from tests.unittest import HomeserverTestCase class DeviceStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_store_new_device(self): self.get_success( diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 43628ce44f..7b72a92424 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -19,7 +19,7 @@ from tests.unittest import HomeserverTestCase class DirectoryStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#my-room:test") diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index 7556171d8a..fb96ab3a2f 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -28,7 +28,7 @@ room_key: RoomKey = { class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver("server", federation_http_client=None) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main return hs def test_room_keys_version_delete(self): diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 3bf6e337f4..0f04493ad0 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -17,7 +17,7 @@ from tests.unittest import HomeserverTestCase class EndToEndKeyStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_key_without_device_name(self): now = 1470174257070 diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index e3273a93f9..401020fd63 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -30,7 +30,7 @@ from tests.unittest import HomeserverTestCase class EventChainStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._next_stream_ordering = 1 def test_simple(self): @@ -492,7 +492,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") self.token = self.login("foo", "pass") self.requester = create_requester(self.user_id) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 667ca90a4d..645d564d1c 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -31,7 +31,7 @@ import tests.utils class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_get_prev_events_for_room(self): room_id = "@ROOM:local" diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 738f3ad1dc..c9e3b9fa79 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -30,7 +30,7 @@ HIGHLIGHT = [ class EventPushActionsStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.persist_events_store = hs.get_datastores().persist_events def test_get_unread_push_actions_for_user_in_range_for_http(self): diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index a8639d8f82..ef5e25873c 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -32,7 +32,7 @@ class ExtremPruneTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() self.persistence = self.hs.get_storage().persistence - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.register_user("user", "pass") self.token = self.login("user", "pass") @@ -341,7 +341,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() self.persistence = self.hs.get_storage().persistence - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main def test_remote_user_rooms_cache_invalidated(self): """Test that if the server leaves a room the `get_rooms_for_user` cache diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 7486078284..6ac4b93f98 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -26,7 +26,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): skip = "Requires Postgres" def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -459,7 +459,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): skip = "Requires Postgres" def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -585,7 +585,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): skip = "Requires Postgres" def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index a94b5fd721..9059095525 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -37,7 +37,7 @@ KEY_2 = decode_verify_key_base64( class KeyStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_server_verify_keys(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main key_id_1 = "ed25519:key1" key_id_2 = "ed25519:KEY_ID_2" @@ -74,7 +74,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): def test_cache(self): """Check that updates correctly invalidate the cache.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main key_id_1 = "ed25519:key1" key_id_2 = "ed25519:key2" diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index f8d11bac4e..4ca212fd11 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -22,7 +22,7 @@ class DataStoreTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: super(DataStoreTestCase, self).setUp() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.user = UserID.from_string("@abcde:test") self.displayname = "Frank" diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index d6b4cdd788..79648d45db 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -45,7 +45,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main # Advance the clock a bit reactor.advance(FORTY_DAYS) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index d37736edf8..b6f99af2f1 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -22,7 +22,7 @@ from tests import unittest class ProfileStoreTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.u_frank = UserID.from_string("@frank:test") diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 22a77c3ccc..08cc60237e 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -30,7 +30,7 @@ class PurgeTests(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = self.hs.get_storage() def test_purge_history(self): @@ -47,7 +47,7 @@ class PurgeTests(HomeserverTestCase): token = self.get_success( self.store.get_topological_token_for_event(last["event_id"]) ) - token_str = self.get_success(token.to_string(self.hs.get_datastore())) + token_str = self.get_success(token.to_string(self.hs.get_datastores().main)) # Purge everything before this topological token self.get_success( diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 8c95a0a2fb..03e9cc7d4a 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -30,7 +30,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 9748065282..1fa495f778 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase class RegistrationStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = "@my-user:test" self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz", "BcDeFgHiJkLmNoPqRsTuVwXyZa"] diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py index cfc8098af6..0baa54312e 100644 --- a/tests/storage/test_rollback_worker.py +++ b/tests/storage/test_rollback_worker.py @@ -56,7 +56,7 @@ class WorkerSchemaTests(HomeserverTestCase): def test_rolling_back(self): """Test that workers can start if the DB is a newer schema version""" - db_pool = self.hs.get_datastore().db_pool + db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( db_pool._db_pool.connect(), db_pool.engine, @@ -72,7 +72,7 @@ class WorkerSchemaTests(HomeserverTestCase): def test_not_upgraded_old_schema_version(self): """Test that workers don't start if the DB has an older schema version""" - db_pool = self.hs.get_datastore().db_pool + db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( db_pool._db_pool.connect(), db_pool.engine, @@ -92,7 +92,7 @@ class WorkerSchemaTests(HomeserverTestCase): Test that workers don't start if the DB is on the current schema version, but there are still outstanding delta migrations to run. """ - db_pool = self.hs.get_datastore().db_pool + db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( db_pool._db_pool.connect(), db_pool.engine, diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 31ce7f6252..42bfca2a83 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -23,7 +23,7 @@ class RoomStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): # We can't test RoomStore on its own without the DirectoryStore, for # management of the 'room_aliases' table - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#a-room-name:test") @@ -71,7 +71,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): # Room events need the full datastore, for persist_event() and # get_room_state() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.event_factory = hs.get_event_factory() diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 8971ecccbd..befaa0fcee 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -46,7 +46,7 @@ class NullByteInsertionTest(HomeserverTestCase): self.assertIn("event_id", response) # Check that search works for the message where the null byte was replaced - store = self.hs.get_datastore() + store = self.hs.get_datastores().main result = self.get_success( store.search_msgs([room_id], "hi bob", ["content.body"]) ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 5cfdfe9b85..7028f0dfb0 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -35,7 +35,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # We can't test the RoomMemberStore on its own without the other event # storage logic - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.u_alice = self.register_user("alice", "pass") self.t_alice = self.login("alice", "pass") @@ -212,7 +212,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() def test_can_rerun_update(self): diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 28c767ecfd..f88f1c55fc 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class StateStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_datastore = self.storage.state.stores.state self.event_builder_factory = hs.get_event_builder_factory() diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index ce782c7e1d..6a1cf33054 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -115,7 +115,7 @@ class PaginationTestCase(HomeserverTestCase): ) events, next_key = self.get_success( - self.hs.get_datastore().paginate_room_events( + self.hs.get_datastores().main.paginate_room_events( room_id=self.room_id, from_key=from_token.room_key, to_key=None, diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py index bea9091d30..e05daa285e 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py @@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase class TransactionStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main def test_get_set_transactions(self): """Tests that we can successfully get a non-existent entry for diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 48f1e9d841..7f1964eb6a 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -149,7 +149,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_dir_helper = GetUserDirectoryTables(self.store) def _purge_and_rebuild_user_dir(self) -> None: @@ -415,7 +415,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): class UserDirectoryStoreTestCase(HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. diff --git a/tests/test_federation.py b/tests/test_federation.py index 2b9804aba0..c39816de85 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -52,11 +52,13 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) )[0]["room_id"] - self.store = self.homeserver.get_datastore() + self.store = self.homeserver.get_datastores().main # Figure out what the most recent event is most_recent = self.get_success( - self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id) + self.homeserver.get_datastores().main.get_latest_event_ids_in_room( + self.room_id + ) )[0] join_event = make_event_from_dict( @@ -185,7 +187,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. - store = self.homeserver.get_datastore() + store = self.homeserver.get_datastores().main store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) # Manually inject a fake device list update. We need this update to include at diff --git a/tests/test_mau.py b/tests/test_mau.py index 80ab40e255..46bd3075de 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -52,7 +52,7 @@ class TestMauLimit(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main def test_simple_deny_mau(self): # Create and sync so that the MAU counts get updated diff --git a/tests/test_state.py b/tests/test_state.py index 76e0e8ca7f..90800421fb 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -162,7 +162,7 @@ class StateTestCase(unittest.TestCase): hs = Mock( spec_set=[ "config", - "get_datastore", + "get_datastores", "get_storage", "get_auth", "get_state_handler", @@ -173,7 +173,7 @@ class StateTestCase(unittest.TestCase): ] ) hs.config = default_config("tesths", True) - hs.get_datastore.return_value = self.store + hs.get_datastores.return_value = Mock(main=self.store) hs.get_state_handler.return_value = None hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index e9ec9e085b..c654e36ee4 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -85,7 +85,9 @@ async def create_event( **kwargs, ) -> Tuple[EventBase, EventContext]: if room_version is None: - room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"]) + room_version = await hs.get_datastores().main.get_room_version_id( + kwargs["room_id"] + ) builder = hs.get_event_builder_factory().for_room_version( KNOWN_ROOM_VERSIONS[room_version], kwargs diff --git a/tests/test_visibility.py b/tests/test_visibility.py index e0b08d67d4..219b5660b1 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -93,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): events_to_filter.append(evt) # the erasey user gets erased - self.get_success(self.hs.get_datastore().mark_user_erased("@erased:local_hs")) + self.get_success( + self.hs.get_datastores().main.mark_user_erased("@erased:local_hs") + ) # ... and the filtering happens. filtered = self.get_success( diff --git a/tests/unittest.py b/tests/unittest.py index 7983c1e8b8..0caa8e7a45 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -280,7 +280,7 @@ class HomeserverTestCase(TestCase): # We need a valid token ID to satisfy foreign key constraints. token_id = self.get_success( - self.hs.get_datastore().add_access_token_to_user( + self.hs.get_datastores().main.add_access_token_to_user( self.helper.auth_user_id, "some_fake_token", None, @@ -337,7 +337,7 @@ class HomeserverTestCase(TestCase): def wait_for_background_updates(self) -> None: """Block until all background database updates have completed.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main while not self.get_success( store.db_pool.updates.has_completed_background_updates() ): @@ -504,7 +504,7 @@ class HomeserverTestCase(TestCase): self.get_success(stor.db_pool.updates.run_background_updates(False)) hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) - stor = hs.get_datastore() + stor = hs.get_datastores().main # Run the database background updates, when running against "master". if hs.__class__.__name__ == "TestHomeServer": @@ -722,14 +722,16 @@ class HomeserverTestCase(TestCase): Add the given event as an extremity to the room. """ self.get_success( - self.hs.get_datastore().db_pool.simple_insert( + self.hs.get_datastores().main.db_pool.simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", ) ) - self.hs.get_datastore().get_latest_event_ids_in_room.invalidate((room_id,)) + self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate( + (room_id,) + ) def attempt_wrong_password_login(self, username, password): """Attempts to login as the user with the given password, asserting @@ -775,7 +777,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version) self.get_success( - hs.get_datastore().store_server_verify_keys( + hs.get_datastores().main.store_server_verify_keys( from_server=self.OTHER_SERVER_NAME, ts_added_ms=clock.time_msec(), verify_keys=[ diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index 9e1bebdc83..26cb71c640 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -24,7 +24,7 @@ from tests.unittest import HomeserverTestCase class RetryLimiterTestCase(HomeserverTestCase): def test_new_destination(self): """A happy-path case with a new destination and a successful operation""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) # advance the clock a bit before making the request @@ -38,7 +38,7 @@ class RetryLimiterTestCase(HomeserverTestCase): def test_limiter(self): """General test case which walks through the process of a failing request""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) diff --git a/tests/utils.py b/tests/utils.py index c06fc320f3..ef99c72e0b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -367,7 +367,7 @@ async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room""" persistence_store = hs.get_storage().persistence - store = hs.get_datastore() + store = hs.get_datastores().main event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() From 5b2b36809fc3543ed0c9ec587398a09f2e176265 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 23 Feb 2022 12:35:53 +0000 Subject: [PATCH 07/40] Remove more references to `get_datastore` (#12067) These have snuck in since #12031 was started. Also a couple of other cleanups while we're in the area. --- changelog.d/12067.feature | 1 + synapse/handlers/account.py | 4 ++-- synapse/rest/client/account.py | 3 --- tests/rest/client/test_account.py | 4 +++- 4 files changed, 6 insertions(+), 6 deletions(-) create mode 100644 changelog.d/12067.feature diff --git a/changelog.d/12067.feature b/changelog.d/12067.feature new file mode 100644 index 0000000000..dc1153c49e --- /dev/null +++ b/changelog.d/12067.feature @@ -0,0 +1 @@ +Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py index f8cfe9f6de..d5badf635b 100644 --- a/synapse/handlers/account.py +++ b/synapse/handlers/account.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: class AccountHandler: def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._main_store = hs.get_datastores().main self._is_mine = hs.is_mine self._federation_client = hs.get_federation_client() @@ -98,7 +98,7 @@ class AccountHandler: """ status = {"exists": False} - userinfo = await self._store.get_userinfo_by_id(user_id.to_string()) + userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string()) if userinfo is not None: status = { diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 4b217882e8..5587cae98a 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -904,9 +904,6 @@ class AccountStatusRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self._auth = hs.get_auth() - self._store = hs.get_datastore() - self._is_mine = hs.is_mine - self._federation_client = hs.get_federation_client() self._account_handler = hs.get_account_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index aa019c9a44..008d635b70 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -1119,7 +1119,9 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): """Tests that the account status endpoint correctly reports a deactivated user.""" user = self.register_user("someuser", "password") self.get_success( - self.hs.get_datastore().set_user_deactivated_status(user, deactivated=True) + self.hs.get_datastores().main.set_user_deactivated_status( + user, deactivated=True + ) ) self._test_status( From 64c73c6ac88a740ee480a0ad1f9afc8596bccfa4 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 23 Feb 2022 14:33:19 +0100 Subject: [PATCH 08/40] Add type hints to `tests/rest/client` (#12066) --- changelog.d/12066.misc | 1 + tests/rest/client/test_auth.py | 70 ++++++++------- tests/rest/client/test_capabilities.py | 30 ++++--- tests/rest/client/test_login.py | 120 ++++++++++++++----------- tests/rest/client/test_sync.py | 47 +++++----- 5 files changed, 149 insertions(+), 119 deletions(-) create mode 100644 changelog.d/12066.misc diff --git a/changelog.d/12066.misc b/changelog.d/12066.misc new file mode 100644 index 0000000000..0360dbd61e --- /dev/null +++ b/changelog.d/12066.misc @@ -0,0 +1 @@ +Add type hints to `tests/rest/client`. diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 4a68d66573..9653f45837 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -13,17 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. from http import HTTPStatus -from typing import Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from twisted.internet.defer import succeed +from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.server import HomeServer from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict, UserID +from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC @@ -33,11 +37,11 @@ from tests.unittest import override_config, skip_unless class DummyRecaptchaChecker(UserInteractiveAuthChecker): - def __init__(self, hs): + def __init__(self, hs: HomeServer) -> None: super().__init__(hs) - self.recaptcha_attempts = [] + self.recaptcha_attempts: List[Tuple[dict, str]] = [] - def check_auth(self, authdict, clientip): + def check_auth(self, authdict: dict, clientip: str) -> Any: self.recaptcha_attempts.append((authdict, clientip)) return succeed(True) @@ -50,7 +54,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): ] hijack_auth = False - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() @@ -61,7 +65,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.recaptcha_checker = DummyRecaptchaChecker(hs) auth_handler = hs.get_auth_handler() auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker @@ -101,7 +105,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): self.assertEqual(len(attempts), 1) self.assertEqual(attempts[0][0]["response"], "a") - def test_fallback_captcha(self): + def test_fallback_captcha(self) -> None: """Ensure that fallback auth via a captcha works.""" # Returns a 401 as per the spec channel = self.register( @@ -132,7 +136,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") - def test_complete_operation_unknown_session(self): + def test_complete_operation_unknown_session(self) -> None: """ Attempting to mark an invalid session as complete should error. """ @@ -165,7 +169,7 @@ class UIAuthTests(unittest.HomeserverTestCase): register.register_servlets, ] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns @@ -182,12 +186,12 @@ class UIAuthTests(unittest.HomeserverTestCase): return config - def create_resource_dict(self): + def create_resource_dict(self) -> Dict[str, Resource]: resource_dict = super().create_resource_dict() resource_dict.update(build_synapse_client_resource_tree(self.hs)) return resource_dict - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) self.device_id = "dev1" @@ -229,7 +233,7 @@ class UIAuthTests(unittest.HomeserverTestCase): return channel - def test_ui_auth(self): + def test_ui_auth(self) -> None: """ Test user interactive authentication outside of registration. """ @@ -259,7 +263,7 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) - def test_grandfathered_identifier(self): + def test_grandfathered_identifier(self) -> None: """Check behaviour without "identifier" dict Synapse used to require clients to submit a "user" field for m.login.password @@ -286,7 +290,7 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) - def test_can_change_body(self): + def test_can_change_body(self) -> None: """ The client dict can be modified during the user interactive authentication session. @@ -325,7 +329,7 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) - def test_cannot_change_uri(self): + def test_cannot_change_uri(self) -> None: """ The initial requested URI cannot be modified during the user interactive authentication session. """ @@ -362,7 +366,7 @@ class UIAuthTests(unittest.HomeserverTestCase): ) @unittest.override_config({"ui_auth": {"session_timeout": "5s"}}) - def test_can_reuse_session(self): + def test_can_reuse_session(self) -> None: """ The session can be reused if configured. @@ -409,7 +413,7 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_ui_auth_via_sso(self): + def test_ui_auth_via_sso(self) -> None: """Test a successful UI Auth flow via SSO This includes: @@ -452,7 +456,7 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_does_not_offer_password_for_sso_user(self): + def test_does_not_offer_password_for_sso_user(self) -> None: login_resp = self.helper.login_via_oidc("username") user_tok = login_resp["access_token"] device_id = login_resp["device_id"] @@ -464,7 +468,7 @@ class UIAuthTests(unittest.HomeserverTestCase): flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) - def test_does_not_offer_sso_for_password_user(self): + def test_does_not_offer_sso_for_password_user(self) -> None: channel = self.delete_device( self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED ) @@ -474,7 +478,7 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_offers_both_flows_for_upgraded_user(self): + def test_offers_both_flows_for_upgraded_user(self) -> None: """A user that had a password and then logged in with SSO should get both flows""" login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) @@ -491,7 +495,7 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_ui_auth_fails_for_incorrect_sso_user(self): + def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: """If the user tries to authenticate with the wrong SSO user, they get an error""" # log the user in login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) @@ -534,7 +538,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): ] hijack_auth = False - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) @@ -548,7 +552,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": refresh_token}, ) - def is_access_token_valid(self, access_token) -> bool: + def is_access_token_valid(self, access_token: str) -> bool: """ Checks whether an access token is valid, returning whether it is or not. """ @@ -561,7 +565,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): return code == HTTPStatus.OK - def test_login_issue_refresh_token(self): + def test_login_issue_refresh_token(self) -> None: """ A login response should include a refresh_token only if asked. """ @@ -591,7 +595,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): self.assertIn("refresh_token", login_with_refresh.json_body) self.assertIn("expires_in_ms", login_with_refresh.json_body) - def test_register_issue_refresh_token(self): + def test_register_issue_refresh_token(self) -> None: """ A register response should include a refresh_token only if asked. """ @@ -627,7 +631,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): self.assertIn("refresh_token", register_with_refresh.json_body) self.assertIn("expires_in_ms", register_with_refresh.json_body) - def test_token_refresh(self): + def test_token_refresh(self) -> None: """ A refresh token can be used to issue a new access token. """ @@ -665,7 +669,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): ) @override_config({"refreshable_access_token_lifetime": "1m"}) - def test_refreshable_access_token_expiration(self): + def test_refreshable_access_token_expiration(self) -> None: """ The access token should have some time as specified in the config. """ @@ -722,7 +726,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "nonrefreshable_access_token_lifetime": "10m", } ) - def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self): + def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens( + self, + ) -> None: """ Tests that the expiry times for refreshable and non-refreshable access tokens can be different. @@ -782,7 +788,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): @override_config( {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} ) - def test_refresh_token_expiry(self): + def test_refresh_token_expiry(self) -> None: """ The refresh token can be configured to have a limited lifetime. When that lifetime has ended, the refresh token can no longer be used to @@ -834,7 +840,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "session_lifetime": "3m", } ) - def test_ultimate_session_expiry(self): + def test_ultimate_session_expiry(self) -> None: """ The session can be configured to have an ultimate, limited lifetime. """ @@ -882,7 +888,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result ) - def test_refresh_token_invalidation(self): + def test_refresh_token_invalidation(self) -> None: """Refresh tokens are invalidated after first use of the next token. A refresh token is considered invalid if: @@ -987,7 +993,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result ) - def test_many_token_refresh(self): + def test_many_token_refresh(self) -> None: """ If a refresh is performed many times during a session, there shouldn't be extra 'cruft' built up over time. diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index 989e801768..d1751e1557 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -13,9 +13,13 @@ # limitations under the License. from http import HTTPStatus +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.rest.client import capabilities, login +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.unittest import override_config @@ -29,24 +33,24 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.url = b"/capabilities" hs = self.setup_test_homeserver() self.config = hs.config self.auth_handler = hs.get_auth_handler() return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.localpart = "user" self.password = "pass" self.user = self.register_user(self.localpart, self.password) - def test_check_auth_required(self): + def test_check_auth_required(self) -> None: channel = self.make_request("GET", self.url) self.assertEqual(channel.code, 401) - def test_get_room_version_capabilities(self): + def test_get_room_version_capabilities(self) -> None: access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) @@ -61,7 +65,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): capabilities["m.room_versions"]["default"], ) - def test_get_change_password_capabilities_password_login(self): + def test_get_change_password_capabilities_password_login(self) -> None: access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) @@ -71,7 +75,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertTrue(capabilities["m.change_password"]["enabled"]) @override_config({"password_config": {"localdb_enabled": False}}) - def test_get_change_password_capabilities_localdb_disabled(self): + def test_get_change_password_capabilities_localdb_disabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None @@ -85,7 +89,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertFalse(capabilities["m.change_password"]["enabled"]) @override_config({"password_config": {"enabled": False}}) - def test_get_change_password_capabilities_password_disabled(self): + def test_get_change_password_capabilities_password_disabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None @@ -98,7 +102,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertFalse(capabilities["m.change_password"]["enabled"]) - def test_get_change_users_attributes_capabilities(self): + def test_get_change_users_attributes_capabilities(self) -> None: """Test that server returns capabilities by default.""" access_token = self.login(self.localpart, self.password) @@ -112,7 +116,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertTrue(capabilities["m.3pid_changes"]["enabled"]) @override_config({"enable_set_displayname": False}) - def test_get_set_displayname_capabilities_displayname_disabled(self): + def test_get_set_displayname_capabilities_displayname_disabled(self) -> None: """Test if set displayname is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) @@ -123,7 +127,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertFalse(capabilities["m.set_displayname"]["enabled"]) @override_config({"enable_set_avatar_url": False}) - def test_get_set_avatar_url_capabilities_avatar_url_disabled(self): + def test_get_set_avatar_url_capabilities_avatar_url_disabled(self) -> None: """Test if set avatar_url is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) @@ -134,7 +138,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertFalse(capabilities["m.set_avatar_url"]["enabled"]) @override_config({"enable_3pid_changes": False}) - def test_get_change_3pid_capabilities_3pid_disabled(self): + def test_get_change_3pid_capabilities_3pid_disabled(self) -> None: """Test if change 3pid is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) @@ -145,7 +149,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertFalse(capabilities["m.3pid_changes"]["enabled"]) @override_config({"experimental_features": {"msc3244_enabled": False}}) - def test_get_does_not_include_msc3244_fields_when_disabled(self): + def test_get_does_not_include_msc3244_fields_when_disabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None @@ -160,7 +164,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"] ) - def test_get_does_include_msc3244_fields_when_enabled(self): + def test_get_does_include_msc3244_fields_when_enabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 26d0d83e00..d48defda63 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -20,6 +20,7 @@ from urllib.parse import urlencode import pymacaroons +from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin @@ -27,12 +28,15 @@ from synapse.appservice import ApplicationService from synapse.rest.client import devices, login, logout, register from synapse.rest.client.account import WhoamiRestServlet from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.server import HomeServer from synapse.types import create_requester +from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.server import FakeChannel from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless @@ -95,7 +99,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() self.hs.config.registration.enable_registration = True self.hs.config.registration.registrations_require_3pid = [] @@ -117,7 +121,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } } ) - def test_POST_ratelimiting_per_address(self): + def test_POST_ratelimiting_per_address(self) -> None: # Create different users so we're sure not to be bothered by the per-user # ratelimiter. for i in range(0, 6): @@ -165,7 +169,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } } ) - def test_POST_ratelimiting_per_account(self): + def test_POST_ratelimiting_per_account(self) -> None: self.register_user("kermit", "monkey") for i in range(0, 6): @@ -210,7 +214,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } } ) - def test_POST_ratelimiting_per_account_failed_attempts(self): + def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: self.register_user("kermit", "monkey") for i in range(0, 6): @@ -243,7 +247,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) @override_config({"session_lifetime": "24h"}) - def test_soft_logout(self): + def test_soft_logout(self) -> None: self.register_user("kermit", "monkey") # we shouldn't be able to make requests without an access token @@ -298,7 +302,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], False) - def _delete_device(self, access_token, user_id, password, device_id): + def _delete_device( + self, access_token: str, user_id: str, password: str, device_id: str + ) -> None: """Perform the UI-Auth to delete a device""" channel = self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token @@ -329,7 +335,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.code, 200, channel.result) @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_after_being_soft_logged_out(self): + def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: self.register_user("kermit", "monkey") # log in as normal @@ -353,7 +359,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): + def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( + self, + ) -> None: self.register_user("kermit", "monkey") # log in as normal @@ -432,7 +440,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): d.update(build_synapse_client_resource_tree(self.hs)) return d - def test_get_login_flows(self): + def test_get_login_flows(self) -> None: """GET /login should return password and SSO flows""" channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) @@ -459,12 +467,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ], ) - def test_multi_sso_redirect(self): + def test_multi_sso_redirect(self) -> None: """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker channel = self._make_sso_redirect_request(None) self.assertEqual(channel.code, 302, channel.result) - uri = channel.headers.getRawHeaders("Location")[0] + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + uri = location_headers[0] # hitting that picker should give us some HTML channel = self.make_request("GET", uri) @@ -487,7 +497,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) - def test_multi_sso_redirect_to_cas(self): + def test_multi_sso_redirect_to_cas(self) -> None: """If CAS is chosen, should redirect to the CAS server""" channel = self.make_request( @@ -514,7 +524,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): service_uri_params = urllib.parse.parse_qs(service_uri_query) self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) - def test_multi_sso_redirect_to_saml(self): + def test_multi_sso_redirect_to_saml(self) -> None: """If SAML is chosen, should redirect to the SAML server""" channel = self.make_request( "GET", @@ -536,7 +546,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): relay_state_param = saml_uri_params["RelayState"][0] self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) - def test_login_via_oidc(self): + def test_login_via_oidc(self) -> None: """If OIDC is chosen, should redirect to the OIDC auth endpoint""" # pick the default OIDC provider @@ -604,7 +614,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.json_body["user_id"], "@user1:test") - def test_multi_sso_redirect_to_unknown(self): + def test_multi_sso_redirect_to_unknown(self) -> None: """An unknown IdP should cause a 400""" channel = self.make_request( "GET", @@ -612,23 +622,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def test_client_idp_redirect_to_unknown(self): + def test_client_idp_redirect_to_unknown(self) -> None: """If the client tries to pick an unknown IdP, return a 404""" channel = self._make_sso_redirect_request("xxx") self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") - def test_client_idp_redirect_to_oidc(self): + def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" channel = self._make_sso_redirect_request("oidc") self.assertEqual(channel.code, 302, channel.result) - oidc_uri = channel.headers.getRawHeaders("Location")[0] + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + oidc_uri = location_headers[0] oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) - def _make_sso_redirect_request(self, idp_prov: Optional[str] = None): + def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel: """Send a request to /_matrix/client/r0/login/sso/redirect ... possibly specifying an IDP provider @@ -659,7 +671,7 @@ class CASTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.base_url = "https://matrix.goodserver.com/" self.redirect_path = "_synapse/client/login/sso/redirect/confirm" @@ -675,7 +687,7 @@ class CASTestCase(unittest.HomeserverTestCase): cas_user_id = "username" self.user_id = "@%s:test" % cas_user_id - async def get_raw(uri, args): + async def get_raw(uri: str, args: Any) -> bytes: """Return an example response payload from a call to the `/proxyValidate` endpoint of a CAS server, copied from https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 @@ -709,10 +721,10 @@ class CASTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.deactivate_account_handler = hs.get_deactivate_account_handler() - def test_cas_redirect_confirm(self): + def test_cas_redirect_confirm(self) -> None: """Tests that the SSO login flow serves a confirmation page before redirecting a user to the redirect URL. """ @@ -754,15 +766,15 @@ class CASTestCase(unittest.HomeserverTestCase): } } ) - def test_cas_redirect_whitelisted(self): + def test_cas_redirect_whitelisted(self) -> None: """Tests that the SSO login flow serves a redirect to a whitelisted url""" self._test_redirect("https://legit-site.com/") @override_config({"public_baseurl": "https://example.com"}) - def test_cas_redirect_login_fallback(self): + def test_cas_redirect_login_fallback(self) -> None: self._test_redirect("https://example.com/_matrix/static/client/login") - def _test_redirect(self, redirect_url): + def _test_redirect(self, redirect_url: str) -> None: """Tests that the SSO login flow serves a redirect for the given redirect URL.""" cas_ticket_url = ( "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" @@ -778,7 +790,7 @@ class CASTestCase(unittest.HomeserverTestCase): self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) - def test_deactivated_user(self): + def test_deactivated_user(self) -> None: """Logging in as a deactivated account should error.""" redirect_url = "https://legit-site.com/" @@ -821,7 +833,7 @@ class JWTTestCase(unittest.HomeserverTestCase): "algorithm": jwt_algorithm, } - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # If jwt_config has been defined (eg via @override_config), don't replace it. @@ -837,23 +849,23 @@ class JWTTestCase(unittest.HomeserverTestCase): return result.decode("ascii") return result - def jwt_login(self, *args): + def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} channel = self.make_request(b"POST", LOGIN_URL, params) return channel - def test_login_jwt_valid_registered(self): + def test_login_jwt_valid_registered(self) -> None: self.register_user("kermit", "monkey") channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") - def test_login_jwt_valid_unregistered(self): + def test_login_jwt_valid_unregistered(self) -> None: channel = self.jwt_login({"sub": "frog"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") - def test_login_jwt_invalid_signature(self): + def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, "notsecret") self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -862,7 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase): "JWT validation failed: Signature verification failed", ) - def test_login_jwt_expired(self): + def test_login_jwt_expired(self) -> None: channel = self.jwt_login({"sub": "frog", "exp": 864000}) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -870,7 +882,7 @@ class JWTTestCase(unittest.HomeserverTestCase): channel.json_body["error"], "JWT validation failed: Signature has expired" ) - def test_login_jwt_not_before(self): + def test_login_jwt_not_before(self) -> None: now = int(time.time()) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -880,14 +892,14 @@ class JWTTestCase(unittest.HomeserverTestCase): "JWT validation failed: The token is not yet valid (nbf)", ) - def test_login_no_sub(self): + def test_login_no_sub(self) -> None: channel = self.jwt_login({"username": "root"}) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}}) - def test_login_iss(self): + def test_login_iss(self) -> None: """Test validating the issuer claim.""" # A valid issuer. channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) @@ -911,14 +923,14 @@ class JWTTestCase(unittest.HomeserverTestCase): 'JWT validation failed: Token is missing the "iss" claim', ) - def test_login_iss_no_config(self): + def test_login_iss_no_config(self) -> None: """Test providing an issuer claim without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) - def test_login_aud(self): + def test_login_aud(self) -> None: """Test validating the audience claim.""" # A valid audience. channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) @@ -942,7 +954,7 @@ class JWTTestCase(unittest.HomeserverTestCase): 'JWT validation failed: Token is missing the "aud" claim', ) - def test_login_aud_no_config(self): + def test_login_aud_no_config(self) -> None: """Test providing an audience without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -951,20 +963,20 @@ class JWTTestCase(unittest.HomeserverTestCase): channel.json_body["error"], "JWT validation failed: Invalid audience" ) - def test_login_default_sub(self): + def test_login_default_sub(self) -> None: """Test reading user ID from the default subject claim.""" channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) - def test_login_custom_sub(self): + def test_login_custom_sub(self) -> None: """Test reading user ID from a custom subject claim.""" channel = self.jwt_login({"username": "frog"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") - def test_login_no_token(self): + def test_login_no_token(self) -> None: params = {"type": "org.matrix.login.jwt"} channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -1026,7 +1038,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): ] ) - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["jwt_config"] = { "enabled": True, @@ -1042,17 +1054,17 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): return result.decode("ascii") return result - def jwt_login(self, *args): + def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} channel = self.make_request(b"POST", LOGIN_URL, params) return channel - def test_login_jwt_valid(self): + def test_login_jwt_valid(self) -> None: channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") - def test_login_jwt_invalid_signature(self): + def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -1071,7 +1083,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): register.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() self.service = ApplicationService( @@ -1105,7 +1117,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.hs.get_datastores().main.services_cache.append(self.another_service) return self.hs - def test_login_appservice_user(self): + def test_login_appservice_user(self) -> None: """Test that an appservice user can use /login""" self.register_appservice_user(AS_USER, self.service.token) @@ -1119,7 +1131,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) - def test_login_appservice_user_bot(self): + def test_login_appservice_user_bot(self) -> None: """Test that the appservice bot can use /login""" self.register_appservice_user(AS_USER, self.service.token) @@ -1133,7 +1145,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) - def test_login_appservice_wrong_user(self): + def test_login_appservice_wrong_user(self) -> None: """Test that non-as users cannot login with the as token""" self.register_appservice_user(AS_USER, self.service.token) @@ -1147,7 +1159,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) - def test_login_appservice_wrong_as(self): + def test_login_appservice_wrong_as(self) -> None: """Test that as users cannot login with wrong as token""" self.register_appservice_user(AS_USER, self.service.token) @@ -1161,7 +1173,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) - def test_login_appservice_no_token(self): + def test_login_appservice_no_token(self) -> None: """Test that users must provide a token when using the appservice login method """ @@ -1182,7 +1194,7 @@ class UsernamePickerTestCase(HomeserverTestCase): servlets = [login.register_servlets] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL @@ -1202,7 +1214,7 @@ class UsernamePickerTestCase(HomeserverTestCase): d.update(build_synapse_client_resource_tree(self.hs)) return d - def test_username_picker(self): + def test_username_picker(self) -> None: """Test the happy path of a username picker flow.""" # do the start of the login flow diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index e062561365..69b4ef5378 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +from typing import List, Optional from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import ( EventContentFields, @@ -24,6 +27,9 @@ from synapse.api.constants import ( RelationTypes, ) from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.federation.transport.test_knocking import ( @@ -43,7 +49,7 @@ class FilterTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def test_sync_argless(self): + def test_sync_argless(self) -> None: channel = self.make_request("GET", "/sync") self.assertEqual(channel.code, 200) @@ -58,7 +64,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def test_sync_filter_labels(self): + def test_sync_filter_labels(self) -> None: """Test that we can filter by a label.""" sync_filter = json.dumps( { @@ -77,7 +83,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - def test_sync_filter_not_labels(self): + def test_sync_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label.""" sync_filter = json.dumps( { @@ -99,7 +105,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): events[2]["content"]["body"], "with two wrong labels", events[2] ) - def test_sync_filter_labels_not_labels(self): + def test_sync_filter_labels_not_labels(self) -> None: """Test that we can filter by both a label and the absence of another label.""" sync_filter = json.dumps( { @@ -118,7 +124,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): self.assertEqual(len(events), 1, [event["content"] for event in events]) self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - def _test_sync_filter_labels(self, sync_filter): + def _test_sync_filter_labels(self, sync_filter: str) -> List[JsonDict]: user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") @@ -194,7 +200,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): user_id = True hijack_auth = False - def test_sync_backwards_typing(self): + def test_sync_backwards_typing(self) -> None: """ If the typing serial goes backwards and the typing handler is then reset (such as when the master restarts and sets the typing serial to 0), we @@ -298,7 +304,7 @@ class SyncKnockTestCase( knock.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.url = "/sync?since=%s" self.next_batch = "s0" @@ -336,7 +342,7 @@ class SyncKnockTestCase( ) @override_config({"experimental_features": {"msc2403_enabled": True}}) - def test_knock_room_state(self): + def test_knock_room_state(self) -> None: """Tests that /sync returns state from a room after knocking on it.""" # Knock on a room channel = self.make_request( @@ -383,7 +389,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/sync?since=%s" self.next_batch = "s0" @@ -402,7 +408,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) @override_config({"experimental_features": {"msc2285_enabled": True}}) - def test_hidden_read_receipts(self): + def test_hidden_read_receipts(self) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) @@ -441,8 +447,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): ] ) def test_read_receipt_with_empty_body( - self, name, user_agent: str, expected_status_code: int - ): + self, name: str, user_agent: str, expected_status_code: int + ) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) @@ -455,11 +461,11 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, expected_status_code) - def _get_read_receipt(self): + def _get_read_receipt(self) -> Optional[JsonDict]: """Syncs and returns the read receipt.""" # Checks if event is a read receipt - def is_read_receipt(event): + def is_read_receipt(event: JsonDict) -> bool: return event["type"] == "m.receipt" # Sync @@ -477,7 +483,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ "ephemeral" ]["events"] - return next(filter(is_read_receipt, ephemeral_events), None) + receipt_event = filter(is_read_receipt, ephemeral_events) + return next(receipt_event, None) class UnreadMessagesTestCase(unittest.HomeserverTestCase): @@ -490,7 +497,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): receipts.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/sync?since=%s" self.next_batch = "s0" @@ -533,7 +540,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): tok=self.tok, ) - def test_unread_counts(self): + def test_unread_counts(self) -> None: """Tests that /sync returns the right value for the unread count (MSC2654).""" # Check that our own messages don't increase the unread count. @@ -640,7 +647,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): ) self._check_unread_count(5) - def _check_unread_count(self, expected_count: int): + def _check_unread_count(self, expected_count: int) -> None: """Syncs and compares the unread count with the expected value.""" channel = self.make_request( @@ -669,7 +676,7 @@ class SyncCacheTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def test_noop_sync_does_not_tightloop(self): + def test_noop_sync_does_not_tightloop(self) -> None: """If the sync times out, we shouldn't cache the result Essentially a regression test for #8518. @@ -720,7 +727,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase): devices.register_servlets, ] - def test_user_with_no_rooms_receives_self_device_list_updates(self): + def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None: """Tests that a user with no rooms still receives their own device list updates""" device_id = "TESTDEVICE" From a711ae78a8f8ba406ff122035c8bf096fac9a26c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Feb 2022 14:22:22 +0000 Subject: [PATCH 09/40] Add logging to `/sync` for debugging #11916 (#12068) --- changelog.d/12068.misc | 1 + synapse/handlers/sync.py | 9 +++++++++ 2 files changed, 10 insertions(+) create mode 100644 changelog.d/12068.misc diff --git a/changelog.d/12068.misc b/changelog.d/12068.misc new file mode 100644 index 0000000000..72b211e4f5 --- /dev/null +++ b/changelog.d/12068.misc @@ -0,0 +1 @@ +Add some logging to `/sync` to try and track down #11916. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 98eaad3318..0aa3052fd6 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -697,6 +697,15 @@ class SyncHandler: else: # no events in this room - so presumably no state state = {} + + # (erikj) This should be rarely hit, but we've had some reports that + # we get more state down gappy syncs than we should, so let's add + # some logging. + logger.info( + "Failed to find any events in room %s at %s", + room_id, + stream_position.room_key, + ) return state async def compute_summary( From c56bfb08bc071368db23f3b1c593724eb4f205f0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 23 Feb 2022 17:49:04 -0500 Subject: [PATCH 10/40] Add documentation for missing worker types. (#11599) And clean-up the endpoints which should be routed to workers. --- changelog.d/11599.doc | 1 + docs/workers.md | 90 ++++++++++++++++++++++++++++++++++++------- 2 files changed, 77 insertions(+), 14 deletions(-) create mode 100644 changelog.d/11599.doc diff --git a/changelog.d/11599.doc b/changelog.d/11599.doc new file mode 100644 index 0000000000..f07cfbef4e --- /dev/null +++ b/changelog.d/11599.doc @@ -0,0 +1 @@ +Document support for the `to_device`, `account_data`, `receipts`, and `presence` stream writers for workers. diff --git a/docs/workers.md b/docs/workers.md index dadde4d726..b82a6900ac 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -178,8 +178,11 @@ recommend the use of `systemd` where available: for information on setting up ### `synapse.app.generic_worker` -This worker can handle API requests matching the following regular -expressions: +This worker can handle API requests matching the following regular expressions. +These endpoints can be routed to any worker. If a worker is set up to handle a +stream then, for maximum efficiency, additional endpoints should be routed to that +worker: refer to the [stream writers](#stream-writers) section below for further +information. # Sync requests ^/_matrix/client/(v2_alpha|r0|v3)/sync$ @@ -225,19 +228,23 @@ expressions: ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/spaces$ ^/_matrix/client/(v1|unstable/org.matrix.msc2946)/rooms/.*/hierarchy$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/devices$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/query$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/changes$ + ^/_matrix/client/(r0|v3|unstable)/account/3pid$ + ^/_matrix/client/(r0|v3|unstable)/devices$ ^/_matrix/client/versions$ ^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_groups$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/publicised_groups$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/publicised_groups/ + ^/_matrix/client/(r0|v3|unstable)/joined_groups$ + ^/_matrix/client/(r0|v3|unstable)/publicised_groups$ + ^/_matrix/client/(r0|v3|unstable)/publicised_groups/ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/ ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$ ^/_matrix/client/(api/v1|r0|v3|unstable)/search$ + # Encryption requests + ^/_matrix/client/(r0|v3|unstable)/keys/query$ + ^/_matrix/client/(r0|v3|unstable)/keys/changes$ + ^/_matrix/client/(r0|v3|unstable)/keys/claim$ + ^/_matrix/client/(r0|v3|unstable)/room_keys/ + # Registration/login requests ^/_matrix/client/(api/v1|r0|v3|unstable)/login$ ^/_matrix/client/(r0|v3|unstable)/register$ @@ -251,6 +258,20 @@ expressions: ^/_matrix/client/(api/v1|r0|v3|unstable)/join/ ^/_matrix/client/(api/v1|r0|v3|unstable)/profile/ + # Device requests + ^/_matrix/client/(r0|v3|unstable)/sendToDevice/ + + # Account data requests + ^/_matrix/client/(r0|v3|unstable)/.*/tags + ^/_matrix/client/(r0|v3|unstable)/.*/account_data + + # Receipts requests + ^/_matrix/client/(r0|v3|unstable)/rooms/.*/receipt + ^/_matrix/client/(r0|v3|unstable)/rooms/.*/read_markers + + # Presence requests + ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/ + Additionally, the following REST endpoints can be handled for GET requests: @@ -330,12 +351,10 @@ Additionally, there is *experimental* support for moving writing of specific streams (such as events) off of the main process to a particular worker. (This is only supported with Redis-based replication.) -Currently supported streams are `events` and `typing`. - To enable this, the worker must have a HTTP replication listener configured, -have a `worker_name` and be listed in the `instance_map` config. For example to -move event persistence off to a dedicated worker, the shared configuration would -include: +have a `worker_name` and be listed in the `instance_map` config. The same worker +can handle multiple streams. For example, to move event persistence off to a +dedicated worker, the shared configuration would include: ```yaml instance_map: @@ -347,6 +366,12 @@ stream_writers: events: event_persister1 ``` +Some of the streams have associated endpoints which, for maximum efficiency, should +be routed to the workers handling that stream. See below for the currently supported +streams and the endpoints associated with them: + +##### The `events` stream + The `events` stream also experimentally supports having multiple writers, where work is sharded between them by room ID. Note that you *must* restart all worker instances when adding or removing event persisters. An example `stream_writers` @@ -359,6 +384,43 @@ stream_writers: - event_persister2 ``` +##### The `typing` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `typing` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/typing + +##### The `to_device` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `to_device` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/ + +##### The `account_data` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `account_data` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags + ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data + +##### The `receipts` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `receipts` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers + +##### The `presence` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `presence` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/ + #### Background tasks There is also *experimental* support for moving background tasks to a separate From 41cf4c2cf6432336cc7477f130a2847449cff99a Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Thu, 24 Feb 2022 11:52:28 +0000 Subject: [PATCH 11/40] Fix non-strings in the `event_search` table (#12037) Don't attempt to add non-string `value`s to `event_search` and add a background update to clear out bad rows from `event_search` when using sqlite. Signed-off-by: Sean Quah --- changelog.d/12037.bugfix | 1 + synapse/storage/databases/main/events.py | 18 +-- synapse/storage/databases/main/search.py | 26 ++++ ...e_non_strings_from_event_search.sql.sqlite | 22 ++++ tests/storage/test_room_search.py | 117 +++++++++++++++++- 5 files changed, 173 insertions(+), 11 deletions(-) create mode 100644 changelog.d/12037.bugfix create mode 100644 synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite diff --git a/changelog.d/12037.bugfix b/changelog.d/12037.bugfix new file mode 100644 index 0000000000..9295cb4dc0 --- /dev/null +++ b/changelog.d/12037.bugfix @@ -0,0 +1 @@ +Properly fix a long-standing bug where wrong data could be inserted in the `event_search` table when using sqlite. This could block running `synapse_port_db` with an "argument of type 'int' is not iterable" error. This bug was partially fixed by a change in Synapse 1.44.0. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index a1d7a9b413..e53e84054a 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1473,10 +1473,10 @@ class PersistEventsStore: def _update_metadata_tables_txn( self, - txn, + txn: LoggingTransaction, *, - events_and_contexts, - all_events_and_contexts, + events_and_contexts: List[Tuple[EventBase, EventContext]], + all_events_and_contexts: List[Tuple[EventBase, EventContext]], inhibit_local_membership_updates: bool = False, ): """Update all the miscellaneous tables for new events @@ -1953,20 +1953,20 @@ class PersistEventsStore: txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) - def _store_room_topic_txn(self, txn, event): - if hasattr(event, "content") and "topic" in event.content: + def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase): + if isinstance(event.content.get("topic"), str): self.store_event_search_txn( txn, event, "content.topic", event.content["topic"] ) - def _store_room_name_txn(self, txn, event): - if hasattr(event, "content") and "name" in event.content: + def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase): + if isinstance(event.content.get("name"), str): self.store_event_search_txn( txn, event, "content.name", event.content["name"] ) - def _store_room_message_txn(self, txn, event): - if hasattr(event, "content") and "body" in event.content: + def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase): + if isinstance(event.content.get("body"), str): self.store_event_search_txn( txn, event, "content.body", event.content["body"] ) diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index acea300ed3..e23b119072 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -115,6 +115,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" + EVENT_SEARCH_DELETE_NON_STRINGS = "event_search_sqlite_delete_non_strings" def __init__( self, @@ -147,6 +148,10 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search ) + self.db_pool.updates.register_background_update_handler( + self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings + ) + async def _background_reindex_search(self, progress, batch_size): # we work through the events table from highest stream id to lowest target_min_stream_id = progress["target_min_stream_id_inclusive"] @@ -372,6 +377,27 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): return num_rows + async def _background_delete_non_strings( + self, progress: JsonDict, batch_size: int + ) -> int: + """Deletes rows with non-string `value`s from `event_search` if using sqlite. + + Prior to Synapse 1.44.0, malformed events received over federation could cause integers + to be inserted into the `event_search` table when using sqlite. + """ + + def delete_non_strings_txn(txn: LoggingTransaction) -> None: + txn.execute("DELETE FROM event_search WHERE typeof(value) != 'text'") + + await self.db_pool.runInteraction( + self.EVENT_SEARCH_DELETE_NON_STRINGS, delete_non_strings_txn + ) + + await self.db_pool.updates._end_background_update( + self.EVENT_SEARCH_DELETE_NON_STRINGS + ) + return 1 + class SearchStore(SearchBackgroundUpdateStore): def __init__( diff --git a/synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite b/synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite new file mode 100644 index 0000000000..140df65264 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite @@ -0,0 +1,22 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Delete rows with non-string `value`s from `event_search` if using sqlite. +-- +-- Prior to Synapse 1.44.0, malformed events received over federation could +-- cause integers to be inserted into the `event_search` table. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6805, 'event_search_sqlite_delete_non_strings', '{}'); diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index befaa0fcee..d62e01726c 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -13,13 +13,16 @@ # limitations under the License. import synapse.rest.admin +from synapse.api.constants import EventTypes +from synapse.api.errors import StoreError from synapse.rest.client import login, room from synapse.storage.engines import PostgresEngine -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, skip_unless +from tests.utils import USE_POSTGRES_FOR_TESTS -class NullByteInsertionTest(HomeserverTestCase): +class EventSearchInsertionTest(HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -72,3 +75,113 @@ class NullByteInsertionTest(HomeserverTestCase): ) if isinstance(store.database_engine, PostgresEngine): self.assertIn("alice", result.get("highlights")) + + def test_non_string(self): + """Test that non-string `value`s are not inserted into `event_search`. + + This is particularly important when using sqlite, since a sqlite column can hold + both strings and integers. When using Postgres, integers are automatically + converted to strings. + + Regression test for #11918. + """ + store = self.hs.get_datastores().main + + # Register a user and create a room + user_id = self.register_user("alice", "password") + access_token = self.login("alice", "password") + room_id = self.helper.create_room_as("alice", tok=access_token) + room_version = self.get_success(store.get_room_version(room_id)) + + # Construct a message with a numeric body to be received over federation + # The message can't be sent using the client API, since Synapse's event + # validation will reject it. + prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) + prev_event = self.get_success(store.get_event(prev_event_ids[0])) + prev_state_map = self.get_success( + self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0]) + ) + + event_dict = { + "type": EventTypes.Message, + "content": {"msgtype": "m.text", "body": 2}, + "room_id": room_id, + "sender": user_id, + "depth": prev_event.depth + 1, + "prev_events": prev_event_ids, + "origin_server_ts": self.clock.time_msec(), + } + builder = self.hs.get_event_builder_factory().for_room_version( + room_version, event_dict + ) + event = self.get_success( + builder.build( + prev_event_ids=prev_event_ids, + auth_event_ids=self.hs.get_event_auth_handler().compute_auth_events( + builder, + prev_state_map, + for_verification=False, + ), + depth=event_dict["depth"], + ) + ) + + # Receive the event + self.get_success( + self.hs.get_federation_event_handler().on_receive_pdu( + self.hs.hostname, event + ) + ) + + # The event should not have an entry in the `event_search` table + f = self.get_failure( + store.db_pool.simple_select_one_onecol( + "event_search", + {"room_id": room_id, "event_id": event.event_id}, + "event_id", + ), + StoreError, + ) + self.assertEqual(f.value.code, 404) + + @skip_unless(not USE_POSTGRES_FOR_TESTS, "requires sqlite") + def test_sqlite_non_string_deletion_background_update(self): + """Test the background update to delete bad rows from `event_search`.""" + store = self.hs.get_datastores().main + + # Populate `event_search` with dummy data + self.get_success( + store.db_pool.simple_insert_many( + "event_search", + keys=["event_id", "room_id", "key", "value"], + values=[ + ("event1", "room_id", "content.body", "hi"), + ("event2", "room_id", "content.body", "2"), + ("event3", "room_id", "content.body", 3), + ], + desc="populate_event_search", + ) + ) + + # Run the background update + store.db_pool.updates._all_done = False + self.get_success( + store.db_pool.simple_insert( + "background_updates", + { + "update_name": "event_search_sqlite_delete_non_strings", + "progress_json": "{}", + }, + ) + ) + self.wait_for_background_updates() + + # The non-string `value`s ought to be gone now. + values = self.get_success( + store.db_pool.simple_select_onecol( + "event_search", + {"room_id": "room_id"}, + "value", + ), + ) + self.assertCountEqual(values, ["hi", "2"]) From 2cc5ea933dbe65445e3711bb3f05022b007029ea Mon Sep 17 00:00:00 2001 From: reivilibre Date: Thu, 24 Feb 2022 17:55:45 +0000 Subject: [PATCH 12/40] Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services. (#11617) Co-authored-by: Erik Johnston --- changelog.d/11617.feature | 1 + synapse/appservice/__init__.py | 16 ++ synapse/appservice/api.py | 20 +- synapse/appservice/scheduler.py | 98 ++++++++- synapse/config/appservice.py | 13 +- synapse/config/experimental.py | 16 +- synapse/storage/databases/main/appservice.py | 33 ++- .../storage/databases/main/end_to_end_keys.py | 112 ++++++++++ tests/appservice/test_scheduler.py | 55 +++-- tests/handlers/test_appservice.py | 194 +++++++++++++++++- tests/storage/test_appservice.py | 8 +- 11 files changed, 528 insertions(+), 38 deletions(-) create mode 100644 changelog.d/11617.feature diff --git a/changelog.d/11617.feature b/changelog.d/11617.feature new file mode 100644 index 0000000000..cf03f00e7c --- /dev/null +++ b/changelog.d/11617.feature @@ -0,0 +1 @@ +Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services. \ No newline at end of file diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index a340a8c9c7..4d3f8e4923 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -31,6 +31,14 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# Type for the `device_one_time_key_counts` field in an appservice transaction +# user ID -> {device ID -> {algorithm -> count}} +TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]] + +# Type for the `device_unused_fallback_keys` field in an appservice transaction +# user ID -> {device ID -> [algorithm]} +TransactionUnusedFallbackKeys = Dict[str, Dict[str, List[str]]] + class ApplicationServiceState(Enum): DOWN = "down" @@ -72,6 +80,7 @@ class ApplicationService: rate_limited: bool = True, ip_range_whitelist: Optional[IPSet] = None, supports_ephemeral: bool = False, + msc3202_transaction_extensions: bool = False, ): self.token = token self.url = ( @@ -84,6 +93,7 @@ class ApplicationService: self.id = id self.ip_range_whitelist = ip_range_whitelist self.supports_ephemeral = supports_ephemeral + self.msc3202_transaction_extensions = msc3202_transaction_extensions if "|" in self.id: raise Exception("application service ID cannot contain '|' character") @@ -339,12 +349,16 @@ class AppServiceTransaction: events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, ): self.service = service self.id = id self.events = events self.ephemeral = ephemeral self.to_device_messages = to_device_messages + self.one_time_key_counts = one_time_key_counts + self.unused_fallback_keys = unused_fallback_keys async def send(self, as_api: "ApplicationServiceApi") -> bool: """Sends this transaction using the provided AS API interface. @@ -359,6 +373,8 @@ class AppServiceTransaction: events=self.events, ephemeral=self.ephemeral, to_device_messages=self.to_device_messages, + one_time_key_counts=self.one_time_key_counts, + unused_fallback_keys=self.unused_fallback_keys, txn_id=self.id, ) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 73be7ff3d4..a0ea958af6 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -19,6 +19,11 @@ from prometheus_client import Counter from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException +from synapse.appservice import ( + ApplicationService, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.events import EventBase from synapse.events.utils import serialize_event from synapse.http.client import SimpleHttpClient @@ -26,7 +31,6 @@ from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: - from synapse.appservice import ApplicationService from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -219,6 +223,8 @@ class ApplicationServiceApi(SimpleHttpClient): events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, txn_id: Optional[int] = None, ) -> bool: """ @@ -252,7 +258,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id))) # Never send ephemeral events to appservices that do not support it - body: Dict[str, List[JsonDict]] = {"events": serialized_events} + body: JsonDict = {"events": serialized_events} if service.supports_ephemeral: body.update( { @@ -262,6 +268,16 @@ class ApplicationServiceApi(SimpleHttpClient): } ) + if service.msc3202_transaction_extensions: + if one_time_key_counts: + body[ + "org.matrix.msc3202.device_one_time_key_counts" + ] = one_time_key_counts + if unused_fallback_keys: + body[ + "org.matrix.msc3202.device_unused_fallback_keys" + ] = unused_fallback_keys + try: await self.put_json( uri=uri, diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index b4e602e880..72417151ba 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -54,12 +54,19 @@ from typing import ( Callable, Collection, Dict, + Iterable, List, Optional, Set, + Tuple, ) -from synapse.appservice import ApplicationService, ApplicationServiceState +from synapse.appservice import ( + ApplicationService, + ApplicationServiceState, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.appservice.api import ApplicationServiceApi from synapse.events import EventBase from synapse.logging.context import run_in_background @@ -96,7 +103,7 @@ class ApplicationServiceScheduler: self.as_api = hs.get_application_service_api() self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) - self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) + self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock, hs) async def start(self) -> None: logger.info("Starting appservice scheduler") @@ -153,7 +160,9 @@ class _ServiceQueuer: appservice at a given time. """ - def __init__(self, txn_ctrl: "_TransactionController", clock: Clock): + def __init__( + self, txn_ctrl: "_TransactionController", clock: Clock, hs: "HomeServer" + ): # dict of {service_id: [events]} self.queued_events: Dict[str, List[EventBase]] = {} # dict of {service_id: [events]} @@ -165,6 +174,10 @@ class _ServiceQueuer: self.requests_in_flight: Set[str] = set() self.txn_ctrl = txn_ctrl self.clock = clock + self._msc3202_transaction_extensions_enabled: bool = ( + hs.config.experimental.msc3202_transaction_extensions + ) + self._store = hs.get_datastores().main def start_background_request(self, service: ApplicationService) -> None: # start a sender for this appservice if we don't already have one @@ -202,15 +215,84 @@ class _ServiceQueuer: if not events and not ephemeral and not to_device_messages_to_send: return + one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None + unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None + + if ( + self._msc3202_transaction_extensions_enabled + and service.msc3202_transaction_extensions + ): + # Compute the one-time key counts and fallback key usage states + # for the users which are mentioned in this transaction, + # as well as the appservice's sender. + ( + one_time_key_counts, + unused_fallback_keys, + ) = await self._compute_msc3202_otk_counts_and_fallback_keys( + service, events, ephemeral, to_device_messages_to_send + ) + try: await self.txn_ctrl.send( - service, events, ephemeral, to_device_messages_to_send + service, + events, + ephemeral, + to_device_messages_to_send, + one_time_key_counts, + unused_fallback_keys, ) except Exception: logger.exception("AS request failed") finally: self.requests_in_flight.discard(service.id) + async def _compute_msc3202_otk_counts_and_fallback_keys( + self, + service: ApplicationService, + events: Iterable[EventBase], + ephemerals: Iterable[JsonDict], + to_device_messages: Iterable[JsonDict], + ) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]: + """ + Given a list of the events, ephemeral messages and to-device messages, + - first computes a list of application services users that may have + interesting updates to the one-time key counts or fallback key usage. + - then computes one-time key counts and fallback key usages for those users. + Given a list of application service users that are interesting, + compute one-time key counts and fallback key usages for the users. + """ + + # Set of 'interesting' users who may have updates + users: Set[str] = set() + + # The sender is always included + users.add(service.sender) + + # All AS users that would receive the PDUs or EDUs sent to these rooms + # are classed as 'interesting'. + rooms_of_interesting_users: Set[str] = set() + # PDUs + rooms_of_interesting_users.update(event.room_id for event in events) + # EDUs + rooms_of_interesting_users.update( + ephemeral["room_id"] for ephemeral in ephemerals + ) + + # Look up the AS users in those rooms + for room_id in rooms_of_interesting_users: + users.update( + await self._store.get_app_service_users_in_room(room_id, service) + ) + + # Add recipients of to-device messages. + # device_message["user_id"] is the ID of the recipient. + users.update(device_message["user_id"] for device_message in to_device_messages) + + # Compute and return the counts / fallback key usage states + otk_counts = await self._store.count_bulk_e2e_one_time_keys_for_as(users) + unused_fbks = await self._store.get_e2e_bulk_unused_fallback_key_types(users) + return otk_counts, unused_fbks + class _TransactionController: """Transaction manager. @@ -238,6 +320,8 @@ class _TransactionController: events: List[EventBase], ephemeral: Optional[List[JsonDict]] = None, to_device_messages: Optional[List[JsonDict]] = None, + one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None, + unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, ) -> None: """ Create a transaction with the given data and send to the provided @@ -248,6 +332,10 @@ class _TransactionController: events: The persistent events to include in the transaction. ephemeral: The ephemeral events to include in the transaction. to_device_messages: The to-device messages to include in the transaction. + one_time_key_counts: Counts of remaining one-time keys for relevant + appservice devices in the transaction. + unused_fallback_keys: Lists of unused fallback keys for relevant + appservice devices in the transaction. """ try: txn = await self.store.create_appservice_txn( @@ -255,6 +343,8 @@ class _TransactionController: events=events, ephemeral=ephemeral or [], to_device_messages=to_device_messages or [], + one_time_key_counts=one_time_key_counts or {}, + unused_fallback_keys=unused_fallback_keys or {}, ) service_is_up = await self._is_service_up(service) if service_is_up: diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 7fad2e0422..439bfe1526 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -166,6 +166,16 @@ def _load_appservice( supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False) + # Opt-in flag for the MSC3202-specific transactional behaviour. + # When enabled, appservice transactions contain the following information: + # - device One-Time Key counts + # - device unused fallback key usage states + msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False) + if not isinstance(msc3202_transaction_extensions, bool): + raise ValueError( + "The `org.matrix.msc3202` option should be true or false if specified." + ) + return ApplicationService( token=as_info["as_token"], hostname=hostname, @@ -174,8 +184,9 @@ def _load_appservice( hs_token=as_info["hs_token"], sender=user_id, id=as_info["id"], - supports_ephemeral=supports_ephemeral, protocols=protocols, rate_limited=rate_limited, ip_range_whitelist=ip_range_whitelist, + supports_ephemeral=supports_ephemeral, + msc3202_transaction_extensions=msc3202_transaction_extensions, ) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 772eb35013..41338b39df 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -47,11 +47,6 @@ class ExperimentalConfig(Config): # MSC3030 (Jump to date API endpoint) self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) - # The portion of MSC3202 which is related to device masquerading. - self.msc3202_device_masquerading_enabled: bool = experimental.get( - "msc3202_device_masquerading", False - ) - # MSC2409 (this setting only relates to optionally sending to-device messages). # Presence, typing and read receipt EDUs are already sent to application services that # have opted in to receive them. If enabled, this adds to-device messages to that list. @@ -59,6 +54,17 @@ class ExperimentalConfig(Config): "msc2409_to_device_messages_enabled", False ) + # The portion of MSC3202 which is related to device masquerading. + self.msc3202_device_masquerading_enabled: bool = experimental.get( + "msc3202_device_masquerading", False + ) + + # Portion of MSC3202 related to transaction extensions: + # sending one-time key counts and fallback key usage to application services. + self.msc3202_transaction_extensions: bool = experimental.get( + "msc3202_transaction_extensions", False + ) + # MSC3706 (server-side support for partial state in /send_join responses) self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 304814af5d..0694446558 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -20,14 +20,18 @@ from synapse.appservice import ( ApplicationService, ApplicationServiceState, AppServiceTransaction, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, ) from synapse.config.appservice import load_appservices from synapse.events import EventBase -from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage._base import db_to_json from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.types import JsonDict from synapse.util import json_encoder +from synapse.util.caches.descriptors import _CacheContext, cached if TYPE_CHECKING: from synapse.server import HomeServer @@ -56,7 +60,7 @@ def _make_exclusive_regex( return exclusive_user_pattern -class ApplicationServiceWorkerStore(SQLBaseStore): +class ApplicationServiceWorkerStore(RoomMemberWorkerStore): def __init__( self, database: DatabasePool, @@ -124,6 +128,18 @@ class ApplicationServiceWorkerStore(SQLBaseStore): return service return None + @cached(iterable=True, cache_context=True) + async def get_app_service_users_in_room( + self, + room_id: str, + app_service: "ApplicationService", + cache_context: _CacheContext, + ) -> List[str]: + users_in_room = await self.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate + ) + return list(filter(app_service.is_interested_in_user, users_in_room)) + class ApplicationServiceStore(ApplicationServiceWorkerStore): # This is currently empty due to there not being any AS storage functions @@ -199,6 +215,8 @@ class ApplicationServiceTransactionWorkerStore( events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service with the given list of events. Ephemeral events are NOT persisted to the @@ -209,6 +227,10 @@ class ApplicationServiceTransactionWorkerStore( events: A list of persistent events to put in the transaction. ephemeral: A list of ephemeral events to put in the transaction. to_device_messages: A list of to-device messages to put in the transaction. + one_time_key_counts: Counts of remaining one-time keys for relevant + appservice devices in the transaction. + unused_fallback_keys: Lists of unused fallback keys for relevant + appservice devices in the transaction. Returns: A new transaction. @@ -244,6 +266,8 @@ class ApplicationServiceTransactionWorkerStore( events=events, ephemeral=ephemeral, to_device_messages=to_device_messages, + one_time_key_counts=one_time_key_counts, + unused_fallback_keys=unused_fallback_keys, ) return await self.db_pool.runInteraction( @@ -335,12 +359,17 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) + # TODO: to-device messages, one-time key counts and unused fallback keys + # are not yet populated for catch-up transactions. + # We likely want to populate those for reliability. return AppServiceTransaction( service=service, id=entry["txn_id"], events=events, ephemeral=[], to_device_messages=[], + one_time_key_counts={}, + unused_fallback_keys={}, ) def _get_last_txn(self, txn, service_id: Optional[str]) -> int: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 1f8447b507..9b293475c8 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -29,6 +29,10 @@ import attr from canonicaljson import encode_canonical_json from synapse.api.constants import DeviceKeyAlgorithms +from synapse.appservice import ( + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -439,6 +443,114 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker "count_e2e_one_time_keys", _count_e2e_one_time_keys ) + async def count_bulk_e2e_one_time_keys_for_as( + self, user_ids: Collection[str] + ) -> TransactionOneTimeKeyCounts: + """ + Counts, in bulk, the one-time keys for all the users specified. + Intended to be used by application services for populating OTK counts in + transactions. + + Return structure is of the shape: + user_id -> device_id -> algorithm -> count + Empty algorithm -> count dicts are created if needed to represent a + lack of unused one-time keys. + """ + + def _count_bulk_e2e_one_time_keys_txn( + txn: LoggingTransaction, + ) -> TransactionOneTimeKeyCounts: + user_in_where_clause, user_parameters = make_in_list_sql_clause( + self.database_engine, "user_id", user_ids + ) + sql = f""" + SELECT user_id, device_id, algorithm, COUNT(key_id) + FROM devices + LEFT JOIN e2e_one_time_keys_json USING (user_id, device_id) + WHERE {user_in_where_clause} + GROUP BY user_id, device_id, algorithm + """ + txn.execute(sql, user_parameters) + + result: TransactionOneTimeKeyCounts = {} + + for user_id, device_id, algorithm, count in txn: + # We deliberately construct empty dictionaries for + # users and devices without any unused one-time keys. + # We *could* omit these empty dicts if there have been no + # changes since the last transaction, but we currently don't + # do any change tracking! + device_count_by_algo = result.setdefault(user_id, {}).setdefault( + device_id, {} + ) + if algorithm is not None: + # algorithm will be None if this device has no keys. + device_count_by_algo[algorithm] = count + + return result + + return await self.db_pool.runInteraction( + "count_bulk_e2e_one_time_keys", _count_bulk_e2e_one_time_keys_txn + ) + + async def get_e2e_bulk_unused_fallback_key_types( + self, user_ids: Collection[str] + ) -> TransactionUnusedFallbackKeys: + """ + Finds, in bulk, the types of unused fallback keys for all the users specified. + Intended to be used by application services for populating unused fallback + keys in transactions. + + Return structure is of the shape: + user_id -> device_id -> algorithms + Empty lists are created for devices if there are no unused fallback + keys. This matches the response structure of MSC3202. + """ + if len(user_ids) == 0: + return {} + + def _get_bulk_e2e_unused_fallback_keys_txn( + txn: LoggingTransaction, + ) -> TransactionUnusedFallbackKeys: + user_in_where_clause, user_parameters = make_in_list_sql_clause( + self.database_engine, "devices.user_id", user_ids + ) + # We can't use USING here because we require the `.used` condition + # to be part of the JOIN condition so that we generate empty lists + # when all keys are used (as opposed to just when there are no keys at all). + sql = f""" + SELECT devices.user_id, devices.device_id, algorithm + FROM devices + LEFT JOIN e2e_fallback_keys_json AS fallback_keys + ON devices.user_id = fallback_keys.user_id + AND devices.device_id = fallback_keys.device_id + AND NOT fallback_keys.used + WHERE + {user_in_where_clause} + """ + txn.execute(sql, user_parameters) + + result: TransactionUnusedFallbackKeys = {} + + for user_id, device_id, algorithm in txn: + # We deliberately construct empty dictionaries and lists for + # users and devices without any unused fallback keys. + # We *could* omit these empty dicts if there have been no + # changes since the last transaction, but we currently don't + # do any change tracking! + device_unused_keys = result.setdefault(user_id, {}).setdefault( + device_id, [] + ) + if algorithm is not None: + # algorithm will be None if this device has no keys. + device_unused_keys.append(algorithm) + + return result + + return await self.db_pool.runInteraction( + "_get_bulk_e2e_unused_fallback_keys", _get_bulk_e2e_unused_fallback_keys_txn + ) + async def set_e2e_fallback_keys( self, user_id: str, device_id: str, fallback_keys: JsonDict ) -> None: diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 8fb6687f89..b9dc4dfe1b 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -68,6 +68,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], # txn made and saved + one_time_key_counts={}, + unused_fallback_keys={}, ) self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed @@ -92,6 +94,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], # txn made and saved + one_time_key_counts={}, + unused_fallback_keys={}, ) self.assertEquals(0, txn.send.call_count) # txn not sent though self.assertEquals(0, txn.complete.call_count) # or completed @@ -114,7 +118,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events, ephemeral=[], to_device_messages=[] + service=service, + events=events, + ephemeral=[], + to_device_messages=[], + one_time_key_counts={}, + unused_fallback_keys={}, ) self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made self.assertEquals(1, self.recoverer.recover.call_count) # and invoked @@ -216,7 +225,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): service = Mock(id=4) event = Mock() self.scheduler.enqueue_for_appservice(service, events=[event]) - self.txn_ctrl.send.assert_called_once_with(service, [event], [], []) + self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None) def test_send_single_event_with_queue(self): d = defer.Deferred() @@ -231,11 +240,13 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # (call enqueue_for_appservice multiple times deliberately) self.scheduler.enqueue_for_appservice(service, events=[event2]) self.scheduler.enqueue_for_appservice(service, events=[event3]) - self.txn_ctrl.send.assert_called_with(service, [event], [], []) + self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [event2, event3], [], []) + self.txn_ctrl.send.assert_called_with( + service, [event2, event3], [], [], None, None + ) self.assertEquals(2, self.txn_ctrl.send.call_count) def test_multiple_service_queues(self): @@ -261,15 +272,15 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # send events for different ASes and make sure they are sent self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event]) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2]) - self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], []) + self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event]) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2]) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], []) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], []) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None) self.assertEquals(3, self.txn_ctrl.send.call_count) def test_send_large_txns(self): @@ -288,13 +299,19 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): self.scheduler.enqueue_for_appservice(service, [event], []) # Expect the first event to be sent immediately. - self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [], []) + self.txn_ctrl.send.assert_called_with( + service, [event_list[0]], [], [], None, None + ) srv_1_defer.callback(service) # Then send the next 100 events - self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [], []) + self.txn_ctrl.send.assert_called_with( + service, event_list[1:101], [], [], None, None + ) srv_2_defer.callback(service) # Then the final 99 events - self.txn_ctrl.send.assert_called_with(service, event_list[101:], [], []) + self.txn_ctrl.send.assert_called_with( + service, event_list[101:], [], [], None, None + ) self.assertEquals(3, self.txn_ctrl.send.call_count) def test_send_single_ephemeral_no_queue(self): @@ -302,14 +319,18 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): service = Mock(id=4, name="service") event_list = [Mock(name="event")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], event_list, [], None, None + ) def test_send_multiple_ephemeral_no_queue(self): # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], event_list, [], None, None + ) def test_send_single_ephemeral_with_queue(self): d = defer.Deferred() @@ -324,13 +345,13 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # Send more events: expect send() to NOT be called multiple times. self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) - self.txn_ctrl.send.assert_called_with(service, [], event_list_1, []) + self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve txn_ctrl.send d.callback(service) # Expect the queued events to be sent self.txn_ctrl.send.assert_called_with( - service, [], event_list_2 + event_list_3, [] + service, [], event_list_2 + event_list_3, [], None, None ) self.assertEquals(2, self.txn_ctrl.send.call_count) @@ -343,7 +364,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)] event_list = first_chunk + second_chunk self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], first_chunk, [], None, None + ) d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [], second_chunk, []) + self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None) self.assertEquals(2, self.txn_ctrl.send.call_count) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 9918ff6807..6e0ec37963 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -16,17 +16,25 @@ from typing import Dict, Iterable, List, Optional from unittest.mock import Mock from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin import synapse.storage -from synapse.appservice import ApplicationService +from synapse.appservice import ( + ApplicationService, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.handlers.appservice import ApplicationServicesHandler -from synapse.rest.client import login, receipts, room, sendtodevice +from synapse.rest.client import login, receipts, register, room, sendtodevice +from synapse.server import HomeServer from synapse.types import RoomStreamToken +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest from tests.test_utils import make_awaitable, simple_async_mock +from tests.unittest import override_config from tests.utils import MockClock @@ -428,7 +436,14 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # # The uninterested application service should not have been notified at all. self.send_mock.assert_called_once() - service, _events, _ephemeral, to_device_messages = self.send_mock.call_args[0] + ( + service, + _events, + _ephemeral, + to_device_messages, + _otks, + _fbks, + ) = self.send_mock.call_args[0] # Assert that this was the same to-device message that local_user sent self.assertEqual(service, interested_appservice) @@ -540,7 +555,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): service_id_to_message_count: Dict[str, int] = {} for call in self.send_mock.call_args_list: - service, _events, _ephemeral, to_device_messages = call[0] + service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0] # Check that this was made to an interested service self.assertIn(service, interested_appservices) @@ -582,3 +597,174 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): self._services.append(appservice) return appservice + + +class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): + # Argument indices for pulling out arguments from a `send_mock`. + ARG_OTK_COUNTS = 4 + ARG_FALLBACK_KEYS = 5 + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + room.register_servlets, + sendtodevice.register_servlets, + receipts.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + # Mock the ApplicationServiceScheduler's _TransactionController's send method so that + # we can track what's going out + self.send_mock = simple_async_mock() + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method. + + # Define an application service for the tests + self._service_token = "VERYSECRET" + self._service = ApplicationService( + self._service_token, + "as1.invalid", + "as1", + "@as.sender:test", + namespaces={ + "users": [ + {"regex": "@_as_.*:test", "exclusive": True}, + {"regex": "@as.sender:test", "exclusive": True}, + ] + }, + msc3202_transaction_extensions=True, + ) + self.hs.get_datastores().main.services_cache = [self._service] + + # Register some appservice users + self._sender_user, self._sender_device = self.register_appservice_user( + "as.sender", self._service_token + ) + self._namespaced_user, self._namespaced_device = self.register_appservice_user( + "_as_user1", self._service_token + ) + + # Register a real user as well. + self._real_user = self.register_user("real.user", "meow") + self._real_user_token = self.login("real.user", "meow") + + async def _add_otks_for_device( + self, user_id: str, device_id: str, otk_count: int + ) -> None: + """ + Add some dummy keys. It doesn't matter if they're not a real algorithm; + that should be opaque to the server anyway. + """ + await self.hs.get_datastores().main.add_e2e_one_time_keys( + user_id, + device_id, + self.clock.time_msec(), + [("algo", f"k{i}", "{}") for i in range(otk_count)], + ) + + async def _add_fallback_key_for_device( + self, user_id: str, device_id: str, used: bool + ) -> None: + """ + Adds a fake fallback key to a device, optionally marking it as used + right away. + """ + store = self.hs.get_datastores().main + await store.set_e2e_fallback_keys(user_id, device_id, {"algo:fk": "fall back!"}) + if used is True: + # Mark the key as used + await store.db_pool.simple_update_one( + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": "algo", + "key_id": "fk", + }, + updatevalues={"used": True}, + desc="_get_fallback_key_set_used", + ) + + def _set_up_devices_and_a_room(self) -> str: + """ + Helper to set up devices for all the users + and a room for the users to talk in. + """ + + async def preparation(): + await self._add_otks_for_device(self._sender_user, self._sender_device, 42) + await self._add_fallback_key_for_device( + self._sender_user, self._sender_device, used=True + ) + await self._add_otks_for_device( + self._namespaced_user, self._namespaced_device, 36 + ) + await self._add_fallback_key_for_device( + self._namespaced_user, self._namespaced_device, used=False + ) + + # Register a device for the real user, too, so that we can later ensure + # that we don't leak information to the AS about the non-AS user. + await self.hs.get_datastores().main.store_device( + self._real_user, "REALDEV", "UltraMatrix 3000" + ) + await self._add_otks_for_device(self._real_user, "REALDEV", 50) + + self.get_success(preparation()) + + room_id = self.helper.create_room_as( + self._real_user, is_public=True, tok=self._real_user_token + ) + self.helper.join( + room_id, + self._namespaced_user, + tok=self._service_token, + appservice_user_id=self._namespaced_user, + ) + + # Check it was called for sanity. (This was to send the join event to the AS.) + self.send_mock.assert_called() + self.send_mock.reset_mock() + + return room_id + + @override_config( + {"experimental_features": {"msc3202_transaction_extensions": True}} + ) + def test_application_services_receive_otk_counts_and_fallback_key_usages_with_pdus( + self, + ) -> None: + """ + Tests that: + - the AS receives one-time key counts and unused fallback keys for: + - the specified sender; and + - any user who is in receipt of the PDUs + """ + + room_id = self._set_up_devices_and_a_room() + + # Send a message into the AS's room + self.helper.send(room_id, "woof woof", tok=self._real_user_token) + + # Capture what was sent as an AS transaction. + self.send_mock.assert_called() + last_args, _last_kwargs = self.send_mock.call_args + otks: Optional[TransactionOneTimeKeyCounts] = last_args[self.ARG_OTK_COUNTS] + unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[ + self.ARG_FALLBACK_KEYS + ] + + self.assertEqual( + otks, + { + "@as.sender:test": {self._sender_device: {"algo": 42}}, + "@_as_user1:test": {self._namespaced_device: {"algo": 36}}, + }, + ) + self.assertEqual( + unused_fallbacks, + { + "@as.sender:test": {self._sender_device: []}, + "@_as_user1:test": {self._namespaced_device: ["algo"]}, + }, + ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 50703ccaee..d2f654214e 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -267,7 +267,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) txn = self.get_success( defer.ensureDeferred( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) ) self.assertEquals(txn.id, 1) @@ -283,7 +283,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(service.id, 9644, events)) self.get_success(self._insert_txn(service.id, 9645, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) @@ -296,7 +296,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) self.get_success(self._set_last_txn(service.id, 9643)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) @@ -320,7 +320,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) From 54e74cc15f30585f5874780437614c0df6f639d9 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 24 Feb 2022 19:56:38 +0100 Subject: [PATCH 13/40] Add type hints to `tests/rest/client` (#12072) --- changelog.d/12072.misc | 1 + tests/rest/client/test_consent.py | 19 +++-- tests/rest/client/test_device_lists.py | 16 ++-- tests/rest/client/test_ephemeral_message.py | 19 +++-- tests/rest/client/test_identity.py | 13 +++- tests/rest/client/test_keys.py | 6 +- tests/rest/client/test_password_policy.py | 39 +++++----- tests/rest/client/test_power_levels.py | 47 +++++++----- tests/rest/client/test_presence.py | 15 ++-- tests/rest/client/test_room_batch.py | 2 +- tests/rest/client/utils.py | 85 +++++++++++++-------- 11 files changed, 160 insertions(+), 102 deletions(-) create mode 100644 changelog.d/12072.misc diff --git a/changelog.d/12072.misc b/changelog.d/12072.misc new file mode 100644 index 0000000000..0360dbd61e --- /dev/null +++ b/changelog.d/12072.misc @@ -0,0 +1 @@ +Add type hints to `tests/rest/client`. diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index fcdc565814..b1ca81a911 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -13,11 +13,16 @@ # limitations under the License. import os +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.urls import ConsentURIBuilder from synapse.rest.client import login, room from synapse.rest.consent import consent_resource +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeSite, make_request @@ -32,7 +37,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): user_id = True hijack_auth = False - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["form_secret"] = "123abc" @@ -56,7 +61,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) return hs - def test_render_public_consent(self): + def test_render_public_consent(self) -> None: """You can observe the terms form without specifying a user""" resource = consent_resource.ConsentResource(self.hs) channel = make_request( @@ -66,9 +71,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): "/consent?v=1", shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) - def test_accept_consent(self): + def test_accept_consent(self) -> None: """ A user can use the consent form to accept the terms. """ @@ -92,7 +97,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Get the version from the body, and whether we've consented version, consented = channel.result["body"].decode("ascii").split(",") @@ -107,7 +112,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Fetch the consent page, to get the consent version -- it should have # changed @@ -119,7 +124,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Get the version from the body, and check that it's the version we # agreed to, and that we've consented to it. diff --git a/tests/rest/client/test_device_lists.py b/tests/rest/client/test_device_lists.py index 16070cf027..a8af4e2435 100644 --- a/tests/rest/client/test_device_lists.py +++ b/tests/rest/client/test_device_lists.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + from synapse.rest import admin, devices, room, sync from synapse.rest.client import account, login, register @@ -30,7 +32,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase): devices.register_servlets, ] - def test_receiving_local_device_list_changes(self): + def test_receiving_local_device_list_changes(self) -> None: """Tests that a local users that share a room receive each other's device list changes. """ @@ -84,7 +86,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase): }, access_token=alice_access_token, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # Check that bob's incremental sync contains the updated device list. # If not, the client would only receive the device list update on the @@ -97,7 +99,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase): ) self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body) - def test_not_receiving_local_device_list_changes(self): + def test_not_receiving_local_device_list_changes(self) -> None: """Tests a local users DO NOT receive device updates from each other if they do not share a room. """ @@ -119,7 +121,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase): "/sync", access_token=bob_access_token, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) next_batch_token = channel.json_body["next_batch"] # ...and then an incremental sync. This should block until the sync stream is woken up, @@ -141,11 +143,13 @@ class DeviceListsTestCase(unittest.HomeserverTestCase): }, access_token=alice_access_token, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # Check that bob's incremental sync does not contain the updated device list. bob_sync_channel.await_result() - self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + self.assertEqual( + bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body + ) changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( "changed", [] diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py index 3d7aa8ec86..9fa1f82dfe 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py @@ -11,9 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventContentFields, EventTypes from synapse.rest import admin from synapse.rest.client import room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest @@ -27,7 +34,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["enable_ephemeral_messages"] = True @@ -35,10 +42,10 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver(config=config) return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) - def test_message_expiry_no_delay(self): + def test_message_expiry_no_delay(self) -> None: """Tests that sending a message sent with a m.self_destruct_after field set to the past results in that event being deleted right away. """ @@ -61,7 +68,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): event_content = self.get_event(self.room_id, event_id)["content"] self.assertFalse(bool(event_content), event_content) - def test_message_expiry_delay(self): + def test_message_expiry_delay(self) -> None: """Tests that sending a message with a m.self_destruct_after field set to the future results in that event not being deleted right away, but advancing the clock to after that expiry timestamp causes the event to be deleted. @@ -89,7 +96,9 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): event_content = self.get_event(self.room_id, event_id)["content"] self.assertFalse(bool(event_content), event_content) - def get_event(self, room_id, event_id, expected_code=200): + def get_event( + self, room_id: str, event_id: str, expected_code: int = HTTPStatus.OK + ) -> JsonDict: url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) channel = self.make_request("GET", url) diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index becb4e8dcc..299b9d21e2 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -13,9 +13,14 @@ # limitations under the License. import json +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -28,7 +33,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["enable_3pid_lookup"] = False @@ -36,14 +41,14 @@ class IdentityTestCase(unittest.HomeserverTestCase): return self.hs - def test_3pid_lookup_disabled(self): + def test_3pid_lookup_disabled(self) -> None: self.hs.config.registration.enable_3pid_lookup = False self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) room_id = channel.json_body["room_id"] params = { @@ -56,4 +61,4 @@ class IdentityTestCase(unittest.HomeserverTestCase): channel = self.make_request( b"POST", request_url, request_data, access_token=tok ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py index d7fa635eae..bbc8e74243 100644 --- a/tests/rest/client/test_keys.py +++ b/tests/rest/client/test_keys.py @@ -28,7 +28,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def test_rejects_device_id_ice_key_outside_of_list(self): + def test_rejects_device_id_ice_key_outside_of_list(self) -> None: self.register_user("alice", "wonderland") alice_token = self.login("alice", "wonderland") bob = self.register_user("bob", "uncle") @@ -49,7 +49,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): channel.result, ) - def test_rejects_device_key_given_as_map_to_bool(self): + def test_rejects_device_key_given_as_map_to_bool(self) -> None: self.register_user("alice", "wonderland") alice_token = self.login("alice", "wonderland") bob = self.register_user("bob", "uncle") @@ -73,7 +73,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): channel.result, ) - def test_requires_device_key(self): + def test_requires_device_key(self) -> None: """`device_keys` is required. We should complain if it's missing.""" self.register_user("alice", "wonderland") alice_token = self.login("alice", "wonderland") diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py index 3cf5871899..3a74d2e96c 100644 --- a/tests/rest/client/test_password_policy.py +++ b/tests/rest/client/test_password_policy.py @@ -13,11 +13,16 @@ # limitations under the License. import json +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.rest import admin from synapse.rest.client import account, login, password_policy, register +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -46,7 +51,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): account.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.register_url = "/_matrix/client/r0/register" self.policy = { "enabled": True, @@ -65,12 +70,12 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) return hs - def test_get_policy(self): + def test_get_policy(self) -> None: """Tests if the /password_policy endpoint returns the configured policy.""" channel = self.make_request("GET", "/_matrix/client/r0/password_policy") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual( channel.json_body, { @@ -83,70 +88,70 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): channel.result, ) - def test_password_too_short(self): + def test_password_too_short(self) -> None: request_data = json.dumps({"username": "kermit", "password": "shorty"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result, ) - def test_password_no_digit(self): + def test_password_no_digit(self) -> None: request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result, ) - def test_password_no_symbol(self): + def test_password_no_symbol(self) -> None: request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result, ) - def test_password_no_uppercase(self): + def test_password_no_uppercase(self) -> None: request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result, ) - def test_password_no_lowercase(self): + def test_password_no_lowercase(self) -> None: request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result, ) - def test_password_compliant(self): + def test_password_compliant(self) -> None: request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) channel = self.make_request("POST", self.register_url, request_data) # Getting a 401 here means the password has passed validation and the server has # responded with a list of registration flows. - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) - def test_password_change(self): + def test_password_change(self) -> None: """This doesn't test every possible use case, only that hitting /account/password triggers the password validation code. """ @@ -173,5 +178,5 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): access_token=tok, ) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index c0de4c93a8..27dcfc83d2 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -11,11 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT from synapse.rest import admin from synapse.rest.client import login, room, sync +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -30,12 +35,12 @@ class PowerLevelsTestCase(HomeserverTestCase): sync.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register a room admin, moderator and regular user self.admin_user_id = self.register_user("admin", "pass") self.admin_access_token = self.login("admin", "pass") @@ -88,7 +93,7 @@ class PowerLevelsTestCase(HomeserverTestCase): tok=self.admin_access_token, ) - def test_non_admins_cannot_enable_room_encryption(self): + def test_non_admins_cannot_enable_room_encryption(self) -> None: # have the mod try to enable room encryption self.helper.send_state( self.room_id, @@ -104,10 +109,10 @@ class PowerLevelsTestCase(HomeserverTestCase): "m.room.encryption", {"algorithm": "m.megolm.v1.aes-sha2"}, tok=self.user_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) - def test_non_admins_cannot_send_server_acl(self): + def test_non_admins_cannot_send_server_acl(self) -> None: # have the mod try to send a server ACL self.helper.send_state( self.room_id, @@ -118,7 +123,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "deny": ["*.evil.com", "evil.com"], }, tok=self.mod_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) # have the user try to send a server ACL @@ -131,10 +136,10 @@ class PowerLevelsTestCase(HomeserverTestCase): "deny": ["*.evil.com", "evil.com"], }, tok=self.user_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) - def test_non_admins_cannot_tombstone_room(self): + def test_non_admins_cannot_tombstone_room(self) -> None: # Create another room that will serve as our "upgraded room" self.upgraded_room_id = self.helper.create_room_as( self.admin_user_id, tok=self.admin_access_token @@ -149,7 +154,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "replacement_room": self.upgraded_room_id, }, tok=self.mod_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) # have the user try to send a tombstone event @@ -164,17 +169,17 @@ class PowerLevelsTestCase(HomeserverTestCase): expect_code=403, # expect failure ) - def test_admins_can_enable_room_encryption(self): + def test_admins_can_enable_room_encryption(self) -> None: # have the admin try to enable room encryption self.helper.send_state( self.room_id, "m.room.encryption", {"algorithm": "m.megolm.v1.aes-sha2"}, tok=self.admin_access_token, - expect_code=200, # expect success + expect_code=HTTPStatus.OK, # expect success ) - def test_admins_can_send_server_acl(self): + def test_admins_can_send_server_acl(self) -> None: # have the admin try to send a server ACL self.helper.send_state( self.room_id, @@ -185,10 +190,10 @@ class PowerLevelsTestCase(HomeserverTestCase): "deny": ["*.evil.com", "evil.com"], }, tok=self.admin_access_token, - expect_code=200, # expect success + expect_code=HTTPStatus.OK, # expect success ) - def test_admins_can_tombstone_room(self): + def test_admins_can_tombstone_room(self) -> None: # Create another room that will serve as our "upgraded room" self.upgraded_room_id = self.helper.create_room_as( self.admin_user_id, tok=self.admin_access_token @@ -203,10 +208,10 @@ class PowerLevelsTestCase(HomeserverTestCase): "replacement_room": self.upgraded_room_id, }, tok=self.admin_access_token, - expect_code=200, # expect success + expect_code=HTTPStatus.OK, # expect success ) - def test_cannot_set_string_power_levels(self): + def test_cannot_set_string_power_levels(self) -> None: room_power_levels = self.helper.get_state( self.room_id, "m.room.power_levels", @@ -221,7 +226,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "m.room.power_levels", room_power_levels, tok=self.admin_access_token, - expect_code=400, # expect failure + expect_code=HTTPStatus.BAD_REQUEST, # expect failure ) self.assertEqual( @@ -230,7 +235,7 @@ class PowerLevelsTestCase(HomeserverTestCase): body, ) - def test_cannot_set_unsafe_large_power_levels(self): + def test_cannot_set_unsafe_large_power_levels(self) -> None: room_power_levels = self.helper.get_state( self.room_id, "m.room.power_levels", @@ -247,7 +252,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "m.room.power_levels", room_power_levels, tok=self.admin_access_token, - expect_code=400, # expect failure + expect_code=HTTPStatus.BAD_REQUEST, # expect failure ) self.assertEqual( @@ -256,7 +261,7 @@ class PowerLevelsTestCase(HomeserverTestCase): body, ) - def test_cannot_set_unsafe_small_power_levels(self): + def test_cannot_set_unsafe_small_power_levels(self) -> None: room_power_levels = self.helper.get_state( self.room_id, "m.room.power_levels", @@ -273,7 +278,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "m.room.power_levels", room_power_levels, tok=self.admin_access_token, - expect_code=400, # expect failure + expect_code=HTTPStatus.BAD_REQUEST, # expect failure ) self.assertEqual( diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 56fe1a3d01..0abe378fe4 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -11,14 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from http import HTTPStatus from unittest.mock import Mock from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.presence import PresenceHandler from synapse.rest.client import presence +from synapse.server import HomeServer from synapse.types import UserID +from synapse.util import Clock from tests import unittest @@ -31,7 +34,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): user = UserID.from_string(user_id) servlets = [presence.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: presence_handler = Mock(spec=PresenceHandler) presence_handler.set_state.return_value = defer.succeed(None) @@ -45,7 +48,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): return hs - def test_put_presence(self): + def test_put_presence(self) -> None: """ PUT to the status endpoint with use_presence enabled will call set_state on the presence handler. @@ -57,11 +60,11 @@ class PresenceTestCase(unittest.HomeserverTestCase): "PUT", "/presence/%s/status" % (self.user_id,), body ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) @unittest.override_config({"use_presence": False}) - def test_put_presence_disabled(self): + def test_put_presence_disabled(self) -> None: """ PUT to the status endpoint with use_presence disabled will NOT call set_state on the presence handler. @@ -72,5 +75,5 @@ class PresenceTestCase(unittest.HomeserverTestCase): "PUT", "/presence/%s/status" % (self.user_id,), body ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py index e9f8704035..44f333a0ee 100644 --- a/tests/rest/client/test_room_batch.py +++ b/tests/rest/client/test_room_batch.py @@ -134,7 +134,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): return room_id, event_id_a, event_id_b, event_id_c @unittest.override_config({"experimental_features": {"msc2716_enabled": True}}) - def test_same_state_groups_for_whole_historical_batch(self): + def test_same_state_groups_for_whole_historical_batch(self) -> None: """Make sure that when using the `/batch_send` endpoint to import a bunch of historical messages, it re-uses the same `state_group` across the whole batch. This is an easy optimization to make sure we're getting diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 2b3fdadffa..46cd5f70a8 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -19,6 +19,7 @@ import json import re import time import urllib.parse +from http import HTTPStatus from typing import ( Any, AnyStr, @@ -89,7 +90,7 @@ class RestHelper: is_public: Optional[bool] = None, room_version: Optional[str] = None, tok: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, extra_content: Optional[Dict] = None, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, ) -> Optional[str]: @@ -137,12 +138,19 @@ class RestHelper: assert channel.result["code"] == b"%d" % expect_code, channel.result self.auth_user_id = temp_id - if expect_code == 200: + if expect_code == HTTPStatus.OK: return channel.json_body["room_id"] else: return None - def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): + def invite( + self, + room: Optional[str] = None, + src: Optional[str] = None, + targ: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: self.change_membership( room=room, src=src, @@ -156,7 +164,7 @@ class RestHelper: self, room: str, user: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, appservice_user_id: Optional[str] = None, ) -> None: @@ -170,7 +178,14 @@ class RestHelper: expect_code=expect_code, ) - def knock(self, room=None, user=None, reason=None, expect_code=200, tok=None): + def knock( + self, + room: Optional[str] = None, + user: Optional[str] = None, + reason: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: temp_id = self.auth_user_id self.auth_user_id = user path = "/knock/%s" % room @@ -199,7 +214,13 @@ class RestHelper: self.auth_user_id = temp_id - def leave(self, room=None, user=None, expect_code=200, tok=None): + def leave( + self, + room: Optional[str] = None, + user: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: self.change_membership( room=room, src=user, @@ -209,7 +230,7 @@ class RestHelper: expect_code=expect_code, ) - def ban(self, room: str, src: str, targ: str, **kwargs: object): + def ban(self, room: str, src: str, targ: str, **kwargs: object) -> None: """A convenience helper: `change_membership` with `membership` preset to "ban".""" self.change_membership( room=room, @@ -228,7 +249,7 @@ class RestHelper: extra_data: Optional[dict] = None, tok: Optional[str] = None, appservice_user_id: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, expect_errcode: Optional[str] = None, ) -> None: """ @@ -294,13 +315,13 @@ class RestHelper: def send( self, - room_id, - body=None, - txn_id=None, - tok=None, - expect_code=200, + room_id: str, + body: Optional[str] = None, + txn_id: Optional[str] = None, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): + ) -> JsonDict: if body is None: body = "body_text_here" @@ -318,14 +339,14 @@ class RestHelper: def send_event( self, - room_id, - type, + room_id: str, + type: str, content: Optional[dict] = None, - txn_id=None, - tok=None, - expect_code=200, + txn_id: Optional[str] = None, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): + ) -> JsonDict: if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -358,10 +379,10 @@ class RestHelper: event_type: str, body: Optional[Dict[str, Any]], tok: str, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, state_key: str = "", method: str = "GET", - ) -> Dict: + ) -> JsonDict: """Read or write some state from a given room Args: @@ -410,9 +431,9 @@ class RestHelper: room_id: str, event_type: str, tok: str, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, state_key: str = "", - ): + ) -> JsonDict: """Gets some state from a room Args: @@ -438,9 +459,9 @@ class RestHelper: event_type: str, body: Dict[str, Any], tok: str, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, state_key: str = "", - ): + ) -> JsonDict: """Set some state in a room Args: @@ -467,8 +488,8 @@ class RestHelper: image_data: bytes, tok: str, filename: str = "test.png", - expect_code: int = 200, - ) -> dict: + expect_code: int = HTTPStatus.OK, + ) -> JsonDict: """Upload a piece of test media to the media repo Args: resource: The resource that will handle the upload request @@ -513,7 +534,7 @@ class RestHelper: channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) # expect a confirmation page - assert channel.code == 200, channel.result + assert channel.code == HTTPStatus.OK, channel.result # fish the matrix login token out of the body of the confirmation page m = re.search( @@ -532,7 +553,7 @@ class RestHelper: "/login", content={"type": "m.login.token", "token": login_token}, ) - assert channel.code == 200 + assert channel.code == HTTPStatus.OK return channel.json_body def auth_via_oidc( @@ -641,7 +662,7 @@ class RestHelper: (expected_uri, resp_obj) = expected_requests.pop(0) assert uri == expected_uri resp = FakeResponse( - code=200, + code=HTTPStatus.OK, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"), ) @@ -739,7 +760,7 @@ class RestHelper: self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint ) # that should serve a confirmation page - assert channel.code == 200, channel.text_body + assert channel.code == HTTPStatus.OK, channel.text_body channel.extract_cookies(cookies) # parse the confirmation page to fish out the link. From f3fd8558cdb5d91d0e54ca35b55a3dba2610b215 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 25 Feb 2022 10:19:49 +0000 Subject: [PATCH 14/40] Minor typing fixes for `synapse/storage/persist_events.py` (#12069) Signed-off-by: Sean Quah --- changelog.d/12069.misc | 1 + synapse/storage/databases/main/events.py | 23 ++++++++++++---------- synapse/storage/persist_events.py | 25 ++++++++++++------------ 3 files changed, 26 insertions(+), 23 deletions(-) create mode 100644 changelog.d/12069.misc diff --git a/changelog.d/12069.misc b/changelog.d/12069.misc new file mode 100644 index 0000000000..8374a63220 --- /dev/null +++ b/changelog.d/12069.misc @@ -0,0 +1 @@ +Minor typing fixes. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index e53e84054a..23fa089bca 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -130,7 +130,7 @@ class PersistEventsStore: *, current_state_for_room: Dict[str, StateMap[str]], state_delta_for_room: Dict[str, DeltaState], - new_forward_extremeties: Dict[str, List[str]], + new_forward_extremities: Dict[str, Set[str]], use_negative_stream_ordering: bool = False, inhibit_local_membership_updates: bool = False, ) -> None: @@ -143,7 +143,7 @@ class PersistEventsStore: the room based on forward extremities state_delta_for_room: Map from room_id to the delta to apply to room state - new_forward_extremities: Map from room_id to list of event IDs + new_forward_extremities: Map from room_id to set of event IDs that are the new forward extremities of the room. use_negative_stream_ordering: Whether to start stream_ordering on the negative side and decrement. This should be set as True @@ -193,7 +193,7 @@ class PersistEventsStore: events_and_contexts=events_and_contexts, inhibit_local_membership_updates=inhibit_local_membership_updates, state_delta_for_room=state_delta_for_room, - new_forward_extremeties=new_forward_extremeties, + new_forward_extremities=new_forward_extremities, ) persist_event_counter.inc(len(events_and_contexts)) @@ -220,7 +220,7 @@ class PersistEventsStore: for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) - for room_id, latest_event_ids in new_forward_extremeties.items(): + for room_id, latest_event_ids in new_forward_extremities.items(): self.store.get_latest_event_ids_in_room.prefill( (room_id,), list(latest_event_ids) ) @@ -334,8 +334,8 @@ class PersistEventsStore: events_and_contexts: List[Tuple[EventBase, EventContext]], inhibit_local_membership_updates: bool = False, state_delta_for_room: Optional[Dict[str, DeltaState]] = None, - new_forward_extremeties: Optional[Dict[str, List[str]]] = None, - ): + new_forward_extremities: Optional[Dict[str, Set[str]]] = None, + ) -> None: """Insert some number of room events into the necessary database tables. Rejected events are only inserted into the events table, the events_json table, @@ -353,13 +353,13 @@ class PersistEventsStore: from the database. This is useful when retrying due to IntegrityError. state_delta_for_room: The current-state delta for each room. - new_forward_extremetie: The new forward extremities for each room. + new_forward_extremities: The new forward extremities for each room. For each room, a list of the event ids which are the forward extremities. """ state_delta_for_room = state_delta_for_room or {} - new_forward_extremeties = new_forward_extremeties or {} + new_forward_extremities = new_forward_extremities or {} all_events_and_contexts = events_and_contexts @@ -372,7 +372,7 @@ class PersistEventsStore: self._update_forward_extremities_txn( txn, - new_forward_extremities=new_forward_extremeties, + new_forward_extremities=new_forward_extremities, max_stream_order=max_stream_order, ) @@ -1158,7 +1158,10 @@ class PersistEventsStore: ) def _update_forward_extremities_txn( - self, txn, new_forward_extremities, max_stream_order + self, + txn: LoggingTransaction, + new_forward_extremities: Dict[str, Set[str]], + max_stream_order: int, ): for room_id in new_forward_extremities.keys(): self.db_pool.simple_delete_txn( diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 428d66a617..7d543fdbe0 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -427,21 +427,21 @@ class EventsPersistenceStorage: # NB: Assumes that we are only persisting events for one room # at a time. - # map room_id->list[event_ids] giving the new forward + # map room_id->set[event_ids] giving the new forward # extremities in each room - new_forward_extremeties = {} + new_forward_extremities: Dict[str, Set[str]] = {} # map room_id->(type,state_key)->event_id tracking the full # state in each room after adding these events. # This is simply used to prefill the get_current_state_ids # cache - current_state_for_room = {} + current_state_for_room: Dict[str, StateMap[str]] = {} # map room_id->(to_delete, to_insert) where to_delete is a list # of type/state keys to remove from current state, and to_insert # is a map (type,key)->event_id giving the state delta in each # room - state_delta_for_room = {} + state_delta_for_room: Dict[str, DeltaState] = {} # Set of remote users which were in rooms the server has left. We # should check if we still share any rooms and if not we mark their @@ -460,14 +460,13 @@ class EventsPersistenceStorage: ) for room_id, ev_ctx_rm in events_by_room.items(): - latest_event_ids = ( + latest_event_ids = set( await self.main_store.get_latest_event_ids_in_room(room_id) ) new_latest_event_ids = await self._calculate_new_extremities( room_id, ev_ctx_rm, latest_event_ids ) - latest_event_ids = set(latest_event_ids) if new_latest_event_ids == latest_event_ids: # No change in extremities, so no change in state continue @@ -478,7 +477,7 @@ class EventsPersistenceStorage: # extremities, so we'll `continue` above and skip this bit.) assert new_latest_event_ids, "No forward extremities left!" - new_forward_extremeties[room_id] = new_latest_event_ids + new_forward_extremities[room_id] = new_latest_event_ids len_1 = ( len(latest_event_ids) == 1 @@ -533,7 +532,7 @@ class EventsPersistenceStorage: # extremities, so we'll `continue` above and skip this bit.) assert new_latest_event_ids, "No forward extremities left!" - new_forward_extremeties[room_id] = new_latest_event_ids + new_forward_extremities[room_id] = new_latest_event_ids # If either are not None then there has been a change, # and we need to work out the delta (or use that @@ -567,7 +566,7 @@ class EventsPersistenceStorage: ) if not is_still_joined: logger.info("Server no longer in room %s", room_id) - latest_event_ids = [] + latest_event_ids = set() current_state = {} delta.no_longer_in_room = True @@ -582,7 +581,7 @@ class EventsPersistenceStorage: chunk, current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room, - new_forward_extremeties=new_forward_extremeties, + new_forward_extremities=new_forward_extremities, use_negative_stream_ordering=backfilled, inhibit_local_membership_updates=backfilled, ) @@ -596,7 +595,7 @@ class EventsPersistenceStorage: room_id: str, event_contexts: List[Tuple[EventBase, EventContext]], latest_event_ids: Collection[str], - ): + ) -> Set[str]: """Calculates the new forward extremities for a room given events to persist. @@ -906,9 +905,9 @@ class EventsPersistenceStorage: # Ideally we'd figure out a way of still being able to drop old # dummy events that reference local events, but this is good enough # as a first cut. - events_to_check = [event] + events_to_check: Collection[EventBase] = [event] while events_to_check: - new_events = set() + new_events: Set[str] = set() for event_to_check in events_to_check: if self.is_mine_id(event_to_check.sender): if event_to_check.type != EventTypes.Dummy: From b43c3ef8e2306829074d847bed50575d5e7c7ea3 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 25 Feb 2022 10:20:40 +0000 Subject: [PATCH 15/40] Ensure that `get_datastores().main` is typed (#12070) Signed-off-by: Sean Quah --- changelog.d/12070.misc | 1 + synapse/storage/databases/__init__.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12070.misc diff --git a/changelog.d/12070.misc b/changelog.d/12070.misc new file mode 100644 index 0000000000..d4bedc6b97 --- /dev/null +++ b/changelog.d/12070.misc @@ -0,0 +1 @@ +Remove legacy `HomeServer.get_datastore()`. diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index cfe887b7f7..ce3d1d4e94 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -24,6 +24,7 @@ from synapse.storage.prepare_database import prepare_database if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -44,7 +45,7 @@ class Databases(Generic[DataStoreT]): """ databases: List[DatabasePool] - main: DataStoreT + main: "DataStore" # FIXME: #11165: actually an instance of `main_store_class` state: StateGroupDataStore persist_events: Optional[PersistEventsStore] From ab3ef49059e465198754a3d818d1f3b21771f5ef Mon Sep 17 00:00:00 2001 From: lukasdenk <63459921+lukasdenk@users.noreply.github.com> Date: Mon, 28 Feb 2022 12:42:13 +0100 Subject: [PATCH 16/40] synctl: print warning if synctl_cache_factor is set in config (#11865) Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/11865.removal | 1 + synctl | 8 ++++++++ 2 files changed, 9 insertions(+) create mode 100644 changelog.d/11865.removal diff --git a/changelog.d/11865.removal b/changelog.d/11865.removal new file mode 100644 index 0000000000..9fcabfc720 --- /dev/null +++ b/changelog.d/11865.removal @@ -0,0 +1 @@ +Deprecate using `synctl` with the config option `synctl_cache_factor` and print a warning if a user still uses this option. diff --git a/synctl b/synctl index 0e54f4847b..1ab36949c7 100755 --- a/synctl +++ b/synctl @@ -37,6 +37,13 @@ YELLOW = "\x1b[1;33m" RED = "\x1b[1;31m" NORMAL = "\x1b[m" +SYNCTL_CACHE_FACTOR_WARNING = """\ +Setting 'synctl_cache_factor' in the config is deprecated. Instead, please do +one of the following: + - Either set the environment variable 'SYNAPSE_CACHE_FACTOR' + - or set 'caches.global_factor' in the homeserver config. +--------------------------------------------------------------------------------""" + def pid_running(pid): try: @@ -228,6 +235,7 @@ def main(): start_stop_synapse = True if cache_factor: + write(SYNCTL_CACHE_FACTOR_WARNING) os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) cache_factors = config.get("synctl_cache_factors", {}) From 02d708568b476f2f7716000b35c0adfa4cbd31b3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 28 Feb 2022 07:12:29 -0500 Subject: [PATCH 17/40] Replace assertEquals and friends with non-deprecated versions. (#12092) --- changelog.d/12092.misc | 1 + tests/api/test_auth.py | 36 +-- tests/api/test_filtering.py | 20 +- tests/api/test_ratelimiting.py | 28 +- tests/appservice/test_scheduler.py | 62 ++--- tests/crypto/test_event_signing.py | 8 +- tests/crypto/test_keyring.py | 6 +- tests/events/test_utils.py | 16 +- tests/federation/test_complexity.py | 4 +- tests/federation/test_federation_server.py | 12 +- tests/federation/transport/test_knocking.py | 16 +- tests/federation/transport/test_server.py | 4 +- tests/handlers/test_appservice.py | 18 +- tests/handlers/test_directory.py | 28 +- tests/handlers/test_presence.py | 78 +++--- tests/handlers/test_profile.py | 20 +- tests/handlers/test_receipts.py | 2 +- tests/handlers/test_register.py | 4 +- tests/handlers/test_sync.py | 6 +- tests/handlers/test_typing.py | 40 +-- tests/handlers/test_user_directory.py | 4 +- tests/http/federation/test_srv_resolver.py | 24 +- .../replication/slave/storage/test_events.py | 2 +- tests/rest/admin/test_room.py | 14 +- tests/rest/client/test_account.py | 22 +- tests/rest/client/test_events.py | 10 +- tests/rest/client/test_filter.py | 10 +- tests/rest/client/test_groups.py | 12 +- tests/rest/client/test_login.py | 78 +++--- tests/rest/client/test_profile.py | 2 +- tests/rest/client/test_register.py | 166 ++++++------ tests/rest/client/test_relations.py | 248 +++++++++--------- tests/rest/client/test_rooms.py | 192 +++++++------- tests/rest/client/test_shadow_banned.py | 20 +- tests/rest/client/test_shared_rooms.py | 24 +- tests/rest/client/test_sync.py | 16 +- tests/rest/client/test_third_party_rules.py | 14 +- tests/rest/client/test_typing.py | 18 +- tests/rest/client/test_upgrade_room.py | 18 +- tests/rest/media/v1/test_media_storage.py | 2 +- .../databases/main/test_events_worker.py | 12 +- tests/storage/test_appservice.py | 86 +++--- tests/storage/test_base.py | 6 +- tests/storage/test_directory.py | 2 +- tests/storage/test_event_push_actions.py | 2 +- tests/storage/test_main.py | 8 +- tests/storage/test_profile.py | 4 +- tests/storage/test_registration.py | 6 +- tests/storage/test_room.py | 4 +- tests/storage/test_room_search.py | 6 +- tests/storage/test_roommember.py | 2 +- tests/test_distributor.py | 2 +- tests/test_terms_auth.py | 6 +- tests/test_test_utils.py | 2 +- tests/test_types.py | 16 +- tests/unittest.py | 6 +- tests/util/caches/test_deferred_cache.py | 2 +- tests/util/caches/test_descriptors.py | 70 ++--- tests/util/test_expiring_cache.py | 40 +-- tests/util/test_logcontext.py | 2 +- tests/util/test_lrucache.py | 140 +++++----- tests/util/test_treecache.py | 48 ++-- 62 files changed, 888 insertions(+), 889 deletions(-) create mode 100644 changelog.d/12092.misc diff --git a/changelog.d/12092.misc b/changelog.d/12092.misc new file mode 100644 index 0000000000..62653d6f8d --- /dev/null +++ b/changelog.d/12092.misc @@ -0,0 +1 @@ +User `assertEqual` instead of the deprecated `assertEquals` in test code. diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 686d17c0de..3e05789923 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -71,7 +71,7 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals(requester.user.to_string(), self.test_user) + self.assertEqual(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): self.store.get_user_by_access_token = simple_async_mock(None) @@ -109,7 +109,7 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals(requester.user.to_string(), self.test_user) + self.assertEqual(requester.user.to_string(), self.test_user) def test_get_user_by_req_appservice_valid_token_good_ip(self): from netaddr import IPSet @@ -128,7 +128,7 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals(requester.user.to_string(), self.test_user) + self.assertEqual(requester.user.to_string(), self.test_user) def test_get_user_by_req_appservice_valid_token_bad_ip(self): from netaddr import IPSet @@ -195,7 +195,7 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals( + self.assertEqual( requester.user.to_string(), masquerading_user_id.decode("utf8") ) @@ -242,10 +242,10 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals( + self.assertEqual( requester.user.to_string(), masquerading_user_id.decode("utf8") ) - self.assertEquals(requester.device_id, masquerading_device_id.decode("utf8")) + self.assertEqual(requester.device_id, masquerading_device_id.decode("utf8")) @override_config({"experimental_features": {"msc3202_device_masquerading": True}}) def test_get_user_by_req_appservice_valid_token_invalid_device_id(self): @@ -275,8 +275,8 @@ class AuthTestCase(unittest.HomeserverTestCase): request.requestHeaders.getRawHeaders = mock_getRawHeaders() failure = self.get_failure(self.auth.get_user_by_req(request), AuthError) - self.assertEquals(failure.value.code, 400) - self.assertEquals(failure.value.errcode, Codes.EXCLUSIVE) + self.assertEqual(failure.value.code, 400) + self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE) def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self): self.store.get_user_by_access_token = simple_async_mock( @@ -309,7 +309,7 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals(self.store.insert_client_ip.call_count, 2) + self.assertEqual(self.store.insert_client_ip.call_count, 2) def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = simple_async_mock( @@ -369,9 +369,9 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_monthly_active_count = simple_async_mock(lots_of_users) e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.value.code, 403) + self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.code, 403) # Ensure does not throw an error self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) @@ -473,9 +473,9 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.value.code, 403) + self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.code, 403) def test_hs_disabled_no_server_notices_user(self): """Check that 'hs_disabled_message' works correctly when there is no @@ -488,9 +488,9 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.value.code, 403) + self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.code, 403) def test_server_notices_mxid_special_cased(self): self.auth_blocking._hs_disabled = True diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 973f0f7fa1..2525018e95 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -364,7 +364,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) results = self.get_success(user_filter.filter_presence(events=events)) - self.assertEquals(events, results) + self.assertEqual(events, results) def test_filter_presence_no_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} @@ -388,7 +388,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) results = self.get_success(user_filter.filter_presence(events=events)) - self.assertEquals([], results) + self.assertEqual([], results) def test_filter_room_state_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} @@ -407,7 +407,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) results = self.get_success(user_filter.filter_room_state(events=events)) - self.assertEquals(events, results) + self.assertEqual(events, results) def test_filter_room_state_no_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} @@ -428,7 +428,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) results = self.get_success(user_filter.filter_room_state(events)) - self.assertEquals([], results) + self.assertEqual([], results) def test_filter_rooms(self): definition = { @@ -444,7 +444,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): filtered_room_ids = list(Filter(self.hs, definition).filter_rooms(room_ids)) - self.assertEquals(filtered_room_ids, ["!allowed:example.com"]) + self.assertEqual(filtered_room_ids, ["!allowed:example.com"]) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_filter_relations(self): @@ -486,7 +486,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): Filter(self.hs, definition)._check_event_relations(events) ) ) - self.assertEquals(filtered_events, events[1:]) + self.assertEqual(filtered_events, events[1:]) def test_add_filter(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} @@ -497,8 +497,8 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals(filter_id, 0) - self.assertEquals( + self.assertEqual(filter_id, 0) + self.assertEqual( user_filter_json, ( self.get_success( @@ -524,6 +524,6 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals(filter.get_filter_json(), user_filter_json) + self.assertEqual(filter.get_filter_json(), user_filter_json) - self.assertRegexpMatches(repr(filter), r"") + self.assertRegex(repr(filter), r"") diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 4ef754a186..483d5463ad 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -14,19 +14,19 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter.can_do_action(None, key="test_id", _time_now_s=0) ) self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=5) ) self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=10) ) self.assertTrue(allowed) - self.assertEquals(20.0, time_allowed) + self.assertEqual(20.0, time_allowed) def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): appservice = ApplicationService( @@ -45,19 +45,19 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) - self.assertEquals(20.0, time_allowed) + self.assertEqual(20.0, time_allowed) def test_allowed_appservice_via_can_requester_do_action(self): appservice = ApplicationService( @@ -76,19 +76,19 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) - self.assertEquals(-1, time_allowed) + self.assertEqual(-1, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertTrue(allowed) - self.assertEquals(-1, time_allowed) + self.assertEqual(-1, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) - self.assertEquals(-1, time_allowed) + self.assertEqual(-1, time_allowed) def test_allowed_via_ratelimit(self): limiter = Ratelimiter( @@ -246,7 +246,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter.can_do_action(None, key="test_id", n_actions=3, _time_now_s=0) ) self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) # Test that, after doing these 3 actions, we can't do any more action without # waiting. @@ -254,7 +254,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter.can_do_action(None, key="test_id", n_actions=1, _time_now_s=0) ) self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) # Test that after waiting we can do only 1 action. allowed, time_allowed = self.get_success_or_raise( @@ -269,7 +269,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): self.assertTrue(allowed) # The time allowed is the current time because we could still repeat the action # once. - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=10) @@ -277,7 +277,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): self.assertFalse(allowed) # The time allowed doesn't change despite allowed being False because, while we # don't allow 2 actions, we could still do 1. - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) # Test that after waiting a bit more we can do 2 actions. allowed, time_allowed = self.get_success_or_raise( @@ -286,4 +286,4 @@ class TestRatelimiter(unittest.HomeserverTestCase): self.assertTrue(allowed) # The time allowed is the current time because we could still repeat the action # once. - self.assertEquals(20.0, time_allowed) + self.assertEqual(20.0, time_allowed) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index b9dc4dfe1b..1cbb059357 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -71,7 +71,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): one_time_key_counts={}, unused_fallback_keys={}, ) - self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made + self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed def test_single_service_down(self): @@ -97,8 +97,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): one_time_key_counts={}, unused_fallback_keys={}, ) - self.assertEquals(0, txn.send.call_count) # txn not sent though - self.assertEquals(0, txn.complete.call_count) # or completed + self.assertEqual(0, txn.send.call_count) # txn not sent though + self.assertEqual(0, txn.complete.call_count) # or completed def test_single_service_up_txn_not_sent(self): # Test: The AS is up and the txn is not sent. A Recoverer is made and @@ -125,10 +125,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): one_time_key_counts={}, unused_fallback_keys={}, ) - self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made - self.assertEquals(1, self.recoverer.recover.call_count) # and invoked - self.assertEquals(1, len(self.txnctrl.recoverers)) # and stored - self.assertEquals(0, txn.complete.call_count) # txn not completed + self.assertEqual(1, self.recoverer_fn.call_count) # recoverer made + self.assertEqual(1, self.recoverer.recover.call_count) # and invoked + self.assertEqual(1, len(self.txnctrl.recoverers)) # and stored + self.assertEqual(0, txn.complete.call_count) # txn not completed self.store.set_appservice_state.assert_called_once_with( service, ApplicationServiceState.DOWN # service marked as down ) @@ -161,17 +161,17 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() # shouldn't have called anything prior to waiting for exp backoff - self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) + self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) txn.send = simple_async_mock(True) txn.complete = simple_async_mock(None) # wait for exp backoff self.clock.advance_time(2) - self.assertEquals(1, txn.send.call_count) - self.assertEquals(1, txn.complete.call_count) + self.assertEqual(1, txn.send.call_count) + self.assertEqual(1, txn.complete.call_count) # 2 because it needs to get None to know there are no more txns - self.assertEquals(2, self.store.get_oldest_unsent_txn.call_count) + self.assertEqual(2, self.store.get_oldest_unsent_txn.call_count) self.callback.assert_called_once_with(self.recoverer) - self.assertEquals(self.recoverer.service, self.service) + self.assertEqual(self.recoverer.service, self.service) def test_recover_retry_txn(self): txn = Mock() @@ -187,26 +187,26 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn) self.recoverer.recover() - self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) + self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) txn.send = simple_async_mock(False) txn.complete = simple_async_mock(None) self.clock.advance_time(2) - self.assertEquals(1, txn.send.call_count) - self.assertEquals(0, txn.complete.call_count) - self.assertEquals(0, self.callback.call_count) + self.assertEqual(1, txn.send.call_count) + self.assertEqual(0, txn.complete.call_count) + self.assertEqual(0, self.callback.call_count) self.clock.advance_time(4) - self.assertEquals(2, txn.send.call_count) - self.assertEquals(0, txn.complete.call_count) - self.assertEquals(0, self.callback.call_count) + self.assertEqual(2, txn.send.call_count) + self.assertEqual(0, txn.complete.call_count) + self.assertEqual(0, self.callback.call_count) self.clock.advance_time(8) - self.assertEquals(3, txn.send.call_count) - self.assertEquals(0, txn.complete.call_count) - self.assertEquals(0, self.callback.call_count) + self.assertEqual(3, txn.send.call_count) + self.assertEqual(0, txn.complete.call_count) + self.assertEqual(0, self.callback.call_count) txn.send = simple_async_mock(True) # successfully send the txn pop_txn = True # returns the txn the first time, then no more. self.clock.advance_time(16) - self.assertEquals(1, txn.send.call_count) # new mock reset call count - self.assertEquals(1, txn.complete.call_count) + self.assertEqual(1, txn.send.call_count) # new mock reset call count + self.assertEqual(1, txn.complete.call_count) self.callback.assert_called_once_with(self.recoverer) @@ -241,13 +241,13 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): self.scheduler.enqueue_for_appservice(service, events=[event2]) self.scheduler.enqueue_for_appservice(service, events=[event3]) self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None) - self.assertEquals(1, self.txn_ctrl.send.call_count) + self.assertEqual(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) self.txn_ctrl.send.assert_called_with( service, [event2, event3], [], [], None, None ) - self.assertEquals(2, self.txn_ctrl.send.call_count) + self.assertEqual(2, self.txn_ctrl.send.call_count) def test_multiple_service_queues(self): # Tests that each service has its own queue, and that they don't block @@ -281,7 +281,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # service srv_2_defer.callback(srv2) self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None) - self.assertEquals(3, self.txn_ctrl.send.call_count) + self.assertEqual(3, self.txn_ctrl.send.call_count) def test_send_large_txns(self): srv_1_defer = defer.Deferred() @@ -312,7 +312,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): self.txn_ctrl.send.assert_called_with( service, event_list[101:], [], [], None, None ) - self.assertEquals(3, self.txn_ctrl.send.call_count) + self.assertEqual(3, self.txn_ctrl.send.call_count) def test_send_single_ephemeral_no_queue(self): # Expect the event to be sent immediately. @@ -346,14 +346,14 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None) - self.assertEquals(1, self.txn_ctrl.send.call_count) + self.assertEqual(1, self.txn_ctrl.send.call_count) # Resolve txn_ctrl.send d.callback(service) # Expect the queued events to be sent self.txn_ctrl.send.assert_called_with( service, [], event_list_2 + event_list_3, [], None, None ) - self.assertEquals(2, self.txn_ctrl.send.call_count) + self.assertEqual(2, self.txn_ctrl.send.call_count) def test_send_large_txns_ephemeral(self): d = defer.Deferred() @@ -369,4 +369,4 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): ) d.callback(service) self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None) - self.assertEquals(2, self.txn_ctrl.send.call_count) + self.assertEqual(2, self.txn_ctrl.send.call_count) diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index a72a0103d3..694020fbef 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -63,14 +63,14 @@ class EventSigningTestCase(unittest.TestCase): self.assertTrue(hasattr(event, "hashes")) self.assertIn("sha256", event.hashes) - self.assertEquals( + self.assertEqual( event.hashes["sha256"], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI" ) self.assertTrue(hasattr(event, "signatures")) self.assertIn(HOSTNAME, event.signatures) self.assertIn(KEY_NAME, event.signatures["domain"]) - self.assertEquals( + self.assertEqual( event.signatures[HOSTNAME][KEY_NAME], "2Wptgo4CwmLo/Y8B8qinxApKaCkBG2fjTWB7AbP5Uy+" "aIbygsSdLOFzvdDjww8zUVKCmI02eP9xtyJxc/cLiBA", @@ -97,14 +97,14 @@ class EventSigningTestCase(unittest.TestCase): self.assertTrue(hasattr(event, "hashes")) self.assertIn("sha256", event.hashes) - self.assertEquals( + self.assertEqual( event.hashes["sha256"], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g" ) self.assertTrue(hasattr(event, "signatures")) self.assertIn(HOSTNAME, event.signatures) self.assertIn(KEY_NAME, event.signatures["domain"]) - self.assertEquals( + self.assertEqual( event.signatures[HOSTNAME][KEY_NAME], "Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw" "u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA", diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 3a4d502719..d00ef24ca8 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -76,7 +76,7 @@ class FakeRequest: @logcontext_clean class KeyringTestCase(unittest.HomeserverTestCase): def check_context(self, val, expected): - self.assertEquals(getattr(current_context(), "request", None), expected) + self.assertEqual(getattr(current_context(), "request", None), expected) return val def test_verify_json_objects_for_server_awaits_previous_requests(self): @@ -96,7 +96,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): async def first_lookup_fetch( server_name: str, key_ids: List[str], minimum_valid_until_ts: int ) -> Dict[str, FetchKeyResult]: - # self.assertEquals(current_context().request.id, "context_11") + # self.assertEqual(current_context().request.id, "context_11") self.assertEqual(server_name, "server10") self.assertEqual(key_ids, [get_key_id(key1)]) self.assertEqual(minimum_valid_until_ts, 0) @@ -137,7 +137,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): async def second_lookup_fetch( server_name: str, key_ids: List[str], minimum_valid_until_ts: int ) -> Dict[str, FetchKeyResult]: - # self.assertEquals(current_context().request.id, "context_12") + # self.assertEqual(current_context().request.id, "context_12") return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)} mock_fetcher.get_keys.reset_mock() diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 1dea09e480..45e3395b33 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -395,7 +395,7 @@ class SerializeEventTestCase(unittest.TestCase): return serialize_event(ev, 1479807801915, only_event_fields=fields) def test_event_fields_works_with_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"] ), @@ -403,7 +403,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_works_with_nested_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -416,7 +416,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_works_with_dot_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -429,7 +429,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_works_with_nested_dot_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -445,7 +445,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_nops_with_unknown_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -458,7 +458,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_nops_with_non_dict_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -471,7 +471,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_nops_with_array_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -484,7 +484,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_all_fields_if_empty(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( type="foo", diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 9336181c96..9f1115dd23 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -50,7 +50,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) complexity = channel.json_body["v1"] self.assertTrue(complexity > 0, complexity) @@ -62,7 +62,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) complexity = channel.json_body["v1"] self.assertEqual(complexity, 1.23) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index d084919ef7..30e7e5093a 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -59,7 +59,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase): "/_matrix/federation/v1/get_missing_events/%s" % (room_1,), query_content, ) - self.assertEquals(400, channel.code, channel.result) + self.assertEqual(400, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON") @@ -125,7 +125,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertEqual( channel.json_body["room_version"], @@ -157,7 +157,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -189,7 +189,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}" f"?ver={DEFAULT_ROOM_VERSION}", ) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) return channel.json_body def test_send_join(self): @@ -209,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): f"/_matrix/federation/v2/send_join/{self._room_id}/x", content=join_event_dict, ) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) # we should get complete room state back returned_state = [ @@ -266,7 +266,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true", content=join_event_dict, ) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) # expect a reduced room state returned_state = [ diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index adf0535d97..648a01618e 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -169,7 +169,7 @@ class KnockingStrippedStateEventHelperMixin(TestCase): self.assertIn(event_type, expected_room_state) # Check the state content matches - self.assertEquals( + self.assertEqual( expected_room_state[event_type]["content"], event["content"] ) @@ -256,7 +256,7 @@ class FederationKnockingTestCase( RoomVersions.V7.identifier, ), ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Note: We don't expect the knock membership event to be sent over federation as # part of the stripped room state, as the knocking homeserver already has that @@ -266,11 +266,11 @@ class FederationKnockingTestCase( knock_event = channel.json_body["event"] # Check that the event has things we expect in it - self.assertEquals(knock_event["room_id"], room_id) - self.assertEquals(knock_event["sender"], fake_knocking_user_id) - self.assertEquals(knock_event["state_key"], fake_knocking_user_id) - self.assertEquals(knock_event["type"], EventTypes.Member) - self.assertEquals(knock_event["content"]["membership"], Membership.KNOCK) + self.assertEqual(knock_event["room_id"], room_id) + self.assertEqual(knock_event["sender"], fake_knocking_user_id) + self.assertEqual(knock_event["state_key"], fake_knocking_user_id) + self.assertEqual(knock_event["type"], EventTypes.Member) + self.assertEqual(knock_event["content"]["membership"], Membership.KNOCK) # Turn the event json dict into a proper event. # We won't sign it properly, but that's OK as we stub out event auth in `prepare` @@ -294,7 +294,7 @@ class FederationKnockingTestCase( % (room_id, signed_knock_event.event_id), signed_knock_event_json, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Check that we got the stripped room state in return room_state_events = channel.json_body["knock_state_events"] diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py index eb62addda8..ce49d094d7 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py @@ -26,7 +26,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): "GET", "/_matrix/federation/v1/publicRooms", ) - self.assertEquals(403, channel.code) + self.assertEqual(403, channel.code) @override_config({"allow_public_rooms_over_federation": True}) def test_open_public_room_list_over_federation(self): @@ -37,4 +37,4 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): "GET", "/_matrix/federation/v1/publicRooms", ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 6e0ec37963..072e6bbcdd 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -147,8 +147,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.query_alias.assert_called_once_with( interested_service, room_alias_str ) - self.assertEquals(result.room_id, room_id) - self.assertEquals(result.servers, servers) + self.assertEqual(result.room_id, room_id) + self.assertEqual(result.servers, servers) def test_get_3pe_protocols_no_appservices(self): self.mock_store.get_app_services.return_value = [] @@ -156,7 +156,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) ) self.mock_as_api.get_3pe_protocol.assert_not_called() - self.assertEquals(response, {}) + self.assertEqual(response, {}) def test_get_3pe_protocols_no_protocols(self): service = self._mkservice(False, []) @@ -165,7 +165,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.ensureDeferred(self.handler.get_3pe_protocols()) ) self.mock_as_api.get_3pe_protocol.assert_not_called() - self.assertEquals(response, {}) + self.assertEqual(response, {}) def test_get_3pe_protocols_protocol_no_response(self): service = self._mkservice(False, ["my-protocol"]) @@ -177,7 +177,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.get_3pe_protocol.assert_called_once_with( service, "my-protocol" ) - self.assertEquals(response, {}) + self.assertEqual(response, {}) def test_get_3pe_protocols_select_one_protocol(self): service = self._mkservice(False, ["my-protocol"]) @@ -191,7 +191,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.get_3pe_protocol.assert_called_once_with( service, "my-protocol" ) - self.assertEquals( + self.assertEqual( response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} ) @@ -207,7 +207,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.get_3pe_protocol.assert_called_once_with( service, "my-protocol" ) - self.assertEquals( + self.assertEqual( response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} ) @@ -222,7 +222,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.ensureDeferred(self.handler.get_3pe_protocols()) ) self.mock_as_api.get_3pe_protocol.assert_called() - self.assertEquals( + self.assertEqual( response, { "my-protocol": {"x-protocol-data": 42, "instances": []}, @@ -254,7 +254,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.ensureDeferred(self.handler.get_3pe_protocols()) ) # It's expected that the second service's data doesn't appear in the response - self.assertEquals( + self.assertEqual( response, { "my-protocol": { diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 65ab107d0e..6e403a87c5 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -63,7 +63,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_association(self.my_room)) - self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result) + self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result) def test_get_remote_association(self): self.mock_federation.make_query.return_value = make_awaitable( @@ -72,7 +72,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_association(self.remote_room)) - self.assertEquals( + self.assertEqual( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}, result ) self.mock_federation.make_query.assert_called_with( @@ -94,7 +94,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.handler.on_directory_query({"room_alias": "#your-room:test"}) ) - self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) + self.assertEqual({"room_id": "!8765asdf:test", "servers": ["test"]}, response) class TestCreateAlias(unittest.HomeserverTestCase): @@ -224,7 +224,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): create_requester(self.test_user), self.room_alias ) ) - self.assertEquals(self.room_id, result) + self.assertEqual(self.room_id, result) # Confirm the alias is gone. self.get_failure( @@ -243,7 +243,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): create_requester(self.admin_user), self.room_alias ) ) - self.assertEquals(self.room_id, result) + self.assertEqual(self.room_id, result) # Confirm the alias is gone. self.get_failure( @@ -269,7 +269,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): create_requester(self.test_user), self.room_alias ) ) - self.assertEquals(self.room_id, result) + self.assertEqual(self.room_id, result) # Confirm the alias is gone. self.get_failure( @@ -411,7 +411,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): b"directory/room/%23test%3Atest", {"room_id": room_id}, ) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) def test_allowed(self): room_id = self.helper.create_room_as(self.user_id) @@ -421,7 +421,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): b"directory/room/%23unofficial_test%3Atest", {"room_id": room_id}, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) def test_denied_during_creation(self): """A room alias that is not allowed should be rejected during creation.""" @@ -443,8 +443,8 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): "GET", b"directory/room/%23unofficial_test%3Atest", ) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(channel.json_body["room_id"], room_id) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(channel.json_body["room_id"], room_id) class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): @@ -572,7 +572,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): channel = self.make_request( "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.room_list_handler = hs.get_room_list_handler() self.directory_handler = hs.get_directory_handler() @@ -585,7 +585,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): # Room list is enabled so we should get some results channel = self.make_request("GET", b"publicRooms") - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) > 0) self.room_list_handler.enable_room_list_search = False @@ -593,7 +593,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): # Room list disabled so we should get no results channel = self.make_request("GET", b"publicRooms") - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) == 0) # Room list disabled so we shouldn't be allowed to publish rooms @@ -601,4 +601,4 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): channel = self.make_request( "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 61d28603ae..6ddec9ecf1 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -61,11 +61,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertTrue(persist_and_notify) self.assertTrue(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 3) + self.assertEqual(wheel_timer.insert.call_count, 3) wheel_timer.insert.assert_has_calls( [ call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), @@ -104,11 +104,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertFalse(persist_and_notify) self.assertTrue(federation_ping) self.assertTrue(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 3) + self.assertEqual(wheel_timer.insert.call_count, 3) wheel_timer.insert.assert_has_calls( [ call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), @@ -149,11 +149,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertFalse(persist_and_notify) self.assertTrue(federation_ping) self.assertTrue(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 3) + self.assertEqual(wheel_timer.insert.call_count, 3) wheel_timer.insert.assert_has_calls( [ call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), @@ -191,11 +191,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertTrue(persist_and_notify) self.assertFalse(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 2) + self.assertEqual(wheel_timer.insert.call_count, 2) wheel_timer.insert.assert_has_calls( [ call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), @@ -227,10 +227,10 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertFalse(persist_and_notify) self.assertFalse(federation_ping) self.assertFalse(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) - self.assertEquals(wheel_timer.insert.call_count, 1) + self.assertEqual(wheel_timer.insert.call_count, 1) wheel_timer.insert.assert_has_calls( [ call( @@ -259,10 +259,10 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): ) self.assertTrue(persist_and_notify) - self.assertEquals(new_state.state, state.state) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 0) + self.assertEqual(wheel_timer.insert.call_count, 0) def test_online_to_idle(self): wheel_timer = Mock() @@ -281,12 +281,12 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): ) self.assertTrue(persist_and_notify) - self.assertEquals(new_state.state, state.state) - self.assertEquals(state.last_federation_update_ts, now) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) + self.assertEqual(new_state.state, state.state) + self.assertEqual(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) - self.assertEquals(wheel_timer.insert.call_count, 1) + self.assertEqual(wheel_timer.insert.call_count, 1) wheel_timer.insert.assert_has_calls( [ call( @@ -357,8 +357,8 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.UNAVAILABLE) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.UNAVAILABLE) + self.assertEqual(new_state.status_msg, status_msg) def test_busy_no_idle(self): """ @@ -380,8 +380,8 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.BUSY) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.BUSY) + self.assertEqual(new_state.status_msg, status_msg) def test_sync_timeout(self): user_id = "@foo:bar" @@ -399,8 +399,8 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.OFFLINE) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.OFFLINE) + self.assertEqual(new_state.status_msg, status_msg) def test_sync_online(self): user_id = "@foo:bar" @@ -420,8 +420,8 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.ONLINE) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.ONLINE) + self.assertEqual(new_state.status_msg, status_msg) def test_federation_ping(self): user_id = "@foo:bar" @@ -440,7 +440,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(state, new_state) + self.assertEqual(state, new_state) def test_no_timeout(self): user_id = "@foo:bar" @@ -477,8 +477,8 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.OFFLINE) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.OFFLINE) + self.assertEqual(new_state.status_msg, status_msg) def test_last_active(self): user_id = "@foo:bar" @@ -497,7 +497,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(state, new_state) + self.assertEqual(state, new_state) class PresenceHandlerTestCase(unittest.HomeserverTestCase): diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 69e299fc17..972cbac6e4 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -65,7 +65,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): displayname = self.get_success(self.handler.get_displayname(self.frank)) - self.assertEquals("Frank", displayname) + self.assertEqual("Frank", displayname) def test_set_my_name(self): self.get_success( @@ -74,7 +74,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( ( self.get_success( self.store.get_profile_displayname(self.frank.localpart) @@ -90,7 +90,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( ( self.get_success( self.store.get_profile_displayname(self.frank.localpart) @@ -118,7 +118,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.store.set_profile_displayname(self.frank.localpart, "Frank") ) - self.assertEquals( + self.assertEqual( ( self.get_success( self.store.get_profile_displayname(self.frank.localpart) @@ -150,7 +150,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): displayname = self.get_success(self.handler.get_displayname(self.alice)) - self.assertEquals(displayname, "Alice") + self.assertEqual(displayname, "Alice") self.mock_federation.make_query.assert_called_with( destination="remote", query_type="profile", @@ -172,7 +172,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals({"displayname": "Caroline"}, response) + self.assertEqual({"displayname": "Caroline"}, response) def test_get_my_avatar(self): self.get_success( @@ -182,7 +182,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) avatar_url = self.get_success(self.handler.get_avatar_url(self.frank)) - self.assertEquals("http://my.server/me.png", avatar_url) + self.assertEqual("http://my.server/me.png", avatar_url) def test_set_my_avatar(self): self.get_success( @@ -193,7 +193,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/pic.gif", ) @@ -207,7 +207,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/me.png", ) @@ -235,7 +235,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/me.png", ) diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 5de89c873b..5081b97573 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -314,4 +314,4 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ): """Tests that the _filter_out_hidden returns the expected output""" filtered_events = self.event_source.filter_out_hidden(events, "@me:server.org") - self.assertEquals(filtered_events, expected_output) + self.assertEqual(filtered_events, expected_output) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 51ee667ab4..45fd30cf43 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -167,7 +167,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): result_user_id, result_token = self.get_success( self.get_or_create_user(requester, frank.localpart, "Frankie") ) - self.assertEquals(result_user_id, user_id) + self.assertEqual(result_user_id, user_id) self.assertIsInstance(result_token, str) self.assertGreater(len(result_token), 20) @@ -183,7 +183,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): result_user_id, result_token = self.get_success( self.get_or_create_user(requester, local_part, None) ) - self.assertEquals(result_user_id, user_id) + self.assertEqual(result_user_id, user_id) self.assertTrue(result_token is not None) @override_config({"limit_usage_by_mau": False}) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 66b0bd4d1a..3aedc0767b 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -69,7 +69,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.sync_handler.wait_for_sync_for_user(requester, sync_config), ResourceLimitError, ) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.auth_blocking._hs_disabled = False @@ -80,7 +80,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.sync_handler.wait_for_sync_for_user(requester, sync_config), ResourceLimitError, ) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def test_unknown_room_version(self): """ @@ -122,7 +122,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): b"{}", tok, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # The rooms should appear in the sync response. result = self.get_success( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index e461e03599..f91a80b9fa 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_started_typing_local(self): self.room_members = [U_APPLE, U_BANANA] - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) self.get_success( self.handler.started_typing( @@ -169,13 +169,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -220,7 +220,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_started_typing_remote_recv(self): self.room_members = [U_APPLE, U_ONION] - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) channel = self.make_request( "PUT", @@ -239,13 +239,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -259,7 +259,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_started_typing_remote_recv_not_in_room(self): self.room_members = [U_APPLE, U_ONION] - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) channel = self.make_request( "PUT", @@ -278,7 +278,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_not_called() - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, @@ -288,8 +288,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): is_guest=False, ) ) - self.assertEquals(events[0], []) - self.assertEquals(events[1], 0) + self.assertEqual(events[0], []) + self.assertEqual(events[1], 0) @override_config({"send_federation": True}) def test_stopped_typing(self): @@ -302,7 +302,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.handler._member_typing_until[member] = 1002000 self.handler._room_typing[ROOM_ID] = {U_APPLE.to_string()} - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) self.get_success( self.handler.stopped_typing( @@ -332,13 +332,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): try_trailing_slash_on_400=True, ) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False ) ) - self.assertEquals( + self.assertEqual( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) @@ -346,7 +346,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_typing_timeout(self): self.room_members = [U_APPLE, U_BANANA] - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) self.get_success( self.handler.started_typing( @@ -360,7 +360,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, @@ -370,7 +370,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -385,7 +385,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) - self.assertEquals(self.event_source.get_current_key(), 2) + self.assertEqual(self.event_source.get_current_key(), 2) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, @@ -395,7 +395,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) @@ -414,7 +414,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() - self.assertEquals(self.event_source.get_current_key(), 3) + self.assertEqual(self.event_source.get_current_key(), 3) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, @@ -424,7 +424,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [ { diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index e159169e22..92012cd6f7 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -1042,7 +1042,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): b'{"search_term":"user2"}', access_token=u1_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) > 0) # Disable user directory and check search returns nothing @@ -1053,5 +1053,5 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): b'{"search_term":"user2"}', access_token=u1_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) == 0) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index c49be33b9f..77ce8432ac 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -65,9 +65,9 @@ class SrvResolverTestCase(unittest.TestCase): servers = self.successResultOf(test_d) - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) - self.assertEquals(servers[0].host, host_name) + self.assertEqual(len(servers), 1) + self.assertEqual(servers, cache[service_name]) + self.assertEqual(servers[0].host, host_name) @defer.inlineCallbacks def test_from_cache_expired_and_dns_fail(self): @@ -88,8 +88,8 @@ class SrvResolverTestCase(unittest.TestCase): dns_client_mock.lookupService.assert_called_once_with(service_name) - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) + self.assertEqual(len(servers), 1) + self.assertEqual(servers, cache[service_name]) @defer.inlineCallbacks def test_from_cache(self): @@ -114,8 +114,8 @@ class SrvResolverTestCase(unittest.TestCase): self.assertFalse(dns_client_mock.lookupService.called) - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) + self.assertEqual(len(servers), 1) + self.assertEqual(servers, cache[service_name]) @defer.inlineCallbacks def test_empty_cache(self): @@ -144,8 +144,8 @@ class SrvResolverTestCase(unittest.TestCase): servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) - self.assertEquals(len(servers), 0) - self.assertEquals(len(cache), 0) + self.assertEqual(len(servers), 0) + self.assertEqual(len(cache), 0) def test_disabled_service(self): """ @@ -201,6 +201,6 @@ class SrvResolverTestCase(unittest.TestCase): servers = self.successResultOf(resolve_d) - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) - self.assertEquals(servers[0].host, b"host") + self.assertEqual(len(servers), 1) + self.assertEqual(servers, cache[service_name]) + self.assertEqual(servers[0].host, b"host") diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index eca6a443af..17dc42fd37 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -59,7 +59,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): def setUp(self): # Patch up the equality operator for events so that we can check - # whether lists of events match using assertEquals + # whether lists of events match using assertEqual self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)] return super().setUp() diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 09c48e85c7..95282f078e 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1909,7 +1909,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_not_member(self) -> None: @@ -1957,7 +1957,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) # Join user to room. @@ -1980,7 +1980,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_owner(self) -> None: @@ -2010,7 +2010,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_context_as_non_admin(self) -> None: @@ -2044,7 +2044,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=tok, ) - self.assertEquals(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_context_as_admin(self) -> None: @@ -2074,8 +2074,8 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=self.admin_user_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEquals( + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual( channel.json_body["event"]["event_id"], events[midway]["event_id"] ) diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 008d635b70..6c4462e74a 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -104,7 +104,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) link = self._get_link_from_email() self._validate_token(link) @@ -143,7 +143,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(email, client_secret, ip) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) link = self._get_link_from_email() self._validate_token(link) @@ -193,7 +193,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(email_passwort_reset, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) link = self._get_link_from_email() self._validate_token(link) @@ -230,7 +230,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) # Attempt to reset password without clicking the link self._reset_password(new_password, session_id, client_secret, expected_code=401) @@ -322,7 +322,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Now POST to the same endpoint, mimicking the same behaviour as clicking the # password reset confirm button @@ -337,7 +337,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, content_is_form=True, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) def _get_link_from_email(self): assert self.email_attempts, "No emails have been sent" @@ -376,7 +376,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): }, }, ) - self.assertEquals(expected_code, channel.code, channel.result) + self.assertEqual(expected_code, channel.code, channel.result) class DeactivateTestCase(unittest.HomeserverTestCase): @@ -676,7 +676,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(self.email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) link = self._get_link_from_email() self._validate_token(link) @@ -780,7 +780,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(self.email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) # Attempt to add email without clicking the link channel = self.make_request( @@ -981,7 +981,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): path = link.replace("https://example.com", "") channel = self.make_request("GET", path, shorthand=False) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) def _get_link_from_email(self): assert self.email_attempts, "No emails have been sent" @@ -1010,7 +1010,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(request_email, client_secret) - self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1) + self.assertEqual(len(self.email_attempts) - previous_email_attempts, 1) link = self._get_link_from_email() self._validate_token(link) diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index a90294003e..145f247836 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -65,13 +65,13 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "/events?access_token=%s" % ("invalid" + self.token,) ) - self.assertEquals(channel.code, 401, msg=channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) # valid token, expect content channel = self.make_request( "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) - self.assertEquals(channel.code, 200, msg=channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertTrue("chunk" in channel.json_body) self.assertTrue("start" in channel.json_body) self.assertTrue("end" in channel.json_body) @@ -89,10 +89,10 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) - self.assertEquals(channel.code, 200, msg=channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # We may get a presence event for ourselves down - self.assertEquals( + self.assertEqual( 0, len( [ @@ -153,4 +153,4 @@ class GetEventsTestCase(unittest.HomeserverTestCase): "/events/" + event_id, access_token=self.token, ) - self.assertEquals(channel.code, 200, msg=channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index a573cc3c2e..5c31a54421 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -45,7 +45,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body, {"filter_id": "0"}) filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) self.pump() - self.assertEquals(filter.result, self.EXAMPLE_FILTER) + self.assertEqual(filter.result, self.EXAMPLE_FILTER) def test_add_filter_for_other_user(self): channel = self.make_request( @@ -55,7 +55,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.result["code"], b"403") - self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) + self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) def test_add_filter_non_local_user(self): _is_mine = self.hs.is_mine @@ -68,7 +68,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.hs.is_mine = _is_mine self.assertEqual(channel.result["code"], b"403") - self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) + self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) def test_get_filter(self): filter_id = defer.ensureDeferred( @@ -83,7 +83,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.result["code"], b"200") - self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) + self.assertEqual(channel.json_body, self.EXAMPLE_FILTER) def test_get_filter_non_existant(self): channel = self.make_request( @@ -91,7 +91,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.result["code"], b"404") - self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) # Currently invalid params do not have an appropriate errcode # in errors.py diff --git a/tests/rest/client/test_groups.py b/tests/rest/client/test_groups.py index ad0425ae65..c99f54cf4f 100644 --- a/tests/rest/client/test_groups.py +++ b/tests/rest/client/test_groups.py @@ -30,8 +30,8 @@ class GroupsTestCase(unittest.HomeserverTestCase): # Alice creates a group channel = self.make_request("POST", "/create_group", {"localpart": "spqr"}) - self.assertEquals(channel.code, 200, msg=channel.text_body) - self.assertEquals(channel.json_body, {"group_id": group_id}) + self.assertEqual(channel.code, 200, msg=channel.text_body) + self.assertEqual(channel.json_body, {"group_id": group_id}) # Bob creates a private room room_id = self.helper.create_room_as(self.room_creator_user_id, is_public=False) @@ -45,12 +45,12 @@ class GroupsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", f"/groups/{group_id}/admin/rooms/{room_id}", {} ) - self.assertEquals(channel.code, 200, msg=channel.text_body) - self.assertEquals(channel.json_body, {}) + self.assertEqual(channel.code, 200, msg=channel.text_body) + self.assertEqual(channel.json_body, {}) # Alice now tries to retrieve the room list of the space. channel = self.make_request("GET", f"/groups/{group_id}/rooms") - self.assertEquals(channel.code, 200, msg=channel.text_body) - self.assertEquals( + self.assertEqual(channel.code, 200, msg=channel.text_body) + self.assertEqual( channel.json_body, {"chunk": [], "total_room_count_estimate": 0} ) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index d48defda63..090d2d0a29 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -136,10 +136,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -154,7 +154,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config( { @@ -181,10 +181,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -199,7 +199,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config( { @@ -226,10 +226,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -244,7 +244,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) @override_config({"session_lifetime": "24h"}) def test_soft_logout(self) -> None: @@ -252,8 +252,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we shouldn't be able to make requests without an access token channel = self.make_request(b"GET", TEST_URL) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN") + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN") # log in as normal params = { @@ -263,22 +263,22 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) access_token = channel.json_body["access_token"] device_id = channel.json_body["device_id"] # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], True) # # test behaviour after deleting the expired device @@ -290,17 +290,17 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # more requests with the expired token should still return a soft-logout self.reactor.advance(3600) channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], True) # ... but if we delete that device, it will be a proper logout self._delete_device(access_token_2, "kermit", "monkey", device_id) channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], False) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], False) def _delete_device( self, access_token: str, user_id: str, password: str, device_id: str @@ -309,7 +309,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token ) - self.assertEquals(channel.code, 401, channel.result) + self.assertEqual(channel.code, 401, channel.result) # check it's a UI-Auth fail self.assertEqual( set(channel.json_body.keys()), @@ -332,7 +332,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): access_token=access_token, content={"auth": auth}, ) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) @override_config({"session_lifetime": "24h"}) def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: @@ -343,20 +343,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], True) # Now try to hard logout this session channel = self.make_request(b"POST", "/logout", access_token=access_token) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config({"session_lifetime": "24h"}) def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( @@ -369,20 +369,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], True) # Now try to hard log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=access_token) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") @@ -1129,7 +1129,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) def test_login_appservice_user_bot(self) -> None: """Test that the appservice bot can use /login""" @@ -1143,7 +1143,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) def test_login_appservice_wrong_user(self) -> None: """Test that non-as users cannot login with the as token""" @@ -1157,7 +1157,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) def test_login_appservice_wrong_as(self) -> None: """Test that as users cannot login with wrong as token""" @@ -1171,7 +1171,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.another_service.token ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) def test_login_appservice_no_token(self) -> None: """Test that users must provide a token when using the appservice @@ -1185,7 +1185,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) @skip_unless(HAS_OIDC, "requires OIDC") diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index b9647d5bd8..4239e1e610 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -80,7 +80,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): def test_get_displayname_other(self): res = self._get_displayname(self.other) - self.assertEquals(res, "Bob") + self.assertEqual(res, "Bob") def test_set_displayname_other(self): channel = self.make_request( diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 2835d86e5b..4b95b8541c 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -65,7 +65,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) det_data = {"user_id": user_id, "home_server": self.hs.hostname} self.assertDictContainsSubset(det_data, channel.json_body) @@ -87,7 +87,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.result["code"], b"400", channel.result) def test_POST_appservice_registration_invalid(self): self.appservice = None # no application service exists @@ -98,21 +98,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) def test_POST_bad_password(self): request_data = json.dumps({"username": "kermit", "password": 666}) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid password") + self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.json_body["error"], "Invalid password") def test_POST_bad_username(self): request_data = json.dumps({"username": 777, "password": "monkey"}) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid username") + self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.json_body["error"], "Invalid username") def test_POST_user_valid(self): user_id = "@kermit:test" @@ -131,7 +131,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) @override_config({"enable_registration": False}) @@ -141,9 +141,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals(channel.json_body["error"], "Registration has been disabled") - self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["error"], "Registration has been disabled") + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") def test_POST_guest_registration(self): self.hs.config.key.macaroon_secret_key = "test" @@ -152,7 +152,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self): @@ -160,8 +160,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals(channel.json_body["error"], "Guest access is disabled") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["error"], "Guest access is disabled") @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting_guest(self): @@ -170,16 +170,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", url, b"{}") if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting(self): @@ -194,16 +194,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url, request_data) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config({"registration_requires_token": True}) def test_POST_registration_requires_token(self): @@ -231,7 +231,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Request without auth to get flows and session channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] # Synapse adds a dummy stage to differentiate flows where otherwise one # flow would be a subset of another flow. @@ -249,7 +249,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): } request_data = json.dumps(params) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) completed = channel.json_body["completed"] self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @@ -265,7 +265,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) # Check the `completed` counter has been incremented and pending is 0 @@ -276,8 +276,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): retcols=["pending", "completed"], ) ) - self.assertEquals(res["completed"], 1) - self.assertEquals(res["pending"], 0) + self.assertEqual(res["completed"], 1) + self.assertEqual(res["pending"], 0) @override_config({"registration_requires_token": True}) def test_POST_registration_token_invalid(self): @@ -295,23 +295,23 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "session": session, } channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM) + self.assertEqual(channel.json_body["completed"], []) # Test with non-string (invalid) params["auth"]["token"] = 1234 channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + self.assertEqual(channel.json_body["completed"], []) # Test with unknown token (invalid) params["auth"]["token"] = "1234" channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(channel.json_body["completed"], []) @override_config({"registration_requires_token": True}) def test_POST_registration_token_limit_uses(self): @@ -354,7 +354,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): retcol="pending", ) ) - self.assertEquals(pending, 1) + self.assertEqual(pending, 1) # Check auth fails when using token with session2 params2["auth"] = { @@ -363,9 +363,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "session": session2, } channel = self.make_request(b"POST", self.url, json.dumps(params2)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(channel.json_body["completed"], []) # Complete registration with session1 params1["auth"]["type"] = LoginType.DUMMY @@ -378,14 +378,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): retcols=["pending", "completed"], ) ) - self.assertEquals(res["pending"], 0) - self.assertEquals(res["completed"], 1) + self.assertEqual(res["pending"], 0) + self.assertEqual(res["completed"], 1) # Check auth still fails when using token with session2 channel = self.make_request(b"POST", self.url, json.dumps(params2)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(channel.json_body["completed"], []) @override_config({"registration_requires_token": True}) def test_POST_registration_token_expiry(self): @@ -417,9 +417,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "session": session, } channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(channel.json_body["completed"], []) # Update token so it expires tomorrow self.get_success( @@ -504,7 +504,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): retcol="result", ) ) - self.assertEquals(db_to_json(result2), token) + self.assertEqual(db_to_json(result2), token) # Delete both sessions (mimics expiry) self.get_success( @@ -519,7 +519,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): retcol="pending", ) ) - self.assertEquals(pending, 0) + self.assertEqual(pending, 0) @override_config({"registration_requires_token": True}) def test_POST_registration_token_session_expiry_deleted_token(self): @@ -572,7 +572,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_advertised_flows(self): channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] # with the stock config, we only expect the dummy flow @@ -595,7 +595,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) def test_advertised_flows_captcha_and_terms_and_3pids(self): channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] self.assertCountEqual( @@ -627,7 +627,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) def test_advertised_flows_no_msisdn_email_required(self): channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] # with the stock config, we expect all four combinations of 3pid @@ -671,7 +671,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"register/email/requestToken", {"client_secret": "foobar", "email": email, "send_attempt": 1}, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertIsNotNone(channel.json_body.get("sid")) @@ -694,9 +694,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"register/email/requestToken", {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1}, ) - self.assertEquals(400, channel.code, channel.result) + self.assertEqual(400, channel.code, channel.result) # Check error to ensure that we're not erroring due to a bug in the test. - self.assertEquals( + self.assertEqual( channel.json_body, {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, ) @@ -707,8 +707,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"register/email/requestToken", {"client_secret": "foobar", "email": "email", "send_attempt": 1}, ) - self.assertEquals(400, channel.code, channel.result) - self.assertEquals( + self.assertEqual(400, channel.code, channel.result) + self.assertEqual( channel.json_body, {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, ) @@ -720,8 +720,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"register/email/requestToken", {"client_secret": "foobar", "email": email, "send_attempt": 1}, ) - self.assertEquals(400, channel.code, channel.result) - self.assertEquals( + self.assertEqual(400, channel.code, channel.result) + self.assertEqual( channel.json_body, {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, ) @@ -745,7 +745,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Check that /available correctly ignores the username provided despite the # username being already registered. channel = self.make_request("GET", "register/available?username=" + username) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Test that when starting a UIA registration flow the request doesn't fail because # of a conflicting username @@ -799,14 +799,14 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals( + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -826,12 +826,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): params = {"user_id": user_id} request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) def test_manual_expire(self): user_id = self.register_user("kermit", "monkey") @@ -848,13 +848,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): } request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals( + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -873,18 +873,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): } request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Try to log the user out channel = self.make_request(b"POST", "/logout", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Log the user in again (allowed for expired accounts) tok = self.login("kermit", "monkey") # Try to log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): @@ -959,7 +959,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -977,7 +977,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # Move 1 day forward. Try to renew with the same token again. url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -997,14 +997,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # succeed. self.reactor.advance(datetime.timedelta(days=3).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) def test_renewal_invalid_token(self): # Hit the renewal endpoint with an invalid token and check that it behaves as # expected, i.e. that it responds with 404 Not Found and the correct HTML. url = "/_matrix/client/unstable/account_validity/renew?token=123" channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"404", channel.result) + self.assertEqual(channel.result["code"], b"404", channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -1028,7 +1028,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -1103,7 +1103,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -1183,8 +1183,8 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertEquals(channel.json_body["valid"], True) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["valid"], True) def test_GET_token_invalid(self): token = "1234" @@ -1192,8 +1192,8 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertEquals(channel.json_body["valid"], False) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["valid"], False) @override_config( {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}} @@ -1208,10 +1208,10 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): ) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -1219,4 +1219,4 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 5687dea48d..8f7181103b 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -69,7 +69,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) event_id = channel.json_body["event_id"] @@ -78,7 +78,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/event/%s" % (self.room, event_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assert_dict( { @@ -103,7 +103,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, ) - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) # Unless that event is referenced from another event! self.get_success( @@ -123,7 +123,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) def test_deny_invalid_room(self): """Test that we deny relations on non-existant events""" @@ -136,15 +136,15 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A" ) - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) def test_deny_double_react(self): """Test that we deny relations on membership events""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) def test_deny_forked_thread(self): """It is invalid to start a thread off a thread.""" @@ -154,7 +154,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"msgtype": "m.text", "body": "foo"}, parent_id=self.parent_id, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) parent_id = channel.json_body["event_id"] channel = self._send_relation( @@ -163,16 +163,16 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"msgtype": "m.text", "body": "foo"}, parent_id=parent_id, ) - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) def test_basic_paginate_relations(self): """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) first_annotation_id = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) second_annotation_id = channel.json_body["event_id"] channel = self.make_request( @@ -180,11 +180,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect to get back a single pagination result, which is the latest # full relation event we sent above. - self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) self.assert_dict( { "event_id": second_annotation_id, @@ -195,7 +195,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) # We also expect to get the original event (the id of which is self.parent_id) - self.assertEquals( + self.assertEqual( channel.json_body["original_event"]["event_id"], self.parent_id ) @@ -212,11 +212,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect to get back a single pagination result, which is the earliest # full relation event we sent above. - self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) self.assert_dict( { "event_id": first_annotation_id, @@ -245,7 +245,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) expected_event_ids.append(channel.json_body["event_id"]) prev_token = "" @@ -260,12 +260,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: @@ -273,7 +273,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # We paginated backwards, so reverse found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) + self.assertEqual(found_event_ids, expected_event_ids) # Reset and try again, but convert the tokens to the legacy format. prev_token = "" @@ -288,12 +288,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: @@ -301,12 +301,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): # We paginated backwards, so reverse found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) + self.assertEqual(found_event_ids, expected_event_ids) def test_pagination_from_sync_and_messages(self): """Pagination tokens from /sync and /messages can be used to paginate /relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] # Send an event after the relation events. self.helper.send(self.room, body="Latest event", tok=self.user_token) @@ -319,7 +319,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", f"/sync?filter={filter}", access_token=self.user_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] sync_prev_batch = room_timeline["prev_batch"] self.assertIsNotNone(sync_prev_batch) @@ -335,7 +335,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/messages?dir=b&limit=1", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) messages_end = channel.json_body["end"] self.assertIsNotNone(messages_end) # Ensure the relation event is not in the chunk returned from /messages. @@ -355,7 +355,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # The relation should be in the returned chunk. self.assertIn( @@ -386,7 +386,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): key=key, access_token=access_tokens[idx], ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) idx += 1 idx %= len(access_tokens) @@ -404,7 +404,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id, from_token), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) @@ -419,13 +419,13 @@ class RelationsTestCase(unittest.HomeserverTestCase): next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: break - self.assertEquals(sent_groups, found_groups) + self.assertEqual(sent_groups, found_groups) def test_aggregation_pagination_within_group(self): """Test that we can paginate within an annotation group.""" @@ -449,14 +449,14 @@ class RelationsTestCase(unittest.HomeserverTestCase): key="👍", access_token=access_tokens[idx], ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) expected_event_ids.append(channel.json_body["event_id"]) idx += 1 # Also send a different type of reaction so that we test we don't see it channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) prev_token = "" found_event_ids: List[str] = [] @@ -473,7 +473,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/m.reaction/{encoded_key}?limit=1{from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) @@ -481,7 +481,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: @@ -489,7 +489,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # We paginated backwards, so reverse found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) + self.assertEqual(found_event_ids, expected_event_ids) # Reset and try again, but convert the tokens to the legacy format. prev_token = "" @@ -506,7 +506,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/m.reaction/{encoded_key}?limit=1{from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) @@ -514,7 +514,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: @@ -522,21 +522,21 @@ class RelationsTestCase(unittest.HomeserverTestCase): # We paginated backwards, so reverse found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) + self.assertEqual(found_event_ids, expected_event_ids) def test_aggregation(self): """Test that annotations get correctly aggregated.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -544,9 +544,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) - self.assertEquals( + self.assertEqual( channel.json_body, { "chunk": [ @@ -560,13 +560,13 @@ class RelationsTestCase(unittest.HomeserverTestCase): """Test that annotations get correctly aggregated after a redaction.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) to_redact_event_id = channel.json_body["event_id"] channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Now lets redact one of the 'a' reactions channel = self.make_request( @@ -575,7 +575,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, content={}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -583,9 +583,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) - self.assertEquals( + self.assertEqual( channel.json_body, {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, ) @@ -599,7 +599,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id, RelationTypes.REPLACE), access_token=self.user_token, ) - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) @unittest.override_config( {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}} @@ -615,29 +615,29 @@ class RelationsTestCase(unittest.HomeserverTestCase): """ # Setup by sending a variety of relations. channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) reply_1 = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) reply_2 = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_2 = channel.json_body["event_id"] def assert_bundle(event_json: JsonDict) -> None: @@ -655,7 +655,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) # Check the values of each field. - self.assertEquals( + self.assertEqual( { "chunk": [ {"type": "m.reaction", "key": "a", "count": 2}, @@ -665,12 +665,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): relations_dict[RelationTypes.ANNOTATION], ) - self.assertEquals( + self.assertEqual( {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, relations_dict[RelationTypes.REFERENCE], ) - self.assertEquals( + self.assertEqual( 2, relations_dict[RelationTypes.THREAD].get("count"), ) @@ -701,7 +701,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body) # Request the room messages. @@ -710,7 +710,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/messages?dir=b", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) # Request the room context. @@ -719,12 +719,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/context/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body["event"]) # Request sync. channel = self.make_request("GET", "/sync", access_token=self.user_token) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) assert_bundle(self._find_event_in_chunk(room_timeline["events"])) @@ -737,7 +737,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"search_categories": {"room_events": {"search_term": "Hi"}}}, access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) chunk = [ result["result"] for result in channel.json_body["search_categories"]["room_events"][ @@ -751,42 +751,42 @@ class RelationsTestCase(unittest.HomeserverTestCase): when directly requested. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] # Annotate the annotation. channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", f"/rooms/{self.room}/event/{annotation_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) def test_aggregation_get_event_for_thread(self): """Test that threads get bundled aggregations included when directly requested.""" channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_id = channel.json_body["event_id"] # Annotate the annotation. channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", f"/rooms/{self.room}/event/{thread_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals( + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( channel.json_body["unsigned"].get("m.relations"), { RelationTypes.ANNOTATION: { @@ -801,11 +801,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1) thread_message = channel.json_body["chunk"][0] - self.assertEquals( + self.assertEqual( thread_message["unsigned"].get("m.relations"), { RelationTypes.ANNOTATION: { @@ -905,7 +905,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) # And when fetching aggregations. @@ -914,7 +914,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) # And for bundled aggregations. @@ -923,7 +923,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{room2}/event/{parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertNotIn("m.relations", channel.json_body["unsigned"]) @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) @@ -936,7 +936,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] @@ -958,8 +958,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals(channel.json_body["content"], new_body) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["content"], new_body) assert_bundle(channel.json_body) # Request the room messages. @@ -968,7 +968,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/messages?dir=b", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) # Request the room context. @@ -977,7 +977,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/context/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body["event"]) # Request sync, but limit the timeline so it becomes limited (and includes @@ -988,7 +988,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", f"/sync?filter={filter}", access_token=self.user_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) assert_bundle(self._find_event_in_chunk(room_timeline["events"])) @@ -1001,7 +1001,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"search_categories": {"room_events": {"search_term": "Hi"}}}, access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) chunk = [ result["result"] for result in channel.json_body["search_categories"]["room_events"][ @@ -1024,7 +1024,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) new_body = {"msgtype": "m.text", "body": "I've been edited!"} channel = self._send_relation( @@ -1032,7 +1032,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] @@ -1045,16 +1045,16 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, }, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) - self.assertEquals(channel.json_body["content"], new_body) + self.assertEqual(channel.json_body["content"], new_body) relations_dict = channel.json_body["unsigned"].get("m.relations") self.assertIn(RelationTypes.REPLACE, relations_dict) @@ -1076,7 +1076,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.room.message", content={"msgtype": "m.text", "body": "A reply!"}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) reply = channel.json_body["event_id"] new_body = {"msgtype": "m.text", "body": "I've been edited!"} @@ -1086,7 +1086,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, parent_id=reply, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] @@ -1095,7 +1095,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/event/%s" % (self.room, reply), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect to see the new body in the dict, as well as the reference # metadata sill intact. @@ -1133,7 +1133,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.room.message", content={"msgtype": "m.text", "body": "A threaded reply!"}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) threaded_event_id = channel.json_body["event_id"] new_body = {"msgtype": "m.text", "body": "I've been edited!"} @@ -1143,7 +1143,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, parent_id=threaded_event_id, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Fetch the thread root, to get the bundled aggregation for the thread. channel = self.make_request( @@ -1151,7 +1151,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect that the edit message appears in the thread summary in the # unsigned relations section. @@ -1161,9 +1161,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): thread_summary = relations_dict[RelationTypes.THREAD] self.assertIn("latest_event", thread_summary) latest_event_in_thread = thread_summary["latest_event"] - self.assertEquals( - latest_event_in_thread["content"]["body"], "I've been edited!" - ) + self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!") def test_edit_edit(self): """Test that an edit cannot be edited.""" @@ -1177,7 +1175,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.new_content": new_body, }, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] # Edit the edit event. @@ -1191,7 +1189,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, parent_id=edit_event_id, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Request the original event. channel = self.make_request( @@ -1199,9 +1197,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # The edit to the edit should be ignored. - self.assertEquals(channel.json_body["content"], new_body) + self.assertEqual(channel.json_body["content"], new_body) # The relations information should not include the edit to the edit. relations_dict = channel.json_body["unsigned"].get("m.relations") @@ -1234,7 +1232,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Check the relation is returned channel = self.make_request( @@ -1243,10 +1241,10 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, original_event_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertIn("chunk", channel.json_body) - self.assertEquals(len(channel.json_body["chunk"]), 1) + self.assertEqual(len(channel.json_body["chunk"]), 1) # Redact the original event channel = self.make_request( @@ -1256,7 +1254,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, content="{}", ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Try to check for remaining m.replace relations channel = self.make_request( @@ -1265,11 +1263,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, original_event_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Check that no relations are returned self.assertIn("chunk", channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) + self.assertEqual(channel.json_body["chunk"], []) def test_aggregations_redaction_prevents_access_to_aggregations(self): """Test that annotations of an event are redacted when the original event @@ -1283,7 +1281,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Redact the original channel = self.make_request( @@ -1297,7 +1295,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, content="{}", ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Check that aggregations returns zero channel = self.make_request( @@ -1306,15 +1304,15 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, original_event_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertIn("chunk", channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) + self.assertEqual(channel.json_body["chunk"], []) def test_unknown_relations(self): """Unknown relations should be accepted.""" channel = self._send_relation("m.relation.test", "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) event_id = channel.json_body["event_id"] channel = self.make_request( @@ -1323,18 +1321,18 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect to get back a single pagination result, which is the full # relation event we sent above. - self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) self.assert_dict( {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"}, channel.json_body["chunk"][0], ) # We also expect to get the original event (the id of which is self.parent_id) - self.assertEquals( + self.assertEqual( channel.json_body["original_event"]["event_id"], self.parent_id ) @@ -1344,7 +1342,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertNotIn("m.relations", channel.json_body["unsigned"]) # But unknown relations can be directly queried. @@ -1354,8 +1352,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: """ @@ -1422,15 +1420,15 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_background_update(self): """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) annotation_event_id_good = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) annotation_event_id_bad = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_event_id = channel.json_body["event_id"] # Clean-up the table as if the inserts did not happen during event creation. @@ -1450,8 +1448,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals( + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( [ev["event_id"] for ev in channel.json_body["chunk"]], [annotation_event_id_good], ) @@ -1475,7 +1473,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertCountEqual( [ev["event_id"] for ev in channel.json_body["chunk"]], [annotation_event_id_good, thread_event_id], diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 1afd96b8f5..e0b11e7264 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -95,7 +95,7 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request( "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # set topic for public room channel = self.make_request( @@ -103,7 +103,7 @@ class RoomPermissionsTestCase(RoomBase): ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), b'{"topic":"Public Room Topic"}', ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # auth as user_id now self.helper.auth_user_id = self.user_id @@ -125,28 +125,28 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), msg_content, ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # send message in created room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # send message in created room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # send message in created room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_topic_perms(self): topic_content = b'{"topic":"My Topic Name"}' @@ -156,28 +156,28 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request( "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 channel = self.make_request("GET", topic_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) @@ -185,25 +185,25 @@ class RoomPermissionsTestCase(RoomBase): # Only room ops can set topic by default self.helper.auth_user_id = self.rmcreator_id channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.helper.auth_user_id = self.user_id channel = self.make_request("GET", topic_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 channel = self.make_request( @@ -211,7 +211,7 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/state/m.room.topic" % self.created_public_rmid, topic_content, ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def _test_get_membership( self, room=None, members: Iterable = frozenset(), expect_code=None @@ -219,7 +219,7 @@ class RoomPermissionsTestCase(RoomBase): for member in members: path = "/rooms/%s/state/m.room.member/%s" % (room, member) channel = self.make_request("GET", path) - self.assertEquals(expect_code, channel.code) + self.assertEqual(expect_code, channel.code) def test_membership_basic_room_perms(self): # === room does not exist === @@ -478,16 +478,16 @@ class RoomsMemberListTestCase(RoomBase): def test_get_member_list(self): room_id = self.helper.create_room_as(self.user_id) channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self): channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self): room_id = self.helper.create_room_as("@some_other_guy:red") channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_with_at_token(self): """ @@ -498,7 +498,7 @@ class RoomsMemberListTestCase(RoomBase): # first sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) sync_token = channel.json_body["next_batch"] # check that permission is denied for @sid1:red to get the @@ -507,7 +507,7 @@ class RoomsMemberListTestCase(RoomBase): "GET", f"/rooms/{room_id}/members?at={sync_token}", ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member(self): """ @@ -520,14 +520,14 @@ class RoomsMemberListTestCase(RoomBase): # check that the user can see the member list to start with channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # ban the user self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban") # check the user can no longer see the member list channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member_with_at_token(self): """ @@ -541,14 +541,14 @@ class RoomsMemberListTestCase(RoomBase): # sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) sync_token = channel.json_body["next_batch"] # check that the user can see the member list to start with channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # ban the user (Note: the user is actually allowed to see this event and # state so that they know they're banned!) @@ -560,14 +560,14 @@ class RoomsMemberListTestCase(RoomBase): # now, with the original user, sync again to get a new at token channel = self.make_request("GET", "/sync") - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) sync_token = channel.json_body["next_batch"] # check the user can no longer see the updated member list channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self): room_creator = "@some_other_guy:red" @@ -576,17 +576,17 @@ class RoomsMemberListTestCase(RoomBase): self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. channel = self.make_request("GET", room_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined channel = self.make_request("GET", room_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left channel = self.make_request("GET", room_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) class RoomsCreateTestCase(RoomBase): @@ -598,19 +598,19 @@ class RoomsCreateTestCase(RoomBase): # POST with no config keys, expect new room id channel = self.make_request("POST", "/createRoom", "{}") - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) def test_post_room_visibility_key(self): # POST with visibility config key, expect new room id channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_custom_key(self): # POST with custom config keys, expect new room id channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_known_and_unknown_keys(self): @@ -618,16 +618,16 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_invalid_content(self): # POST with invalid content / paths, expect 400 channel = self.make_request("POST", "/createRoom", b'{"visibili') - self.assertEquals(400, channel.code) + self.assertEqual(400, channel.code) channel = self.make_request("POST", "/createRoom", b'["hello"]') - self.assertEquals(400, channel.code) + self.assertEqual(400, channel.code) def test_post_room_invitees_invalid_mxid(self): # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 @@ -635,7 +635,7 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' ) - self.assertEquals(400, channel.code) + self.assertEqual(400, channel.code) @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) def test_post_room_invitees_ratelimit(self): @@ -694,9 +694,9 @@ class RoomsCreateTestCase(RoomBase): "/createRoom", {}, ) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) - self.assertEquals(join_mock.call_count, 0) + self.assertEqual(join_mock.call_count, 0) class RoomTopicTestCase(RoomBase): @@ -712,54 +712,54 @@ class RoomTopicTestCase(RoomBase): def test_invalid_puts(self): # missing keys or invalid json channel = self.make_request("PUT", self.path, "{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", self.path, '{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", self.path, '{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request( "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", self.path, "text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", self.path, "") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) # valid key, wrong type content = '{"topic":["Topic name"]}' channel = self.make_request("PUT", self.path, content) - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) def test_rooms_topic(self): # nothing should be there channel = self.make_request("GET", self.path) - self.assertEquals(404, channel.code, msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' channel = self.make_request("PUT", self.path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) def test_rooms_topic_with_extra_keys(self): # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' channel = self.make_request("PUT", self.path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) @@ -775,22 +775,22 @@ class RoomMemberStateTestCase(RoomBase): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json channel = self.make_request("PUT", path, "{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, '{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, '{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, "text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, "") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) # valid keys, wrong types content = '{"membership":["%s","%s","%s"]}' % ( @@ -799,7 +799,7 @@ class RoomMemberStateTestCase(RoomBase): Membership.LEAVE, ) channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) def test_rooms_members_self(self): path = "/rooms/%s/state/m.room.member/%s" % ( @@ -810,13 +810,13 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} - self.assertEquals(expected_response, channel.json_body) + self.assertEqual(expected_response, channel.json_body) def test_rooms_members_other(self): self.other_id = "@zzsid1:red" @@ -828,11 +828,11 @@ class RoomMemberStateTestCase(RoomBase): # valid invite message content = '{"membership":"%s"}' % Membership.INVITE channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assertEquals(json.loads(content), channel.json_body) + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(json.loads(content), channel.json_body) def test_rooms_members_other_custom_keys(self): self.other_id = "@zzsid1:red" @@ -847,11 +847,11 @@ class RoomMemberStateTestCase(RoomBase): "Join us!", ) channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assertEquals(json.loads(content), channel.json_body) + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(json.loads(content), channel.json_body) class RoomInviteRatelimitTestCase(RoomBase): @@ -937,7 +937,7 @@ class RoomJoinTestCase(RoomBase): False, ), ) - self.assertEquals( + self.assertEqual( callback_mock.call_args, expected_call_args, callback_mock.call_args, @@ -955,7 +955,7 @@ class RoomJoinTestCase(RoomBase): True, ), ) - self.assertEquals( + self.assertEqual( callback_mock.call_args, expected_call_args, callback_mock.call_args, @@ -1013,7 +1013,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # Update the display name for the user. path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id channel = self.make_request("PUT", path, {"displayname": "John Doe"}) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) # Check that all the rooms have been sent a profile update into. for room_id in room_ids: @@ -1023,10 +1023,10 @@ class RoomJoinRatelimitTestCase(RoomBase): ) channel = self.make_request("GET", path) - self.assertEquals(channel.code, 200) + self.assertEqual(channel.code, 200) self.assertIn("displayname", channel.json_body) - self.assertEquals(channel.json_body["displayname"], "John Doe") + self.assertEqual(channel.json_body["displayname"], "John Doe") @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} @@ -1047,7 +1047,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # if all of these requests ended up joining the user to a room. for _ in range(4): channel = self.make_request("POST", path % room_id, {}) - self.assertEquals(channel.code, 200) + self.assertEqual(channel.code, 200) @unittest.override_config( { @@ -1078,40 +1078,40 @@ class RoomMessagesTestCase(RoomBase): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json channel = self.make_request("PUT", path, b"{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b'{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b'{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b"text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b"") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) def test_rooms_messages_sent(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = b'{"body":"test","msgtype":{"type":"a"}}' channel = self.make_request("PUT", path, content) - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) # custom message types content = b'{"body":"test","msgtype":"test.custom.text"}' channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = b'{"body":"test2","msgtype":"m.text"}' channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) class RoomInitialSyncTestCase(RoomBase): @@ -1125,10 +1125,10 @@ class RoomInitialSyncTestCase(RoomBase): def test_initial_sync(self): channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) - self.assertEquals(self.room_id, channel.json_body["room_id"]) - self.assertEquals("join", channel.json_body["membership"]) + self.assertEqual(self.room_id, channel.json_body["room_id"]) + self.assertEqual("join", channel.json_body["membership"]) # Room state is easier to assert on if we unpack it into a dict state = {} @@ -1152,7 +1152,7 @@ class RoomInitialSyncTestCase(RoomBase): e["content"]["user_id"]: e for e in channel.json_body["presence"] } self.assertTrue(self.user_id in presence_by_user) - self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) + self.assertEqual("m.presence", presence_by_user[self.user_id]["type"]) class RoomMessageListTestCase(RoomBase): @@ -1168,9 +1168,9 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body["start"]) + self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) @@ -1179,9 +1179,9 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body["start"]) + self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) @@ -2614,7 +2614,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): }, access_token=self.tok, ) - self.assertEquals(channel.code, 200) + self.assertEqual(channel.code, 200) # Check that the callback was called with the right params. mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id) @@ -2636,7 +2636,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): }, access_token=self.tok, ) - self.assertEquals(channel.code, 403) + self.assertEqual(channel.code, 403) # Also check that it stopped before calling _make_and_store_3pid_invite. make_invite_mock.assert_called_once() diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 7d0e66b534..2634c98dde 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -96,7 +96,7 @@ class RoomTestCase(_ShadowBannedBase): {"id_server": "test", "medium": "email", "address": "test@test.test"}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # This should have raised an error earlier, but double check this wasn't called. identity_handler.lookup_3pid.assert_not_called() @@ -110,7 +110,7 @@ class RoomTestCase(_ShadowBannedBase): {"visibility": "public", "invite": [self.other_user_id]}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) room_id = channel.json_body["room_id"] # But the user wasn't actually invited. @@ -165,7 +165,7 @@ class RoomTestCase(_ShadowBannedBase): {"new_version": "6"}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # A new room_id should be returned. self.assertIn("replacement_room", channel.json_body) @@ -190,11 +190,11 @@ class RoomTestCase(_ShadowBannedBase): {"typing": True, "timeout": 30000}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) # There should be no typing events. event_source = self.hs.get_event_sources().sources.typing - self.assertEquals(event_source.get_current_key(), 0) + self.assertEqual(event_source.get_current_key(), 0) # The other user can join and send typing events. self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) @@ -205,10 +205,10 @@ class RoomTestCase(_ShadowBannedBase): {"typing": True, "timeout": 30000}, access_token=self.other_access_token, ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) # These appear in the room. - self.assertEquals(event_source.get_current_key(), 1) + self.assertEqual(event_source.get_current_key(), 1) events = self.get_success( event_source.get_new_events( user=UserID.from_string(self.other_user_id), @@ -218,7 +218,7 @@ class RoomTestCase(_ShadowBannedBase): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -257,7 +257,7 @@ class ProfileTestCase(_ShadowBannedBase): {"displayname": new_display_name}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertEqual(channel.json_body, {}) # The user's display name should be updated. @@ -299,7 +299,7 @@ class ProfileTestCase(_ShadowBannedBase): {"membership": "join", "displayname": new_display_name}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertIn("event_id", channel.json_body) # The display name in the room should not be changed. diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py index c42c8aff6c..294f46fb95 100644 --- a/tests/rest/client/test_shared_rooms.py +++ b/tests/rest/client/test_shared_rooms.py @@ -91,9 +91,9 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): # Check shared rooms from user1's perspective. # We should see the one room in common channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room_id_one) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 1) + self.assertEqual(channel.json_body["joined"][0], room_id_one) # Create another room and invite user2 to it room_id_two = self.helper.create_room_as( @@ -104,8 +104,8 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): # Check shared rooms again. We should now see both rooms. channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 2) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 2) for room_id_id in channel.json_body["joined"]: self.assertIn(room_id_id, [room_id_one, room_id_two]) @@ -125,18 +125,18 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): # Assert user directory is not empty channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 1) + self.assertEqual(channel.json_body["joined"][0], room) self.helper.leave(room, user=u1, tok=u1_token) # Check user1's view of shared rooms with user2 channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 0) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 0) # Check user2's view of shared rooms with user1 channel = self._get_shared_rooms(u2_token, u1) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 0) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 0) diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 69b4ef5378..4351013952 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -237,10 +237,10 @@ class SyncTypingTests(unittest.HomeserverTestCase): typing_url % (room, other_user_id, other_access_token), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,)) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) next_batch = channel.json_body["next_batch"] # Stop typing. @@ -249,7 +249,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): typing_url % (room, other_user_id, other_access_token), b'{"typing": false}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) # Start typing. channel = self.make_request( @@ -257,11 +257,11 @@ class SyncTypingTests(unittest.HomeserverTestCase): typing_url % (room, other_user_id, other_access_token), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) # Should return immediately channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) next_batch = channel.json_body["next_batch"] # Reset typing serial back to 0, as if the master had. @@ -273,7 +273,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): self.helper.send(room, body="There!", tok=other_access_token) channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) next_batch = channel.json_body["next_batch"] # This should time out! But it does not, because our stream token is @@ -281,7 +281,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): # already seen) is new, since it's got a token above our new, now-reset # stream token. channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) next_batch = channel.json_body["next_batch"] # Clear the typing information, so that it doesn't think everything is @@ -351,7 +351,7 @@ class SyncKnockTestCase( b"{}", self.knocker_tok, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # We expect to see the knock event in the stripped room state later self.expected_room_state[EventTypes.Member] = { diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index ac6b86ff6b..9cca9edd30 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -139,7 +139,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {}, access_token=self.tok, ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) callback.assert_called_once() @@ -157,7 +157,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {}, access_token=self.tok, ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) def test_third_party_rules_workaround_synapse_errors_pass_through(self): """ @@ -193,7 +193,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): access_token=self.tok, ) # Check the error code - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) # Check the JSON body has had the `nasty` key injected self.assertEqual( channel.json_body, @@ -329,10 +329,10 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): self.hs.get_module_api().create_and_send_event_into_room(event_dict) ) - self.assertEquals(event.sender, self.user_id) - self.assertEquals(event.room_id, self.room_id) - self.assertEquals(event.type, "m.room.message") - self.assertEquals(event.content, content) + self.assertEqual(event.sender, self.user_id) + self.assertEqual(event.room_id, self.room_id) + self.assertEqual(event.type, "m.room.message") + self.assertEqual(event.content, content) @unittest.override_config( { diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index de312cb63c..8b2da88e8a 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -72,9 +72,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=UserID.from_string(self.user_id), @@ -84,7 +84,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -101,7 +101,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": false}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) def test_typing_timeout(self): channel = self.make_request( @@ -109,19 +109,19 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) self.reactor.advance(36) - self.assertEquals(self.event_source.get_current_key(), 2) + self.assertEqual(self.event_source.get_current_key(), 2) channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) - self.assertEquals(self.event_source.get_current_key(), 3) + self.assertEqual(self.event_source.get_current_key(), 3) diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index 7f79336abc..658c21b2a1 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -65,7 +65,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): Upgrading a room should work fine. """ channel = self._upgrade_room() - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertIn("replacement_room", channel.json_body) def test_not_in_room(self): @@ -77,7 +77,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): roomless_token = self.login(roomless, "pass") channel = self._upgrade_room(roomless_token) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) def test_power_levels(self): """ @@ -85,7 +85,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): """ # The other user doesn't have the proper power level. channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) # Increase the power levels so that this user can upgrade. power_levels = self.helper.get_state( @@ -103,7 +103,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): # The upgrade should succeed! channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) def test_power_levels_user_default(self): """ @@ -111,7 +111,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): """ # The other user doesn't have the proper power level. channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) # Increase the power levels so that this user can upgrade. power_levels = self.helper.get_state( @@ -129,7 +129,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): # The upgrade should succeed! channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) def test_power_levels_tombstone(self): """ @@ -137,7 +137,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): """ # The other user doesn't have the proper power level. channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) # Increase the power levels so that this user can upgrade. power_levels = self.helper.get_state( @@ -155,7 +155,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): # The upgrade should succeed! channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) power_levels = self.helper.get_state( self.room_id, @@ -197,7 +197,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): # Upgrade the room! channel = self._upgrade_room(room_id=space_id) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertIn("replacement_room", channel.json_body) new_space_id = channel.json_body["replacement_room"] diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 6878ccddbf..cba9be17c4 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -94,7 +94,7 @@ class MediaStorageTests(unittest.HomeserverTestCase): self.assertTrue(os.path.exists(local_path)) # Asserts the file is under the expected local cache directory - self.assertEquals( + self.assertEqual( os.path.commonprefix([self.primary_base_path, local_path]), self.primary_base_path, ) diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 59def6e59c..1f6a9eb07b 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -88,18 +88,18 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): res = self.get_success( self.store.have_seen_events("room1", ["event10", "event19"]) ) - self.assertEquals(res, {"event10"}) + self.assertEqual(res, {"event10"}) # that should result in a single db query - self.assertEquals(ctx.get_resource_usage().db_txn_count, 1) + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) # a second lookup of the same events should cause no queries with LoggingContext(name="test") as ctx: res = self.get_success( self.store.have_seen_events("room1", ["event10", "event19"]) ) - self.assertEquals(res, {"event10"}) - self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + self.assertEqual(res, {"event10"}) + self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) def test_query_via_event_cache(self): # fetch an event into the event cache @@ -108,8 +108,8 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # looking it up should now cause no db hits with LoggingContext(name="test") as ctx: res = self.get_success(self.store.have_seen_events("room1", ["event10"])) - self.assertEquals(res, {"event10"}) - self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + self.assertEqual(res, {"event10"}) + self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) class EventCacheTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index d2f654214e..ee599f4336 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -88,21 +88,21 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): def test_retrieve_unknown_service_token(self) -> None: service = self.store.get_app_service_by_token("invalid_token") - self.assertEquals(service, None) + self.assertEqual(service, None) def test_retrieval_of_service(self) -> None: stored_service = self.store.get_app_service_by_token(self.as_token) assert stored_service is not None - self.assertEquals(stored_service.token, self.as_token) - self.assertEquals(stored_service.id, self.as_id) - self.assertEquals(stored_service.url, self.as_url) - self.assertEquals(stored_service.namespaces[ApplicationService.NS_ALIASES], []) - self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], []) - self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], []) + self.assertEqual(stored_service.token, self.as_token) + self.assertEqual(stored_service.id, self.as_id) + self.assertEqual(stored_service.url, self.as_url) + self.assertEqual(stored_service.namespaces[ApplicationService.NS_ALIASES], []) + self.assertEqual(stored_service.namespaces[ApplicationService.NS_ROOMS], []) + self.assertEqual(stored_service.namespaces[ApplicationService.NS_USERS], []) def test_retrieval_of_all_services(self) -> None: services = self.store.get_app_services() - self.assertEquals(len(services), 3) + self.assertEqual(len(services), 3) class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): @@ -182,7 +182,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): ) -> None: service = Mock(id="999") state = self.get_success(self.store.get_appservice_state(service)) - self.assertEquals(None, state) + self.assertEqual(None, state) def test_get_appservice_state_up( self, @@ -194,7 +194,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): state = self.get_success( defer.ensureDeferred(self.store.get_appservice_state(service)) ) - self.assertEquals(ApplicationServiceState.UP, state) + self.assertEqual(ApplicationServiceState.UP, state) def test_get_appservice_state_down( self, @@ -210,7 +210,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): ) service = Mock(id=self.as_list[1]["id"]) state = self.get_success(self.store.get_appservice_state(service)) - self.assertEquals(ApplicationServiceState.DOWN, state) + self.assertEqual(ApplicationServiceState.DOWN, state) def test_get_appservices_by_state_none( self, @@ -218,7 +218,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) - self.assertEquals(0, len(services)) + self.assertEqual(0, len(services)) def test_set_appservices_state_down( self, @@ -235,7 +235,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (ApplicationServiceState.DOWN.value,), ) ) - self.assertEquals(service.id, rows[0][0]) + self.assertEqual(service.id, rows[0][0]) def test_set_appservices_state_multiple_up( self, @@ -258,7 +258,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (ApplicationServiceState.UP.value,), ) ) - self.assertEquals(service.id, rows[0][0]) + self.assertEqual(service.id, rows[0][0]) def test_create_appservice_txn_first( self, @@ -270,9 +270,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.store.create_appservice_txn(service, events, [], [], {}, {}) ) ) - self.assertEquals(txn.id, 1) - self.assertEquals(txn.events, events) - self.assertEquals(txn.service, service) + self.assertEqual(txn.id, 1) + self.assertEqual(txn.events, events) + self.assertEqual(txn.service, service) def test_create_appservice_txn_older_last_txn( self, @@ -285,9 +285,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): txn = self.get_success( self.store.create_appservice_txn(service, events, [], [], {}, {}) ) - self.assertEquals(txn.id, 9646) - self.assertEquals(txn.events, events) - self.assertEquals(txn.service, service) + self.assertEqual(txn.id, 9646) + self.assertEqual(txn.events, events) + self.assertEqual(txn.service, service) def test_create_appservice_txn_up_to_date_last_txn( self, @@ -298,9 +298,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): txn = self.get_success( self.store.create_appservice_txn(service, events, [], [], {}, {}) ) - self.assertEquals(txn.id, 9644) - self.assertEquals(txn.events, events) - self.assertEquals(txn.service, service) + self.assertEqual(txn.id, 9644) + self.assertEqual(txn.events, events) + self.assertEqual(txn.service, service) def test_create_appservice_txn_up_fuzzing( self, @@ -322,9 +322,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): txn = self.get_success( self.store.create_appservice_txn(service, events, [], [], {}, {}) ) - self.assertEquals(txn.id, 9644) - self.assertEquals(txn.events, events) - self.assertEquals(txn.service, service) + self.assertEqual(txn.id, 9644) + self.assertEqual(txn.events, events) + self.assertEqual(txn.service, service) def test_complete_appservice_txn_first_txn( self, @@ -346,8 +346,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (service.id,), ) ) - self.assertEquals(1, len(res)) - self.assertEquals(txn_id, res[0][0]) + self.assertEqual(1, len(res)) + self.assertEqual(txn_id, res[0][0]) res = self.get_success( self.db_pool.runQuery( @@ -357,7 +357,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (txn_id,), ) ) - self.assertEquals(0, len(res)) + self.assertEqual(0, len(res)) def test_complete_appservice_txn_existing_in_state_table( self, @@ -379,9 +379,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (service.id,), ) ) - self.assertEquals(1, len(res)) - self.assertEquals(txn_id, res[0][0]) - self.assertEquals(ApplicationServiceState.UP.value, res[0][1]) + self.assertEqual(1, len(res)) + self.assertEqual(txn_id, res[0][0]) + self.assertEqual(ApplicationServiceState.UP.value, res[0][1]) res = self.get_success( self.db_pool.runQuery( @@ -391,7 +391,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (txn_id,), ) ) - self.assertEquals(0, len(res)) + self.assertEqual(0, len(res)) def test_get_oldest_unsent_txn_none( self, @@ -399,7 +399,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): service = Mock(id=self.as_list[0]["id"]) txn = self.get_success(self.store.get_oldest_unsent_txn(service)) - self.assertEquals(None, txn) + self.assertEqual(None, txn) def test_get_oldest_unsent_txn(self) -> None: service = Mock(id=self.as_list[0]["id"]) @@ -416,9 +416,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(service.id, 12, other_events)) txn = self.get_success(self.store.get_oldest_unsent_txn(service)) - self.assertEquals(service, txn.service) - self.assertEquals(10, txn.id) - self.assertEquals(events, txn.events) + self.assertEqual(service, txn.service) + self.assertEqual(10, txn.id) + self.assertEqual(events, txn.events) def test_get_appservices_by_state_single( self, @@ -433,8 +433,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) - self.assertEquals(1, len(services)) - self.assertEquals(self.as_list[0]["id"], services[0].id) + self.assertEqual(1, len(services)) + self.assertEqual(self.as_list[0]["id"], services[0].id) def test_get_appservices_by_state_multiple( self, @@ -455,8 +455,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) - self.assertEquals(2, len(services)) - self.assertEquals( + self.assertEqual(2, len(services)) + self.assertEqual( {self.as_list[2]["id"], self.as_list[0]["id"]}, {services[0].id, services[1].id}, ) @@ -476,12 +476,12 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "read_receipt") ) - self.assertEquals(value, 0) + self.assertEqual(value, 0) value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "presence") ) - self.assertEquals(value, 0) + self.assertEqual(value, 0) def test_get_type_stream_id_for_appservice_invalid_type(self) -> None: self.get_failure( diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 3e4f0579c9..a8ffb52c05 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -103,7 +103,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertEquals("Value", value) + self.assertEqual("Value", value) self.mock_txn.execute.assert_called_with( "SELECT retcol FROM tablename WHERE keycol = ?", ["TheKey"] ) @@ -121,7 +121,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) + self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret) self.mock_txn.execute.assert_called_with( "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"] ) @@ -154,7 +154,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) + self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) self.mock_txn.execute.assert_called_with( "SELECT colA FROM tablename WHERE keycol = ?", ["A set"] ) diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 7b72a92424..20bf3ca17b 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -31,7 +31,7 @@ class DirectoryStoreTestCase(HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( ["#my-room:test"], (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))), ) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index c9e3b9fa79..0f9add4841 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -57,7 +57,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) ) - self.assertEquals( + self.assertEqual( counts, NotifCounts( notify_count=noitf_count, diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index 4ca212fd11..5806cb0e4b 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -38,12 +38,12 @@ class DataStoreTestCase(unittest.HomeserverTestCase): self.store.get_users_paginate(0, 10, name="bc", guests=False) ) - self.assertEquals(1, total) - self.assertEquals(self.displayname, users.pop()["displayname"]) + self.assertEqual(1, total) + self.assertEqual(self.displayname, users.pop()["displayname"]) users, total = self.get_success( self.store.get_users_paginate(0, 10, name="BC", guests=False) ) - self.assertEquals(1, total) - self.assertEquals(self.displayname, users.pop()["displayname"]) + self.assertEqual(1, total) + self.assertEqual(self.displayname, users.pop()["displayname"]) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index b6f99af2f1..a019d06e09 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): self.store.set_profile_displayname(self.u_frank.localpart, "Frank") ) - self.assertEquals( + self.assertEqual( "Frank", ( self.get_success( @@ -60,7 +60,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( "http://my.site/here", ( self.get_success( diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 1fa495f778..a49ac1525e 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -30,7 +30,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): def test_register(self): self.get_success(self.store.register_user(self.user_id, self.pwhash)) - self.assertEquals( + self.assertEqual( { # TODO(paul): Surely this field should be 'user_id', not 'name' "name": self.user_id, @@ -131,7 +131,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): ), ThreepidValidationError, ) - self.assertEquals(e.value.msg, "Unknown session_id", e) + self.assertEqual(e.value.msg, "Unknown session_id", e) # Set the config setting to true. self.store._ignore_unknown_session_error = True @@ -146,4 +146,4 @@ class RegistrationStoreTestCase(HomeserverTestCase): ), ThreepidValidationError, ) - self.assertEquals(e.value.msg, "Validation token not found or has expired", e) + self.assertEqual(e.value.msg, "Validation token not found or has expired", e) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 42bfca2a83..5b011e18cd 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -104,7 +104,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): self.store.get_current_state(room_id=self.room.to_string()) ) - self.assertEquals(1, len(state)) + self.assertEqual(1, len(state)) self.assertObjectHasAttributes( {"type": "m.room.name", "room_id": self.room.to_string(), "name": name}, state[0], @@ -121,7 +121,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): self.store.get_current_state(room_id=self.room.to_string()) ) - self.assertEquals(1, len(state)) + self.assertEqual(1, len(state)) self.assertObjectHasAttributes( {"type": "m.room.topic", "room_id": self.room.to_string(), "topic": topic}, state[0], diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index d62e01726c..8dfc1e1db9 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -53,7 +53,7 @@ class EventSearchInsertionTest(HomeserverTestCase): result = self.get_success( store.search_msgs([room_id], "hi bob", ["content.body"]) ) - self.assertEquals(result.get("count"), 1) + self.assertEqual(result.get("count"), 1) if isinstance(store.database_engine, PostgresEngine): self.assertIn("hi", result.get("highlights")) self.assertIn("bob", result.get("highlights")) @@ -62,14 +62,14 @@ class EventSearchInsertionTest(HomeserverTestCase): result = self.get_success( store.search_msgs([room_id], "another", ["content.body"]) ) - self.assertEquals(result.get("count"), 1) + self.assertEqual(result.get("count"), 1) if isinstance(store.database_engine, PostgresEngine): self.assertIn("another", result.get("highlights")) # Check that search works for a search term that overlaps with the message # containing a null byte and an unrelated message. result = self.get_success(store.search_msgs([room_id], "hi", ["content.body"])) - self.assertEquals(result.get("count"), 2) + self.assertEqual(result.get("count"), 2) result = self.get_success( store.search_msgs([room_id], "hi alice", ["content.body"]) ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 7028f0dfb0..b8f09a8ee0 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -55,7 +55,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals([self.room], [m.room_id for m in rooms_for_user]) + self.assertEqual([self.room], [m.room_id for m in rooms_for_user]) def test_count_known_servers(self): """ diff --git a/tests/test_distributor.py b/tests/test_distributor.py index f8341041ee..31546ea52b 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -48,7 +48,7 @@ class DistributorTestCase(unittest.TestCase): observers[0].assert_called_once_with("Go") observers[1].assert_called_once_with("Go") - self.assertEquals(mock_logger.warning.call_count, 1) + self.assertEqual(mock_logger.warning.call_count, 1) self.assertIsInstance(mock_logger.warning.call_args[0][0], str) def test_signal_prereg(self): diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 67dcf567cd..37fada5c53 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -54,7 +54,7 @@ class TermsTestCase(unittest.HomeserverTestCase): request_data = json.dumps({"username": "kermit", "password": "monkey"}) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) self.assertTrue(channel.json_body is not None) self.assertIsInstance(channel.json_body["session"], str) @@ -99,7 +99,7 @@ class TermsTestCase(unittest.HomeserverTestCase): # We don't bother checking that the response is correct - we'll leave that to # other tests. We just want to make sure we're on the right path. - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) # Finish the UI auth for terms request_data = json.dumps( @@ -117,7 +117,7 @@ class TermsTestCase(unittest.HomeserverTestCase): # We're interested in getting a response that looks like a successful # registration, not so much that the details are exactly what we want. - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertTrue(channel.json_body is not None) self.assertIsInstance(channel.json_body["user_id"], str) diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index f2ef1c6051..d04bcae0fa 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -25,7 +25,7 @@ class MockClockTestCase(unittest.TestCase): self.clock.advance_time(20) - self.assertEquals(20, self.clock.time() - start_time) + self.assertEqual(20, self.clock.time() - start_time) def test_later(self): invoked = [0, 0] diff --git a/tests/test_types.py b/tests/test_types.py index 0d0c00d97a..80888a744d 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -22,9 +22,9 @@ class UserIDTestCase(unittest.HomeserverTestCase): def test_parse(self): user = UserID.from_string("@1234abcd:test") - self.assertEquals("1234abcd", user.localpart) - self.assertEquals("test", user.domain) - self.assertEquals(True, self.hs.is_mine(user)) + self.assertEqual("1234abcd", user.localpart) + self.assertEqual("test", user.domain) + self.assertEqual(True, self.hs.is_mine(user)) def test_pase_empty(self): with self.assertRaises(SynapseError): @@ -33,7 +33,7 @@ class UserIDTestCase(unittest.HomeserverTestCase): def test_build(self): user = UserID("5678efgh", "my.domain") - self.assertEquals(user.to_string(), "@5678efgh:my.domain") + self.assertEqual(user.to_string(), "@5678efgh:my.domain") def test_compare(self): userA = UserID.from_string("@userA:my.domain") @@ -48,14 +48,14 @@ class RoomAliasTestCase(unittest.HomeserverTestCase): def test_parse(self): room = RoomAlias.from_string("#channel:test") - self.assertEquals("channel", room.localpart) - self.assertEquals("test", room.domain) - self.assertEquals(True, self.hs.is_mine(room)) + self.assertEqual("channel", room.localpart) + self.assertEqual("test", room.domain) + self.assertEqual(True, self.hs.is_mine(room)) def test_build(self): room = RoomAlias("channel", "my.domain") - self.assertEquals(room.to_string(), "#channel:my.domain") + self.assertEqual(room.to_string(), "#channel:my.domain") def test_validate(self): id_string = "#test:domain,test" diff --git a/tests/unittest.py b/tests/unittest.py index 0caa8e7a45..326895f4c9 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -152,12 +152,12 @@ class TestCase(unittest.TestCase): def assertObjectHasAttributes(self, attrs, obj): """Asserts that the given object has each of the attributes given, and - that the value of each matches according to assertEquals.""" + that the value of each matches according to assertEqual.""" for key in attrs.keys(): if not hasattr(obj, key): raise AssertionError("Expected obj to have a '.%s'" % key) try: - self.assertEquals(attrs[key], getattr(obj, key)) + self.assertEqual(attrs[key], getattr(obj, key)) except AssertionError as e: raise (type(e))(f"Assert error for '.{key}':") from e @@ -169,7 +169,7 @@ class TestCase(unittest.TestCase): actual (dict): The test result. Extra keys will not be checked. """ for key in required: - self.assertEquals( + self.assertEqual( required[key], actual[key], msg="%s mismatch. %s" % (key, actual) ) diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index c613ce3f10..02b99b466a 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -31,7 +31,7 @@ class DeferredCacheTestCase(TestCase): cache = DeferredCache("test") cache.prefill("foo", 123) - self.assertEquals(self.successResultOf(cache.get("foo")), 123) + self.assertEqual(self.successResultOf(cache.get("foo")), 123) def test_hit_deferred(self): cache = DeferredCache("test") diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index ced3efd93f..b92d3f0c1b 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -434,8 +434,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): a = A() - self.assertEquals((yield a.func("foo")), "foo") - self.assertEquals((yield a.func("bar")), "bar") + self.assertEqual((yield a.func("foo")), "foo") + self.assertEqual((yield a.func("bar")), "bar") @defer.inlineCallbacks def test_hit(self): @@ -450,10 +450,10 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): a = A() yield a.func("foo") - self.assertEquals(callcount[0], 1) + self.assertEqual(callcount[0], 1) - self.assertEquals((yield a.func("foo")), "foo") - self.assertEquals(callcount[0], 1) + self.assertEqual((yield a.func("foo")), "foo") + self.assertEqual(callcount[0], 1) @defer.inlineCallbacks def test_invalidate(self): @@ -468,13 +468,13 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): a = A() yield a.func("foo") - self.assertEquals(callcount[0], 1) + self.assertEqual(callcount[0], 1) a.func.invalidate(("foo",)) yield a.func("foo") - self.assertEquals(callcount[0], 2) + self.assertEqual(callcount[0], 2) def test_invalidate_missing(self): class A: @@ -499,7 +499,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): for k in range(0, 12): yield a.func(k) - self.assertEquals(callcount[0], 12) + self.assertEqual(callcount[0], 12) # There must have been at least 2 evictions, meaning if we calculate # all 12 values again, we must get called at least 2 more times @@ -525,8 +525,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): a.func.prefill(("foo",), 456) - self.assertEquals(a.func("foo").result, 456) - self.assertEquals(callcount[0], 0) + self.assertEqual(a.func("foo").result, 456) + self.assertEqual(callcount[0], 0) @defer.inlineCallbacks def test_invalidate_context(self): @@ -547,19 +547,19 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): a = A() yield a.func2("foo") - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 1) + self.assertEqual(callcount[0], 1) + self.assertEqual(callcount2[0], 1) a.func.invalidate(("foo",)) yield a.func("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 1) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 1) yield a.func2("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 2) @defer.inlineCallbacks def test_eviction_context(self): @@ -581,22 +581,22 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): yield a.func2("foo") yield a.func2("foo2") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 2) yield a.func2("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 2) yield a.func("foo3") - self.assertEquals(callcount[0], 3) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 3) + self.assertEqual(callcount2[0], 2) yield a.func2("foo") - self.assertEquals(callcount[0], 4) - self.assertEquals(callcount2[0], 3) + self.assertEqual(callcount[0], 4) + self.assertEqual(callcount2[0], 3) @defer.inlineCallbacks def test_double_get(self): @@ -619,30 +619,30 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): yield a.func2("foo") - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 1) + self.assertEqual(callcount[0], 1) + self.assertEqual(callcount2[0], 1) a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.del_multi.call_count, 1) + self.assertEqual(a.func2.cache.cache.del_multi.call_count, 1) yield a.func2("foo") a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.del_multi.call_count, 2) + self.assertEqual(a.func2.cache.cache.del_multi.call_count, 2) - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 1) + self.assertEqual(callcount2[0], 2) a.func.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.del_multi.call_count, 3) + self.assertEqual(a.func2.cache.cache.del_multi.call_count, 3) yield a.func("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 2) yield a.func2("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 3) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 3) class CachedListDescriptorTestCase(unittest.TestCase): diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index e6e13ba06c..7f60aae5ba 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -26,8 +26,8 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): cache = ExpiringCache("test", clock, max_len=1) cache["key"] = "value" - self.assertEquals(cache.get("key"), "value") - self.assertEquals(cache["key"], "value") + self.assertEqual(cache.get("key"), "value") + self.assertEqual(cache["key"], "value") def test_eviction(self): clock = MockClock() @@ -35,13 +35,13 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): cache["key"] = "value" cache["key2"] = "value2" - self.assertEquals(cache.get("key"), "value") - self.assertEquals(cache.get("key2"), "value2") + self.assertEqual(cache.get("key"), "value") + self.assertEqual(cache.get("key2"), "value2") cache["key3"] = "value3" - self.assertEquals(cache.get("key"), None) - self.assertEquals(cache.get("key2"), "value2") - self.assertEquals(cache.get("key3"), "value3") + self.assertEqual(cache.get("key"), None) + self.assertEqual(cache.get("key2"), "value2") + self.assertEqual(cache.get("key3"), "value3") def test_iterable_eviction(self): clock = MockClock() @@ -51,15 +51,15 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): cache["key2"] = [2, 3] cache["key3"] = [4, 5] - self.assertEquals(cache.get("key"), [1]) - self.assertEquals(cache.get("key2"), [2, 3]) - self.assertEquals(cache.get("key3"), [4, 5]) + self.assertEqual(cache.get("key"), [1]) + self.assertEqual(cache.get("key2"), [2, 3]) + self.assertEqual(cache.get("key3"), [4, 5]) cache["key4"] = [6, 7] - self.assertEquals(cache.get("key"), None) - self.assertEquals(cache.get("key2"), None) - self.assertEquals(cache.get("key3"), [4, 5]) - self.assertEquals(cache.get("key4"), [6, 7]) + self.assertEqual(cache.get("key"), None) + self.assertEqual(cache.get("key2"), None) + self.assertEqual(cache.get("key3"), [4, 5]) + self.assertEqual(cache.get("key4"), [6, 7]) def test_time_eviction(self): clock = MockClock() @@ -69,13 +69,13 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): clock.advance_time(0.5) cache["key2"] = 2 - self.assertEquals(cache.get("key"), 1) - self.assertEquals(cache.get("key2"), 2) + self.assertEqual(cache.get("key"), 1) + self.assertEqual(cache.get("key2"), 2) clock.advance_time(0.9) - self.assertEquals(cache.get("key"), None) - self.assertEquals(cache.get("key2"), 2) + self.assertEqual(cache.get("key"), None) + self.assertEqual(cache.get("key2"), 2) clock.advance_time(1) - self.assertEquals(cache.get("key"), None) - self.assertEquals(cache.get("key2"), None) + self.assertEqual(cache.get("key"), None) + self.assertEqual(cache.get("key2"), None) diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 621b0f9fcd..2ad321e184 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -17,7 +17,7 @@ from .. import unittest class LoggingContextTestCase(unittest.TestCase): def _check_test_key(self, value): - self.assertEquals(current_context().name, value) + self.assertEqual(current_context().name, value) def test_with_context(self): with LoggingContext("test"): diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 291644eb7d..321fc1776f 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -27,37 +27,37 @@ class LruCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): cache = LruCache(1) cache["key"] = "value" - self.assertEquals(cache.get("key"), "value") - self.assertEquals(cache["key"], "value") + self.assertEqual(cache.get("key"), "value") + self.assertEqual(cache["key"], "value") def test_eviction(self): cache = LruCache(2) cache[1] = 1 cache[2] = 2 - self.assertEquals(cache.get(1), 1) - self.assertEquals(cache.get(2), 2) + self.assertEqual(cache.get(1), 1) + self.assertEqual(cache.get(2), 2) cache[3] = 3 - self.assertEquals(cache.get(1), None) - self.assertEquals(cache.get(2), 2) - self.assertEquals(cache.get(3), 3) + self.assertEqual(cache.get(1), None) + self.assertEqual(cache.get(2), 2) + self.assertEqual(cache.get(3), 3) def test_setdefault(self): cache = LruCache(1) - self.assertEquals(cache.setdefault("key", 1), 1) - self.assertEquals(cache.get("key"), 1) - self.assertEquals(cache.setdefault("key", 2), 1) - self.assertEquals(cache.get("key"), 1) + self.assertEqual(cache.setdefault("key", 1), 1) + self.assertEqual(cache.get("key"), 1) + self.assertEqual(cache.setdefault("key", 2), 1) + self.assertEqual(cache.get("key"), 1) cache["key"] = 2 # Make sure overriding works. - self.assertEquals(cache.get("key"), 2) + self.assertEqual(cache.get("key"), 2) def test_pop(self): cache = LruCache(1) cache["key"] = 1 - self.assertEquals(cache.pop("key"), 1) - self.assertEquals(cache.pop("key"), None) + self.assertEqual(cache.pop("key"), 1) + self.assertEqual(cache.pop("key"), None) def test_del_multi(self): cache = LruCache(4, cache_type=TreeCache) @@ -66,23 +66,23 @@ class LruCacheTestCase(unittest.HomeserverTestCase): cache[("vehicles", "car")] = "vroom" cache[("vehicles", "train")] = "chuff" - self.assertEquals(len(cache), 4) + self.assertEqual(len(cache), 4) - self.assertEquals(cache.get(("animal", "cat")), "mew") - self.assertEquals(cache.get(("vehicles", "car")), "vroom") + self.assertEqual(cache.get(("animal", "cat")), "mew") + self.assertEqual(cache.get(("vehicles", "car")), "vroom") cache.del_multi(("animal",)) - self.assertEquals(len(cache), 2) - self.assertEquals(cache.get(("animal", "cat")), None) - self.assertEquals(cache.get(("animal", "dog")), None) - self.assertEquals(cache.get(("vehicles", "car")), "vroom") - self.assertEquals(cache.get(("vehicles", "train")), "chuff") + self.assertEqual(len(cache), 2) + self.assertEqual(cache.get(("animal", "cat")), None) + self.assertEqual(cache.get(("animal", "dog")), None) + self.assertEqual(cache.get(("vehicles", "car")), "vroom") + self.assertEqual(cache.get(("vehicles", "train")), "chuff") # Man from del_multi say "Yes". def test_clear(self): cache = LruCache(1) cache["key"] = 1 cache.clear() - self.assertEquals(len(cache), 0) + self.assertEqual(len(cache), 0) @override_config({"caches": {"per_cache_factors": {"mycache": 10}}}) def test_special_size(self): @@ -105,10 +105,10 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): self.assertFalse(m.called) cache.set("key", "value2") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.set("key", "value") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) def test_multi_get(self): m = Mock() @@ -124,10 +124,10 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): self.assertFalse(m.called) cache.set("key", "value2") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.set("key", "value") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) def test_set(self): m = Mock() @@ -140,10 +140,10 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): self.assertFalse(m.called) cache.set("key", "value2") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.set("key", "value") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) def test_pop(self): m = Mock() @@ -153,13 +153,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): self.assertFalse(m.called) cache.pop("key") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.set("key", "value") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.pop("key") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) def test_del_multi(self): m1 = Mock() @@ -173,17 +173,17 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.set(("b", "1"), "value", callbacks=[m3]) cache.set(("b", "2"), "value", callbacks=[m4]) - self.assertEquals(m1.call_count, 0) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) - self.assertEquals(m4.call_count, 0) + self.assertEqual(m1.call_count, 0) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) + self.assertEqual(m4.call_count, 0) cache.del_multi(("a",)) - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 1) - self.assertEquals(m3.call_count, 0) - self.assertEquals(m4.call_count, 0) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 1) + self.assertEqual(m3.call_count, 0) + self.assertEqual(m4.call_count, 0) def test_clear(self): m1 = Mock() @@ -193,13 +193,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.set("key1", "value", callbacks=[m1]) cache.set("key2", "value", callbacks=[m2]) - self.assertEquals(m1.call_count, 0) - self.assertEquals(m2.call_count, 0) + self.assertEqual(m1.call_count, 0) + self.assertEqual(m2.call_count, 0) cache.clear() - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 1) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 1) def test_eviction(self): m1 = Mock(name="m1") @@ -210,33 +210,33 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.set("key1", "value", callbacks=[m1]) cache.set("key2", "value", callbacks=[m2]) - self.assertEquals(m1.call_count, 0) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) + self.assertEqual(m1.call_count, 0) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) cache.set("key3", "value", callbacks=[m3]) - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) cache.set("key3", "value") - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) cache.get("key2") - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) cache.set("key1", "value", callbacks=[m1]) - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 1) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 1) class LruCacheSizedTestCase(unittest.HomeserverTestCase): @@ -247,20 +247,20 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase): cache["key3"] = [3] cache["key4"] = [4] - self.assertEquals(cache["key1"], [0]) - self.assertEquals(cache["key2"], [1, 2]) - self.assertEquals(cache["key3"], [3]) - self.assertEquals(cache["key4"], [4]) - self.assertEquals(len(cache), 5) + self.assertEqual(cache["key1"], [0]) + self.assertEqual(cache["key2"], [1, 2]) + self.assertEqual(cache["key3"], [3]) + self.assertEqual(cache["key4"], [4]) + self.assertEqual(len(cache), 5) cache["key5"] = [5, 6] - self.assertEquals(len(cache), 4) - self.assertEquals(cache.get("key1"), None) - self.assertEquals(cache.get("key2"), None) - self.assertEquals(cache["key3"], [3]) - self.assertEquals(cache["key4"], [4]) - self.assertEquals(cache["key5"], [5, 6]) + self.assertEqual(len(cache), 4) + self.assertEqual(cache.get("key1"), None) + self.assertEqual(cache.get("key2"), None) + self.assertEqual(cache["key3"], [3]) + self.assertEqual(cache["key4"], [4]) + self.assertEqual(cache["key5"], [5, 6]) def test_zero_size_drop_from_cache(self) -> None: """Test that `drop_from_cache` works correctly with 0-sized entries.""" diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py index 6066372053..567cb18468 100644 --- a/tests/util/test_treecache.py +++ b/tests/util/test_treecache.py @@ -23,61 +23,61 @@ class TreeCacheTestCase(unittest.TestCase): cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" - self.assertEquals(cache.get(("a",)), "A") - self.assertEquals(cache.get(("b",)), "B") - self.assertEquals(len(cache), 2) + self.assertEqual(cache.get(("a",)), "A") + self.assertEqual(cache.get(("b",)), "B") + self.assertEqual(len(cache), 2) def test_pop_onelevel(self): cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" - self.assertEquals(cache.pop(("a",)), "A") - self.assertEquals(cache.pop(("a",)), None) - self.assertEquals(cache.get(("b",)), "B") - self.assertEquals(len(cache), 1) + self.assertEqual(cache.pop(("a",)), "A") + self.assertEqual(cache.pop(("a",)), None) + self.assertEqual(cache.get(("b",)), "B") + self.assertEqual(len(cache), 1) def test_get_set_twolevel(self): cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" cache[("b", "a")] = "BA" - self.assertEquals(cache.get(("a", "a")), "AA") - self.assertEquals(cache.get(("a", "b")), "AB") - self.assertEquals(cache.get(("b", "a")), "BA") - self.assertEquals(len(cache), 3) + self.assertEqual(cache.get(("a", "a")), "AA") + self.assertEqual(cache.get(("a", "b")), "AB") + self.assertEqual(cache.get(("b", "a")), "BA") + self.assertEqual(len(cache), 3) def test_pop_twolevel(self): cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" cache[("b", "a")] = "BA" - self.assertEquals(cache.pop(("a", "a")), "AA") - self.assertEquals(cache.get(("a", "a")), None) - self.assertEquals(cache.get(("a", "b")), "AB") - self.assertEquals(cache.pop(("b", "a")), "BA") - self.assertEquals(cache.pop(("b", "a")), None) - self.assertEquals(len(cache), 1) + self.assertEqual(cache.pop(("a", "a")), "AA") + self.assertEqual(cache.get(("a", "a")), None) + self.assertEqual(cache.get(("a", "b")), "AB") + self.assertEqual(cache.pop(("b", "a")), "BA") + self.assertEqual(cache.pop(("b", "a")), None) + self.assertEqual(len(cache), 1) def test_pop_mixedlevel(self): cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" cache[("b", "a")] = "BA" - self.assertEquals(cache.get(("a", "a")), "AA") + self.assertEqual(cache.get(("a", "a")), "AA") popped = cache.pop(("a",)) - self.assertEquals(cache.get(("a", "a")), None) - self.assertEquals(cache.get(("a", "b")), None) - self.assertEquals(cache.get(("b", "a")), "BA") - self.assertEquals(len(cache), 1) + self.assertEqual(cache.get(("a", "a")), None) + self.assertEqual(cache.get(("a", "b")), None) + self.assertEqual(cache.get(("b", "a")), "BA") + self.assertEqual(len(cache), 1) - self.assertEquals({"AA", "AB"}, set(iterate_tree_cache_entry(popped))) + self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped))) def test_clear(self): cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" cache.clear() - self.assertEquals(len(cache), 0) + self.assertEqual(len(cache), 0) def test_contains(self): cache = TreeCache() From 9e83521af860cb33a7459dbe74188ce5ef39f446 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 28 Feb 2022 07:52:44 -0500 Subject: [PATCH 18/40] Properly failover for unknown endpoints from Conduit/Dendrite. (#12077) Before this fix, a legitimate 404 from a federation endpoint (e.g. due to an unknown room) would be treated as an unknown endpoint. This could cause unnecessary federation traffic. --- changelog.d/12077.bugfix | 1 + synapse/federation/federation_client.py | 22 +++++++++++++--------- 2 files changed, 14 insertions(+), 9 deletions(-) create mode 100644 changelog.d/12077.bugfix diff --git a/changelog.d/12077.bugfix b/changelog.d/12077.bugfix new file mode 100644 index 0000000000..1bce82082d --- /dev/null +++ b/changelog.d/12077.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would make additional failing requests over federation for missing data. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 2121e92e3a..a4bae3c4c8 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -615,11 +615,15 @@ class FederationClient(FederationBase): synapse_error = e.to_synapse_error() # There is no good way to detect an "unknown" endpoint. # - # Dendrite returns a 404 (with no body); synapse returns a 400 + # Dendrite returns a 404 (with a body of "404 page not found"); + # Conduit returns a 404 (with no body); and Synapse returns a 400 # with M_UNRECOGNISED. - return e.code == 404 or ( - e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED - ) + # + # This needs to be rather specific as some endpoints truly do return 404 + # errors. + return ( + e.code == 404 and (not e.response or e.response == b"404 page not found") + ) or (e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED) async def _try_destination_list( self, @@ -1002,7 +1006,7 @@ class FederationClient(FederationBase): ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, - # fallback to the v1 endpoint. Otherwise consider it a legitmate error + # fallback to the v1 endpoint. Otherwise, consider it a legitimate error # and raise. if not self._is_unknown_endpoint(e): raise @@ -1071,7 +1075,7 @@ class FederationClient(FederationBase): except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, # fallback to the v1 endpoint if the room uses old-style event IDs. - # Otherwise consider it a legitmate error and raise. + # Otherwise, consider it a legitimate error and raise. err = e.to_synapse_error() if self._is_unknown_endpoint(e, err): if room_version.event_format != EventFormatVersions.V1: @@ -1132,7 +1136,7 @@ class FederationClient(FederationBase): ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, - # fallback to the v1 endpoint. Otherwise consider it a legitmate error + # fallback to the v1 endpoint. Otherwise, consider it a legitimate error # and raise. if not self._is_unknown_endpoint(e): raise @@ -1458,8 +1462,8 @@ class FederationClient(FederationBase): ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, - # fallback to the unstable endpoint. Otherwise consider it a - # legitmate error and raise. + # fallback to the unstable endpoint. Otherwise, consider it a + # legitimate error and raise. if not self._is_unknown_endpoint(e): raise From 5565f454e1b323b637dd418549f70fadac0f44b4 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 28 Feb 2022 14:10:36 +0000 Subject: [PATCH 19/40] Actually fix bad debug logging rejecting device list & signing key transactions (#12098) --- changelog.d/12098.bugfix | 1 + .../federation/transport/server/federation.py | 2 +- tests/federation/transport/test_server.py | 20 ++++++++++++++++++- 3 files changed, 21 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12098.bugfix diff --git a/changelog.d/12098.bugfix b/changelog.d/12098.bugfix new file mode 100644 index 0000000000..6b696692e3 --- /dev/null +++ b/changelog.d/12098.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.51.0rc1 where incoming federation transactions containing at least one EDU would be dropped if debug logging was enabled for `synapse.8631_debug`. \ No newline at end of file diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 9cc9a7339d..23ce343057 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -110,7 +110,7 @@ class FederationSendServlet(BaseFederationServerServlet): if issue_8631_logger.isEnabledFor(logging.DEBUG): DEVICE_UPDATE_EDUS = ["m.device_list_update", "m.signing_key_update"] device_list_updates = [ - edu.content + edu.get("content", {}) for edu in transaction_data.get("edus", []) if edu.get("edu_type") in DEVICE_UPDATE_EDUS ] diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py index ce49d094d7..5f001c33b0 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py @@ -13,7 +13,7 @@ # limitations under the License. from tests import unittest -from tests.unittest import override_config +from tests.unittest import DEBUG, override_config class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): @@ -38,3 +38,21 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): "/_matrix/federation/v1/publicRooms", ) self.assertEqual(200, channel.code) + + @DEBUG + def test_edu_debugging_doesnt_explode(self): + """Sanity check incoming federation succeeds with `synapse.debug_8631` enabled. + + Remove this when we strip out issue_8631_logger. + """ + channel = self.make_signed_federation_request( + "PUT", + "/_matrix/federation/v1/send/txn_id_1234/", + content={ + "edus": [ + {"edu_type": "m.device_list_update", "content": {"foo": "bar"}} + ], + "pdus": [], + }, + ) + self.assertEqual(200, channel.code) From 6c0b44a3d73f73dc5913f081418347645dc84d6f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 28 Feb 2022 17:40:24 +0000 Subject: [PATCH 20/40] Fix `PushRuleEvaluator` and `Filter` to work on frozendicts (#12100) * Fix `PushRuleEvaluator` to work on frozendicts frozendicts do not (necessarily) inherit from dict, so this needs to handle them correctly. * Fix event filtering for frozen events Looks like this one was introduced by #11194. --- changelog.d/12100.bugfix | 1 + synapse/api/filtering.py | 5 +++-- synapse/push/push_rule_evaluator.py | 8 ++++---- tests/api/test_filtering.py | 10 ++++++++++ tests/push/test_push_rule_evaluator.py | 9 +++++++++ 5 files changed, 27 insertions(+), 6 deletions(-) create mode 100644 changelog.d/12100.bugfix diff --git a/changelog.d/12100.bugfix b/changelog.d/12100.bugfix new file mode 100644 index 0000000000..181095ad99 --- /dev/null +++ b/changelog.d/12100.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which could cause push notifications to malfunction if `use_frozen_dicts` was set in the configuration. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index fe4cc2e8ee..cb532d7238 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -22,6 +22,7 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, Set, TypeVar, @@ -361,10 +362,10 @@ class Filter: return self._check_fields(field_matchers) else: content = event.get("content") - # Content is assumed to be a dict below, so ensure it is. This should + # Content is assumed to be a mapping below, so ensure it is. This should # always be true for events, but account_data has been allowed to # have non-dict content. - if not isinstance(content, dict): + if not isinstance(content, Mapping): content = {} sender = event.get("sender", None) diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 659a53805d..f617c759e6 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -15,12 +15,12 @@ import logging import re -from typing import Any, Dict, List, Optional, Pattern, Tuple, Union +from typing import Any, Dict, List, Mapping, Optional, Pattern, Tuple, Union from matrix_common.regex import glob_to_regex, to_word_pattern from synapse.events import EventBase -from synapse.types import JsonDict, UserID +from synapse.types import UserID from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -223,7 +223,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: def _flatten_dict( - d: Union[EventBase, JsonDict], + d: Union[EventBase, Mapping[str, Any]], prefix: Optional[List[str]] = None, result: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: @@ -234,7 +234,7 @@ def _flatten_dict( for key, value in d.items(): if isinstance(value, str): result[".".join(prefix + [key])] = value.lower() - elif isinstance(value, dict): + elif isinstance(value, Mapping): _flatten_dict(value, prefix=(prefix + [key]), result=result) return result diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 2525018e95..8c3354ce3c 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -18,6 +18,7 @@ from unittest.mock import patch import jsonschema +from frozendict import frozendict from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError @@ -327,6 +328,15 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertFalse(Filter(self.hs, definition)._check(event)) + # check it works with frozendicts too + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content=frozendict({EventContentFields.LABELS: ["#fun"]}), + ) + self.assertTrue(Filter(self.hs, definition)._check(event)) + def test_filter_not_labels(self): definition = {"org.matrix.not_labels": ["#fun"]} event = MockEvent( diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index a52e89e407..3849beb9d6 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -14,6 +14,8 @@ from typing import Any, Dict +import frozendict + from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent from synapse.push import push_rule_evaluator @@ -191,6 +193,13 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "pattern should only match at the start/end of the value", ) + # it should work on frozendicts too + self._assert_matches( + condition, + frozendict.frozendict({"value": "FoobaZ"}), + "patterns should match on frozendicts", + ) + # wildcards should match condition = { "kind": "event_match", From 1901cb1d4a8b7d9af64493fbd336e9aa2561c20c Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 28 Feb 2022 18:47:37 +0100 Subject: [PATCH 21/40] Add type hints to `tests/rest/client` (#12084) --- changelog.d/12084.misc | 1 + mypy.ini | 3 +- tests/rest/client/test_profile.py | 78 +++++++++++++---------- tests/rest/client/test_push_rule_attrs.py | 26 ++++---- tests/rest/client/test_redactions.py | 26 +++++--- tests/rest/client/test_relations.py | 58 +++++++++-------- tests/rest/client/test_retention.py | 41 ++++++++---- tests/rest/client/test_sendtodevice.py | 8 +-- tests/rest/client/test_shadow_banned.py | 22 ++++--- tests/rest/client/test_shared_rooms.py | 20 +++--- tests/rest/client/test_upgrade_room.py | 17 +++-- tests/rest/client/utils.py | 36 +++++++---- 12 files changed, 198 insertions(+), 138 deletions(-) create mode 100644 changelog.d/12084.misc diff --git a/changelog.d/12084.misc b/changelog.d/12084.misc new file mode 100644 index 0000000000..0360dbd61e --- /dev/null +++ b/changelog.d/12084.misc @@ -0,0 +1 @@ +Add type hints to `tests/rest/client`. diff --git a/mypy.ini b/mypy.ini index 610660b9b7..bd75905c8d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -84,7 +84,6 @@ exclude = (?x) |tests/rest/client/test_third_party_rules.py |tests/rest/client/test_transactions.py |tests/rest/client/test_typing.py - |tests/rest/client/utils.py |tests/rest/key/v2/test_remote_key_resource.py |tests/rest/media/v1/test_base.py |tests/rest/media/v1/test_media_storage.py @@ -253,7 +252,7 @@ disallow_untyped_defs = True [mypy-tests.rest.admin.*] disallow_untyped_defs = True -[mypy-tests.rest.client.test_directory] +[mypy-tests.rest.client.*] disallow_untyped_defs = True [mypy-tests.federation.transport.test_client] diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 4239e1e610..77c3ced42e 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -13,12 +13,16 @@ # limitations under the License. """Tests REST events for /profile paths.""" -from typing import Any, Dict +from typing import Any, Dict, Optional + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.rest import admin from synapse.rest.client import login, profile, room +from synapse.server import HomeServer from synapse.types import UserID +from synapse.util import Clock from tests import unittest @@ -32,20 +36,20 @@ class ProfileTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.owner = self.register_user("owner", "pass") self.owner_tok = self.login("owner", "pass") self.other = self.register_user("other", "pass", displayname="Bob") - def test_get_displayname(self): + def test_get_displayname(self) -> None: res = self._get_displayname() self.assertEqual(res, "owner") - def test_set_displayname(self): + def test_set_displayname(self) -> None: channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), @@ -57,7 +61,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): res = self._get_displayname() self.assertEqual(res, "test") - def test_set_displayname_noauth(self): + def test_set_displayname_noauth(self) -> None: channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), @@ -65,7 +69,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 401, channel.result) - def test_set_displayname_too_long(self): + def test_set_displayname_too_long(self) -> None: """Attempts to set a stupid displayname should get a 400""" channel = self.make_request( "PUT", @@ -78,11 +82,11 @@ class ProfileTestCase(unittest.HomeserverTestCase): res = self._get_displayname() self.assertEqual(res, "owner") - def test_get_displayname_other(self): + def test_get_displayname_other(self) -> None: res = self._get_displayname(self.other) self.assertEqual(res, "Bob") - def test_set_displayname_other(self): + def test_set_displayname_other(self) -> None: channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.other,), @@ -91,11 +95,11 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def test_get_avatar_url(self): + def test_get_avatar_url(self) -> None: res = self._get_avatar_url() self.assertIsNone(res) - def test_set_avatar_url(self): + def test_set_avatar_url(self) -> None: channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.owner,), @@ -107,7 +111,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): res = self._get_avatar_url() self.assertEqual(res, "http://my.server/pic.gif") - def test_set_avatar_url_noauth(self): + def test_set_avatar_url_noauth(self) -> None: channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.owner,), @@ -115,7 +119,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 401, channel.result) - def test_set_avatar_url_too_long(self): + def test_set_avatar_url_too_long(self) -> None: """Attempts to set a stupid avatar_url should get a 400""" channel = self.make_request( "PUT", @@ -128,11 +132,11 @@ class ProfileTestCase(unittest.HomeserverTestCase): res = self._get_avatar_url() self.assertIsNone(res) - def test_get_avatar_url_other(self): + def test_get_avatar_url_other(self) -> None: res = self._get_avatar_url(self.other) self.assertIsNone(res) - def test_set_avatar_url_other(self): + def test_set_avatar_url_other(self) -> None: channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.other,), @@ -141,14 +145,14 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def _get_displayname(self, name=None): + def _get_displayname(self, name: Optional[str] = None) -> str: channel = self.make_request( "GET", "/profile/%s/displayname" % (name or self.owner,) ) self.assertEqual(channel.code, 200, channel.result) return channel.json_body["displayname"] - def _get_avatar_url(self, name=None): + def _get_avatar_url(self, name: Optional[str] = None) -> str: channel = self.make_request( "GET", "/profile/%s/avatar_url" % (name or self.owner,) ) @@ -156,7 +160,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): return channel.json_body.get("avatar_url") @unittest.override_config({"max_avatar_size": 50}) - def test_avatar_size_limit_global(self): + def test_avatar_size_limit_global(self) -> None: """Tests that the maximum size limit for avatars is enforced when updating a global profile. """ @@ -187,7 +191,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) @unittest.override_config({"max_avatar_size": 50}) - def test_avatar_size_limit_per_room(self): + def test_avatar_size_limit_per_room(self) -> None: """Tests that the maximum size limit for avatars is enforced when updating a per-room profile. """ @@ -220,7 +224,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) - def test_avatar_allowed_mime_type_global(self): + def test_avatar_allowed_mime_type_global(self) -> None: """Tests that the MIME type whitelist for avatars is enforced when updating a global profile. """ @@ -251,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) - def test_avatar_allowed_mime_type_per_room(self): + def test_avatar_allowed_mime_type_per_room(self) -> None: """Tests that the MIME type whitelist for avatars is enforced when updating a per-room profile. """ @@ -283,7 +287,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 200, channel.result) - def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]): + def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None: """Stores metadata about files in the database. Args: @@ -316,8 +320,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): - + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["require_auth_for_profile_requests"] = True config["limit_profile_requests_to_users_who_share_rooms"] = True @@ -325,7 +328,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # User owning the requested profile. self.owner = self.register_user("owner", "pass") self.owner_tok = self.login("owner", "pass") @@ -337,22 +340,24 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok) - def test_no_auth(self): + def test_no_auth(self) -> None: self.try_fetch_profile(401) - def test_not_in_shared_room(self): + def test_not_in_shared_room(self) -> None: self.ensure_requester_left_room() self.try_fetch_profile(403, access_token=self.requester_tok) - def test_in_shared_room(self): + def test_in_shared_room(self) -> None: self.ensure_requester_left_room() self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok) self.try_fetch_profile(200, self.requester_tok) - def try_fetch_profile(self, expected_code, access_token=None): + def try_fetch_profile( + self, expected_code: int, access_token: Optional[str] = None + ) -> None: self.request_profile(expected_code, access_token=access_token) self.request_profile( @@ -363,13 +368,18 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): expected_code, url_suffix="/avatar_url", access_token=access_token ) - def request_profile(self, expected_code, url_suffix="", access_token=None): + def request_profile( + self, + expected_code: int, + url_suffix: str = "", + access_token: Optional[str] = None, + ) -> None: channel = self.make_request( "GET", self.profile_url + url_suffix, access_token=access_token ) self.assertEqual(channel.code, expected_code, channel.result) - def ensure_requester_left_room(self): + def ensure_requester_left_room(self) -> None: try: self.helper.leave( room=self.room_id, user=self.requester, tok=self.requester_tok @@ -389,7 +399,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): profile.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["require_auth_for_profile_requests"] = True config["limit_profile_requests_to_users_who_share_rooms"] = True @@ -397,12 +407,12 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # User requesting the profile. self.requester = self.register_user("requester", "pass") self.requester_tok = self.login("requester", "pass") - def test_can_lookup_own_profile(self): + def test_can_lookup_own_profile(self) -> None: """Tests that a user can lookup their own profile without having to be in a room if 'require_auth_for_profile_requests' is set to true in the server's config. """ diff --git a/tests/rest/client/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py index d0ce91ccd9..4f875b9289 100644 --- a/tests/rest/client/test_push_rule_attrs.py +++ b/tests/rest/client/test_push_rule_attrs.py @@ -27,7 +27,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): ] hijack_auth = False - def test_enabled_on_creation(self): + def test_enabled_on_creation(self) -> None: """ Tests the GET and PUT of push rules' `enabled` endpoints. Tests that a rule is enabled upon creation, even though a rule with that @@ -56,7 +56,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) - def test_enabled_on_recreation(self): + def test_enabled_on_recreation(self) -> None: """ Tests the GET and PUT of push rules' `enabled` endpoints. Tests that a rule is enabled upon creation, even if a rule with that @@ -113,7 +113,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) - def test_enabled_disable(self): + def test_enabled_disable(self) -> None: """ Tests the GET and PUT of push rules' `enabled` endpoints. Tests that a rule is disabled and enabled when we ask for it. @@ -166,7 +166,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) - def test_enabled_404_when_get_non_existent(self): + def test_enabled_404_when_get_non_existent(self) -> None: """ Tests that `enabled` gives 404 when the rule doesn't exist. """ @@ -212,7 +212,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_enabled_404_when_get_non_existent_server_rule(self): + def test_enabled_404_when_get_non_existent_server_rule(self) -> None: """ Tests that `enabled` gives 404 when the server-default rule doesn't exist. """ @@ -226,7 +226,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_enabled_404_when_put_non_existent_rule(self): + def test_enabled_404_when_put_non_existent_rule(self) -> None: """ Tests that `enabled` gives 404 when we put to a rule that doesn't exist. """ @@ -243,7 +243,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_enabled_404_when_put_non_existent_server_rule(self): + def test_enabled_404_when_put_non_existent_server_rule(self) -> None: """ Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist. """ @@ -260,7 +260,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_actions_get(self): + def test_actions_get(self) -> None: """ Tests that `actions` gives you what you expect on a fresh rule. """ @@ -289,7 +289,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}] ) - def test_actions_put(self): + def test_actions_put(self) -> None: """ Tests that PUT on actions updates the value you'd get from GET. """ @@ -325,7 +325,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["actions"], ["dont_notify"]) - def test_actions_404_when_get_non_existent(self): + def test_actions_404_when_get_non_existent(self) -> None: """ Tests that `actions` gives 404 when the rule doesn't exist. """ @@ -365,7 +365,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_actions_404_when_get_non_existent_server_rule(self): + def test_actions_404_when_get_non_existent_server_rule(self) -> None: """ Tests that `actions` gives 404 when the server-default rule doesn't exist. """ @@ -379,7 +379,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_actions_404_when_put_non_existent_rule(self): + def test_actions_404_when_put_non_existent_rule(self) -> None: """ Tests that `actions` gives 404 when putting to a rule that doesn't exist. """ @@ -396,7 +396,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_actions_404_when_put_non_existent_server_rule(self): + def test_actions_404_when_put_non_existent_server_rule(self) -> None: """ Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist. """ diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index 433d715f69..7401b5e0c0 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -11,9 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + +from twisted.test.proto_helpers import MemoryReactor from synapse.rest import admin from synapse.rest.client import login, room, sync +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -28,7 +34,7 @@ class RedactionsTestCase(HomeserverTestCase): sync.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["rc_message"] = {"per_second": 0.2, "burst_count": 10} @@ -36,7 +42,7 @@ class RedactionsTestCase(HomeserverTestCase): return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register a couple of users self.mod_user_id = self.register_user("user1", "pass") self.mod_access_token = self.login("user1", "pass") @@ -60,7 +66,9 @@ class RedactionsTestCase(HomeserverTestCase): room=self.room_id, user=self.other_user_id, tok=self.other_access_token ) - def _redact_event(self, access_token, room_id, event_id, expect_code=200): + def _redact_event( + self, access_token: str, room_id: str, event_id: str, expect_code: int = 200 + ) -> JsonDict: """Helper function to send a redaction event. Returns the json body. @@ -71,13 +79,13 @@ class RedactionsTestCase(HomeserverTestCase): self.assertEqual(int(channel.result["code"]), expect_code) return channel.json_body - def _sync_room_timeline(self, access_token, room_id): + def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]: channel = self.make_request("GET", "sync", access_token=self.mod_access_token) self.assertEqual(channel.result["code"], b"200") room_sync = channel.json_body["rooms"]["join"][room_id] return room_sync["timeline"]["events"] - def test_redact_event_as_moderator(self): + def test_redact_event_as_moderator(self) -> None: # as a regular user, send a message to redact b = self.helper.send(room_id=self.room_id, tok=self.other_access_token) msg_id = b["event_id"] @@ -98,7 +106,7 @@ class RedactionsTestCase(HomeserverTestCase): self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id) self.assertEqual(timeline[-2]["content"], {}) - def test_redact_event_as_normal(self): + def test_redact_event_as_normal(self) -> None: # as a regular user, send a message to redact b = self.helper.send(room_id=self.room_id, tok=self.other_access_token) normal_msg_id = b["event_id"] @@ -133,7 +141,7 @@ class RedactionsTestCase(HomeserverTestCase): self.assertEqual(timeline[-3]["unsigned"]["redacted_by"], redaction_id) self.assertEqual(timeline[-3]["content"], {}) - def test_redact_nonexistent_event(self): + def test_redact_nonexistent_event(self) -> None: # control case: an existing event b = self.helper.send(room_id=self.room_id, tok=self.other_access_token) msg_id = b["event_id"] @@ -158,7 +166,7 @@ class RedactionsTestCase(HomeserverTestCase): self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id) self.assertEqual(timeline[-2]["content"], {}) - def test_redact_create_event(self): + def test_redact_create_event(self) -> None: # control case: an existing event b = self.helper.send(room_id=self.room_id, tok=self.mod_access_token) msg_id = b["event_id"] @@ -178,7 +186,7 @@ class RedactionsTestCase(HomeserverTestCase): self.other_access_token, self.room_id, create_event_id, expect_code=403 ) - def test_redact_event_as_moderator_ratelimit(self): + def test_redact_event_as_moderator_ratelimit(self) -> None: """Tests that the correct ratelimiting is applied to redactions""" message_ids = [] diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 8f7181103b..c8db45719e 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -18,11 +18,15 @@ import urllib.parse from typing import Dict, List, Optional, Tuple from unittest.mock import patch +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, register, relations, room, sync +from synapse.server import HomeServer from synapse.storage.relations import RelationPaginationToken from synapse.types import JsonDict, StreamToken +from synapse.util import Clock from tests import unittest from tests.server import FakeChannel @@ -52,7 +56,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): return config - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id, self.user_token = self._create_user("alice") @@ -63,7 +67,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): res = self.helper.send(self.room, body="Hi!", tok=self.user_token) self.parent_id = res["event_id"] - def test_send_relation(self): + def test_send_relation(self) -> None: """Tests that sending a relation using the new /send_relation works creates the right shape of event. """ @@ -95,7 +99,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel.json_body, ) - def test_deny_invalid_event(self): + def test_deny_invalid_event(self) -> None: """Test that we deny relations on non-existant events""" channel = self._send_relation( RelationTypes.ANNOTATION, @@ -125,7 +129,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(200, channel.code, channel.json_body) - def test_deny_invalid_room(self): + def test_deny_invalid_room(self) -> None: """Test that we deny relations on non-existant events""" # Create another room and send a message in it. room2 = self.helper.create_room_as(self.user_id, tok=self.user_token) @@ -138,7 +142,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(400, channel.code, channel.json_body) - def test_deny_double_react(self): + def test_deny_double_react(self) -> None: """Test that we deny relations on membership events""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") self.assertEqual(200, channel.code, channel.json_body) @@ -146,7 +150,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEqual(400, channel.code, channel.json_body) - def test_deny_forked_thread(self): + def test_deny_forked_thread(self) -> None: """It is invalid to start a thread off a thread.""" channel = self._send_relation( RelationTypes.THREAD, @@ -165,7 +169,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(400, channel.code, channel.json_body) - def test_basic_paginate_relations(self): + def test_basic_paginate_relations(self) -> None: """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEqual(200, channel.code, channel.json_body) @@ -235,7 +239,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ).to_string(self.store) ) - def test_repeated_paginate_relations(self): + def test_repeated_paginate_relations(self) -> None: """Test that if we paginate using a limit and tokens then we get the expected events. """ @@ -303,7 +307,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): found_event_ids.reverse() self.assertEqual(found_event_ids, expected_event_ids) - def test_pagination_from_sync_and_messages(self): + def test_pagination_from_sync_and_messages(self) -> None: """Pagination tokens from /sync and /messages can be used to paginate /relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") self.assertEqual(200, channel.code, channel.json_body) @@ -362,7 +366,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] ) - def test_aggregation_pagination_groups(self): + def test_aggregation_pagination_groups(self) -> None: """Test that we can paginate annotation groups correctly.""" # We need to create ten separate users to send each reaction. @@ -427,7 +431,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEqual(sent_groups, found_groups) - def test_aggregation_pagination_within_group(self): + def test_aggregation_pagination_within_group(self) -> None: """Test that we can paginate within an annotation group.""" # We need to create ten separate users to send each reaction. @@ -524,7 +528,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): found_event_ids.reverse() self.assertEqual(found_event_ids, expected_event_ids) - def test_aggregation(self): + def test_aggregation(self) -> None: """Test that annotations get correctly aggregated.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") @@ -556,7 +560,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, ) - def test_aggregation_redactions(self): + def test_aggregation_redactions(self) -> None: """Test that annotations get correctly aggregated after a redaction.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") @@ -590,7 +594,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, ) - def test_aggregation_must_be_annotation(self): + def test_aggregation_must_be_annotation(self) -> None: """Test that aggregations must be annotations.""" channel = self.make_request( @@ -604,7 +608,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): @unittest.override_config( {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}} ) - def test_bundled_aggregations(self): + def test_bundled_aggregations(self) -> None: """ Test that annotations, references, and threads get correctly bundled. @@ -746,7 +750,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ] assert_bundle(self._find_event_in_chunk(chunk)) - def test_aggregation_get_event_for_annotation(self): + def test_aggregation_get_event_for_annotation(self) -> None: """Test that annotations do not get bundled aggregations included when directly requested. """ @@ -768,7 +772,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, channel.json_body) self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) - def test_aggregation_get_event_for_thread(self): + def test_aggregation_get_event_for_thread(self) -> None: """Test that threads get bundled aggregations included when directly requested.""" channel = self._send_relation(RelationTypes.THREAD, "m.room.test") self.assertEqual(200, channel.code, channel.json_body) @@ -815,7 +819,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) - def test_ignore_invalid_room(self): + def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. room2 = self.helper.create_room_as(self.user_id, tok=self.user_token) @@ -927,7 +931,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertNotIn("m.relations", channel.json_body["unsigned"]) @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) - def test_edit(self): + def test_edit(self) -> None: """Test that a simple edit works.""" new_body = {"msgtype": "m.text", "body": "I've been edited!"} @@ -1010,7 +1014,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ] assert_bundle(self._find_event_in_chunk(chunk)) - def test_multi_edit(self): + def test_multi_edit(self) -> None: """Test that multiple edits, including attempts by people who shouldn't be allowed, are correctly handled. """ @@ -1067,7 +1071,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - def test_edit_reply(self): + def test_edit_reply(self) -> None: """Test that editing a reply works.""" # Create a reply to edit. @@ -1124,7 +1128,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) - def test_edit_thread(self): + def test_edit_thread(self) -> None: """Test that editing a thread works.""" # Create a thread and edit the last event. @@ -1163,7 +1167,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): latest_event_in_thread = thread_summary["latest_event"] self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!") - def test_edit_edit(self): + def test_edit_edit(self) -> None: """Test that an edit cannot be edited.""" new_body = {"msgtype": "m.text", "body": "Initial edit"} channel = self._send_relation( @@ -1213,7 +1217,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - def test_relations_redaction_redacts_edits(self): + def test_relations_redaction_redacts_edits(self) -> None: """Test that edits of an event are redacted when the original event is redacted. """ @@ -1269,7 +1273,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertIn("chunk", channel.json_body) self.assertEqual(channel.json_body["chunk"], []) - def test_aggregations_redaction_prevents_access_to_aggregations(self): + def test_aggregations_redaction_prevents_access_to_aggregations(self) -> None: """Test that annotations of an event are redacted when the original event is redacted. """ @@ -1309,7 +1313,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertIn("chunk", channel.json_body) self.assertEqual(channel.json_body["chunk"], []) - def test_unknown_relations(self): + def test_unknown_relations(self) -> None: """Unknown relations should be accepted.""" channel = self._send_relation("m.relation.test", "m.room.test") self.assertEqual(200, channel.code, channel.json_body) @@ -1417,7 +1421,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): return user_id, access_token - def test_background_update(self): + def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") self.assertEqual(200, channel.code, channel.json_body) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index c41a1c14a1..f3bf8d0934 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -13,9 +13,14 @@ # limitations under the License. from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from synapse.visibility import filter_events_for_client from tests import unittest @@ -31,7 +36,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["retention"] = { "enabled": True, @@ -47,7 +52,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("user", "password") self.token = self.login("user", "password") @@ -55,7 +60,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.serializer = self.hs.get_event_client_serializer() self.clock = self.hs.get_clock() - def test_retention_event_purged_with_state_event(self): + def test_retention_event_purged_with_state_event(self) -> None: """Tests that expired events are correctly purged when the room's retention policy is defined by a state event. """ @@ -72,7 +77,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self._test_retention_event_purged(room_id, one_day_ms * 1.5) - def test_retention_event_purged_with_state_event_outside_allowed(self): + def test_retention_event_purged_with_state_event_outside_allowed(self) -> None: """Tests that the server configuration can override the policy for a room when running the purge jobs. """ @@ -102,7 +107,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # instead of the one specified in the room's policy. self._test_retention_event_purged(room_id, one_day_ms * 0.5) - def test_retention_event_purged_without_state_event(self): + def test_retention_event_purged_without_state_event(self) -> None: """Tests that expired events are correctly purged when the room's retention policy is defined by the server's configuration's default retention policy. """ @@ -110,7 +115,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self._test_retention_event_purged(room_id, one_day_ms * 2) - def test_visibility(self): + def test_visibility(self) -> None: """Tests that synapse.visibility.filter_events_for_client correctly filters out outdated events """ @@ -152,7 +157,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # That event should be the second, not outdated event. self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) - def _test_retention_event_purged(self, room_id: str, increment: float): + def _test_retention_event_purged(self, room_id: str, increment: float) -> None: """Run the following test scenario to test the message retention policy support: 1. Send event 1 @@ -186,6 +191,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): resp = self.helper.send(room_id=room_id, body="1", tok=self.token) expired_event_id = resp.get("event_id") + assert expired_event_id is not None # Check that we can retrieve the event. expired_event = self.get_event(expired_event_id) @@ -201,6 +207,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): resp = self.helper.send(room_id=room_id, body="2", tok=self.token) valid_event_id = resp.get("event_id") + assert valid_event_id is not None # Advance the time again. Now our first event should have expired but our second # one should still be kept. @@ -218,7 +225,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # has been purged. self.get_event(room_id, create_event.event_id) - def get_event(self, event_id, expect_none=False): + def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict: event = self.get_success(self.store.get_event(event_id, allow_none=True)) if expect_none: @@ -240,7 +247,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["retention"] = { "enabled": True, @@ -254,11 +261,11 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): ) return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("user", "password") self.token = self.login("user", "password") - def test_no_default_policy(self): + def test_no_default_policy(self) -> None: """Tests that an event doesn't get expired if there is neither a default retention policy nor a policy specific to the room. """ @@ -266,7 +273,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): self._test_retention(room_id) - def test_state_policy(self): + def test_state_policy(self) -> None: """Tests that an event gets correctly expired if there is no default retention policy but there's a policy specific to the room. """ @@ -283,12 +290,15 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): self._test_retention(room_id, expected_code_for_first_event=404) - def _test_retention(self, room_id, expected_code_for_first_event=200): + def _test_retention( + self, room_id: str, expected_code_for_first_event: int = 200 + ) -> None: # Send a first event to the room. This is the event we'll want to be purged at the # end of the test. resp = self.helper.send(room_id=room_id, body="1", tok=self.token) first_event_id = resp.get("event_id") + assert first_event_id is not None # Check that we can retrieve the event. expired_event = self.get_event(room_id, first_event_id) @@ -304,6 +314,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): resp = self.helper.send(room_id=room_id, body="2", tok=self.token) second_event_id = resp.get("event_id") + assert second_event_id is not None # Advance the time by another month. self.reactor.advance(one_day_ms * 30 / 1000) @@ -322,7 +333,9 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): second_event = self.get_event(room_id, second_event_id) self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event) - def get_event(self, room_id, event_id, expected_code=200): + def get_event( + self, room_id: str, event_id: str, expected_code: int = 200 + ) -> JsonDict: url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) channel = self.make_request("GET", url, access_token=self.token) diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py index e2ed14457f..c3942889e1 100644 --- a/tests/rest/client/test_sendtodevice.py +++ b/tests/rest/client/test_sendtodevice.py @@ -26,7 +26,7 @@ class SendToDeviceTestCase(HomeserverTestCase): sync.register_servlets, ] - def test_user_to_user(self): + def test_user_to_user(self) -> None: """A to-device message from one user to another should get delivered""" user1 = self.register_user("u1", "pass") @@ -73,7 +73,7 @@ class SendToDeviceTestCase(HomeserverTestCase): self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), []) @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_local_room_key_request(self): + def test_local_room_key_request(self) -> None: """m.room_key_request has special-casing; test from local user""" user1 = self.register_user("u1", "pass") user1_tok = self.login("u1", "pass", "d1") @@ -128,7 +128,7 @@ class SendToDeviceTestCase(HomeserverTestCase): ) @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_remote_room_key_request(self): + def test_remote_room_key_request(self) -> None: """m.room_key_request has special-casing; test from remote user""" user2 = self.register_user("u2", "pass") user2_tok = self.login("u2", "pass", "d2") @@ -199,7 +199,7 @@ class SendToDeviceTestCase(HomeserverTestCase): }, ) - def test_limited_sync(self): + def test_limited_sync(self) -> None: """If a limited sync for to-devices happens the next /sync should respond immediately.""" self.register_user("u1", "pass") diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 2634c98dde..ae5ada3be7 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -14,6 +14,8 @@ from unittest.mock import Mock, patch +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client import ( @@ -23,13 +25,15 @@ from synapse.rest.client import ( room, room_upgrade_rest_servlet, ) +from synapse.server import HomeServer from synapse.types import UserID +from synapse.util import Clock from tests import unittest class _ShadowBannedBase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Create two users, one of which is shadow-banned. self.banned_user_id = self.register_user("banned", "test") self.banned_access_token = self.login("banned", "test") @@ -55,7 +59,7 @@ class RoomTestCase(_ShadowBannedBase): room_upgrade_rest_servlet.register_servlets, ] - def test_invite(self): + def test_invite(self) -> None: """Invites from shadow-banned users don't actually get sent.""" # The create works fine. @@ -77,7 +81,7 @@ class RoomTestCase(_ShadowBannedBase): ) self.assertEqual(invited_rooms, []) - def test_invite_3pid(self): + def test_invite_3pid(self) -> None: """Ensure that a 3PID invite does not attempt to contact the identity server.""" identity_handler = self.hs.get_identity_handler() identity_handler.lookup_3pid = Mock( @@ -101,7 +105,7 @@ class RoomTestCase(_ShadowBannedBase): # This should have raised an error earlier, but double check this wasn't called. identity_handler.lookup_3pid.assert_not_called() - def test_create_room(self): + def test_create_room(self) -> None: """Invitations during a room creation should be discarded, but the room still gets created.""" # The room creation is successful. channel = self.make_request( @@ -126,7 +130,7 @@ class RoomTestCase(_ShadowBannedBase): users = self.get_success(self.store.get_users_in_room(room_id)) self.assertCountEqual(users, ["@banned:test", "@otheruser:test"]) - def test_message(self): + def test_message(self) -> None: """Messages from shadow-banned users don't actually get sent.""" room_id = self.helper.create_room_as( @@ -151,7 +155,7 @@ class RoomTestCase(_ShadowBannedBase): ) self.assertNotIn(event_id, latest_events) - def test_upgrade(self): + def test_upgrade(self) -> None: """A room upgrade should fail, but look like it succeeded.""" # The create works fine. @@ -177,7 +181,7 @@ class RoomTestCase(_ShadowBannedBase): # The summary should be empty since the room doesn't exist. self.assertEqual(summary, {}) - def test_typing(self): + def test_typing(self) -> None: """Typing notifications should not be propagated into the room.""" # The create works fine. room_id = self.helper.create_room_as( @@ -240,7 +244,7 @@ class ProfileTestCase(_ShadowBannedBase): room.register_servlets, ] - def test_displayname(self): + def test_displayname(self) -> None: """Profile changes should succeed, but don't end up in a room.""" original_display_name = "banned" new_display_name = "new name" @@ -281,7 +285,7 @@ class ProfileTestCase(_ShadowBannedBase): event.content, {"membership": "join", "displayname": original_display_name} ) - def test_room_displayname(self): + def test_room_displayname(self) -> None: """Changes to state events for a room should be processed, but not end up in the room.""" original_display_name = "banned" new_display_name = "new name" diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py index 294f46fb95..3818b7b14b 100644 --- a/tests/rest/client/test_shared_rooms.py +++ b/tests/rest/client/test_shared_rooms.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.rest.client import login, room, shared_rooms +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeChannel @@ -30,16 +34,16 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): shared_rooms.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["update_user_directory"] = True return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.handler = hs.get_user_directory_handler() - def _get_shared_rooms(self, token, other_user) -> FakeChannel: + def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel: return self.make_request( "GET", "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" @@ -47,14 +51,14 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): access_token=token, ) - def test_shared_room_list_public(self): + def test_shared_room_list_public(self) -> None: """ A room should show up in the shared list of rooms between two users if it is public. """ self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) - def test_shared_room_list_private(self): + def test_shared_room_list_private(self) -> None: """ A room should show up in the shared list of rooms between two users if it is private. @@ -63,7 +67,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): room_one_is_public=False, room_two_is_public=False ) - def test_shared_room_list_mixed(self): + def test_shared_room_list_mixed(self) -> None: """ The shared room list between two users should contain both public and private rooms. @@ -72,7 +76,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): def _check_shared_rooms_with( self, room_one_is_public: bool, room_two_is_public: bool - ): + ) -> None: """Checks that shared public or private rooms between two users appear in their shared room lists """ @@ -109,7 +113,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): for room_id_id in channel.json_body["joined"]: self.assertIn(room_id_id, [room_id_one, room_id_two]) - def test_shared_room_list_after_leave(self): + def test_shared_room_list_after_leave(self) -> None: """ A room should no longer be considered shared if the other user has left it. diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index 658c21b2a1..b7d0f42daf 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -13,11 +13,14 @@ # limitations under the License. from typing import Optional +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventContentFields, EventTypes, RoomTypes from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.rest import admin from synapse.rest.client import login, room, room_upgrade_rest_servlet from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeChannel @@ -31,7 +34,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): room_upgrade_rest_servlet.register_servlets, ] - def prepare(self, reactor, clock, hs: "HomeServer"): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.creator = self.register_user("creator", "pass") @@ -60,7 +63,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): access_token=token or self.creator_token, ) - def test_upgrade(self): + def test_upgrade(self) -> None: """ Upgrading a room should work fine. """ @@ -68,7 +71,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, channel.result) self.assertIn("replacement_room", channel.json_body) - def test_not_in_room(self): + def test_not_in_room(self) -> None: """ Upgrading a room should work fine. """ @@ -79,7 +82,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): channel = self._upgrade_room(roomless_token) self.assertEqual(403, channel.code, channel.result) - def test_power_levels(self): + def test_power_levels(self) -> None: """ Another user can upgrade the room if their power level is increased. """ @@ -105,7 +108,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): channel = self._upgrade_room(self.other_token) self.assertEqual(200, channel.code, channel.result) - def test_power_levels_user_default(self): + def test_power_levels_user_default(self) -> None: """ Another user can upgrade the room if the default power level for users is increased. """ @@ -131,7 +134,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): channel = self._upgrade_room(self.other_token) self.assertEqual(200, channel.code, channel.result) - def test_power_levels_tombstone(self): + def test_power_levels_tombstone(self) -> None: """ Another user can upgrade the room if they can send the tombstone event. """ @@ -164,7 +167,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): ) self.assertNotIn(self.other, power_levels["users"]) - def test_space(self): + def test_space(self) -> None: """Test upgrading a space.""" # Create a space. diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 46cd5f70a8..28663826fc 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -41,6 +41,7 @@ from twisted.web.resource import Resource from twisted.web.server import Site from synapse.api.constants import Membership +from synapse.server import HomeServer from synapse.types import JsonDict from tests.server import FakeChannel, FakeSite, make_request @@ -48,15 +49,15 @@ from tests.test_utils import FakeResponse from tests.test_utils.html_parsers import TestHtmlParser -@attr.s +@attr.s(auto_attribs=True) class RestHelper: """Contains extra helper functions to quickly and clearly perform a given REST action, which isn't the focus of the test. """ - hs = attr.ib() - site = attr.ib(type=Site) - auth_user_id = attr.ib() + hs: HomeServer + site: Site + auth_user_id: Optional[str] @overload def create_room_as( @@ -145,7 +146,7 @@ class RestHelper: def invite( self, - room: Optional[str] = None, + room: str, src: Optional[str] = None, targ: Optional[str] = None, expect_code: int = HTTPStatus.OK, @@ -216,7 +217,7 @@ class RestHelper: def leave( self, - room: Optional[str] = None, + room: str, user: Optional[str] = None, expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, @@ -230,14 +231,22 @@ class RestHelper: expect_code=expect_code, ) - def ban(self, room: str, src: str, targ: str, **kwargs: object) -> None: + def ban( + self, + room: str, + src: str, + targ: str, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: """A convenience helper: `change_membership` with `membership` preset to "ban".""" self.change_membership( room=room, src=src, targ=targ, + tok=tok, membership=Membership.BAN, - **kwargs, + expect_code=expect_code, ) def change_membership( @@ -378,7 +387,7 @@ class RestHelper: room_id: str, event_type: str, body: Optional[Dict[str, Any]], - tok: str, + tok: Optional[str], expect_code: int = HTTPStatus.OK, state_key: str = "", method: str = "GET", @@ -458,7 +467,7 @@ class RestHelper: room_id: str, event_type: str, body: Dict[str, Any], - tok: str, + tok: Optional[str], expect_code: int = HTTPStatus.OK, state_key: str = "", ) -> JsonDict: @@ -658,7 +667,12 @@ class RestHelper: (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), ] - async def mock_req(method: str, uri: str, data=None, headers=None): + async def mock_req( + method: str, + uri: str, + data: Optional[dict] = None, + headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, + ): (expected_uri, resp_obj) = expected_requests.pop(0) assert uri == expected_uri resp = FakeResponse( From 1866fb39d7ffc86d7374a9aed916f70a91ec65fa Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 28 Feb 2022 13:29:09 -0500 Subject: [PATCH 22/40] Move experimental support for MSC3440 to /versions. (#12099) Instead of being part of /capabilities, this matches a change to MSC3440 to properly use these endpoints. --- changelog.d/12099.misc | 1 + synapse/rest/client/capabilities.py | 3 --- synapse/rest/client/versions.py | 2 ++ 3 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 changelog.d/12099.misc diff --git a/changelog.d/12099.misc b/changelog.d/12099.misc new file mode 100644 index 0000000000..0553825dbc --- /dev/null +++ b/changelog.d/12099.misc @@ -0,0 +1 @@ +Move experimental support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) to /versions. diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index b80fdd3712..4237071c61 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -72,9 +72,6 @@ class CapabilitiesRestServlet(RestServlet): "org.matrix.msc3244.room_capabilities" ] = MSC3244_CAPABILITIES - if self.config.experimental.msc3440_enabled: - response["capabilities"]["io.element.thread"] = {"enabled": True} - if self.config.experimental.msc3720_enabled: response["capabilities"]["org.matrix.msc3720.account_status"] = { "enabled": True, diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 00f29344a8..2e5d0e4e22 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -99,6 +99,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc2716": self.config.experimental.msc2716_enabled, # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030 "org.matrix.msc3030": self.config.experimental.msc3030_enabled, + # Adds support for thread relations, per MSC3440. + "org.matrix.msc3440": self.config.experimental.msc3440_enabled, }, }, ) From 7754af24ab163a3666bc04c7df409e59ace0d763 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 28 Feb 2022 13:33:00 -0500 Subject: [PATCH 23/40] Remove the unstable `/spaces` endpoint. (#12073) ...and various code supporting it. The /spaces endpoint was from an old version of MSC2946 and included both a Client-Server and Server-Server API. Note that the unstable /hierarchy endpoint (from the final version of MSC2946) is not yet removed. --- changelog.d/12073.removal | 1 + docs/workers.md | 2 - synapse/federation/federation_client.py | 226 ++---------- synapse/federation/transport/client.py | 33 -- .../federation/transport/server/federation.py | 76 ----- synapse/handlers/room_summary.py | 323 +----------------- synapse/rest/client/room.py | 68 ---- tests/handlers/test_room_summary.py | 119 +------ 8 files changed, 46 insertions(+), 802 deletions(-) create mode 100644 changelog.d/12073.removal diff --git a/changelog.d/12073.removal b/changelog.d/12073.removal new file mode 100644 index 0000000000..1f39792712 --- /dev/null +++ b/changelog.d/12073.removal @@ -0,0 +1 @@ +Remove the unstable `/spaces` endpoint from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/docs/workers.md b/docs/workers.md index b82a6900ac..b0f8599ef0 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -212,7 +212,6 @@ information. ^/_matrix/federation/v1/user/devices/ ^/_matrix/federation/v1/get_groups_publicised$ ^/_matrix/key/v2/query - ^/_matrix/federation/unstable/org.matrix.msc2946/spaces/ ^/_matrix/federation/(v1|unstable/org.matrix.msc2946)/hierarchy/ # Inbound federation transaction request @@ -225,7 +224,6 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/context/.*$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/members$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$ - ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/spaces$ ^/_matrix/client/(v1|unstable/org.matrix.msc2946)/rooms/.*/hierarchy$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ ^/_matrix/client/(r0|v3|unstable)/account/3pid$ diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index a4bae3c4c8..64e595e748 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1362,61 +1362,6 @@ class FederationClient(FederationBase): # server doesn't give it to us. return None - async def get_space_summary( - self, - destinations: Iterable[str], - room_id: str, - suggested_only: bool, - max_rooms_per_space: Optional[int], - exclude_rooms: List[str], - ) -> "FederationSpaceSummaryResult": - """ - Call other servers to get a summary of the given space - - - Args: - destinations: The remote servers. We will try them in turn, omitting any - that have been blacklisted. - - room_id: ID of the space to be queried - - suggested_only: If true, ask the remote server to only return children - with the "suggested" flag set - - max_rooms_per_space: A limit on the number of children to return for each - space - - exclude_rooms: A list of room IDs to tell the remote server to skip - - Returns: - a parsed FederationSpaceSummaryResult - - Raises: - SynapseError if we were unable to get a valid summary from any of the - remote servers - """ - - async def send_request(destination: str) -> FederationSpaceSummaryResult: - res = await self.transport_layer.get_space_summary( - destination=destination, - room_id=room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_rooms_per_space, - exclude_rooms=exclude_rooms, - ) - - try: - return FederationSpaceSummaryResult.from_json_dict(res) - except ValueError as e: - raise InvalidResponseError(str(e)) - - return await self._try_destination_list( - "fetch space summary", - destinations, - send_request, - failover_on_unknown_endpoint=True, - ) - async def get_room_hierarchy( self, destinations: Iterable[str], @@ -1488,10 +1433,8 @@ class FederationClient(FederationBase): if any(not isinstance(e, dict) for e in children_state): raise InvalidResponseError("Invalid event in 'children_state' list") try: - [ - FederationSpaceSummaryEventResult.from_json_dict(e) - for e in children_state - ] + for child_state in children_state: + _validate_hierarchy_event(child_state) except ValueError as e: raise InvalidResponseError(str(e)) @@ -1513,62 +1456,12 @@ class FederationClient(FederationBase): return room, children_state, children, inaccessible_children - try: - result = await self._try_destination_list( - "fetch room hierarchy", - destinations, - send_request, - failover_on_unknown_endpoint=True, - ) - except SynapseError as e: - # If an unexpected error occurred, re-raise it. - if e.code != 502: - raise - - logger.debug( - "Couldn't fetch room hierarchy, falling back to the spaces API" - ) - - # Fallback to the old federation API and translate the results if - # no servers implement the new API. - # - # The algorithm below is a bit inefficient as it only attempts to - # parse information for the requested room, but the legacy API may - # return additional layers. - legacy_result = await self.get_space_summary( - destinations, - room_id, - suggested_only, - max_rooms_per_space=None, - exclude_rooms=[], - ) - - # Find the requested room in the response (and remove it). - for _i, room in enumerate(legacy_result.rooms): - if room.get("room_id") == room_id: - break - else: - # The requested room was not returned, nothing we can do. - raise - requested_room = legacy_result.rooms.pop(_i) - - # Find any children events of the requested room. - children_events = [] - children_room_ids = set() - for event in legacy_result.events: - if event.room_id == room_id: - children_events.append(event.data) - children_room_ids.add(event.state_key) - - # Find the children rooms. - children = [] - for room in legacy_result.rooms: - if room.get("room_id") in children_room_ids: - children.append(room) - - # It isn't clear from the response whether some of the rooms are - # not accessible. - result = (requested_room, children_events, children, ()) + result = await self._try_destination_list( + "fetch room hierarchy", + destinations, + send_request, + failover_on_unknown_endpoint=True, + ) # Cache the result to avoid fetching data over federation every time. self._get_room_hierarchy_cache[(room_id, suggested_only)] = result @@ -1710,89 +1603,34 @@ class TimestampToEventResponse: return cls(event_id, origin_server_ts, d) -@attr.s(frozen=True, slots=True, auto_attribs=True) -class FederationSpaceSummaryEventResult: - """Represents a single event in the result of a successful get_space_summary call. +def _validate_hierarchy_event(d: JsonDict) -> None: + """Validate an event within the result of a /hierarchy request - It's essentially just a serialised event object, but we do a bit of parsing and - validation in `from_json_dict` and store some of the validated properties in - object attributes. + Args: + d: json object to be parsed + + Raises: + ValueError if d is not a valid event """ - event_type: str - room_id: str - state_key: str - via: Sequence[str] + event_type = d.get("type") + if not isinstance(event_type, str): + raise ValueError("Invalid event: 'event_type' must be a str") - # the raw data, including the above keys - data: JsonDict + room_id = d.get("room_id") + if not isinstance(room_id, str): + raise ValueError("Invalid event: 'room_id' must be a str") - @classmethod - def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryEventResult": - """Parse an event within the result of a /spaces/ request + state_key = d.get("state_key") + if not isinstance(state_key, str): + raise ValueError("Invalid event: 'state_key' must be a str") - Args: - d: json object to be parsed + content = d.get("content") + if not isinstance(content, dict): + raise ValueError("Invalid event: 'content' must be a dict") - Raises: - ValueError if d is not a valid event - """ - - event_type = d.get("type") - if not isinstance(event_type, str): - raise ValueError("Invalid event: 'event_type' must be a str") - - room_id = d.get("room_id") - if not isinstance(room_id, str): - raise ValueError("Invalid event: 'room_id' must be a str") - - state_key = d.get("state_key") - if not isinstance(state_key, str): - raise ValueError("Invalid event: 'state_key' must be a str") - - content = d.get("content") - if not isinstance(content, dict): - raise ValueError("Invalid event: 'content' must be a dict") - - via = content.get("via") - if not isinstance(via, Sequence): - raise ValueError("Invalid event: 'via' must be a list") - if any(not isinstance(v, str) for v in via): - raise ValueError("Invalid event: 'via' must be a list of strings") - - return cls(event_type, room_id, state_key, via, d) - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class FederationSpaceSummaryResult: - """Represents the data returned by a successful get_space_summary call.""" - - rooms: List[JsonDict] - events: Sequence[FederationSpaceSummaryEventResult] - - @classmethod - def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryResult": - """Parse the result of a /spaces/ request - - Args: - d: json object to be parsed - - Raises: - ValueError if d is not a valid /spaces/ response - """ - rooms = d.get("rooms") - if not isinstance(rooms, List): - raise ValueError("'rooms' must be a list") - if any(not isinstance(r, dict) for r in rooms): - raise ValueError("Invalid room in 'rooms' list") - - events = d.get("events") - if not isinstance(events, Sequence): - raise ValueError("'events' must be a list") - if any(not isinstance(e, dict) for e in events): - raise ValueError("Invalid event in 'events' list") - parsed_events = [ - FederationSpaceSummaryEventResult.from_json_dict(e) for e in events - ] - - return cls(rooms, parsed_events) + via = content.get("via") + if not isinstance(via, Sequence): + raise ValueError("Invalid event: 'via' must be a list") + if any(not isinstance(v, str) for v in via): + raise ValueError("Invalid event: 'via' must be a list of strings") diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 69998de520..de6e5f44fe 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -1179,39 +1179,6 @@ class TransportLayerClient: return await self.client.get_json(destination=destination, path=path) - async def get_space_summary( - self, - destination: str, - room_id: str, - suggested_only: bool, - max_rooms_per_space: Optional[int], - exclude_rooms: List[str], - ) -> JsonDict: - """ - Args: - destination: The remote server - room_id: The room ID to ask about. - suggested_only: if True, only suggested rooms will be returned - max_rooms_per_space: an optional limit to the number of children to be - returned per space - exclude_rooms: a list of any rooms we can skip - """ - # TODO When switching to the stable endpoint, use GET instead of POST. - path = _create_path( - FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc2946/spaces/%s", room_id - ) - - params = { - "suggested_only": suggested_only, - "exclude_rooms": exclude_rooms, - } - if max_rooms_per_space is not None: - params["max_rooms_per_space"] = max_rooms_per_space - - return await self.client.post_json( - destination=destination, path=path, data=params - ) - async def get_room_hierarchy( self, destination: str, room_id: str, suggested_only: bool ) -> JsonDict: diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 23ce343057..aed3d5069c 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -624,81 +624,6 @@ class FederationVersionServlet(BaseFederationServlet): ) -class FederationSpaceSummaryServlet(BaseFederationServlet): - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" - PATH = "/spaces/(?P[^/]*)" - - def __init__( - self, - hs: "HomeServer", - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_room_summary_handler() - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Mapping[bytes, Sequence[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - suggested_only = parse_boolean_from_args(query, "suggested_only", default=False) - - max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space") - if max_rooms_per_space is not None and max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[]) - - return 200, await self.handler.federation_space_summary( - origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms - ) - - # TODO When switching to the stable endpoint, remove the POST handler. - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Mapping[bytes, Sequence[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - suggested_only = content.get("suggested_only", False) - if not isinstance(suggested_only, bool): - raise SynapseError( - 400, "'suggested_only' must be a boolean", Codes.BAD_JSON - ) - - exclude_rooms = content.get("exclude_rooms", []) - if not isinstance(exclude_rooms, list) or any( - not isinstance(x, str) for x in exclude_rooms - ): - raise SynapseError(400, "bad value for 'exclude_rooms'", Codes.BAD_JSON) - - max_rooms_per_space = content.get("max_rooms_per_space") - if max_rooms_per_space is not None: - if not isinstance(max_rooms_per_space, int): - raise SynapseError( - 400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON - ) - if max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - return 200, await self.handler.federation_space_summary( - origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms - ) - - class FederationRoomHierarchyServlet(BaseFederationServlet): PATH = "/hierarchy/(?P[^/]*)" @@ -826,7 +751,6 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( On3pidBindServlet, FederationVersionServlet, RoomComplexityServlet, - FederationSpaceSummaryServlet, FederationRoomHierarchyServlet, FederationRoomHierarchyUnstableServlet, FederationV1SendKnockServlet, diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 2e61d1cbe9..55c2cbdba8 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -15,7 +15,6 @@ import itertools import logging import re -from collections import deque from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Set, Tuple import attr @@ -107,153 +106,6 @@ class RoomSummaryHandler: "get_room_hierarchy", ) - async def get_space_summary( - self, - requester: str, - room_id: str, - suggested_only: bool = False, - max_rooms_per_space: Optional[int] = None, - ) -> JsonDict: - """ - Implementation of the space summary C-S API - - Args: - requester: user id of the user making this request - - room_id: room id to start the summary at - - suggested_only: whether we should only return children with the "suggested" - flag set. - - max_rooms_per_space: an optional limit on the number of child rooms we will - return. This does not apply to the root room (ie, room_id), and - is overridden by MAX_ROOMS_PER_SPACE. - - Returns: - summary dict to return - """ - # First of all, check that the room is accessible. - if not await self._is_local_room_accessible(room_id, requester): - raise AuthError( - 403, - "User %s not in room %s, and room previews are disabled" - % (requester, room_id), - ) - - # the queue of rooms to process - room_queue = deque((_RoomQueueEntry(room_id, ()),)) - - # rooms we have already processed - processed_rooms: Set[str] = set() - - # events we have already processed. We don't necessarily have their event ids, - # so instead we key on (room id, state key) - processed_events: Set[Tuple[str, str]] = set() - - rooms_result: List[JsonDict] = [] - events_result: List[JsonDict] = [] - - if max_rooms_per_space is None or max_rooms_per_space > MAX_ROOMS_PER_SPACE: - max_rooms_per_space = MAX_ROOMS_PER_SPACE - - while room_queue and len(rooms_result) < MAX_ROOMS: - queue_entry = room_queue.popleft() - room_id = queue_entry.room_id - if room_id in processed_rooms: - # already done this room - continue - - logger.debug("Processing room %s", room_id) - - is_in_room = await self._store.is_host_joined(room_id, self._server_name) - - # The client-specified max_rooms_per_space limit doesn't apply to the - # room_id specified in the request, so we ignore it if this is the - # first room we are processing. - max_children = max_rooms_per_space if processed_rooms else MAX_ROOMS - - if is_in_room: - room_entry = await self._summarize_local_room( - requester, None, room_id, suggested_only, max_children - ) - - events: Sequence[JsonDict] = [] - if room_entry: - rooms_result.append(room_entry.room) - events = room_entry.children_state_events - - logger.debug( - "Query of local room %s returned events %s", - room_id, - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], - ) - else: - fed_rooms = await self._summarize_remote_room( - queue_entry, - suggested_only, - max_children, - exclude_rooms=processed_rooms, - ) - - # The results over federation might include rooms that the we, - # as the requesting server, are allowed to see, but the requesting - # user is not permitted see. - # - # Filter the returned results to only what is accessible to the user. - events = [] - for room_entry in fed_rooms: - room = room_entry.room - fed_room_id = room_entry.room_id - - # The user can see the room, include it! - if await self._is_remote_room_accessible( - requester, fed_room_id, room - ): - # Before returning to the client, remove the allowed_room_ids - # and allowed_spaces keys. - room.pop("allowed_room_ids", None) - room.pop("allowed_spaces", None) # historical - - rooms_result.append(room) - events.extend(room_entry.children_state_events) - - # All rooms returned don't need visiting again (even if the user - # didn't have access to them). - processed_rooms.add(fed_room_id) - - logger.debug( - "Query of %s returned rooms %s, events %s", - room_id, - [room_entry.room.get("room_id") for room_entry in fed_rooms], - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], - ) - - # the room we queried may or may not have been returned, but don't process - # it again, anyway. - processed_rooms.add(room_id) - - # XXX: is it ok that we blindly iterate through any events returned by - # a remote server, whether or not they actually link to any rooms in our - # tree? - for ev in events: - # remote servers might return events we have already processed - # (eg, Dendrite returns inward pointers as well as outward ones), so - # we need to filter them out, to avoid returning duplicate links to the - # client. - ev_key = (ev["room_id"], ev["state_key"]) - if ev_key in processed_events: - continue - events_result.append(ev) - - # add the child to the queue. we have already validated - # that the vias are a list of server names. - room_queue.append( - _RoomQueueEntry(ev["state_key"], ev["content"]["via"]) - ) - processed_events.add(ev_key) - - return {"rooms": rooms_result, "events": events_result} - async def get_room_hierarchy( self, requester: Requester, @@ -398,8 +250,6 @@ class RoomSummaryHandler: None, room_id, suggested_only, - # Do not limit the maximum children. - max_children=None, ) # Otherwise, attempt to use information for federation. @@ -488,74 +338,6 @@ class RoomSummaryHandler: return result - async def federation_space_summary( - self, - origin: str, - room_id: str, - suggested_only: bool, - max_rooms_per_space: Optional[int], - exclude_rooms: Iterable[str], - ) -> JsonDict: - """ - Implementation of the space summary Federation API - - Args: - origin: The server requesting the spaces summary. - - room_id: room id to start the summary at - - suggested_only: whether we should only return children with the "suggested" - flag set. - - max_rooms_per_space: an optional limit on the number of child rooms we will - return. Unlike the C-S API, this applies to the root room (room_id). - It is clipped to MAX_ROOMS_PER_SPACE. - - exclude_rooms: a list of rooms to skip over (presumably because the - calling server has already seen them). - - Returns: - summary dict to return - """ - # the queue of rooms to process - room_queue = deque((room_id,)) - - # the set of rooms that we should not walk further. Initialise it with the - # excluded-rooms list; we will add other rooms as we process them so that - # we do not loop. - processed_rooms: Set[str] = set(exclude_rooms) - - rooms_result: List[JsonDict] = [] - events_result: List[JsonDict] = [] - - # Set a limit on the number of rooms to return. - if max_rooms_per_space is None or max_rooms_per_space > MAX_ROOMS_PER_SPACE: - max_rooms_per_space = MAX_ROOMS_PER_SPACE - - while room_queue and len(rooms_result) < MAX_ROOMS: - room_id = room_queue.popleft() - if room_id in processed_rooms: - # already done this room - continue - - room_entry = await self._summarize_local_room( - None, origin, room_id, suggested_only, max_rooms_per_space - ) - - processed_rooms.add(room_id) - - if room_entry: - rooms_result.append(room_entry.room) - events_result.extend(room_entry.children_state_events) - - # add any children to the queue - room_queue.extend( - edge_event["state_key"] - for edge_event in room_entry.children_state_events - ) - - return {"rooms": rooms_result, "events": events_result} - async def get_federation_hierarchy( self, origin: str, @@ -579,7 +361,7 @@ class RoomSummaryHandler: The JSON hierarchy dictionary. """ root_room_entry = await self._summarize_local_room( - None, origin, requested_room_id, suggested_only, max_children=None + None, origin, requested_room_id, suggested_only ) if root_room_entry is None: # Room is inaccessible to the requesting server. @@ -600,7 +382,7 @@ class RoomSummaryHandler: continue room_entry = await self._summarize_local_room( - None, origin, room_id, suggested_only, max_children=0 + None, origin, room_id, suggested_only, include_children=False ) # If the room is accessible, include it in the results. # @@ -626,7 +408,7 @@ class RoomSummaryHandler: origin: Optional[str], room_id: str, suggested_only: bool, - max_children: Optional[int], + include_children: bool = True, ) -> Optional["_RoomEntry"]: """ Generate a room entry and a list of event entries for a given room. @@ -641,9 +423,8 @@ class RoomSummaryHandler: room_id: The room ID to summarize. suggested_only: True if only suggested children should be returned. Otherwise, all children are returned. - max_children: - The maximum number of children rooms to include. A value of None - means no limit. + include_children: + Whether to include the events of any children. Returns: A room entry if the room should be returned. None, otherwise. @@ -653,9 +434,8 @@ class RoomSummaryHandler: room_entry = await self._build_room_entry(room_id, for_federation=bool(origin)) - # If the room is not a space or the children don't matter, return just - # the room information. - if room_entry.get("room_type") != RoomTypes.SPACE or max_children == 0: + # If the room is not a space return just the room information. + if room_entry.get("room_type") != RoomTypes.SPACE or not include_children: return _RoomEntry(room_id, room_entry) # Otherwise, look for child rooms/spaces. @@ -665,14 +445,6 @@ class RoomSummaryHandler: # we only care about suggested children child_events = filter(_is_suggested_child_event, child_events) - # TODO max_children is legacy code for the /spaces endpoint. - if max_children is not None: - child_iter: Iterable[EventBase] = itertools.islice( - child_events, max_children - ) - else: - child_iter = child_events - stripped_events: List[JsonDict] = [ { "type": e.type, @@ -682,80 +454,10 @@ class RoomSummaryHandler: "sender": e.sender, "origin_server_ts": e.origin_server_ts, } - for e in child_iter + for e in child_events ] return _RoomEntry(room_id, room_entry, stripped_events) - async def _summarize_remote_room( - self, - room: "_RoomQueueEntry", - suggested_only: bool, - max_children: Optional[int], - exclude_rooms: Iterable[str], - ) -> Iterable["_RoomEntry"]: - """ - Request room entries and a list of event entries for a given room by querying a remote server. - - Args: - room: The room to summarize. - suggested_only: True if only suggested children should be returned. - Otherwise, all children are returned. - max_children: - The maximum number of children rooms to include. This is capped - to a server-set limit. - exclude_rooms: - Rooms IDs which do not need to be summarized. - - Returns: - An iterable of room entries. - """ - room_id = room.room_id - logger.info("Requesting summary for %s via %s", room_id, room.via) - - # we need to make the exclusion list json-serialisable - exclude_rooms = list(exclude_rooms) - - via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) - try: - res = await self._federation_client.get_space_summary( - via, - room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_children, - exclude_rooms=exclude_rooms, - ) - except Exception as e: - logger.warning( - "Unable to get summary of %s via federation: %s", - room_id, - e, - exc_info=logger.isEnabledFor(logging.DEBUG), - ) - return () - - # Group the events by their room. - children_by_room: Dict[str, List[JsonDict]] = {} - for ev in res.events: - if ev.event_type == EventTypes.SpaceChild: - children_by_room.setdefault(ev.room_id, []).append(ev.data) - - # Generate the final results. - results = [] - for fed_room in res.rooms: - fed_room_id = fed_room.get("room_id") - if not fed_room_id or not isinstance(fed_room_id, str): - continue - - results.append( - _RoomEntry( - fed_room_id, - fed_room, - children_by_room.get(fed_room_id, []), - ) - ) - - return results - async def _summarize_remote_room_hierarchy( self, room: "_RoomQueueEntry", suggested_only: bool ) -> Tuple[Optional["_RoomEntry"], Dict[str, JsonDict], Set[str]]: @@ -958,9 +660,8 @@ class RoomSummaryHandler: ): return True - # Check if the user is a member of any of the allowed spaces - # from the response. - allowed_rooms = room.get("allowed_room_ids") or room.get("allowed_spaces") + # Check if the user is a member of any of the allowed rooms from the response. + allowed_rooms = room.get("allowed_room_ids") if allowed_rooms and isinstance(allowed_rooms, list): if await self._event_auth_handler.is_user_in_rooms( allowed_rooms, requester @@ -1028,8 +729,6 @@ class RoomSummaryHandler: ) if allowed_rooms: entry["allowed_room_ids"] = allowed_rooms - # TODO Remove this key once the API is stable. - entry["allowed_spaces"] = allowed_rooms # Filter out Nones – rather omit the field altogether room_entry = {k: v for k, v in entry.items() if v is not None} @@ -1094,7 +793,7 @@ class RoomSummaryHandler: room_id, # Suggested-only doesn't matter since no children are requested. suggested_only=False, - max_children=0, + include_children=False, ) if not room_entry: diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 5ccfe5a92f..8a06ab8c5f 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1141,73 +1141,6 @@ class TimestampLookupRestServlet(RestServlet): } -class RoomSpaceSummaryRestServlet(RestServlet): - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc2946" - "/rooms/(?P[^/]*)/spaces$" - ), - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self._auth = hs.get_auth() - self._room_summary_handler = hs.get_room_summary_handler() - - async def on_GET( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request, allow_guest=True) - - max_rooms_per_space = parse_integer(request, "max_rooms_per_space") - if max_rooms_per_space is not None and max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - return 200, await self._room_summary_handler.get_space_summary( - requester.user.to_string(), - room_id, - suggested_only=parse_boolean(request, "suggested_only", default=False), - max_rooms_per_space=max_rooms_per_space, - ) - - # TODO When switching to the stable endpoint, remove the POST handler. - async def on_POST( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request, allow_guest=True) - content = parse_json_object_from_request(request) - - suggested_only = content.get("suggested_only", False) - if not isinstance(suggested_only, bool): - raise SynapseError( - 400, "'suggested_only' must be a boolean", Codes.BAD_JSON - ) - - max_rooms_per_space = content.get("max_rooms_per_space") - if max_rooms_per_space is not None: - if not isinstance(max_rooms_per_space, int): - raise SynapseError( - 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON - ) - if max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - return 200, await self._room_summary_handler.get_space_summary( - requester.user.to_string(), - room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_rooms_per_space, - ) - - class RoomHierarchyRestServlet(RestServlet): PATTERNS = ( re.compile( @@ -1301,7 +1234,6 @@ def register_servlets( RoomRedactEventRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) - RoomSpaceSummaryRestServlet(hs).register(http_server) RoomHierarchyRestServlet(hs).register(http_server) if hs.config.experimental.msc3266_enabled: RoomSummaryRestServlet(hs).register(http_server) diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index 51b22d2998..b33ff94a39 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -157,35 +157,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): state_key=room_id, ) - def _assert_rooms( - self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] - ) -> None: - """ - Assert that the expected room IDs and events are in the response. - - Args: - result: The result from the API call. - rooms_and_children: An iterable of tuples where each tuple is: - The expected room ID. - The expected IDs of any children rooms. - """ - room_ids = [] - children_ids = [] - for room_id, children in rooms_and_children: - room_ids.append(room_id) - if children: - children_ids.extend([(room_id, child_id) for child_id in children]) - self.assertCountEqual( - [room.get("room_id") for room in result["rooms"]], room_ids - ) - self.assertCountEqual( - [ - (event.get("room_id"), event.get("state_key")) - for event in result["events"] - ], - children_ids, - ) - def _assert_hierarchy( self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] ) -> None: @@ -251,11 +222,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): def test_simple_space(self): """Test a simple space with a single room.""" - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) # The result should have the space and the room in it, along with a link # from space -> room. expected = [(self.space, [self.room]), (self.room, ())] - self._assert_rooms(result, expected) result = self.get_success( self.handler.get_room_hierarchy(create_requester(self.user), self.space) @@ -271,12 +240,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._add_child(self.space, room, self.token) rooms.append(room) - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) - # The spaces result should have the space and the first 50 rooms in it, - # along with the links from space -> room for those 50 rooms. - expected = [(self.space, rooms[:50])] + [(room, []) for room in rooms[:49]] - self._assert_rooms(result, expected) - # The result should have the space and the rooms in it, along with the links # from space -> room. expected = [(self.space, rooms)] + [(room, []) for room in rooms] @@ -300,10 +263,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): token2 = self.login("user2", "pass") # The user can see the space since it is publicly joinable. - result = self.get_success(self.handler.get_space_summary(user2, self.space)) expected = [(self.space, [self.room]), (self.room, ())] - self._assert_rooms(result, expected) - result = self.get_success( self.handler.get_room_hierarchy(create_requester(user2), self.space) ) @@ -316,7 +276,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): body={"join_rule": JoinRules.INVITE}, tok=self.token, ) - self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) self.get_failure( self.handler.get_room_hierarchy(create_requester(user2), self.space), AuthError, @@ -329,9 +288,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): body={"history_visibility": HistoryVisibility.WORLD_READABLE}, tok=self.token, ) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, expected) - result = self.get_success( self.handler.get_room_hierarchy(create_requester(user2), self.space) ) @@ -344,7 +300,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): body={"history_visibility": HistoryVisibility.JOINED}, tok=self.token, ) - self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) self.get_failure( self.handler.get_room_hierarchy(create_requester(user2), self.space), AuthError, @@ -353,19 +308,12 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Join the space and results should be returned. self.helper.invite(self.space, targ=user2, tok=self.token) self.helper.join(self.space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, expected) - result = self.get_success( self.handler.get_room_hierarchy(create_requester(user2), self.space) ) self._assert_hierarchy(result, expected) # Attempting to view an unknown room returns the same error. - self.get_failure( - self.handler.get_space_summary(user2, "#not-a-space:" + self.hs.hostname), - AuthError, - ) self.get_failure( self.handler.get_room_hierarchy( create_requester(user2), "#not-a-space:" + self.hs.hostname @@ -496,7 +444,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Join the space. self.helper.join(self.space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) expected = [ ( self.space, @@ -520,7 +467,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): (world_readable_room, ()), (joined_room, ()), ] - self._assert_rooms(result, expected) result = self.get_success( self.handler.get_room_hierarchy(create_requester(user2), self.space) @@ -554,8 +500,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._add_child(subspace, self.room, token=self.token) self._add_child(subspace, room2, self.token) - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) - # The result should include each room a single time and each link. expected = [ (self.space, [self.room, room2, subspace]), @@ -563,7 +507,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): (subspace, [subroom, self.room, room2]), (subroom, ()), ] - self._assert_rooms(result, expected) result = self.get_success( self.handler.get_room_hierarchy(create_requester(self.user), self.space) @@ -728,10 +671,8 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) ) - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) # The result should have only the space, along with a link from space -> room. expected = [(self.space, [self.room])] - self._assert_rooms(result, expected) result = self.get_success( self.handler.get_room_hierarchy(create_requester(self.user), self.space) @@ -775,41 +716,18 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): "world_readable": True, } - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [ - requested_room_entry, - _RoomEntry( - subroom, - { - "room_id": subroom, - "world_readable": True, - }, - ), - ] - async def summarize_remote_room_hierarchy(_self, room, suggested_only): return requested_room_entry, {subroom: child_room}, set() # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) - with mock.patch( - "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - expected = [ (self.space, [self.room, subspace]), (self.room, ()), (subspace, [subroom]), (subroom, ()), ] - self._assert_rooms(result, expected) with mock.patch( "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", @@ -881,7 +799,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): "room_id": restricted_room, "world_readable": False, "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [], + "allowed_room_ids": [], }, ), ( @@ -890,7 +808,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): "room_id": restricted_accessible_room, "world_readable": False, "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [self.room], + "allowed_room_ids": [self.room], }, ), ( @@ -929,30 +847,12 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ], ) - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [subspace_room_entry] + [ - # A copy is made of the room data since the allowed_spaces key - # is removed. - _RoomEntry(child_room[0], dict(child_room[1])) - for child_room in children_rooms - ] - async def summarize_remote_room_hierarchy(_self, room, suggested_only): return subspace_room_entry, dict(children_rooms), set() # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) - with mock.patch( - "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - expected = [ (self.space, [self.room, subspace]), (self.room, ()), @@ -976,7 +876,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): (world_readable_room, ()), (joined_room, ()), ] - self._assert_rooms(result, expected) with mock.patch( "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", @@ -1010,31 +909,17 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): }, ) - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [fed_room_entry] - async def summarize_remote_room_hierarchy(_self, room, suggested_only): return fed_room_entry, {}, set() # Add a room to the space which is on another server. self._add_child(self.space, fed_room, self.token) - with mock.patch( - "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - expected = [ (self.space, [self.room, fed_room]), (self.room, ()), (fed_room, ()), ] - self._assert_rooms(result, expected) with mock.patch( "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", From 952efd0bca967bc2fcabe5c3f1f58e14ddc41686 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 28 Feb 2022 19:59:00 +0100 Subject: [PATCH 24/40] Add type hints to `tests/rest/client` (#12094) * Add type hints to `tests/rest/client` * update `mypy.ini` * newsfile * add `test_register.py` --- changelog.d/12094.misc | 1 + mypy.ini | 3 - tests/rest/client/test_events.py | 20 +++--- tests/rest/client/test_groups.py | 2 +- tests/rest/client/test_register.py | 110 +++++++++++++++-------------- 5 files changed, 72 insertions(+), 64 deletions(-) create mode 100644 changelog.d/12094.misc diff --git a/changelog.d/12094.misc b/changelog.d/12094.misc new file mode 100644 index 0000000000..0360dbd61e --- /dev/null +++ b/changelog.d/12094.misc @@ -0,0 +1 @@ +Add type hints to `tests/rest/client`. diff --git a/mypy.ini b/mypy.ini index bd75905c8d..38ff787609 100644 --- a/mypy.ini +++ b/mypy.ini @@ -75,10 +75,7 @@ exclude = (?x) |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py |tests/rest/client/test_account.py - |tests/rest/client/test_events.py |tests/rest/client/test_filter.py - |tests/rest/client/test_groups.py - |tests/rest/client/test_register.py |tests/rest/client/test_report_event.py |tests/rest/client/test_rooms.py |tests/rest/client/test_third_party_rules.py diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index 145f247836..1b1392fa2f 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -16,8 +16,12 @@ from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.rest.client import events, login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -32,7 +36,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["enable_registration_captcha"] = False @@ -41,11 +45,11 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) - hs.get_federation_handler = Mock() + hs.get_federation_handler = Mock() # type: ignore[assignment] return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register an account self.user_id = self.register_user("sid1", "pass") @@ -55,7 +59,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): self.other_user = self.register_user("other2", "pass") self.other_token = self.login(self.other_user, "pass") - def test_stream_basic_permissions(self): + def test_stream_basic_permissions(self) -> None: # invalid token, expect 401 # note: this is in violation of the original v1 spec, which expected # 403. However, since the v1 spec no longer exists and the v1 @@ -76,7 +80,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): self.assertTrue("start" in channel.json_body) self.assertTrue("end" in channel.json_body) - def test_stream_room_permissions(self): + def test_stream_room_permissions(self) -> None: room_id = self.helper.create_room_as(self.other_user, tok=self.other_token) self.helper.send(room_id, tok=self.other_token) @@ -111,7 +115,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): # left to room (expect no content for room) - def TODO_test_stream_items(self): + def TODO_test_stream_items(self) -> None: # new user, no content # join room, expect 1 item (join) @@ -136,7 +140,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, hs, reactor, clock): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register an account self.user_id = self.register_user("sid1", "pass") @@ -144,7 +148,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) - def test_get_event_via_events(self): + def test_get_event_via_events(self) -> None: resp = self.helper.send(self.room_id, tok=self.token) event_id = resp["event_id"] diff --git a/tests/rest/client/test_groups.py b/tests/rest/client/test_groups.py index c99f54cf4f..e067cf825c 100644 --- a/tests/rest/client/test_groups.py +++ b/tests/rest/client/test_groups.py @@ -25,7 +25,7 @@ class GroupsTestCase(unittest.HomeserverTestCase): servlets = [room.register_servlets, groups.register_servlets] @override_config({"enable_group_creation": True}) - def test_rooms_limited_by_visibility(self): + def test_rooms_limited_by_visibility(self) -> None: group_id = "+spqr:test" # Alice creates a group diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 4b95b8541c..9aebf1735a 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -16,15 +16,21 @@ import datetime import json import os +from typing import Any, Dict, List, Tuple import pkg_resources +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import account, account_validity, login, logout, register, sync +from synapse.server import HomeServer from synapse.storage._base import db_to_json +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.unittest import override_config @@ -39,12 +45,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ] url = b"/_matrix/client/r0/register" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["allow_guest_access"] = True return config - def test_POST_appservice_registration_valid(self): + def test_POST_appservice_registration_valid(self) -> None: user_id = "@as_user_kermit:test" as_token = "i_am_an_app_service" @@ -69,7 +75,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): det_data = {"user_id": user_id, "home_server": self.hs.hostname} self.assertDictContainsSubset(det_data, channel.json_body) - def test_POST_appservice_registration_no_type(self): + def test_POST_appservice_registration_no_type(self) -> None: as_token = "i_am_an_app_service" appservice = ApplicationService( @@ -89,7 +95,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"400", channel.result) - def test_POST_appservice_registration_invalid(self): + def test_POST_appservice_registration_invalid(self) -> None: self.appservice = None # no application service exists request_data = json.dumps( {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} @@ -100,21 +106,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"401", channel.result) - def test_POST_bad_password(self): + def test_POST_bad_password(self) -> None: request_data = json.dumps({"username": "kermit", "password": 666}) channel = self.make_request(b"POST", self.url, request_data) self.assertEqual(channel.result["code"], b"400", channel.result) self.assertEqual(channel.json_body["error"], "Invalid password") - def test_POST_bad_username(self): + def test_POST_bad_username(self) -> None: request_data = json.dumps({"username": 777, "password": "monkey"}) channel = self.make_request(b"POST", self.url, request_data) self.assertEqual(channel.result["code"], b"400", channel.result) self.assertEqual(channel.json_body["error"], "Invalid username") - def test_POST_user_valid(self): + def test_POST_user_valid(self) -> None: user_id = "@kermit:test" device_id = "frogfone" params = { @@ -135,7 +141,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertDictContainsSubset(det_data, channel.json_body) @override_config({"enable_registration": False}) - def test_POST_disabled_registration(self): + def test_POST_disabled_registration(self) -> None: request_data = json.dumps({"username": "kermit", "password": "monkey"}) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) @@ -145,7 +151,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["error"], "Registration has been disabled") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - def test_POST_guest_registration(self): + def test_POST_guest_registration(self) -> None: self.hs.config.key.macaroon_secret_key = "test" self.hs.config.registration.allow_guest_access = True @@ -155,7 +161,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) - def test_POST_disabled_guest_registration(self): + def test_POST_disabled_guest_registration(self) -> None: self.hs.config.registration.allow_guest_access = False channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") @@ -164,7 +170,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["error"], "Guest access is disabled") @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) - def test_POST_ratelimiting_guest(self): + def test_POST_ratelimiting_guest(self) -> None: for i in range(0, 6): url = self.url + b"?kind=guest" channel = self.make_request(b"POST", url, b"{}") @@ -182,7 +188,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"200", channel.result) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) - def test_POST_ratelimiting(self): + def test_POST_ratelimiting(self) -> None: for i in range(0, 6): params = { "username": "kermit" + str(i), @@ -206,7 +212,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"200", channel.result) @override_config({"registration_requires_token": True}) - def test_POST_registration_requires_token(self): + def test_POST_registration_requires_token(self) -> None: username = "kermit" device_id = "frogfone" token = "abcd" @@ -223,7 +229,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, ) ) - params = { + params: JsonDict = { "username": username, "password": "monkey", "device_id": device_id, @@ -280,8 +286,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(res["pending"], 0) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_invalid(self): - params = { + def test_POST_registration_token_invalid(self) -> None: + params: JsonDict = { "username": "kermit", "password": "monkey", } @@ -314,7 +320,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["completed"], []) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_limit_uses(self): + def test_POST_registration_token_limit_uses(self) -> None: token = "abcd" store = self.hs.get_datastores().main # Create token that can be used once @@ -330,8 +336,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, ) ) - params1 = {"username": "bert", "password": "monkey"} - params2 = {"username": "ernie", "password": "monkey"} + params1: JsonDict = {"username": "bert", "password": "monkey"} + params2: JsonDict = {"username": "ernie", "password": "monkey"} # Do 2 requests without auth to get two session IDs channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) session1 = channel1.json_body["session"] @@ -388,7 +394,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["completed"], []) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_expiry(self): + def test_POST_registration_token_expiry(self) -> None: token = "abcd" now = self.hs.get_clock().time_msec() store = self.hs.get_datastores().main @@ -405,7 +411,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, ) ) - params = {"username": "kermit", "password": "monkey"} + params: JsonDict = {"username": "kermit", "password": "monkey"} # Request without auth to get session channel = self.make_request(b"POST", self.url, json.dumps(params)) session = channel.json_body["session"] @@ -436,7 +442,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_session_expiry(self): + def test_POST_registration_token_session_expiry(self) -> None: """Test `pending` is decremented when an uncompleted session expires.""" token = "abcd" store = self.hs.get_datastores().main @@ -454,8 +460,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) # Do 2 requests without auth to get two session IDs - params1 = {"username": "bert", "password": "monkey"} - params2 = {"username": "ernie", "password": "monkey"} + params1: JsonDict = {"username": "bert", "password": "monkey"} + params2: JsonDict = {"username": "ernie", "password": "monkey"} channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) session1 = channel1.json_body["session"] channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) @@ -522,7 +528,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(pending, 0) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_session_expiry_deleted_token(self): + def test_POST_registration_token_session_expiry_deleted_token(self) -> None: """Test session expiry doesn't break when the token is deleted. 1. Start but don't complete UIA with a registration token @@ -545,7 +551,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) # Do request without auth to get a session ID - params = {"username": "kermit", "password": "monkey"} + params: JsonDict = {"username": "kermit", "password": "monkey"} channel = self.make_request(b"POST", self.url, json.dumps(params)) session = channel.json_body["session"] @@ -570,7 +576,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec()) ) - def test_advertised_flows(self): + def test_advertised_flows(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -593,7 +599,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, } ) - def test_advertised_flows_captcha_and_terms_and_3pids(self): + def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -625,7 +631,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, } ) - def test_advertised_flows_no_msisdn_email_required(self): + def test_advertised_flows_no_msisdn_email_required(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -646,7 +652,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, } ) - def test_request_token_existing_email_inhibit_error(self): + def test_request_token_existing_email_inhibit_error(self) -> None: """Test that requesting a token via this endpoint doesn't leak existing associations if configured that way. """ @@ -685,7 +691,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, } ) - def test_reject_invalid_email(self): + def test_reject_invalid_email(self) -> None: """Check that bad emails are rejected""" # Test for email with multiple @ @@ -731,7 +737,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "inhibit_user_in_use_error": True, } ) - def test_inhibit_user_in_use_error(self): + def test_inhibit_user_in_use_error(self) -> None: """Tests that the 'inhibit_user_in_use_error' configuration flag behaves correctly. """ @@ -779,7 +785,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): account_validity.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # Test for account expiring after a week. config["enable_registration"] = True @@ -791,7 +797,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): return self.hs - def test_validity_period(self): + def test_validity_period(self) -> None: self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -810,7 +816,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) - def test_manual_renewal(self): + def test_manual_renewal(self) -> None: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -833,7 +839,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEqual(channel.result["code"], b"200", channel.result) - def test_manual_expire(self): + def test_manual_expire(self) -> None: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -858,7 +864,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) - def test_logging_out_expired_user(self): + def test_logging_out_expired_user(self) -> None: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -898,7 +904,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): account.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # Test for account expiring after a week and renewal emails being sent 2 @@ -935,17 +941,17 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver(config=config) - async def sendmail(*args, **kwargs): + async def sendmail(*args: Any, **kwargs: Any) -> None: self.email_attempts.append((args, kwargs)) - self.email_attempts = [] + self.email_attempts: List[Tuple[Any, Any]] = [] self.hs.get_send_email_handler()._sendmail = sendmail self.store = self.hs.get_datastores().main return self.hs - def test_renewal_email(self): + def test_renewal_email(self) -> None: self.email_attempts = [] (user_id, tok) = self.create_user() @@ -999,7 +1005,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEqual(channel.result["code"], b"200", channel.result) - def test_renewal_invalid_token(self): + def test_renewal_invalid_token(self) -> None: # Hit the renewal endpoint with an invalid token and check that it behaves as # expected, i.e. that it responds with 404 Not Found and the correct HTML. url = "/_matrix/client/unstable/account_validity/renew?token=123" @@ -1019,7 +1025,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): channel.result["body"], expected_html.encode("utf8"), channel.result ) - def test_manual_email_send(self): + def test_manual_email_send(self) -> None: self.email_attempts = [] (user_id, tok) = self.create_user() @@ -1032,7 +1038,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.assertEqual(len(self.email_attempts), 1) - def test_deactivated_user(self): + def test_deactivated_user(self) -> None: self.email_attempts = [] (user_id, tok) = self.create_user() @@ -1056,7 +1062,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.assertEqual(len(self.email_attempts), 0) - def create_user(self): + def create_user(self) -> Tuple[str, str]: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") # We need to manually add an email address otherwise the handler will do @@ -1073,7 +1079,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): ) return user_id, tok - def test_manual_email_send_expired_account(self): + def test_manual_email_send_expired_account(self) -> None: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -1112,7 +1118,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.validity_period = 10 self.max_delta = self.validity_period * 10.0 / 100.0 @@ -1135,7 +1141,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): return self.hs - def test_background_job(self): + def test_background_job(self) -> None: """ Tests the same thing as test_background_job, except that it sets the startup_job_max_delta parameter and checks that the expiration date is within the @@ -1158,12 +1164,12 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): servlets = [register.register_servlets] url = "/_matrix/client/v1/register/m.login.registration_token/validity" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["registration_requires_token"] = True return config - def test_GET_token_valid(self): + def test_GET_token_valid(self) -> None: token = "abcd" store = self.hs.get_datastores().main self.get_success( @@ -1186,7 +1192,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["valid"], True) - def test_GET_token_invalid(self): + def test_GET_token_invalid(self) -> None: token = "1234" channel = self.make_request( b"GET", @@ -1198,7 +1204,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): @override_config( {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}} ) - def test_GET_ratelimiting(self): + def test_GET_ratelimiting(self) -> None: token = "1234" for i in range(0, 6): From 9d11fee8f223787c04c6574b8a30967e2b73cc35 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Mar 2022 09:34:30 +0000 Subject: [PATCH 25/40] Improve exception handling for concurrent execution (#12109) * fix incorrect unwrapFirstError import this was being imported from the wrong place * Refactor `concurrently_execute` to use `yieldable_gather_results` * Improve exception handling in `yieldable_gather_results` Try to avoid swallowing so many stack traces. * mark unwrapFirstError deprecated * changelog --- changelog.d/12109.misc | 1 + synapse/handlers/message.py | 4 +- synapse/util/__init__.py | 4 +- synapse/util/async_helpers.py | 54 +++++++++------ tests/util/test_async_helpers.py | 115 ++++++++++++++++++++++++++++++- 5 files changed, 151 insertions(+), 27 deletions(-) create mode 100644 changelog.d/12109.misc diff --git a/changelog.d/12109.misc b/changelog.d/12109.misc new file mode 100644 index 0000000000..3295e49f43 --- /dev/null +++ b/changelog.d/12109.misc @@ -0,0 +1 @@ +Improve exception handling for concurrent execution. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a9c964cd75..ce1fa3c78e 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -55,8 +55,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester -from synapse.util import json_decoder, json_encoder, log_failure -from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError +from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError +from synapse.util.async_helpers import Linearizer, gather_results from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 511f52534b..58b4220ff3 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -81,7 +81,9 @@ json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) def unwrapFirstError(failure: Failure) -> Failure: - # defer.gatherResults and DeferredLists wrap failures. + # Deprecated: you probably just want to catch defer.FirstError and reraise + # the subFailure's value, which will do a better job of preserving stacktraces. + # (actually, you probably want to use yieldable_gather_results anyway) failure.trap(defer.FirstError) return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 3f7299aff7..a83296a229 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -29,6 +29,7 @@ from typing import ( Hashable, Iterable, Iterator, + List, Optional, Set, Tuple, @@ -51,7 +52,7 @@ from synapse.logging.context import ( make_deferred_yieldable, run_in_background, ) -from synapse.util import Clock, unwrapFirstError +from synapse.util import Clock logger = logging.getLogger(__name__) @@ -193,9 +194,9 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): T = TypeVar("T") -def concurrently_execute( +async def concurrently_execute( func: Callable[[T], Any], args: Iterable[T], limit: int -) -> defer.Deferred: +) -> None: """Executes the function with each argument concurrently while limiting the number of concurrent executions. @@ -221,20 +222,14 @@ def concurrently_execute( # We use `itertools.islice` to handle the case where the number of args is # less than the limit, avoiding needlessly spawning unnecessary background # tasks. - return make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background(_concurrently_execute_inner, value) - for value in itertools.islice(it, limit) - ], - consumeErrors=True, - ) - ).addErrback(unwrapFirstError) + await yieldable_gather_results( + _concurrently_execute_inner, (value for value in itertools.islice(it, limit)) + ) -def yieldable_gather_results( - func: Callable, iter: Iterable, *args: Any, **kwargs: Any -) -> defer.Deferred: +async def yieldable_gather_results( + func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any +) -> List[T]: """Executes the function with each argument concurrently. Args: @@ -245,15 +240,30 @@ def yieldable_gather_results( **kwargs: Keyword arguments to be passed to each call to func Returns - Deferred[list]: Resolved when all functions have been invoked, or errors if - one of the function calls fails. + A list containing the results of the function """ - return make_deferred_yieldable( - defer.gatherResults( - [run_in_background(func, item, *args, **kwargs) for item in iter], - consumeErrors=True, + try: + return await make_deferred_yieldable( + defer.gatherResults( + [run_in_background(func, item, *args, **kwargs) for item in iter], + consumeErrors=True, + ) ) - ).addErrback(unwrapFirstError) + except defer.FirstError as dfe: + # unwrap the error from defer.gatherResults. + + # The raised exception's traceback only includes func() etc if + # the 'await' happens before the exception is thrown - ie if the failure + # happens *asynchronously* - otherwise Twisted throws away the traceback as it + # could be large. + # + # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe + # we could throw Twisted into the fires of Mordor. + + # suppress exception chaining, because the FirstError doesn't tell us anything + # very interesting. + assert isinstance(dfe.subFailure.value, BaseException) + raise dfe.subFailure.value from None T1 = TypeVar("T1") diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index ab89cab812..cce8d595fc 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import traceback + from twisted.internet import defer -from twisted.internet.defer import CancelledError, Deferred +from twisted.internet.defer import CancelledError, Deferred, ensureDeferred from twisted.internet.task import Clock +from twisted.python.failure import Failure from synapse.logging.context import ( SENTINEL_CONTEXT, @@ -21,7 +24,11 @@ from synapse.logging.context import ( PreserveLoggingContext, current_context, ) -from synapse.util.async_helpers import ObservableDeferred, timeout_deferred +from synapse.util.async_helpers import ( + ObservableDeferred, + concurrently_execute, + timeout_deferred, +) from tests.unittest import TestCase @@ -171,3 +178,107 @@ class TimeoutDeferredTest(TestCase): ) self.failureResultOf(timing_out_d, defer.TimeoutError) self.assertIs(current_context(), context_one) + + +class _TestException(Exception): + pass + + +class ConcurrentlyExecuteTest(TestCase): + def test_limits_runners(self): + """If we have more tasks than runners, we should get the limit of runners""" + started = 0 + waiters = [] + processed = [] + + async def callback(v): + # when we first enter, bump the start count + nonlocal started + started += 1 + + # record the fact we got an item + processed.append(v) + + # wait for the goahead before returning + d2 = Deferred() + waiters.append(d2) + await d2 + + # set it going + d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3)) + + # check we got exactly 3 processes + self.assertEqual(started, 3) + self.assertEqual(len(waiters), 3) + + # let one finish + waiters.pop().callback(0) + + # ... which should start another + self.assertEqual(started, 4) + self.assertEqual(len(waiters), 3) + + # we still shouldn't be done + self.assertNoResult(d2) + + # finish the job + while waiters: + waiters.pop().callback(0) + + # check everything got done + self.assertEqual(started, 5) + self.assertCountEqual(processed, [1, 2, 3, 4, 5]) + self.successResultOf(d2) + + def test_preserves_stacktraces(self): + """Test that the stacktrace from an exception thrown in the callback is preserved""" + d1 = Deferred() + + async def callback(v): + # alas, this doesn't work at all without an await here + await d1 + raise _TestException("bah") + + async def caller(): + try: + await concurrently_execute(callback, [1], 2) + except _TestException as e: + tb = traceback.extract_tb(e.__traceback__) + # we expect to see "caller", "concurrently_execute" and "callback". + self.assertEqual(tb[0].name, "caller") + self.assertEqual(tb[1].name, "concurrently_execute") + self.assertEqual(tb[-1].name, "callback") + else: + self.fail("No exception thrown") + + d2 = ensureDeferred(caller()) + d1.callback(0) + self.successResultOf(d2) + + def test_preserves_stacktraces_on_preformed_failure(self): + """Test that the stacktrace on a Failure returned by the callback is preserved""" + d1 = Deferred() + f = Failure(_TestException("bah")) + + async def callback(v): + # alas, this doesn't work at all without an await here + await d1 + await defer.fail(f) + + async def caller(): + try: + await concurrently_execute(callback, [1], 2) + except _TestException as e: + tb = traceback.extract_tb(e.__traceback__) + # we expect to see "caller", "concurrently_execute", "callback", + # and some magic from inside ensureDeferred that happens when .fail + # is called. + self.assertEqual(tb[0].name, "caller") + self.assertEqual(tb[1].name, "concurrently_execute") + self.assertEqual(tb[-2].name, "callback") + else: + self.fail("No exception thrown") + + d2 = ensureDeferred(caller()) + d1.callback(0) + self.successResultOf(d2) From 5458eb8551be676fea7ff21e2b0d3c3762c871a7 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Mar 2022 09:51:38 +0000 Subject: [PATCH 26/40] Fix 'Unhandled error in Deferred' (#12089) * Fix 'Unhandled error in Deferred' Fixes a CRITICAL "Unhandled error in Deferred" log message which happened when a function wrapped with `@cachedList` failed * Minor optimisation to cachedListDescriptor we can avoid re-using `missing`, which saves looking up entries in `deferreds_map`, and means we don't need to copy it. * Improve type annotation on CachedListDescriptor --- changelog.d/12089.bugfix | 1 + synapse/util/caches/descriptors.py | 62 +++++++++++++-------------- tests/util/caches/test_descriptors.py | 10 ++--- 3 files changed, 37 insertions(+), 36 deletions(-) create mode 100644 changelog.d/12089.bugfix diff --git a/changelog.d/12089.bugfix b/changelog.d/12089.bugfix new file mode 100644 index 0000000000..27172c4828 --- /dev/null +++ b/changelog.d/12089.bugfix @@ -0,0 +1 @@ +Fix occasional 'Unhandled error in Deferred' error message. diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index df4fb156c2..1cdead02f1 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -18,6 +18,7 @@ import inspect import logging from typing import ( Any, + Awaitable, Callable, Dict, Generic, @@ -346,15 +347,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given an iterable of keys it looks in the cache to find any hits, then passes - the tuple of missing keys to the wrapped function. + the set of missing keys to the wrapped function. - Once wrapped, the function returns a Deferred which resolves to the list - of results. + Once wrapped, the function returns a Deferred which resolves to a Dict mapping from + input key to output value. """ def __init__( self, - orig: Callable[..., Any], + orig: Callable[..., Awaitable[Dict]], cached_method_name: str, list_name: str, num_args: Optional[int] = None, @@ -385,13 +386,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): def __get__( self, obj: Optional[Any], objtype: Optional[Type] = None - ) -> Callable[..., Any]: + ) -> Callable[..., "defer.Deferred[Dict[Hashable, Any]]"]: cached_method = getattr(obj, self.cached_method_name) cache: DeferredCache[CacheKey, Any] = cached_method.cache num_args = cached_method.num_args @functools.wraps(self.orig) - def wrapped(*args: Any, **kwargs: Any) -> Any: + def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]": # If we're passed a cache_context then we'll want to call its # invalidate() whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) @@ -444,39 +445,38 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): deferred: "defer.Deferred[Any]" = defer.Deferred() deferreds_map[arg] = deferred key = arg_to_cache_key(arg) - cache.set(key, deferred, callback=invalidate_callback) + cached_defers.append( + cache.set(key, deferred, callback=invalidate_callback) + ) def complete_all(res: Dict[Hashable, Any]) -> None: - # the wrapped function has completed. It returns a - # a dict. We can now resolve the observable deferreds in - # the cache and update our own result map. - for e in missing: + # the wrapped function has completed. It returns a dict. + # We can now update our own result map, and then resolve the + # observable deferreds in the cache. + for e, d1 in deferreds_map.items(): val = res.get(e, None) - deferreds_map[e].callback(val) + # make sure we update the results map before running the + # deferreds, because as soon as we run the last deferred, the + # gatherResults() below will complete and return the result + # dict to our caller. results[e] = val + d1.callback(val) - def errback(f: Failure) -> Failure: - # the wrapped function has failed. Invalidate any cache - # entries we're supposed to be populating, and fail - # their deferreds. - for e in missing: - key = arg_to_cache_key(e) - cache.invalidate(key) - deferreds_map[e].errback(f) - - # return the failure, to propagate to our caller. - return f + def errback_all(f: Failure) -> None: + # the wrapped function has failed. Propagate the failure into + # the cache, which will invalidate the entry, and cause the + # relevant cached_deferreds to fail, which will propagate the + # failure to our caller. + for d1 in deferreds_map.values(): + d1.errback(f) args_to_call = dict(arg_dict) - # copy the missing set before sending it to the callee, to guard against - # modification. - args_to_call[self.list_name] = tuple(missing) + args_to_call[self.list_name] = missing - cached_defers.append( - defer.maybeDeferred( - preserve_fn(self.orig), **args_to_call - ).addCallbacks(complete_all, errback) - ) + # dispatch the call, and attach the two handlers + defer.maybeDeferred( + preserve_fn(self.orig), **args_to_call + ).addCallbacks(complete_all, errback_all) if cached_defers: d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index b92d3f0c1b..19741ffcda 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -673,14 +673,14 @@ class CachedListDescriptorTestCase(unittest.TestCase): self.assertEqual(current_context(), SENTINEL_CONTEXT) r = yield d1 self.assertEqual(current_context(), c1) - obj.mock.assert_called_once_with((10, 20), 2) + obj.mock.assert_called_once_with({10, 20}, 2) self.assertEqual(r, {10: "fish", 20: "chips"}) obj.mock.reset_mock() # a call with different params should call the mock again obj.mock.return_value = {30: "peas"} r = yield obj.list_fn([20, 30], 2) - obj.mock.assert_called_once_with((30,), 2) + obj.mock.assert_called_once_with({30}, 2) self.assertEqual(r, {20: "chips", 30: "peas"}) obj.mock.reset_mock() @@ -701,7 +701,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj.mock.return_value = {40: "gravy"} iterable = (x for x in [10, 40, 40]) r = yield obj.list_fn(iterable, 2) - obj.mock.assert_called_once_with((40,), 2) + obj.mock.assert_called_once_with({40}, 2) self.assertEqual(r, {10: "fish", 40: "gravy"}) def test_concurrent_lookups(self): @@ -729,7 +729,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): d3 = obj.list_fn([10]) # the mock should have been called exactly once - obj.mock.assert_called_once_with((10,)) + obj.mock.assert_called_once_with({10}) obj.mock.reset_mock() # ... and none of the calls should yet be complete @@ -771,7 +771,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): # cache miss obj.mock.return_value = {10: "fish", 20: "chips"} r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0) - obj.mock.assert_called_once_with((10, 20), 2) + obj.mock.assert_called_once_with({10, 20}, 2) self.assertEqual(r1, {10: "fish", 20: "chips"}) obj.mock.reset_mock() From 4ccc2d09aae71da0be725ac177a9d4aced9a53c9 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 1 Mar 2022 12:35:32 +0000 Subject: [PATCH 27/40] Advertise Python 3.10 support in setup.py (#12111) --- changelog.d/12111.misc | 1 + setup.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/12111.misc diff --git a/changelog.d/12111.misc b/changelog.d/12111.misc new file mode 100644 index 0000000000..be84789c9d --- /dev/null +++ b/changelog.d/12111.misc @@ -0,0 +1 @@ +Advertise support for Python 3.10 in packaging files. \ No newline at end of file diff --git a/setup.py b/setup.py index c80cb6f207..26f4650348 100755 --- a/setup.py +++ b/setup.py @@ -165,6 +165,7 @@ setup( "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", ], scripts=["synctl"] + glob.glob("scripts/*"), cmdclass={"test": TestCommand}, From e2e1d90a5e4030616a3de242cde26c0cfff4a6b5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Mar 2022 12:49:54 +0000 Subject: [PATCH 28/40] Faster joins: persist to database (#12012) When we get a partial_state response from send_join, store information in the database about it: * store a record about the room as a whole having partial state, and stash the list of member servers too. * flag the join event itself as having partial state * also, for any new events whose prev-events are partial-stated, note that they will *also* be partial-stated. We don't yet make any attempt to interpret this data, so API calls (and a bunch of other things) are just going to get incorrect data. --- changelog.d/12012.misc | 1 + synapse/events/snapshot.py | 9 +++ synapse/handlers/federation.py | 11 ++- synapse/handlers/federation_event.py | 13 +++- synapse/handlers/message.py | 2 + synapse/state/__init__.py | 31 +++++++- synapse/storage/databases/main/events.py | 25 +++++++ .../storage/databases/main/events_worker.py | 28 ++++++++ synapse/storage/databases/main/room.py | 37 ++++++++++ .../main/delta/68/04partial_state_rooms.sql | 41 +++++++++++ .../68/05partial_state_rooms_triggers.py | 72 +++++++++++++++++++ tests/test_state.py | 59 ++++++++------- 12 files changed, 297 insertions(+), 32 deletions(-) create mode 100644 changelog.d/12012.misc create mode 100644 synapse/storage/schema/main/delta/68/04partial_state_rooms.sql create mode 100644 synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py diff --git a/changelog.d/12012.misc b/changelog.d/12012.misc new file mode 100644 index 0000000000..a473f41e78 --- /dev/null +++ b/changelog.d/12012.misc @@ -0,0 +1 @@ +Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 5833fee25f..46042b2bf7 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -101,6 +101,9 @@ class EventContext: As with _current_state_ids, this is a private attribute. It should be accessed via get_prev_state_ids. + + partial_state: if True, we may be storing this event with a temporary, + incomplete state. """ rejected: Union[bool, str] = False @@ -113,12 +116,15 @@ class EventContext: _current_state_ids: Optional[StateMap[str]] = None _prev_state_ids: Optional[StateMap[str]] = None + partial_state: bool = False + @staticmethod def with_state( state_group: Optional[int], state_group_before_event: Optional[int], current_state_ids: Optional[StateMap[str]], prev_state_ids: Optional[StateMap[str]], + partial_state: bool, prev_group: Optional[int] = None, delta_ids: Optional[StateMap[str]] = None, ) -> "EventContext": @@ -129,6 +135,7 @@ class EventContext: state_group_before_event=state_group_before_event, prev_group=prev_group, delta_ids=delta_ids, + partial_state=partial_state, ) @staticmethod @@ -170,6 +177,7 @@ class EventContext: "prev_group": self.prev_group, "delta_ids": _encode_state_dict(self.delta_ids), "app_service_id": self.app_service.id if self.app_service else None, + "partial_state": self.partial_state, } @staticmethod @@ -196,6 +204,7 @@ class EventContext: prev_group=input["prev_group"], delta_ids=_decode_state_dict(input["delta_ids"]), rejected=input["rejected"], + partial_state=input.get("partial_state", False), ) app_service_id = input["app_service_id"] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c055c26eca..eb03a5accb 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -519,8 +519,17 @@ class FederationHandler: state_events=state, ) + if ret.partial_state: + await self.store.store_partial_state_room(room_id, ret.servers_in_room) + max_stream_id = await self._federation_event_handler.process_remote_join( - origin, room_id, auth_chain, state, event, room_version_obj + origin, + room_id, + auth_chain, + state, + event, + room_version_obj, + partial_state=ret.partial_state, ) # We wait here until this instance has seen the events come down diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 09d0de1ead..4bd87709f3 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -397,6 +397,7 @@ class FederationEventHandler: state: List[EventBase], event: EventBase, room_version: RoomVersion, + partial_state: bool, ) -> int: """Persists the events returned by a send_join @@ -412,6 +413,7 @@ class FederationEventHandler: event room_version: The room version we expect this room to have, and will raise if it doesn't match the version in the create event. + partial_state: True if the state omits non-critical membership events Returns: The stream ID after which all events have been persisted. @@ -453,10 +455,14 @@ class FederationEventHandler: ) # and now persist the join event itself. - logger.info("Peristing join-via-remote %s", event) + logger.info( + "Peristing join-via-remote %s (partial_state: %s)", event, partial_state + ) with nested_logging_context(suffix=event.event_id): context = await self._state_handler.compute_event_context( - event, old_state=state + event, + old_state=state, + partial_state=partial_state, ) context = await self._check_event_auth(origin, event, context) @@ -698,6 +704,8 @@ class FederationEventHandler: try: state = await self._resolve_state_at_missing_prevs(origin, event) + # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does + # not return partial state await self._process_received_pdu( origin, event, state=state, backfilled=backfilled ) @@ -1791,6 +1799,7 @@ class FederationEventHandler: prev_state_ids=prev_state_ids, prev_group=prev_group, delta_ids=state_updates, + partial_state=context.partial_state, ) async def _run_push_actions_and_persist_event( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ce1fa3c78e..61cb133ef2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -992,6 +992,8 @@ class EventCreationHandler: and full_state_ids_at_event and builder.internal_metadata.is_historical() ): + # TODO(faster_joins): figure out how this works, and make sure that the + # old state is complete. old_state = await self.store.get_events_as_list(full_state_ids_at_event) context = await self.state.compute_event_context(event, old_state=old_state) else: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index fcc24ad129..6babd5963c 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -258,7 +258,10 @@ class StateHandler: return await self.store.get_joined_hosts(room_id, entry) async def compute_event_context( - self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None + self, + event: EventBase, + old_state: Optional[Iterable[EventBase]] = None, + partial_state: bool = False, ) -> EventContext: """Build an EventContext structure for a non-outlier event. @@ -273,6 +276,8 @@ class StateHandler: calculated from existing events. This is normally only specified when receiving an event from federation where we don't have the prev events for, e.g. when backfilling. + partial_state: True if `old_state` is partial and omits non-critical + membership events Returns: The event context. """ @@ -295,8 +300,28 @@ class StateHandler: else: # otherwise, we'll need to resolve the state across the prev_events. - logger.debug("calling resolve_state_groups from compute_event_context") + # partial_state should not be set explicitly in this case: + # we work it out dynamically + assert not partial_state + + # if any of the prev-events have partial state, so do we. + # (This is slightly racy - the prev-events might get fixed up before we use + # their states - but I don't think that really matters; it just means we + # might redundantly recalculate the state for this event later.) + prev_event_ids = event.prev_event_ids() + incomplete_prev_events = await self.store.get_partial_state_events( + prev_event_ids + ) + if any(incomplete_prev_events.values()): + logger.debug( + "New/incoming event %s refers to prev_events %s with partial state", + event.event_id, + [k for (k, v) in incomplete_prev_events.items() if v], + ) + partial_state = True + + logger.debug("calling resolve_state_groups from compute_event_context") entry = await self.resolve_state_groups_for_events( event.room_id, event.prev_event_ids() ) @@ -342,6 +367,7 @@ class StateHandler: prev_state_ids=state_ids_before_event, prev_group=state_group_before_event_prev_group, delta_ids=deltas_to_state_group_before_event, + partial_state=partial_state, ) # @@ -373,6 +399,7 @@ class StateHandler: prev_state_ids=state_ids_before_event, prev_group=state_group_before_event, delta_ids=delta_ids, + partial_state=partial_state, ) @measure_func() diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 23fa089bca..ca2a9ba9d1 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2145,6 +2145,14 @@ class PersistEventsStore: state_groups = {} for event, context in events_and_contexts: if event.internal_metadata.is_outlier(): + # double-check that we don't have any events that claim to be outliers + # *and* have partial state (which is meaningless: we should have no + # state at all for an outlier) + if context.partial_state: + raise ValueError( + "Outlier event %s claims to have partial state", event.event_id + ) + continue # if the event was rejected, just give it the same state as its @@ -2155,6 +2163,23 @@ class PersistEventsStore: state_groups[event.event_id] = context.state_group + # if we have partial state for these events, record the fact. (This happens + # here rather than in _store_event_txn because it also needs to happen when + # we de-outlier an event.) + self.db_pool.simple_insert_many_txn( + txn, + table="partial_state_events", + keys=("room_id", "event_id"), + values=[ + ( + event.room_id, + event.event_id, + ) + for event, ctx in events_and_contexts + if ctx.partial_state + ], + ) + self.db_pool.simple_upsert_many_txn( txn, table="event_to_state_groups", diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 2a255d1031..26784f755e 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1953,3 +1953,31 @@ class EventsWorkerStore(SQLBaseStore): "get_event_id_for_timestamp_txn", get_event_id_for_timestamp_txn, ) + + @cachedList("is_partial_state_event", list_name="event_ids") + async def get_partial_state_events( + self, event_ids: Collection[str] + ) -> Dict[str, bool]: + """Checks which of the given events have partial state""" + result = await self.db_pool.simple_select_many_batch( + table="partial_state_events", + column="event_id", + iterable=event_ids, + retcols=["event_id"], + desc="get_partial_state_events", + ) + # convert the result to a dict, to make @cachedList work + partial = {r["event_id"] for r in result} + return {e_id: e_id in partial for e_id in event_ids} + + @cached() + async def is_partial_state_event(self, event_id: str) -> bool: + """Checks if the given event has partial state""" + result = await self.db_pool.simple_select_one_onecol( + table="partial_state_events", + keyvalues={"event_id": event_id}, + retcol="1", + allow_none=True, + desc="is_partial_state_event", + ) + return result is not None diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 0416df64ce..94068940b9 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -20,6 +20,7 @@ from typing import ( TYPE_CHECKING, Any, Awaitable, + Collection, Dict, List, Optional, @@ -1543,6 +1544,42 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): lock=False, ) + async def store_partial_state_room( + self, + room_id: str, + servers: Collection[str], + ) -> None: + """Mark the given room as containing events with partial state + + Args: + room_id: the ID of the room + servers: other servers known to be in the room + """ + await self.db_pool.runInteraction( + "store_partial_state_room", + self._store_partial_state_room_txn, + room_id, + servers, + ) + + @staticmethod + def _store_partial_state_room_txn( + txn: LoggingTransaction, room_id: str, servers: Collection[str] + ) -> None: + DatabasePool.simple_insert_txn( + txn, + table="partial_state_rooms", + values={ + "room_id": room_id, + }, + ) + DatabasePool.simple_insert_many_txn( + txn, + table="partial_state_rooms_servers", + keys=("room_id", "server_name"), + values=((room_id, s) for s in servers), + ) + async def maybe_store_room_on_outlier_membership( self, room_id: str, room_version: RoomVersion ) -> None: diff --git a/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql b/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql new file mode 100644 index 0000000000..815c0cc390 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql @@ -0,0 +1,41 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- rooms which we have done a partial-state-style join to +CREATE TABLE IF NOT EXISTS partial_state_rooms ( + room_id TEXT PRIMARY KEY, + FOREIGN KEY(room_id) REFERENCES rooms(room_id) +); + +-- a list of remote servers we believe are in the room +CREATE TABLE IF NOT EXISTS partial_state_rooms_servers ( + room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id), + server_name TEXT NOT NULL, + UNIQUE(room_id, server_name) +); + +-- a list of events with partial state. We can't store this in the `events` table +-- itself, because `events` is meant to be append-only. +CREATE TABLE IF NOT EXISTS partial_state_events ( + -- the room_id is denormalised for efficient indexing (the canonical source is `events`) + room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id), + event_id TEXT NOT NULL REFERENCES events(event_id), + UNIQUE(event_id) +); + +CREATE INDEX IF NOT EXISTS partial_state_events_room_id_idx + ON partial_state_events (room_id); + + diff --git a/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py new file mode 100644 index 0000000000..a2ec4fc26e --- /dev/null +++ b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py @@ -0,0 +1,72 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This migration adds triggers to the partial_state_events tables to enforce uniqueness + +Triggers cannot be expressed in .sql files, so we have to use a separate file. +""" +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.types import Cursor + + +def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): + # complain if the room_id in partial_state_events doesn't match + # that in `events`. We already have a fk constraint which ensures that the event + # exists in `events`, so all we have to do is raise if there is a row with a + # matching stream_ordering but not a matching room_id. + if isinstance(database_engine, Sqlite3Engine): + cur.execute( + """ + CREATE TRIGGER IF NOT EXISTS partial_state_events_bad_room_id + BEFORE INSERT ON partial_state_events + FOR EACH ROW + BEGIN + SELECT RAISE(ABORT, 'Incorrect room_id in partial_state_events') + WHERE EXISTS ( + SELECT 1 FROM events + WHERE events.event_id = NEW.event_id + AND events.room_id != NEW.room_id + ); + END; + """ + ) + elif isinstance(database_engine, PostgresEngine): + cur.execute( + """ + CREATE OR REPLACE FUNCTION check_partial_state_events() RETURNS trigger AS $BODY$ + BEGIN + IF EXISTS ( + SELECT 1 FROM events + WHERE events.event_id = NEW.event_id + AND events.room_id != NEW.room_id + ) THEN + RAISE EXCEPTION 'Incorrect room_id in partial_state_events'; + END IF; + RETURN NEW; + END; + $BODY$ LANGUAGE plpgsql; + """ + ) + + cur.execute( + """ + CREATE TRIGGER check_partial_state_events BEFORE INSERT OR UPDATE ON partial_state_events + FOR EACH ROW + EXECUTE PROCEDURE check_partial_state_events() + """ + ) + else: + raise NotImplementedError("Unknown database engine") diff --git a/tests/test_state.py b/tests/test_state.py index 90800421fb..e4baa69137 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import Collection, Dict, List, Optional from unittest.mock import Mock from twisted.internet import defer @@ -70,7 +70,7 @@ def create_event( return event -class StateGroupStore: +class _DummyStore: def __init__(self): self._event_to_state_group = {} self._group_to_state = {} @@ -105,6 +105,11 @@ class StateGroupStore: if e_id in self._event_id_to_event } + async def get_partial_state_events( + self, event_ids: Collection[str] + ) -> Dict[str, bool]: + return {e: False for e in event_ids} + async def get_state_group_delta(self, name): return None, None @@ -157,8 +162,8 @@ class Graph: class StateTestCase(unittest.TestCase): def setUp(self): - self.store = StateGroupStore() - storage = Mock(main=self.store, state=self.store) + self.dummy_store = _DummyStore() + storage = Mock(main=self.dummy_store, state=self.dummy_store) hs = Mock( spec_set=[ "config", @@ -173,7 +178,7 @@ class StateTestCase(unittest.TestCase): ] ) hs.config = default_config("tesths", True) - hs.get_datastores.return_value = Mock(main=self.store) + hs.get_datastores.return_value = Mock(main=self.dummy_store) hs.get_state_handler.return_value = None hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) @@ -198,7 +203,7 @@ class StateTestCase(unittest.TestCase): edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}, ) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store: dict[str, EventContext] = {} @@ -206,7 +211,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context ctx_c = context_store["C"] @@ -242,7 +247,7 @@ class StateTestCase(unittest.TestCase): edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}, ) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store = {} @@ -250,7 +255,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context # C ends up winning the resolution between B and C @@ -300,7 +305,7 @@ class StateTestCase(unittest.TestCase): edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]}, ) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store = {} @@ -308,7 +313,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context # C ends up winning the resolution between C and D because bans win over other @@ -375,7 +380,7 @@ class StateTestCase(unittest.TestCase): self._add_depths(nodes, edges) graph = Graph(nodes, edges) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store = {} @@ -383,7 +388,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context # B ends up winning the resolution between B and C because power levels @@ -476,7 +481,7 @@ class StateTestCase(unittest.TestCase): ] group_name = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id, event.room_id, None, @@ -484,7 +489,7 @@ class StateTestCase(unittest.TestCase): {(e.type, e.state_key): e.event_id for e in old_state}, ) ) - self.store.register_event_id_state_group(prev_event_id, group_name) + self.dummy_store.register_event_id_state_group(prev_event_id, group_name) context = yield defer.ensureDeferred(self.state.compute_event_context(event)) @@ -510,7 +515,7 @@ class StateTestCase(unittest.TestCase): ] group_name = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id, event.room_id, None, @@ -518,7 +523,7 @@ class StateTestCase(unittest.TestCase): {(e.type, e.state_key): e.event_id for e in old_state}, ) ) - self.store.register_event_id_state_group(prev_event_id, group_name) + self.dummy_store.register_event_id_state_group(prev_event_id, group_name) context = yield defer.ensureDeferred(self.state.compute_event_context(event)) @@ -554,8 +559,8 @@ class StateTestCase(unittest.TestCase): create_event(type="test4", state_key=""), ] - self.store.register_events(old_state_1) - self.store.register_events(old_state_2) + self.dummy_store.register_events(old_state_1) + self.dummy_store.register_events(old_state_2) context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 @@ -594,10 +599,10 @@ class StateTestCase(unittest.TestCase): create_event(type="test4", state_key=""), ] - store = StateGroupStore() + store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.store.get_events = store.get_events + self.dummy_store.get_events = store.get_events context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 @@ -649,10 +654,10 @@ class StateTestCase(unittest.TestCase): create_event(type="test1", state_key="1", depth=2), ] - store = StateGroupStore() + store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.store.get_events = store.get_events + self.dummy_store.get_events = store.get_events context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 @@ -695,7 +700,7 @@ class StateTestCase(unittest.TestCase): self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 ): sg1 = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id_1, event.room_id, None, @@ -703,10 +708,10 @@ class StateTestCase(unittest.TestCase): {(e.type, e.state_key): e.event_id for e in old_state_1}, ) ) - self.store.register_event_id_state_group(prev_event_id_1, sg1) + self.dummy_store.register_event_id_state_group(prev_event_id_1, sg1) sg2 = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id_2, event.room_id, None, @@ -714,7 +719,7 @@ class StateTestCase(unittest.TestCase): {(e.type, e.state_key): e.event_id for e in old_state_2}, ) ) - self.store.register_event_id_state_group(prev_event_id_2, sg2) + self.dummy_store.register_event_id_state_group(prev_event_id_2, sg2) result = yield defer.ensureDeferred(self.state.compute_event_context(event)) return result From c893632319f9bcd76d105573008e8cb0ec2fe7ce Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 1 Mar 2022 13:41:57 +0000 Subject: [PATCH 29/40] Order in-flight state group queries in biggest-first order (#11610) Co-authored-by: Patrick Cloke --- changelog.d/11610.misc | 1 + synapse/storage/databases/state/store.py | 30 +++++- tests/storage/databases/test_state_store.py | 104 +++++++++++++++++++- 3 files changed, 131 insertions(+), 4 deletions(-) create mode 100644 changelog.d/11610.misc diff --git a/changelog.d/11610.misc b/changelog.d/11610.misc new file mode 100644 index 0000000000..3af049b969 --- /dev/null +++ b/changelog.d/11610.misc @@ -0,0 +1 @@ +Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index b8016f679a..dadf3d1e3a 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -25,6 +25,7 @@ from typing import ( ) import attr +from sortedcontainers import SortedDict from twisted.internet import defer @@ -72,6 +73,24 @@ class _GetStateGroupDelta: return len(self.delta_ids) if self.delta_ids else 0 +def state_filter_rough_priority_comparator( + state_filter: StateFilter, +) -> Tuple[int, int]: + """ + Returns a comparable value that roughly indicates the relative size of this + state filter compared to others. + 'Larger' state filters should sort first when using ascending order, so + this is essentially the opposite of 'size'. + It should be treated as a rough guide only and should not be interpreted to + have any particular meaning. The representation may also change + + The current implementation returns a tuple of the form: + * -1 for include_others, 0 otherwise + * -(number of entries in state_filter.types) + """ + return -int(state_filter.include_others), -len(state_filter.types) + + class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): """A data store for fetching/storing state groups.""" @@ -127,7 +146,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # Current ongoing get_state_for_groups in-flight requests # {group ID -> {StateFilter -> ObservableDeferred}} self._state_group_inflight_requests: Dict[ - int, Dict[StateFilter, AbstractObservableDeferred[StateMap[str]]] + int, SortedDict[StateFilter, AbstractObservableDeferred[StateMap[str]]] ] = {} def get_max_state_group_txn(txn: Cursor) -> int: @@ -279,7 +298,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # The list of ongoing requests which will help narrow the current request. reusable_requests = [] - for (request_state_filter, request_deferred) in inflight_requests.items(): + + # Iterate over existing requests in roughly biggest-first order. + for request_state_filter in inflight_requests: + request_deferred = inflight_requests[request_state_filter] new_state_filter_left_over = state_filter_left_over.approx_difference( request_state_filter ) @@ -358,7 +380,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True) # Insert the ObservableDeferred into the cache - group_request_dict = self._state_group_inflight_requests.setdefault(group, {}) + group_request_dict = self._state_group_inflight_requests.setdefault( + group, SortedDict(state_filter_rough_priority_comparator) + ) group_request_dict[db_state_filter] = observable_deferred return await make_deferred_yieldable(observable_deferred.observe()) diff --git a/tests/storage/databases/test_state_store.py b/tests/storage/databases/test_state_store.py index 076b660809..2b484c95a9 100644 --- a/tests/storage/databases/test_state_store.py +++ b/tests/storage/databases/test_state_store.py @@ -15,11 +15,16 @@ import typing from typing import Dict, List, Sequence, Tuple from unittest.mock import patch +from parameterized import parameterized + from twisted.internet.defer import Deferred, ensureDeferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes -from synapse.storage.databases.state.store import MAX_INFLIGHT_REQUESTS_PER_GROUP +from synapse.storage.databases.state.store import ( + MAX_INFLIGHT_REQUESTS_PER_GROUP, + state_filter_rough_priority_comparator, +) from synapse.storage.state import StateFilter from synapse.types import StateMap from synapse.util import Clock @@ -350,3 +355,100 @@ class StateGroupInflightCachingTestCase(HomeserverTestCase): self._complete_request_fake(groups, sf, d) self.assertTrue(reqs[CAP_COUNT].called) self.assertTrue(reqs[CAP_COUNT + 1].called) + + @parameterized.expand([(False,), (True,)]) + def test_ordering_of_request_reuse(self, reverse: bool) -> None: + """ + Tests that 'larger' in-flight requests are ordered first. + + This is mostly a design decision in order to prevent a request from + hanging on to multiple queries when it would have been sufficient to + hang on to only one bigger query. + + The 'size' of a state filter is a rough heuristic. + + - requests two pieces of state, one 'larger' than the other, but each + spawning a query + - requests a third piece of state + - completes the larger of the first two queries + - checks that the third request gets completed (and doesn't needlessly + wait for the other query) + + Parameters: + reverse: whether to reverse the order of the initial requests, to ensure + that the effect doesn't depend on the order of request submission. + """ + + # We add in an extra state type to make sure that both requests spawn + # queries which are not optimised out. + state_filters = [ + StateFilter.freeze( + {"state.type": {"A"}, "other.state.type": {"a"}}, include_others=False + ), + StateFilter.freeze( + { + "state.type": None, + "other.state.type": {"b"}, + # The current rough size comparator uses the number of state types + # as an indicator of size. + # To influence it to make this state filter bigger than the previous one, + # we add another dummy state type. + "extra.state.type": {"c"}, + }, + include_others=False, + ), + ] + + if reverse: + # For fairness, we perform one test run with the list reversed. + state_filters.reverse() + smallest_state_filter_idx = 1 + biggest_state_filter_idx = 0 + else: + smallest_state_filter_idx = 0 + biggest_state_filter_idx = 1 + + # This assertion is for our own sanity more than anything else. + self.assertLess( + state_filter_rough_priority_comparator( + state_filters[biggest_state_filter_idx] + ), + state_filter_rough_priority_comparator( + state_filters[smallest_state_filter_idx] + ), + "Test invalid: bigger state filter is not actually bigger.", + ) + + # Spawn the initial two requests + for state_filter in state_filters: + ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, + state_filter, + ) + ) + + # Spawn a third request + req = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, + StateFilter.freeze( + { + "state.type": {"A"}, + }, + include_others=False, + ), + ) + ) + self.pump(by=0.1) + + self.assertFalse(req.called) + + # Complete the largest request's query to make sure that the final request + # only waits for that one (and doesn't needlessly wait for both queries) + self._complete_request_fake( + *self.get_state_group_calls[biggest_state_filter_idx] + ) + + # That should have been sufficient to complete the third request + self.assertTrue(req.called) From 91bc15c772d22fbe814170ab2e0fdbfa50f9c372 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 1 Mar 2022 13:51:03 +0000 Subject: [PATCH 30/40] Add `stop_cancellation` utility function (#12106) --- changelog.d/12106.misc | 1 + synapse/util/async_helpers.py | 19 ++++++++++++++ tests/util/test_async_helpers.py | 45 ++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+) create mode 100644 changelog.d/12106.misc diff --git a/changelog.d/12106.misc b/changelog.d/12106.misc new file mode 100644 index 0000000000..d918e9e3b1 --- /dev/null +++ b/changelog.d/12106.misc @@ -0,0 +1 @@ +Add `stop_cancellation` utility function to stop `Deferred`s from being cancelled. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index a83296a229..81320b8972 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -665,3 +665,22 @@ def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]: return value return DoneAwaitable(value) + + +def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": + """Prevent a `Deferred` from being cancelled by wrapping it in another `Deferred`. + + Args: + deferred: The `Deferred` to protect against cancellation. Must not follow the + Synapse logcontext rules. + + Returns: + A new `Deferred`, which will contain the result of the original `Deferred`, + but will not propagate cancellation through to the original. When cancelled, + the new `Deferred` will fail with a `CancelledError` and will not follow the + Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap + the new `Deferred`. + """ + new_deferred: defer.Deferred[T] = defer.Deferred() + deferred.chainDeferred(new_deferred) + return new_deferred diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index cce8d595fc..362014f4cb 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -27,6 +27,7 @@ from synapse.logging.context import ( from synapse.util.async_helpers import ( ObservableDeferred, concurrently_execute, + stop_cancellation, timeout_deferred, ) @@ -282,3 +283,47 @@ class ConcurrentlyExecuteTest(TestCase): d2 = ensureDeferred(caller()) d1.callback(0) self.successResultOf(d2) + + +class StopCancellationTests(TestCase): + """Tests for the `stop_cancellation` function.""" + + def test_succeed(self): + """Test that the new `Deferred` receives the result.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = stop_cancellation(deferred) + + # Success should propagate through. + deferred.callback("success") + self.assertTrue(wrapper_deferred.called) + self.assertEqual("success", self.successResultOf(wrapper_deferred)) + + def test_failure(self): + """Test that the new `Deferred` receives the `Failure`.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = stop_cancellation(deferred) + + # Failure should propagate through. + deferred.errback(ValueError("abc")) + self.assertTrue(wrapper_deferred.called) + self.failureResultOf(wrapper_deferred, ValueError) + self.assertIsNone(deferred.result, "`Failure` was not consumed") + + def test_cancellation(self): + """Test that cancellation of the new `Deferred` leaves the original running.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = stop_cancellation(deferred) + + # Cancel the new `Deferred`. + wrapper_deferred.cancel() + self.assertTrue(wrapper_deferred.called) + self.failureResultOf(wrapper_deferred, CancelledError) + self.assertFalse( + deferred.called, "Original `Deferred` was unexpectedly cancelled." + ) + + # Now make the inner `Deferred` fail. + # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed + # in logs. + deferred.errback(ValueError("abc")) + self.assertIsNone(deferred.result, "`Failure` was not consumed") From f26e390a40288be2801b3b9b3a99269b3f3ff81f Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 1 Mar 2022 13:55:18 +0000 Subject: [PATCH 31/40] Use Python 3.9 in Synapse dockerfiles by default (#12112) --- changelog.d/12112.docker | 1 + docker/Dockerfile | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12112.docker diff --git a/changelog.d/12112.docker b/changelog.d/12112.docker new file mode 100644 index 0000000000..b9e630653d --- /dev/null +++ b/changelog.d/12112.docker @@ -0,0 +1 @@ +Use Python 3.9 in Docker images by default. \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile index e4c1c19b86..a8bb9b0e7f 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,10 +11,10 @@ # There is an optional PYTHON_VERSION build argument which sets the # version of python to build against: for example: # -# DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.9 . +# DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.10 . # -ARG PYTHON_VERSION=3.8 +ARG PYTHON_VERSION=3.9 ### ### Stage 0: builder From 300ed0b8a6050b5187a2a524a82cf87baad3ca73 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 1 Mar 2022 15:00:03 +0000 Subject: [PATCH 32/40] Add module callbacks called for reacting to deactivation status change and profile update (#12062) --- changelog.d/12062.feature | 1 + docs/modules/third_party_rules_callbacks.md | 56 +++++ synapse/events/third_party_rules.py | 56 ++++- synapse/handlers/deactivate_account.py | 20 +- synapse/handlers/profile.py | 14 ++ synapse/module_api/__init__.py | 1 + tests/rest/client/test_third_party_rules.py | 219 +++++++++++++++++++- 7 files changed, 360 insertions(+), 7 deletions(-) create mode 100644 changelog.d/12062.feature diff --git a/changelog.d/12062.feature b/changelog.d/12062.feature new file mode 100644 index 0000000000..46a606709d --- /dev/null +++ b/changelog.d/12062.feature @@ -0,0 +1 @@ +Add module callbacks to react to user deactivation status changes (i.e. deactivations and reactivations) and profile updates. diff --git a/docs/modules/third_party_rules_callbacks.md b/docs/modules/third_party_rules_callbacks.md index a3a17096a8..09ac838107 100644 --- a/docs/modules/third_party_rules_callbacks.md +++ b/docs/modules/third_party_rules_callbacks.md @@ -148,6 +148,62 @@ deny an incoming event, see [`check_event_for_spam`](spam_checker_callbacks.md#c If multiple modules implement this callback, Synapse runs them all in order. +### `on_profile_update` + +_First introduced in Synapse v1.54.0_ + +```python +async def on_profile_update( + user_id: str, + new_profile: "synapse.module_api.ProfileInfo", + by_admin: bool, + deactivation: bool, +) -> None: +``` + +Called after updating a local user's profile. The update can be triggered either by the +user themselves or a server admin. The update can also be triggered by a user being +deactivated (in which case their display name is set to an empty string (`""`) and the +avatar URL is set to `None`). The module is passed the Matrix ID of the user whose profile +has been updated, their new profile, as well as a `by_admin` boolean that is `True` if the +update was triggered by a server admin (and `False` otherwise), and a `deactivated` +boolean that is `True` if the update is a result of the user being deactivated. + +Note that the `by_admin` boolean is also `True` if the profile change happens as a result +of the user logging in through Single Sign-On, or if a server admin updates their own +profile. + +Per-room profile changes do not trigger this callback to be called. Synapse administrators +wishing this callback to be called on every profile change are encouraged to disable +per-room profiles globally using the `allow_per_room_profiles` configuration setting in +Synapse's configuration file. +This callback is not called when registering a user, even when setting it through the +[`get_displayname_for_registration`](https://matrix-org.github.io/synapse/latest/modules/password_auth_provider_callbacks.html#get_displayname_for_registration) +module callback. + +If multiple modules implement this callback, Synapse runs them all in order. + +### `on_user_deactivation_status_changed` + +_First introduced in Synapse v1.54.0_ + +```python +async def on_user_deactivation_status_changed( + user_id: str, deactivated: bool, by_admin: bool +) -> None: +``` + +Called after deactivating a local user, or reactivating them through the admin API. The +deactivation can be triggered either by the user themselves or a server admin. The module +is passed the Matrix ID of the user whose status is changed, as well as a `deactivated` +boolean that is `True` if the user is being deactivated and `False` if they're being +reactivated, and a `by_admin` boolean that is `True` if the deactivation was triggered by +a server admin (and `False` otherwise). This latter `by_admin` boolean is always `True` +if the user is being reactivated, as this operation can only be performed through the +admin API. + +If multiple modules implement this callback, Synapse runs them all in order. + ## Example The example below is a module that implements the third-party rules callback diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 71ec100a7f..dd3104faf3 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tupl from synapse.api.errors import ModuleFailedException, SynapseError from synapse.events import EventBase from synapse.events.snapshot import EventContext +from synapse.storage.roommember import ProfileInfo from synapse.types import Requester, StateMap from synapse.util.async_helpers import maybe_awaitable @@ -37,6 +38,8 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ [str, StateMap[EventBase], str], Awaitable[bool] ] ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable] +ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] +ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: @@ -154,6 +157,10 @@ class ThirdPartyEventRules: CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = [] self._on_new_event_callbacks: List[ON_NEW_EVENT_CALLBACK] = [] + self._on_profile_update_callbacks: List[ON_PROFILE_UPDATE_CALLBACK] = [] + self._on_user_deactivation_status_changed_callbacks: List[ + ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK + ] = [] def register_third_party_rules_callbacks( self, @@ -166,6 +173,8 @@ class ThirdPartyEventRules: CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, + on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None, + on_deactivation: Optional[ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK] = None, ) -> None: """Register callbacks from modules for each hook.""" if check_event_allowed is not None: @@ -187,6 +196,12 @@ class ThirdPartyEventRules: if on_new_event is not None: self._on_new_event_callbacks.append(on_new_event) + if on_profile_update is not None: + self._on_profile_update_callbacks.append(on_profile_update) + + if on_deactivation is not None: + self._on_user_deactivation_status_changed_callbacks.append(on_deactivation) + async def check_event_allowed( self, event: EventBase, context: EventContext ) -> Tuple[bool, Optional[dict]]: @@ -334,9 +349,6 @@ class ThirdPartyEventRules: Args: event_id: The ID of the event. - - Raises: - ModuleFailureError if a callback raised any exception. """ # Bail out early without hitting the store if we don't have any callbacks if len(self._on_new_event_callbacks) == 0: @@ -370,3 +382,41 @@ class ThirdPartyEventRules: state_events[key] = room_state_events[event_id] return state_events + + async def on_profile_update( + self, user_id: str, new_profile: ProfileInfo, by_admin: bool, deactivation: bool + ) -> None: + """Called after the global profile of a user has been updated. Does not include + per-room profile changes. + + Args: + user_id: The user whose profile was changed. + new_profile: The updated profile for the user. + by_admin: Whether the profile update was performed by a server admin. + deactivation: Whether this change was made while deactivating the user. + """ + for callback in self._on_profile_update_callbacks: + try: + await callback(user_id, new_profile, by_admin, deactivation) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) + + async def on_user_deactivation_status_changed( + self, user_id: str, deactivated: bool, by_admin: bool + ) -> None: + """Called after a user has been deactivated or reactivated. + + Args: + user_id: The deactivated user. + deactivated: Whether the user is now deactivated. + by_admin: Whether the deactivation was performed by a server admin. + """ + for callback in self._on_user_deactivation_status_changed_callbacks: + try: + await callback(user_id, deactivated, by_admin) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index e4eae03056..76ae768e6e 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -38,6 +38,7 @@ class DeactivateAccountHandler: self._profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() self._server_name = hs.hostname + self._third_party_rules = hs.get_third_party_event_rules() # Flag that indicates whether the process to part users from rooms is running self._user_parter_running = False @@ -135,9 +136,13 @@ class DeactivateAccountHandler: if erase_data: user = UserID.from_string(user_id) # Remove avatar URL from this user - await self._profile_handler.set_avatar_url(user, requester, "", by_admin) + await self._profile_handler.set_avatar_url( + user, requester, "", by_admin, deactivation=True + ) # Remove displayname from this user - await self._profile_handler.set_displayname(user, requester, "", by_admin) + await self._profile_handler.set_displayname( + user, requester, "", by_admin, deactivation=True + ) logger.info("Marking %s as erased", user_id) await self.store.mark_user_erased(user_id) @@ -160,6 +165,13 @@ class DeactivateAccountHandler: # Remove account data (including ignored users and push rules). await self.store.purge_account_data_for_user(user_id) + # Let modules know the user has been deactivated. + await self._third_party_rules.on_user_deactivation_status_changed( + user_id, + True, + by_admin, + ) + return identity_server_supports_unbinding async def _reject_pending_invites_for_user(self, user_id: str) -> None: @@ -264,6 +276,10 @@ class DeactivateAccountHandler: # Mark the user as active. await self.store.set_user_deactivated_status(user_id, False) + await self._third_party_rules.on_user_deactivation_status_changed( + user_id, False, True + ) + # Add the user to the directory, if necessary. Note that # this must be done after the user is re-activated, because # deactivated users are excluded from the user directory. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index dd27f0accc..6554c0d3c2 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -71,6 +71,8 @@ class ProfileHandler: self.server_name = hs.config.server.server_name + self._third_party_rules = hs.get_third_party_event_rules() + if hs.config.worker.run_background_tasks: self.clock.looping_call( self._update_remote_profile_cache, self.PROFILE_UPDATE_MS @@ -171,6 +173,7 @@ class ProfileHandler: requester: Requester, new_displayname: str, by_admin: bool = False, + deactivation: bool = False, ) -> None: """Set the displayname of a user @@ -179,6 +182,7 @@ class ProfileHandler: requester: The user attempting to make this change. new_displayname: The displayname to give this user. by_admin: Whether this change was made by an administrator. + deactivation: Whether this change was made while deactivating the user. """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -227,6 +231,10 @@ class ProfileHandler: target_user.to_string(), profile ) + await self._third_party_rules.on_profile_update( + target_user.to_string(), profile, by_admin, deactivation + ) + await self._update_join_states(requester, target_user) async def get_avatar_url(self, target_user: UserID) -> Optional[str]: @@ -261,6 +269,7 @@ class ProfileHandler: requester: Requester, new_avatar_url: str, by_admin: bool = False, + deactivation: bool = False, ) -> None: """Set a new avatar URL for a user. @@ -269,6 +278,7 @@ class ProfileHandler: requester: The user attempting to make this change. new_avatar_url: The avatar URL to give this user. by_admin: Whether this change was made by an administrator. + deactivation: Whether this change was made while deactivating the user. """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -315,6 +325,10 @@ class ProfileHandler: target_user.to_string(), profile ) + await self._third_party_rules.on_profile_update( + target_user.to_string(), profile, by_admin, deactivation + ) + await self._update_join_states(requester, target_user) @cached() diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 902916d800..7e46931869 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -145,6 +145,7 @@ __all__ = [ "JsonDict", "EventBase", "StateMap", + "ProfileInfo", ] logger = logging.getLogger(__name__) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 9cca9edd30..bfc04785b7 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -15,12 +15,12 @@ import threading from typing import TYPE_CHECKING, Dict, Optional, Tuple from unittest.mock import Mock -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, LoginType, Membership from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.rest import admin -from synapse.rest.client import login, room +from synapse.rest.client import account, login, profile, room from synapse.types import JsonDict, Requester, StateMap from synapse.util.frozenutils import unfreeze @@ -80,6 +80,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): admin.register_servlets, login.register_servlets, room.register_servlets, + profile.register_servlets, + account.register_servlets, ] def make_homeserver(self, reactor, clock): @@ -530,3 +532,216 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): }, tok=self.tok, ) + + def test_on_profile_update(self): + """Tests that the on_profile_update module callback is correctly called on + profile updates. + """ + displayname = "Foo" + avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" + + # Register a mock callback. + m = Mock(return_value=make_awaitable(None)) + self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m) + + # Change the display name. + channel = self.make_request( + "PUT", + "/_matrix/client/v3/profile/%s/displayname" % self.user_id, + {"displayname": displayname}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the callback has been called once for our user. + m.assert_called_once() + args = m.call_args[0] + self.assertEqual(args[0], self.user_id) + + # Test that by_admin is False. + self.assertFalse(args[2]) + # Test that deactivation is False. + self.assertFalse(args[3]) + + # Check that we've got the right profile data. + profile_info = args[1] + self.assertEqual(profile_info.display_name, displayname) + self.assertIsNone(profile_info.avatar_url) + + # Change the avatar. + channel = self.make_request( + "PUT", + "/_matrix/client/v3/profile/%s/avatar_url" % self.user_id, + {"avatar_url": avatar_url}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the callback has been called once for our user. + self.assertEqual(m.call_count, 2) + args = m.call_args[0] + self.assertEqual(args[0], self.user_id) + + # Test that by_admin is False. + self.assertFalse(args[2]) + # Test that deactivation is False. + self.assertFalse(args[3]) + + # Check that we've got the right profile data. + profile_info = args[1] + self.assertEqual(profile_info.display_name, displayname) + self.assertEqual(profile_info.avatar_url, avatar_url) + + def test_on_profile_update_admin(self): + """Tests that the on_profile_update module callback is correctly called on + profile updates triggered by a server admin. + """ + displayname = "Foo" + avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" + + # Register a mock callback. + m = Mock(return_value=make_awaitable(None)) + self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Change a user's profile. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % self.user_id, + {"displayname": displayname, "avatar_url": avatar_url}, + access_token=admin_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the callback has been called twice (since we update the display name + # and avatar separately). + self.assertEqual(m.call_count, 2) + + # Get the arguments for the last call and check it's about the right user. + args = m.call_args[0] + self.assertEqual(args[0], self.user_id) + + # Check that by_admin is True. + self.assertTrue(args[2]) + # Test that deactivation is False. + self.assertFalse(args[3]) + + # Check that we've got the right profile data. + profile_info = args[1] + self.assertEqual(profile_info.display_name, displayname) + self.assertEqual(profile_info.avatar_url, avatar_url) + + def test_on_user_deactivation_status_changed(self): + """Tests that the on_user_deactivation_status_changed module callback is called + correctly when processing a user's deactivation. + """ + # Register a mocked callback. + deactivation_mock = Mock(return_value=make_awaitable(None)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._on_user_deactivation_status_changed_callbacks.append( + deactivation_mock, + ) + # Also register a mocked callback for profile updates, to check that the + # deactivation code calls it in a way that let modules know the user is being + # deactivated. + profile_mock = Mock(return_value=make_awaitable(None)) + self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append( + profile_mock, + ) + + # Register a user that we'll deactivate. + user_id = self.register_user("altan", "password") + tok = self.login("altan", "password") + + # Deactivate that user. + channel = self.make_request( + "POST", + "/_matrix/client/v3/account/deactivate", + { + "auth": { + "type": LoginType.PASSWORD, + "password": "password", + "identifier": { + "type": "m.id.user", + "user": user_id, + }, + }, + "erase": True, + }, + access_token=tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the mock was called once. + deactivation_mock.assert_called_once() + args = deactivation_mock.call_args[0] + + # Check that the mock was called with the right user ID, and with a True + # deactivated flag and a False by_admin flag. + self.assertEqual(args[0], user_id) + self.assertTrue(args[1]) + self.assertFalse(args[2]) + + # Check that the profile update callback was called twice (once for the display + # name and once for the avatar URL), and that the "deactivation" boolean is true. + self.assertEqual(profile_mock.call_count, 2) + args = profile_mock.call_args[0] + self.assertTrue(args[3]) + + def test_on_user_deactivation_status_changed_admin(self): + """Tests that the on_user_deactivation_status_changed module callback is called + correctly when processing a user's deactivation triggered by a server admin as + well as a reactivation. + """ + # Register a mock callback. + m = Mock(return_value=make_awaitable(None)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._on_user_deactivation_status_changed_callbacks.append(m) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Register a user that we'll deactivate. + user_id = self.register_user("altan", "password") + + # Deactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + {"deactivated": True}, + access_token=admin_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the mock was called once. + m.assert_called_once() + args = m.call_args[0] + + # Check that the mock was called with the right user ID, and with True deactivated + # and by_admin flags. + self.assertEqual(args[0], user_id) + self.assertTrue(args[1]) + self.assertTrue(args[2]) + + # Reactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + {"deactivated": False, "password": "hackme"}, + access_token=admin_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the mock was called once. + self.assertEqual(m.call_count, 2) + args = m.call_args[0] + + # Check that the mock was called with the right user ID, and with a False + # deactivated flag and a True by_admin flag. + self.assertEqual(args[0], user_id) + self.assertFalse(args[1]) + self.assertTrue(args[2]) From 4d6b6c17c860a6ef258e513d841dbda6ea151cbd Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 1 Mar 2022 15:27:15 +0000 Subject: [PATCH 33/40] Fix rare error in `ReadWriteLock` when writers complete immediately (#12105) Signed-off-by: Sean Quah --- changelog.d/12105.bugfix | 1 + synapse/util/async_helpers.py | 5 ++++- tests/util/test_rwlock.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12105.bugfix diff --git a/changelog.d/12105.bugfix b/changelog.d/12105.bugfix new file mode 100644 index 0000000000..f42e63e01f --- /dev/null +++ b/changelog.d/12105.bugfix @@ -0,0 +1 @@ +Fix an extremely rare, long-standing bug in `ReadWriteLock` that would cause an error when a newly unblocked writer completes instantly. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 81320b8972..60c03a66fd 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -555,7 +555,10 @@ class ReadWriteLock: finally: with PreserveLoggingContext(): new_defer.callback(None) - if self.key_to_current_writer[key] == new_defer: + # `self.key_to_current_writer[key]` may be missing if there was another + # writer waiting for us and it completed entirely within the + # `new_defer.callback()` call above. + if self.key_to_current_writer.get(key) == new_defer: self.key_to_current_writer.pop(key) return _ctx_manager() diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index a10071c70f..0774625b85 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -13,6 +13,7 @@ # limitations under the License. from twisted.internet import defer +from twisted.internet.defer import Deferred from synapse.util.async_helpers import ReadWriteLock @@ -83,3 +84,32 @@ class ReadWriteLockTestCase(unittest.TestCase): self.assertTrue(d.called) with d.result: pass + + def test_lock_handoff_to_nonblocking_writer(self): + """Test a writer handing the lock to another writer that completes instantly.""" + rwlock = ReadWriteLock() + key = "key" + + unblock: "Deferred[None]" = Deferred() + + async def blocking_write(): + with await rwlock.write(key): + await unblock + + async def nonblocking_write(): + with await rwlock.write(key): + pass + + d1 = defer.ensureDeferred(blocking_write()) + d2 = defer.ensureDeferred(nonblocking_write()) + self.assertFalse(d1.called) + self.assertFalse(d2.called) + + # Unblock the first writer. The second writer will complete without blocking. + unblock.callback(None) + self.assertTrue(d1.called) + self.assertTrue(d2.called) + + # The `ReadWriteLock` should operate as normal. + d3 = defer.ensureDeferred(nonblocking_write()) + self.assertTrue(d3.called) From 313581e4e9bc2ec3d59ccff86e3a0c02661f71c4 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 1 Mar 2022 17:44:41 +0000 Subject: [PATCH 34/40] Use importlib.metadata to read requirements (#12088) * Pull runtime dep checks into their own module * Reimplement `check_requirements` using `importlib` I've tried to make this clearer. We start by working out which of Synapse's requirements we need to be installed here and now. I was surprised that there wasn't an easier way to see which packages were installed by a given extra. I've pulled out the error messages into functions that deal with "is this for an extra or not". And I've rearranged the loop over two different sets of requirements into one loop with a "must be instaled" flag. I hope you agree that this is clearer. * Test cases --- changelog.d/12088.misc | 1 + synapse/app/__init__.py | 6 +- synapse/app/homeserver.py | 2 +- synapse/config/cache.py | 2 +- synapse/config/metrics.py | 2 +- synapse/config/oidc.py | 2 +- synapse/config/redis.py | 2 +- synapse/config/repository.py | 2 +- synapse/config/saml2.py | 2 +- synapse/config/tracer.py | 2 +- synapse/python_dependencies.py | 107 +--------------------- synapse/util/check_dependencies.py | 127 ++++++++++++++++++++++++++ tests/util/test_check_dependencies.py | 95 +++++++++++++++++++ 13 files changed, 237 insertions(+), 115 deletions(-) create mode 100644 changelog.d/12088.misc create mode 100644 synapse/util/check_dependencies.py create mode 100644 tests/util/test_check_dependencies.py diff --git a/changelog.d/12088.misc b/changelog.d/12088.misc new file mode 100644 index 0000000000..ce4213650c --- /dev/null +++ b/changelog.d/12088.misc @@ -0,0 +1 @@ +Inspect application dependencies using `importlib.metadata` or its backport. \ No newline at end of file diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index ee51480a9e..334c3d2c17 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -15,13 +15,13 @@ import logging import sys from typing import Container -from synapse import python_dependencies # noqa: E402 +from synapse.util import check_dependencies logger = logging.getLogger(__name__) try: - python_dependencies.check_requirements() -except python_dependencies.DependencyException as e: + check_dependencies.check_requirements() +except check_dependencies.DependencyException as e: sys.stderr.writelines( e.message # noqa: B306, DependencyException.message is a property ) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b9931001c2..a6789a840e 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -59,7 +59,6 @@ from synapse.http.server import ( from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.python_dependencies import check_requirements from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest import ClientRestResource @@ -70,6 +69,7 @@ from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer from synapse.storage import DataStore +from synapse.util.check_dependencies import check_requirements from synapse.util.httpresourcetree import create_resource_tree from synapse.util.module_loader import load_module diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 387ac6d115..9a68da9c33 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -20,7 +20,7 @@ from typing import Callable, Dict, Optional import attr -from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util.check_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 1cc26e7578..f62292ecf6 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -15,7 +15,7 @@ import attr -from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util.check_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index e783b11315..f7e4f9ef22 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -20,11 +20,11 @@ import attr from synapse.config._util import validate_config from synapse.config.sso import SsoAttributeRequirement -from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict from synapse.util.module_loader import load_module from synapse.util.stringutils import parse_and_validate_mxc_uri +from ..util.check_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError, read_file DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc.JinjaOidcMappingProvider" diff --git a/synapse/config/redis.py b/synapse/config/redis.py index 33104af734..bdb1aac3a2 100644 --- a/synapse/config/redis.py +++ b/synapse/config/redis.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.config._base import Config -from synapse.python_dependencies import check_requirements +from synapse.util.check_dependencies import check_requirements class RedisConfig(Config): diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 1980351e77..0a0d901bfb 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -20,8 +20,8 @@ from urllib.request import getproxies_environment # type: ignore import attr from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set -from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict +from synapse.util.check_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module from ._base import Config, ConfigError diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index ec9d9f65e7..43c456d5c6 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -17,8 +17,8 @@ import logging from typing import Any, List, Set from synapse.config.sso import SsoAttributeRequirement -from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict +from synapse.util.check_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py index 21b9a88353..7aff618ea6 100644 --- a/synapse/config/tracer.py +++ b/synapse/config/tracer.py @@ -14,7 +14,7 @@ from typing import Set -from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util.check_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index f43fbb5842..8f48a33936 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -17,14 +17,7 @@ import itertools import logging -from typing import List, Set - -from pkg_resources import ( - DistributionNotFound, - Requirement, - VersionConflict, - get_provider, -) +from typing import Set logger = logging.getLogger(__name__) @@ -90,6 +83,8 @@ REQUIREMENTS = [ # ijson 3.1.4 fixes a bug with "." in property names "ijson>=3.1.4", "matrix-common~=1.1.0", + # For runtime introspection of our dependencies + "packaging~=21.3", ] CONDITIONAL_REQUIREMENTS = { @@ -144,102 +139,6 @@ def list_requirements(): return list(set(REQUIREMENTS) | ALL_OPTIONAL_REQUIREMENTS) -class DependencyException(Exception): - @property - def message(self): - return "\n".join( - [ - "Missing Requirements: %s" % (", ".join(self.dependencies),), - "To install run:", - " pip install --upgrade --force %s" % (" ".join(self.dependencies),), - "", - ] - ) - - @property - def dependencies(self): - for i in self.args[0]: - yield '"' + i + '"' - - -def check_requirements(for_feature=None): - deps_needed = [] - errors = [] - - if for_feature: - reqs = CONDITIONAL_REQUIREMENTS[for_feature] - else: - reqs = REQUIREMENTS - - for dependency in reqs: - try: - _check_requirement(dependency) - except VersionConflict as e: - deps_needed.append(dependency) - errors.append( - "Needed %s, got %s==%s" - % ( - dependency, - e.dist.project_name, # type: ignore[attr-defined] # noqa - e.dist.version, # type: ignore[attr-defined] # noqa - ) - ) - except DistributionNotFound: - deps_needed.append(dependency) - if for_feature: - errors.append( - "Needed %s for the '%s' feature but it was not installed" - % (dependency, for_feature) - ) - else: - errors.append("Needed %s but it was not installed" % (dependency,)) - - if not for_feature: - # Check the optional dependencies are up to date. We allow them to not be - # installed. - OPTS: List[str] = sum(CONDITIONAL_REQUIREMENTS.values(), []) - - for dependency in OPTS: - try: - _check_requirement(dependency) - except VersionConflict as e: - deps_needed.append(dependency) - errors.append( - "Needed optional %s, got %s==%s" - % ( - dependency, - e.dist.project_name, # type: ignore[attr-defined] # noqa - e.dist.version, # type: ignore[attr-defined] # noqa - ) - ) - except DistributionNotFound: - # If it's not found, we don't care - pass - - if deps_needed: - for err in errors: - logging.error(err) - - raise DependencyException(deps_needed) - - -def _check_requirement(dependency_string): - """Parses a dependency string, and checks if the specified requirement is installed - - Raises: - VersionConflict if the requirement is installed, but with the the wrong version - DistributionNotFound if nothing is found to provide the requirement - """ - req = Requirement.parse(dependency_string) - - # first check if the markers specify that this requirement needs installing - if req.marker is not None and not req.marker.evaluate(): - # not required for this environment - return - - get_provider(req) - - if __name__ == "__main__": import sys diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py new file mode 100644 index 0000000000..3a1f6b3c75 --- /dev/null +++ b/synapse/util/check_dependencies.py @@ -0,0 +1,127 @@ +import logging +from typing import Iterable, NamedTuple, Optional + +from packaging.requirements import Requirement + +DISTRIBUTION_NAME = "matrix-synapse" + +try: + from importlib import metadata +except ImportError: + import importlib_metadata as metadata # type: ignore[no-redef] + + +class DependencyException(Exception): + @property + def message(self) -> str: + return "\n".join( + [ + "Missing Requirements: %s" % (", ".join(self.dependencies),), + "To install run:", + " pip install --upgrade --force %s" % (" ".join(self.dependencies),), + "", + ] + ) + + @property + def dependencies(self) -> Iterable[str]: + for i in self.args[0]: + yield '"' + i + '"' + + +EXTRAS = set(metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra")) + + +class Dependency(NamedTuple): + requirement: Requirement + must_be_installed: bool + + +def _generic_dependencies() -> Iterable[Dependency]: + """Yield pairs (requirement, must_be_installed).""" + requirements = metadata.requires(DISTRIBUTION_NAME) + assert requirements is not None + for raw_requirement in requirements: + req = Requirement(raw_requirement) + # https://packaging.pypa.io/en/latest/markers.html#usage notes that + # > Evaluating an extra marker with no environment is an error + # so we pass in a dummy empty extra value here. + must_be_installed = req.marker is None or req.marker.evaluate({"extra": ""}) + yield Dependency(req, must_be_installed) + + +def _dependencies_for_extra(extra: str) -> Iterable[Dependency]: + """Yield additional dependencies needed for a given `extra`.""" + requirements = metadata.requires(DISTRIBUTION_NAME) + assert requirements is not None + for raw_requirement in requirements: + req = Requirement(raw_requirement) + # Exclude mandatory deps by only selecting deps needed with this extra. + if ( + req.marker is not None + and req.marker.evaluate({"extra": extra}) + and not req.marker.evaluate({"extra": ""}) + ): + yield Dependency(req, True) + + +def _not_installed(requirement: Requirement, extra: Optional[str] = None) -> str: + if extra: + return f"Need {requirement.name} for {extra}, but it is not installed" + else: + return f"Need {requirement.name}, but it is not installed" + + +def _incorrect_version( + requirement: Requirement, got: str, extra: Optional[str] = None +) -> str: + if extra: + return f"Need {requirement} for {extra}, but got {requirement.name}=={got}" + else: + return f"Need {requirement}, but got {requirement.name}=={got}" + + +def check_requirements(extra: Optional[str] = None) -> None: + """Check Synapse's dependencies are present and correctly versioned. + + If provided, `extra` must be the name of an pacakging extra (e.g. "saml2" in + `pip install matrix-synapse[saml2]`). + + If `extra` is None, this function checks that + - all mandatory dependencies are installed and correctly versioned, and + - each optional dependency that's installed is correctly versioned. + + If `extra` is not None, this function checks that + - the dependencies needed for that extra are installed and correctly versioned. + + :raises DependencyException: if a dependency is missing or incorrectly versioned. + :raises ValueError: if this extra does not exist. + """ + # First work out which dependencies are required, and which are optional. + if extra is None: + dependencies = _generic_dependencies() + elif extra in EXTRAS: + dependencies = _dependencies_for_extra(extra) + else: + raise ValueError(f"Synapse does not provide the feature '{extra}'") + + deps_unfulfilled = [] + errors = [] + + for (requirement, must_be_installed) in dependencies: + try: + dist: metadata.Distribution = metadata.distribution(requirement.name) + except metadata.PackageNotFoundError: + if must_be_installed: + deps_unfulfilled.append(requirement.name) + errors.append(_not_installed(requirement, extra)) + else: + if not requirement.specifier.contains(dist.version): + deps_unfulfilled.append(requirement.name) + errors.append(_incorrect_version(requirement, dist.version, extra)) + + if deps_unfulfilled: + for err in errors: + logging.error(err) + + raise DependencyException(deps_unfulfilled) diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py new file mode 100644 index 0000000000..3c07252252 --- /dev/null +++ b/tests/util/test_check_dependencies.py @@ -0,0 +1,95 @@ +from contextlib import contextmanager +from typing import Generator, Optional +from unittest.mock import patch + +from synapse.util.check_dependencies import ( + DependencyException, + check_requirements, + metadata, +) + +from tests.unittest import TestCase + + +class DummyDistribution(metadata.Distribution): + def __init__(self, version: str): + self._version = version + + @property + def version(self): + return self._version + + def locate_file(self, path): + raise NotImplementedError() + + def read_text(self, filename): + raise NotImplementedError() + + +old = DummyDistribution("0.1.2") +new = DummyDistribution("1.2.3") + +# could probably use stdlib TestCase --- no need for twisted here + + +class TestDependencyChecker(TestCase): + @contextmanager + def mock_installed_package( + self, distribution: Optional[DummyDistribution] + ) -> Generator[None, None, None]: + """Pretend that looking up any distribution yields the given `distribution`.""" + + def mock_distribution(name: str): + if distribution is None: + raise metadata.PackageNotFoundError + else: + return distribution + + with patch( + "synapse.util.check_dependencies.metadata.distribution", + mock_distribution, + ): + yield + + def test_mandatory_dependency(self) -> None: + """Complain if a required package is missing or old.""" + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1"], + ): + with self.mock_installed_package(None): + self.assertRaises(DependencyException, check_requirements) + with self.mock_installed_package(old): + self.assertRaises(DependencyException, check_requirements) + with self.mock_installed_package(new): + # should not raise + check_requirements() + + def test_generic_check_of_optional_dependency(self) -> None: + """Complain if an optional package is old.""" + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1; extra == 'cool-extra'"], + ): + with self.mock_installed_package(None): + # should not raise + check_requirements() + with self.mock_installed_package(old): + self.assertRaises(DependencyException, check_requirements) + with self.mock_installed_package(new): + # should not raise + check_requirements() + + def test_check_for_extra_dependencies(self) -> None: + """Complain if a package required for an extra is missing or old.""" + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1; extra == 'cool-extra'"], + ), patch("synapse.util.check_dependencies.EXTRAS", {"cool-extra"}): + with self.mock_installed_package(None): + self.assertRaises(DependencyException, check_requirements, "cool-extra") + with self.mock_installed_package(old): + self.assertRaises(DependencyException, check_requirements, "cool-extra") + with self.mock_installed_package(new): + # should not raise + check_requirements() From 5f62a094de10b4c4382908231128dace833a1195 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 1 Mar 2022 19:47:02 +0000 Subject: [PATCH 35/40] Detox, part 1 of N (#12119) * Don't use `tox` for `check-sampleconfig` * Don't use `tox` for check-newsfragment --- .github/workflows/tests.yml | 13 ++++++++++--- changelog.d/12119.misc | 1 + scripts-dev/check-newsfragment | 2 +- tox.ini | 10 ---------- 4 files changed, 12 insertions(+), 14 deletions(-) create mode 100644 changelog.d/12119.misc diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bbf1033bdd..e9e4277322 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,12 +10,19 @@ concurrency: cancel-in-progress: true jobs: + check-sampleconfig: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + - run: pip install -e . + - run: scripts-dev/generate_sample_config --check + lint: runs-on: ubuntu-latest strategy: matrix: toxenv: - - "check-sampleconfig" - "check_codestyle" - "check_isort" - "mypy" @@ -43,7 +50,7 @@ jobs: ref: ${{ github.event.pull_request.head.sha }} fetch-depth: 0 - uses: actions/setup-python@v2 - - run: pip install tox + - run: "pip install 'towncrier>=18.6.0rc1'" - run: scripts-dev/check-newsfragment env: PULL_REQUEST_NUMBER: ${{ github.event.number }} @@ -51,7 +58,7 @@ jobs: # Dummy step to gate other tests on without repeating the whole list linting-done: if: ${{ !cancelled() }} # Run this even if prior jobs were skipped - needs: [lint, lint-crlf, lint-newsfile] + needs: [lint, lint-crlf, lint-newsfile, check-sampleconfig] runs-on: ubuntu-latest steps: - run: "true" diff --git a/changelog.d/12119.misc b/changelog.d/12119.misc new file mode 100644 index 0000000000..f02d140f38 --- /dev/null +++ b/changelog.d/12119.misc @@ -0,0 +1 @@ +Move CI checks out of tox, to facilitate a move to using poetry. \ No newline at end of file diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment index c764011d6a..493558ad65 100755 --- a/scripts-dev/check-newsfragment +++ b/scripts-dev/check-newsfragment @@ -35,7 +35,7 @@ CONTRIBUTING_GUIDE_TEXT="!! Please see the contributing guide for help writing y https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#changelog" # If check-newsfragment returns a non-zero exit code, print the contributing guide and exit -tox -qe check-newsfragment || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1) +python -m towncrier.check --compare-with=origin/develop || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1) echo echo "--------------------------" diff --git a/tox.ini b/tox.ini index 436ecf7552..04b972e2c5 100644 --- a/tox.ini +++ b/tox.ini @@ -168,16 +168,6 @@ commands = extras = lint commands = isort -c --df {[base]lint_targets} -[testenv:check-newsfragment] -skip_install = true -usedevelop = false -deps = towncrier>=18.6.0rc1 -commands = - python -m towncrier.check --compare-with=origin/develop - -[testenv:check-sampleconfig] -commands = {toxinidir}/scripts-dev/generate_sample_config --check - [testenv:combine] skip_install = true usedevelop = false From 8e56a1b73c9819ea4bddbe6a4734966e70b3b92c Mon Sep 17 00:00:00 2001 From: lukasdenk <63459921+lukasdenk@users.noreply.github.com> Date: Wed, 2 Mar 2022 11:35:34 +0100 Subject: [PATCH 36/40] Make get_room_version use cached get_room_version_id. (#11808) --- changelog.d/11808.misc | 1 + synapse/storage/databases/main/state.py | 27 ++++++++++++------------- tests/handlers/test_room_summary.py | 5 ++++- 3 files changed, 18 insertions(+), 15 deletions(-) create mode 100644 changelog.d/11808.misc diff --git a/changelog.d/11808.misc b/changelog.d/11808.misc new file mode 100644 index 0000000000..cdc5fc75b7 --- /dev/null +++ b/changelog.d/11808.misc @@ -0,0 +1 @@ +Make method `get_room_version` use cached `get_room_version_id`. diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 2fb3e65192..417aef1dbc 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -42,6 +42,16 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 +def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: + v = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not v: + raise UnsupportedRoomVersionError( + "Room %s uses a room version %s which is no longer supported" + % (room_id, room_version_id) + ) + return v + + # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers.""" @@ -62,11 +72,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Typically this happens if support for the room's version has been removed from Synapse. """ - return await self.db_pool.runInteraction( - "get_room_version_txn", - self.get_room_version_txn, - room_id, - ) + room_version_id = await self.get_room_version_id(room_id) + return _retrieve_and_check_room_version(room_id, room_version_id) def get_room_version_txn( self, txn: LoggingTransaction, room_id: str @@ -82,15 +89,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): removed from Synapse. """ room_version_id = self.get_room_version_id_txn(txn, room_id) - v = KNOWN_ROOM_VERSIONS.get(room_version_id) - - if not v: - raise UnsupportedRoomVersionError( - "Room %s uses a room version %s which is no longer supported" - % (room_id, room_version_id) - ) - - return v + return _retrieve_and_check_room_version(room_id, room_version_id) @cached(max_entries=10000) async def get_room_version_id(self, room_id: str) -> str: diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index b33ff94a39..cff07a8973 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -658,7 +658,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): def test_unknown_room_version(self): """ - If an room with an unknown room version is encountered it should not cause + If a room with an unknown room version is encountered it should not cause the entire summary to skip. """ # Poke the database and update the room version to an unknown one. @@ -670,6 +670,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): desc="updated-room-version", ) ) + # Invalidate method so that it returns the currently updated version + # instead of the cached version. + self.hs.get_datastores().main.get_room_version_id.invalidate((self.room,)) # The result should have only the space, along with a link from space -> room. expected = [(self.space, [self.room])] From c7b2f1ccdc412c4f5f07f4fe630d2c2040caf93d Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 2 Mar 2022 10:37:04 +0000 Subject: [PATCH 37/40] Back out in-flight state caching changes. (#12126) --- changelog.d/10870.misc | 1 - changelog.d/11608.misc | 1 - changelog.d/11610.misc | 1 - changelog.d/12033.misc | 1 - changelog.d/12126.removal | 1 + synapse/storage/databases/state/store.py | 243 ++--------- tests/storage/databases/test_state_store.py | 454 -------------------- 7 files changed, 26 insertions(+), 676 deletions(-) delete mode 100644 changelog.d/10870.misc delete mode 100644 changelog.d/11608.misc delete mode 100644 changelog.d/11610.misc delete mode 100644 changelog.d/12033.misc create mode 100644 changelog.d/12126.removal delete mode 100644 tests/storage/databases/test_state_store.py diff --git a/changelog.d/10870.misc b/changelog.d/10870.misc deleted file mode 100644 index 3af049b969..0000000000 --- a/changelog.d/10870.misc +++ /dev/null @@ -1 +0,0 @@ -Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/changelog.d/11608.misc b/changelog.d/11608.misc deleted file mode 100644 index 3af049b969..0000000000 --- a/changelog.d/11608.misc +++ /dev/null @@ -1 +0,0 @@ -Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/changelog.d/11610.misc b/changelog.d/11610.misc deleted file mode 100644 index 3af049b969..0000000000 --- a/changelog.d/11610.misc +++ /dev/null @@ -1 +0,0 @@ -Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/changelog.d/12033.misc b/changelog.d/12033.misc deleted file mode 100644 index 3af049b969..0000000000 --- a/changelog.d/12033.misc +++ /dev/null @@ -1 +0,0 @@ -Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/changelog.d/12126.removal b/changelog.d/12126.removal new file mode 100644 index 0000000000..8c8bf6ee7e --- /dev/null +++ b/changelog.d/12126.removal @@ -0,0 +1 @@ +Back out in-flight state caching changes. \ No newline at end of file diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index dadf3d1e3a..7614d76ac6 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -13,24 +13,11 @@ # limitations under the License. import logging -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - Iterable, - Optional, - Sequence, - Set, - Tuple, -) +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple import attr -from sortedcontainers import SortedDict - -from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -42,12 +29,6 @@ from synapse.storage.state import StateFilter from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator from synapse.types import MutableStateMap, StateKey, StateMap -from synapse.util import unwrapFirstError -from synapse.util.async_helpers import ( - AbstractObservableDeferred, - ObservableDeferred, - yieldable_gather_results, -) from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -56,8 +37,8 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) + MAX_STATE_DELTA_HOPS = 100 -MAX_INFLIGHT_REQUESTS_PER_GROUP = 5 @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -73,24 +54,6 @@ class _GetStateGroupDelta: return len(self.delta_ids) if self.delta_ids else 0 -def state_filter_rough_priority_comparator( - state_filter: StateFilter, -) -> Tuple[int, int]: - """ - Returns a comparable value that roughly indicates the relative size of this - state filter compared to others. - 'Larger' state filters should sort first when using ascending order, so - this is essentially the opposite of 'size'. - It should be treated as a rough guide only and should not be interpreted to - have any particular meaning. The representation may also change - - The current implementation returns a tuple of the form: - * -1 for include_others, 0 otherwise - * -(number of entries in state_filter.types) - """ - return -int(state_filter.include_others), -len(state_filter.types) - - class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): """A data store for fetching/storing state groups.""" @@ -143,12 +106,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): 500000, ) - # Current ongoing get_state_for_groups in-flight requests - # {group ID -> {StateFilter -> ObservableDeferred}} - self._state_group_inflight_requests: Dict[ - int, SortedDict[StateFilter, AbstractObservableDeferred[StateMap[str]]] - ] = {} - def get_max_state_group_txn(txn: Cursor) -> int: txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") return txn.fetchone()[0] # type: ignore @@ -200,7 +157,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) async def _get_state_groups_from_groups( - self, groups: Sequence[int], state_filter: StateFilter + self, groups: List[int], state_filter: StateFilter ) -> Dict[int, StateMap[str]]: """Returns the state groups for a given set of groups from the database, filtering on types of state events. @@ -271,170 +228,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types - def _get_state_for_group_gather_inflight_requests( - self, group: int, state_filter_left_over: StateFilter - ) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]: - """ - Attempts to gather in-flight requests and re-use them to retrieve state - for the given state group, filtered with the given state filter. - - If there are more than MAX_INFLIGHT_REQUESTS_PER_GROUP in-flight requests, - and there *still* isn't enough information to complete the request by solely - reusing others, a full state filter will be requested to ensure that subsequent - requests can reuse this request. - - Used as part of _get_state_for_group_using_inflight_cache. - - Returns: - Tuple of two values: - A sequence of ObservableDeferreds to observe - A StateFilter representing what else needs to be requested to fulfill the request - """ - - inflight_requests = self._state_group_inflight_requests.get(group) - if inflight_requests is None: - # no requests for this group, need to retrieve it all ourselves - return (), state_filter_left_over - - # The list of ongoing requests which will help narrow the current request. - reusable_requests = [] - - # Iterate over existing requests in roughly biggest-first order. - for request_state_filter in inflight_requests: - request_deferred = inflight_requests[request_state_filter] - new_state_filter_left_over = state_filter_left_over.approx_difference( - request_state_filter - ) - if new_state_filter_left_over == state_filter_left_over: - # Reusing this request would not gain us anything, so don't bother. - continue - - reusable_requests.append(request_deferred) - state_filter_left_over = new_state_filter_left_over - if state_filter_left_over == StateFilter.none(): - # we have managed to collect enough of the in-flight requests - # to cover our StateFilter and give us the state we need. - break - - if ( - state_filter_left_over != StateFilter.none() - and len(inflight_requests) >= MAX_INFLIGHT_REQUESTS_PER_GROUP - ): - # There are too many requests for this group. - # To prevent even more from building up, we request the whole - # state filter to guarantee that we can be reused by any subsequent - # requests for this state group. - return (), StateFilter.all() - - return reusable_requests, state_filter_left_over - - async def _get_state_for_group_fire_request( - self, group: int, state_filter: StateFilter - ) -> StateMap[str]: - """ - Fires off a request to get the state at a state group, - potentially filtering by type and/or state key. - - This request will be tracked in the in-flight request cache and automatically - removed when it is finished. - - Used as part of _get_state_for_group_using_inflight_cache. - - Args: - group: ID of the state group for which we want to get state - state_filter: the state filter used to fetch state from the database - """ - cache_sequence_nm = self._state_group_cache.sequence - cache_sequence_m = self._state_group_members_cache.sequence - - # Help the cache hit ratio by expanding the filter a bit - db_state_filter = state_filter.return_expanded() - - async def _the_request() -> StateMap[str]: - group_to_state_dict = await self._get_state_groups_from_groups( - (group,), state_filter=db_state_filter - ) - - # Now let's update the caches - self._insert_into_cache( - group_to_state_dict, - db_state_filter, - cache_seq_num_members=cache_sequence_m, - cache_seq_num_non_members=cache_sequence_nm, - ) - - # Remove ourselves from the in-flight cache - group_request_dict = self._state_group_inflight_requests[group] - del group_request_dict[db_state_filter] - if not group_request_dict: - # If there are no more requests in-flight for this group, - # clean up the cache by removing the empty dictionary - del self._state_group_inflight_requests[group] - - return group_to_state_dict[group] - - # We don't immediately await the result, so must use run_in_background - # But we DO await the result before the current log context (request) - # finishes, so don't need to run it as a background process. - request_deferred = run_in_background(_the_request) - observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True) - - # Insert the ObservableDeferred into the cache - group_request_dict = self._state_group_inflight_requests.setdefault( - group, SortedDict(state_filter_rough_priority_comparator) - ) - group_request_dict[db_state_filter] = observable_deferred - - return await make_deferred_yieldable(observable_deferred.observe()) - - async def _get_state_for_group_using_inflight_cache( - self, group: int, state_filter: StateFilter - ) -> MutableStateMap[str]: - """ - Gets the state at a state group, potentially filtering by type and/or - state key. - - 1. Calls _get_state_for_group_gather_inflight_requests to gather any - ongoing requests which might overlap with the current request. - 2. Fires a new request, using _get_state_for_group_fire_request, - for any state which cannot be gathered from ongoing requests. - - Args: - group: ID of the state group for which we want to get state - state_filter: the state filter used to fetch state from the database - Returns: - state map - """ - - # first, figure out whether we can re-use any in-flight requests - # (and if so, what would be left over) - ( - reusable_requests, - state_filter_left_over, - ) = self._get_state_for_group_gather_inflight_requests(group, state_filter) - - if state_filter_left_over != StateFilter.none(): - # Fetch remaining state - remaining = await self._get_state_for_group_fire_request( - group, state_filter_left_over - ) - assembled_state: MutableStateMap[str] = dict(remaining) - else: - assembled_state = {} - - gathered = await make_deferred_yieldable( - defer.gatherResults( - (r.observe() for r in reusable_requests), consumeErrors=True - ) - ).addErrback(unwrapFirstError) - - # assemble our result. - for result_piece in gathered: - assembled_state.update(result_piece) - - # Filter out any state that may be more than what we asked for. - return state_filter.filter_state(assembled_state) - async def _get_state_for_groups( self, groups: Iterable[int], state_filter: Optional[StateFilter] = None ) -> Dict[int, MutableStateMap[str]]: @@ -476,17 +269,31 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): if not incomplete_groups: return state - async def get_from_cache(group: int, state_filter: StateFilter) -> None: - state[group] = await self._get_state_for_group_using_inflight_cache( - group, state_filter - ) + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence - await yieldable_gather_results( - get_from_cache, - incomplete_groups, - state_filter, + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = await self._get_state_groups_from_groups( + list(incomplete_groups), state_filter=db_state_filter ) + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in group_to_state_dict.items(): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + return state def _get_state_for_groups_using_cache( diff --git a/tests/storage/databases/test_state_store.py b/tests/storage/databases/test_state_store.py deleted file mode 100644 index 2b484c95a9..0000000000 --- a/tests/storage/databases/test_state_store.py +++ /dev/null @@ -1,454 +0,0 @@ -# Copyright 2022 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import typing -from typing import Dict, List, Sequence, Tuple -from unittest.mock import patch - -from parameterized import parameterized - -from twisted.internet.defer import Deferred, ensureDeferred -from twisted.test.proto_helpers import MemoryReactor - -from synapse.api.constants import EventTypes -from synapse.storage.databases.state.store import ( - MAX_INFLIGHT_REQUESTS_PER_GROUP, - state_filter_rough_priority_comparator, -) -from synapse.storage.state import StateFilter -from synapse.types import StateMap -from synapse.util import Clock - -from tests.unittest import HomeserverTestCase - -if typing.TYPE_CHECKING: - from synapse.server import HomeServer - -# StateFilter for ALL non-m.room.member state events -ALL_NON_MEMBERS_STATE_FILTER = StateFilter.freeze( - types={EventTypes.Member: set()}, - include_others=True, -) - -FAKE_STATE = { - (EventTypes.Member, "@alice:test"): "join", - (EventTypes.Member, "@bob:test"): "leave", - (EventTypes.Member, "@charlie:test"): "invite", - ("test.type", "a"): "AAA", - ("test.type", "b"): "BBB", - ("other.event.type", "state.key"): "123", -} - - -class StateGroupInflightCachingTestCase(HomeserverTestCase): - def prepare( - self, reactor: MemoryReactor, clock: Clock, homeserver: "HomeServer" - ) -> None: - self.state_storage = homeserver.get_storage().state - self.state_datastore = homeserver.get_datastores().state - # Patch out the `_get_state_groups_from_groups`. - # This is useful because it lets us pretend we have a slow database. - get_state_groups_patch = patch.object( - self.state_datastore, - "_get_state_groups_from_groups", - self._fake_get_state_groups_from_groups, - ) - get_state_groups_patch.start() - - self.addCleanup(get_state_groups_patch.stop) - self.get_state_group_calls: List[ - Tuple[Tuple[int, ...], StateFilter, Deferred[Dict[int, StateMap[str]]]] - ] = [] - - def _fake_get_state_groups_from_groups( - self, groups: Sequence[int], state_filter: StateFilter - ) -> "Deferred[Dict[int, StateMap[str]]]": - d: Deferred[Dict[int, StateMap[str]]] = Deferred() - self.get_state_group_calls.append((tuple(groups), state_filter, d)) - return d - - def _complete_request_fake( - self, - groups: Tuple[int, ...], - state_filter: StateFilter, - d: "Deferred[Dict[int, StateMap[str]]]", - ) -> None: - """ - Assemble a fake database response and complete the database request. - """ - - # Return a filtered copy of the fake state - d.callback({group: state_filter.filter_state(FAKE_STATE) for group in groups}) - - def test_duplicate_requests_deduplicated(self) -> None: - """ - Tests that duplicate requests for state are deduplicated. - - This test: - - requests some state (state group 42, 'all' state filter) - - requests it again, before the first request finishes - - checks to see that only one database query was made - - completes the database query - - checks that both requests see the same retrieved state - """ - req1 = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, StateFilter.all() - ) - ) - self.pump(by=0.1) - - # This should have gone to the database - self.assertEqual(len(self.get_state_group_calls), 1) - self.assertFalse(req1.called) - - req2 = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, StateFilter.all() - ) - ) - self.pump(by=0.1) - - # No more calls should have gone to the database - self.assertEqual(len(self.get_state_group_calls), 1) - self.assertFalse(req1.called) - self.assertFalse(req2.called) - - groups, sf, d = self.get_state_group_calls[0] - self.assertEqual(groups, (42,)) - self.assertEqual(sf, StateFilter.all()) - - # Now we can complete the request - self._complete_request_fake(groups, sf, d) - - self.assertEqual(self.get_success(req1), FAKE_STATE) - self.assertEqual(self.get_success(req2), FAKE_STATE) - - def test_smaller_request_deduplicated(self) -> None: - """ - Tests that duplicate requests for state are deduplicated. - - This test: - - requests some state (state group 42, 'all' state filter) - - requests a subset of that state, before the first request finishes - - checks to see that only one database query was made - - completes the database query - - checks that both requests see the correct retrieved state - """ - req1 = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, StateFilter.from_types((("test.type", None),)) - ) - ) - self.pump(by=0.1) - - # This should have gone to the database - self.assertEqual(len(self.get_state_group_calls), 1) - self.assertFalse(req1.called) - - req2 = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, StateFilter.from_types((("test.type", "b"),)) - ) - ) - self.pump(by=0.1) - - # No more calls should have gone to the database, because the second - # request was already in the in-flight cache! - self.assertEqual(len(self.get_state_group_calls), 1) - self.assertFalse(req1.called) - self.assertFalse(req2.called) - - groups, sf, d = self.get_state_group_calls[0] - self.assertEqual(groups, (42,)) - # The state filter is expanded internally for increased cache hit rate, - # so we the database sees a wider state filter than requested. - self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER) - - # Now we can complete the request - self._complete_request_fake(groups, sf, d) - - self.assertEqual( - self.get_success(req1), - {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"}, - ) - self.assertEqual(self.get_success(req2), {("test.type", "b"): "BBB"}) - - def test_partially_overlapping_request_deduplicated(self) -> None: - """ - Tests that partially-overlapping requests are partially deduplicated. - - This test: - - requests a single type of wildcard state - (This is internally expanded to be all non-member state) - - requests the entire state in parallel - - checks to see that two database queries were made, but that the second - one is only for member state. - - completes the database queries - - checks that both requests have the correct result. - """ - - req1 = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, StateFilter.from_types((("test.type", None),)) - ) - ) - self.pump(by=0.1) - - # This should have gone to the database - self.assertEqual(len(self.get_state_group_calls), 1) - self.assertFalse(req1.called) - - req2 = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, StateFilter.all() - ) - ) - self.pump(by=0.1) - - # Because it only partially overlaps, this also went to the database - self.assertEqual(len(self.get_state_group_calls), 2) - self.assertFalse(req1.called) - self.assertFalse(req2.called) - - # First request: - groups, sf, d = self.get_state_group_calls[0] - self.assertEqual(groups, (42,)) - # The state filter is expanded internally for increased cache hit rate, - # so we the database sees a wider state filter than requested. - self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER) - self._complete_request_fake(groups, sf, d) - - # Second request: - groups, sf, d = self.get_state_group_calls[1] - self.assertEqual(groups, (42,)) - # The state filter is narrowed to only request membership state, because - # the remainder of the state is already being queried in the first request! - self.assertEqual( - sf, StateFilter.freeze({EventTypes.Member: None}, include_others=False) - ) - self._complete_request_fake(groups, sf, d) - - # Check the results are correct - self.assertEqual( - self.get_success(req1), - {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"}, - ) - self.assertEqual(self.get_success(req2), FAKE_STATE) - - def test_in_flight_requests_stop_being_in_flight(self) -> None: - """ - Tests that in-flight request deduplication doesn't somehow 'hold on' - to completed requests: once they're done, they're taken out of the - in-flight cache. - """ - req1 = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, StateFilter.all() - ) - ) - self.pump(by=0.1) - - # This should have gone to the database - self.assertEqual(len(self.get_state_group_calls), 1) - self.assertFalse(req1.called) - - # Complete the request right away. - self._complete_request_fake(*self.get_state_group_calls[0]) - self.assertTrue(req1.called) - - # Send off another request - req2 = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, StateFilter.all() - ) - ) - self.pump(by=0.1) - - # It should have gone to the database again, because the previous request - # isn't in-flight and therefore isn't available for deduplication. - self.assertEqual(len(self.get_state_group_calls), 2) - self.assertFalse(req2.called) - - # Complete the request right away. - self._complete_request_fake(*self.get_state_group_calls[1]) - self.assertTrue(req2.called) - groups, sf, d = self.get_state_group_calls[0] - - self.assertEqual(self.get_success(req1), FAKE_STATE) - self.assertEqual(self.get_success(req2), FAKE_STATE) - - def test_inflight_requests_capped(self) -> None: - """ - Tests that the number of in-flight requests is capped to 5. - - - requests several pieces of state separately - (5 to hit the limit, 1 to 'shunt out', another that comes after the - group has been 'shunted out') - - checks to see that the torrent of requests is shunted out by - rewriting one of the filters as the 'all' state filter - - requests after that one do not cause any additional queries - """ - # 5 at the time of writing. - CAP_COUNT = MAX_INFLIGHT_REQUESTS_PER_GROUP - - reqs = [] - - # Request 7 different keys (1 to 7) of the `some.state` type. - for req_id in range(CAP_COUNT + 2): - reqs.append( - ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, - StateFilter.freeze( - {"some.state": {str(req_id + 1)}}, include_others=False - ), - ) - ) - ) - self.pump(by=0.1) - - # There should only be 6 calls to the database, not 7. - self.assertEqual(len(self.get_state_group_calls), CAP_COUNT + 1) - - # Assert that the first 5 are exact requests for the individual pieces - # wanted - for req_id in range(CAP_COUNT): - groups, sf, d = self.get_state_group_calls[req_id] - self.assertEqual( - sf, - StateFilter.freeze( - {"some.state": {str(req_id + 1)}}, include_others=False - ), - ) - - # The 6th request should be the 'all' state filter - groups, sf, d = self.get_state_group_calls[CAP_COUNT] - self.assertEqual(sf, StateFilter.all()) - - # Complete the queries and check which requests complete as a result - for req_id in range(CAP_COUNT): - # This request should not have been completed yet - self.assertFalse(reqs[req_id].called) - - groups, sf, d = self.get_state_group_calls[req_id] - self._complete_request_fake(groups, sf, d) - - # This should have only completed this one request - self.assertTrue(reqs[req_id].called) - - # Now complete the final query; the last 2 requests should complete - # as a result - self.assertFalse(reqs[CAP_COUNT].called) - self.assertFalse(reqs[CAP_COUNT + 1].called) - groups, sf, d = self.get_state_group_calls[CAP_COUNT] - self._complete_request_fake(groups, sf, d) - self.assertTrue(reqs[CAP_COUNT].called) - self.assertTrue(reqs[CAP_COUNT + 1].called) - - @parameterized.expand([(False,), (True,)]) - def test_ordering_of_request_reuse(self, reverse: bool) -> None: - """ - Tests that 'larger' in-flight requests are ordered first. - - This is mostly a design decision in order to prevent a request from - hanging on to multiple queries when it would have been sufficient to - hang on to only one bigger query. - - The 'size' of a state filter is a rough heuristic. - - - requests two pieces of state, one 'larger' than the other, but each - spawning a query - - requests a third piece of state - - completes the larger of the first two queries - - checks that the third request gets completed (and doesn't needlessly - wait for the other query) - - Parameters: - reverse: whether to reverse the order of the initial requests, to ensure - that the effect doesn't depend on the order of request submission. - """ - - # We add in an extra state type to make sure that both requests spawn - # queries which are not optimised out. - state_filters = [ - StateFilter.freeze( - {"state.type": {"A"}, "other.state.type": {"a"}}, include_others=False - ), - StateFilter.freeze( - { - "state.type": None, - "other.state.type": {"b"}, - # The current rough size comparator uses the number of state types - # as an indicator of size. - # To influence it to make this state filter bigger than the previous one, - # we add another dummy state type. - "extra.state.type": {"c"}, - }, - include_others=False, - ), - ] - - if reverse: - # For fairness, we perform one test run with the list reversed. - state_filters.reverse() - smallest_state_filter_idx = 1 - biggest_state_filter_idx = 0 - else: - smallest_state_filter_idx = 0 - biggest_state_filter_idx = 1 - - # This assertion is for our own sanity more than anything else. - self.assertLess( - state_filter_rough_priority_comparator( - state_filters[biggest_state_filter_idx] - ), - state_filter_rough_priority_comparator( - state_filters[smallest_state_filter_idx] - ), - "Test invalid: bigger state filter is not actually bigger.", - ) - - # Spawn the initial two requests - for state_filter in state_filters: - ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, - state_filter, - ) - ) - - # Spawn a third request - req = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, - StateFilter.freeze( - { - "state.type": {"A"}, - }, - include_others=False, - ), - ) - ) - self.pump(by=0.1) - - self.assertFalse(req.called) - - # Complete the largest request's query to make sure that the final request - # only waits for that one (and doesn't needlessly wait for both queries) - self._complete_request_fake( - *self.get_state_group_calls[biggest_state_filter_idx] - ) - - # That should have been sufficient to complete the third request - self.assertTrue(req.called) From a43a5ea5bfefcb25d90209958fb014a3b5e0ead0 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 2 Mar 2022 10:38:10 +0000 Subject: [PATCH 38/40] Remove misleading newsfile from #12126 which backs out an unreleased change. --- changelog.d/12126.removal | 1 - 1 file changed, 1 deletion(-) delete mode 100644 changelog.d/12126.removal diff --git a/changelog.d/12126.removal b/changelog.d/12126.removal deleted file mode 100644 index 8c8bf6ee7e..0000000000 --- a/changelog.d/12126.removal +++ /dev/null @@ -1 +0,0 @@ -Back out in-flight state caching changes. \ No newline at end of file From 879e4a7bd7a90cda4c8ea908aede53d8e038ca7c Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 2 Mar 2022 10:45:16 +0000 Subject: [PATCH 39/40] 1.54.0rc1 --- CHANGES.md | 96 +++++++++++++++++++++++++++++++++++++++ changelog.d/11599.doc | 1 - changelog.d/11617.feature | 1 - changelog.d/11808.misc | 1 - changelog.d/11835.feature | 1 - changelog.d/11865.removal | 1 - changelog.d/11900.misc | 1 - changelog.d/11972.misc | 1 - changelog.d/11974.misc | 1 - changelog.d/11984.misc | 1 - changelog.d/11985.feature | 1 - changelog.d/11991.misc | 1 - changelog.d/11992.bugfix | 1 - changelog.d/11994.misc | 1 - changelog.d/11996.misc | 1 - changelog.d/11997.docker | 1 - changelog.d/11999.bugfix | 1 - changelog.d/12000.feature | 1 - changelog.d/12001.feature | 1 - changelog.d/12003.doc | 1 - changelog.d/12004.doc | 1 - changelog.d/12005.misc | 1 - changelog.d/12008.removal | 1 - changelog.d/12009.feature | 1 - changelog.d/12011.misc | 1 - changelog.d/12012.misc | 1 - changelog.d/12013.misc | 1 - changelog.d/12015.misc | 1 - changelog.d/12016.misc | 1 - changelog.d/12018.removal | 1 - changelog.d/12019.misc | 1 - changelog.d/12020.feature | 1 - changelog.d/12021.feature | 1 - changelog.d/12022.feature | 1 - changelog.d/12024.bugfix | 1 - changelog.d/12025.misc | 1 - changelog.d/12030.misc | 1 - changelog.d/12031.misc | 1 - changelog.d/12034.misc | 1 - changelog.d/12037.bugfix | 1 - changelog.d/12039.misc | 1 - changelog.d/12041.misc | 1 - changelog.d/12051.misc | 1 - changelog.d/12052.misc | 1 - changelog.d/12056.bugfix | 1 - changelog.d/12058.feature | 1 - changelog.d/12059.misc | 1 - changelog.d/12060.misc | 1 - changelog.d/12062.feature | 1 - changelog.d/12063.misc | 1 - changelog.d/12066.misc | 1 - changelog.d/12067.feature | 1 - changelog.d/12068.misc | 1 - changelog.d/12069.misc | 1 - changelog.d/12070.misc | 1 - changelog.d/12072.misc | 1 - changelog.d/12073.removal | 1 - changelog.d/12077.bugfix | 1 - changelog.d/12084.misc | 1 - changelog.d/12088.misc | 1 - changelog.d/12089.bugfix | 1 - changelog.d/12092.misc | 1 - changelog.d/12094.misc | 1 - changelog.d/12098.bugfix | 1 - changelog.d/12099.misc | 1 - changelog.d/12100.bugfix | 1 - changelog.d/12105.bugfix | 1 - changelog.d/12106.misc | 1 - changelog.d/12109.misc | 1 - changelog.d/12111.misc | 1 - changelog.d/12112.docker | 1 - changelog.d/12119.misc | 1 - debian/changelog | 6 +++ synapse/__init__.py | 2 +- 74 files changed, 103 insertions(+), 72 deletions(-) delete mode 100644 changelog.d/11599.doc delete mode 100644 changelog.d/11617.feature delete mode 100644 changelog.d/11808.misc delete mode 100644 changelog.d/11835.feature delete mode 100644 changelog.d/11865.removal delete mode 100644 changelog.d/11900.misc delete mode 100644 changelog.d/11972.misc delete mode 100644 changelog.d/11974.misc delete mode 100644 changelog.d/11984.misc delete mode 100644 changelog.d/11985.feature delete mode 100644 changelog.d/11991.misc delete mode 100644 changelog.d/11992.bugfix delete mode 100644 changelog.d/11994.misc delete mode 100644 changelog.d/11996.misc delete mode 100644 changelog.d/11997.docker delete mode 100644 changelog.d/11999.bugfix delete mode 100644 changelog.d/12000.feature delete mode 100644 changelog.d/12001.feature delete mode 100644 changelog.d/12003.doc delete mode 100644 changelog.d/12004.doc delete mode 100644 changelog.d/12005.misc delete mode 100644 changelog.d/12008.removal delete mode 100644 changelog.d/12009.feature delete mode 100644 changelog.d/12011.misc delete mode 100644 changelog.d/12012.misc delete mode 100644 changelog.d/12013.misc delete mode 100644 changelog.d/12015.misc delete mode 100644 changelog.d/12016.misc delete mode 100644 changelog.d/12018.removal delete mode 100644 changelog.d/12019.misc delete mode 100644 changelog.d/12020.feature delete mode 100644 changelog.d/12021.feature delete mode 100644 changelog.d/12022.feature delete mode 100644 changelog.d/12024.bugfix delete mode 100644 changelog.d/12025.misc delete mode 100644 changelog.d/12030.misc delete mode 100644 changelog.d/12031.misc delete mode 100644 changelog.d/12034.misc delete mode 100644 changelog.d/12037.bugfix delete mode 100644 changelog.d/12039.misc delete mode 100644 changelog.d/12041.misc delete mode 100644 changelog.d/12051.misc delete mode 100644 changelog.d/12052.misc delete mode 100644 changelog.d/12056.bugfix delete mode 100644 changelog.d/12058.feature delete mode 100644 changelog.d/12059.misc delete mode 100644 changelog.d/12060.misc delete mode 100644 changelog.d/12062.feature delete mode 100644 changelog.d/12063.misc delete mode 100644 changelog.d/12066.misc delete mode 100644 changelog.d/12067.feature delete mode 100644 changelog.d/12068.misc delete mode 100644 changelog.d/12069.misc delete mode 100644 changelog.d/12070.misc delete mode 100644 changelog.d/12072.misc delete mode 100644 changelog.d/12073.removal delete mode 100644 changelog.d/12077.bugfix delete mode 100644 changelog.d/12084.misc delete mode 100644 changelog.d/12088.misc delete mode 100644 changelog.d/12089.bugfix delete mode 100644 changelog.d/12092.misc delete mode 100644 changelog.d/12094.misc delete mode 100644 changelog.d/12098.bugfix delete mode 100644 changelog.d/12099.misc delete mode 100644 changelog.d/12100.bugfix delete mode 100644 changelog.d/12105.bugfix delete mode 100644 changelog.d/12106.misc delete mode 100644 changelog.d/12109.misc delete mode 100644 changelog.d/12111.misc delete mode 100644 changelog.d/12112.docker delete mode 100644 changelog.d/12119.misc diff --git a/CHANGES.md b/CHANGES.md index 81333097ae..4f0318970e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,99 @@ +Synapse 1.54.0rc1 (2022-03-02) +============================== + +Features +-------- + +- Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services. ([\#11617](https://github.com/matrix-org/synapse/issues/11617)) +- Make a `POST` to `/rooms//receipt/m.read/` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push. ([\#11835](https://github.com/matrix-org/synapse/issues/11835)) +- Fetch images when previewing Twitter URLs. Contributed by @AndrewRyanChama. ([\#11985](https://github.com/matrix-org/synapse/issues/11985)) +- Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time. ([\#12000](https://github.com/matrix-org/synapse/issues/12000)) +- Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). ([\#12001](https://github.com/matrix-org/synapse/issues/12001), [\#12067](https://github.com/matrix-org/synapse/issues/12067)) +- Enable modules to set a custom display name when registering a user. ([\#12009](https://github.com/matrix-org/synapse/issues/12009)) +- Advertise Matrix 1.1 support on `/_matrix/client/versions`. ([\#12020](https://github.com/matrix-org/synapse/issues/12020)) +- Support only the stable identifier for [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069)'s `is_guest` on `/_matrix/client/v3/account/whoami`. ([\#12021](https://github.com/matrix-org/synapse/issues/12021)) +- Advertise Matrix 1.2 support on `/_matrix/client/versions`. ([\#12022](https://github.com/matrix-org/synapse/issues/12022)) +- Use room version 9 as the default room version (per [MSC3589](https://github.com/matrix-org/matrix-doc/pull/3589)). ([\#12058](https://github.com/matrix-org/synapse/issues/12058)) +- Add module callbacks to react to user deactivation status changes (i.e. deactivations and reactivations) and profile updates. ([\#12062](https://github.com/matrix-org/synapse/issues/12062)) + + +Bugfixes +-------- + +- Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. ([\#11992](https://github.com/matrix-org/synapse/issues/11992)) +- Fix long standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room. ([\#11999](https://github.com/matrix-org/synapse/issues/11999)) +- Fix 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. ([\#12024](https://github.com/matrix-org/synapse/issues/12024)) +- Properly fix a long-standing bug where wrong data could be inserted in the `event_search` table when using sqlite. This could block running `synapse_port_db` with an "argument of type 'int' is not iterable" error. This bug was partially fixed by a change in Synapse 1.44.0. ([\#12037](https://github.com/matrix-org/synapse/issues/12037)) +- Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens. ([\#12056](https://github.com/matrix-org/synapse/issues/12056)) +- Fix a long-standing bug where Synapse would make additional failing requests over federation for missing data. ([\#12077](https://github.com/matrix-org/synapse/issues/12077)) +- Fix occasional 'Unhandled error in Deferred' error message. ([\#12089](https://github.com/matrix-org/synapse/issues/12089)) +- Fix a bug introduced in Synapse 1.51.0rc1 where incoming federation transactions containing at least one EDU would be dropped if debug logging was enabled for `synapse.8631_debug`. ([\#12098](https://github.com/matrix-org/synapse/issues/12098)) +- Fix a long-standing bug which could cause push notifications to malfunction if `use_frozen_dicts` was set in the configuration. ([\#12100](https://github.com/matrix-org/synapse/issues/12100)) +- Fix an extremely rare, long-standing bug in `ReadWriteLock` that would cause an error when a newly unblocked writer completes instantly. ([\#12105](https://github.com/matrix-org/synapse/issues/12105)) + + +Updates to the Docker image +--------------------------- + +- The docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage. ([\#11997](https://github.com/matrix-org/synapse/issues/11997)) +- Use Python 3.9 in Docker images by default. ([\#12112](https://github.com/matrix-org/synapse/issues/12112)) + + +Improved Documentation +---------------------- + +- Document support for the `to_device`, `account_data`, `receipts`, and `presence` stream writers for workers. ([\#11599](https://github.com/matrix-org/synapse/issues/11599)) +- Explain the meaning of spam checker callbacks' return values. ([\#12003](https://github.com/matrix-org/synapse/issues/12003)) +- Clarify information about external Identity Provider IDs. ([\#12004](https://github.com/matrix-org/synapse/issues/12004)) + + +Deprecations and Removals +------------------------- + +- Deprecate using `synctl` with the config option `synctl_cache_factor` and print a warning if a user still uses this option. ([\#11865](https://github.com/matrix-org/synapse/issues/11865)) +- Remove support for the legacy structured logging configuration (please see the the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#legacy-structured-logging-configuration-removal) if you are using `structured: true` in the Synapse configuration). ([\#12008](https://github.com/matrix-org/synapse/issues/12008)) +- Drop support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283) unstable flags now that the stable flags are supported. ([\#12018](https://github.com/matrix-org/synapse/issues/12018)) +- Remove the unstable `/spaces` endpoint from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#12073](https://github.com/matrix-org/synapse/issues/12073)) + + +Internal Changes +---------------- + +- Make method `get_room_version` use cached `get_room_version_id`. ([\#11808](https://github.com/matrix-org/synapse/issues/11808)) +- Remove unnecessary condition on knock->leave auth rule check. ([\#11900](https://github.com/matrix-org/synapse/issues/11900)) +- Add tests for device list changes between local users. ([\#11972](https://github.com/matrix-org/synapse/issues/11972)) +- Optimise calculating device_list changes in `/sync`. ([\#11974](https://github.com/matrix-org/synapse/issues/11974)) +- Add missing type hints to storage classes. ([\#11984](https://github.com/matrix-org/synapse/issues/11984)) +- Refactor the search code for improved readability. ([\#11991](https://github.com/matrix-org/synapse/issues/11991)) +- Move common deduplication code down into `_auth_and_persist_outliers`. ([\#11994](https://github.com/matrix-org/synapse/issues/11994)) +- Limit concurrent joins from applications services. ([\#11996](https://github.com/matrix-org/synapse/issues/11996)) +- Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`. ([\#12005](https://github.com/matrix-org/synapse/issues/12005), [\#12039](https://github.com/matrix-org/synapse/issues/12039)) +- Preparation for faster-room-join work: parse msc3706 fields in send_join response. ([\#12011](https://github.com/matrix-org/synapse/issues/12011)) +- Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database. ([\#12012](https://github.com/matrix-org/synapse/issues/12012)) +- Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server. ([\#12013](https://github.com/matrix-org/synapse/issues/12013)) +- Configure `tox` to use `venv` rather than `virtualenv`. ([\#12015](https://github.com/matrix-org/synapse/issues/12015)) +- Fix bug in `StateFilter.return_expanded()` and add some tests. ([\#12016](https://github.com/matrix-org/synapse/issues/12016)) +- Use Matrix v1.1 endpoints (`/_matrix/client/v3/auth/...`) in fallback auth HTML forms. ([\#12019](https://github.com/matrix-org/synapse/issues/12019)) +- Update the `olddeps` CI job to use an old version of `markupsafe`. ([\#12025](https://github.com/matrix-org/synapse/issues/12025)) +- Upgrade mypy to version 0.931. ([\#12030](https://github.com/matrix-org/synapse/issues/12030)) +- Remove legacy `HomeServer.get_datastore()`. ([\#12031](https://github.com/matrix-org/synapse/issues/12031), [\#12070](https://github.com/matrix-org/synapse/issues/12070)) +- Minor typing fixes. ([\#12034](https://github.com/matrix-org/synapse/issues/12034), [\#12069](https://github.com/matrix-org/synapse/issues/12069)) +- After joining a room, create a dedicated logcontext to process the queued events. ([\#12041](https://github.com/matrix-org/synapse/issues/12041)) +- Tidy up GitHub Actions config which builds distributions for PyPI. ([\#12051](https://github.com/matrix-org/synapse/issues/12051)) +- Move configuration out of `setup.cfg`. ([\#12052](https://github.com/matrix-org/synapse/issues/12052), [\#12059](https://github.com/matrix-org/synapse/issues/12059)) +- Fix error message when a worker process fails to talk to another worker process. ([\#12060](https://github.com/matrix-org/synapse/issues/12060)) +- Fix using the complement.sh script without specifying a dir or a branch. Contributed by Nico on behalf of Famedly. ([\#12063](https://github.com/matrix-org/synapse/issues/12063)) +- Add type hints to `tests/rest/client`. ([\#12066](https://github.com/matrix-org/synapse/issues/12066), [\#12072](https://github.com/matrix-org/synapse/issues/12072), [\#12084](https://github.com/matrix-org/synapse/issues/12084), [\#12094](https://github.com/matrix-org/synapse/issues/12094)) +- Add some logging to `/sync` to try and track down #11916. ([\#12068](https://github.com/matrix-org/synapse/issues/12068)) +- Inspect application dependencies using `importlib.metadata` or its backport. ([\#12088](https://github.com/matrix-org/synapse/issues/12088)) +- User `assertEqual` instead of the deprecated `assertEquals` in test code. ([\#12092](https://github.com/matrix-org/synapse/issues/12092)) +- Move experimental support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) to /versions. ([\#12099](https://github.com/matrix-org/synapse/issues/12099)) +- Add `stop_cancellation` utility function to stop `Deferred`s from being cancelled. ([\#12106](https://github.com/matrix-org/synapse/issues/12106)) +- Improve exception handling for concurrent execution. ([\#12109](https://github.com/matrix-org/synapse/issues/12109)) +- Advertise support for Python 3.10 in packaging files. ([\#12111](https://github.com/matrix-org/synapse/issues/12111)) +- Move CI checks out of tox, to facilitate a move to using poetry. ([\#12119](https://github.com/matrix-org/synapse/issues/12119)) + + Synapse 1.53.0 (2022-02-22) =========================== diff --git a/changelog.d/11599.doc b/changelog.d/11599.doc deleted file mode 100644 index f07cfbef4e..0000000000 --- a/changelog.d/11599.doc +++ /dev/null @@ -1 +0,0 @@ -Document support for the `to_device`, `account_data`, `receipts`, and `presence` stream writers for workers. diff --git a/changelog.d/11617.feature b/changelog.d/11617.feature deleted file mode 100644 index cf03f00e7c..0000000000 --- a/changelog.d/11617.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services. \ No newline at end of file diff --git a/changelog.d/11808.misc b/changelog.d/11808.misc deleted file mode 100644 index cdc5fc75b7..0000000000 --- a/changelog.d/11808.misc +++ /dev/null @@ -1 +0,0 @@ -Make method `get_room_version` use cached `get_room_version_id`. diff --git a/changelog.d/11835.feature b/changelog.d/11835.feature deleted file mode 100644 index 7cee39b08c..0000000000 --- a/changelog.d/11835.feature +++ /dev/null @@ -1 +0,0 @@ -Make a `POST` to `/rooms//receipt/m.read/` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push. diff --git a/changelog.d/11865.removal b/changelog.d/11865.removal deleted file mode 100644 index 9fcabfc720..0000000000 --- a/changelog.d/11865.removal +++ /dev/null @@ -1 +0,0 @@ -Deprecate using `synctl` with the config option `synctl_cache_factor` and print a warning if a user still uses this option. diff --git a/changelog.d/11900.misc b/changelog.d/11900.misc deleted file mode 100644 index edd2852fd4..0000000000 --- a/changelog.d/11900.misc +++ /dev/null @@ -1 +0,0 @@ -Remove unnecessary condition on knock->leave auth rule check. \ No newline at end of file diff --git a/changelog.d/11972.misc b/changelog.d/11972.misc deleted file mode 100644 index 29c38bfd82..0000000000 --- a/changelog.d/11972.misc +++ /dev/null @@ -1 +0,0 @@ -Add tests for device list changes between local users. \ No newline at end of file diff --git a/changelog.d/11974.misc b/changelog.d/11974.misc deleted file mode 100644 index 1debad2361..0000000000 --- a/changelog.d/11974.misc +++ /dev/null @@ -1 +0,0 @@ -Optimise calculating device_list changes in `/sync`. diff --git a/changelog.d/11984.misc b/changelog.d/11984.misc deleted file mode 100644 index 8e405b9226..0000000000 --- a/changelog.d/11984.misc +++ /dev/null @@ -1 +0,0 @@ -Add missing type hints to storage classes. \ No newline at end of file diff --git a/changelog.d/11985.feature b/changelog.d/11985.feature deleted file mode 100644 index 120d888a49..0000000000 --- a/changelog.d/11985.feature +++ /dev/null @@ -1 +0,0 @@ -Fetch images when previewing Twitter URLs. Contributed by @AndrewRyanChama. diff --git a/changelog.d/11991.misc b/changelog.d/11991.misc deleted file mode 100644 index 34a3b3a6b9..0000000000 --- a/changelog.d/11991.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor the search code for improved readability. diff --git a/changelog.d/11992.bugfix b/changelog.d/11992.bugfix deleted file mode 100644 index f73c86bb25..0000000000 --- a/changelog.d/11992.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. diff --git a/changelog.d/11994.misc b/changelog.d/11994.misc deleted file mode 100644 index d64297dd78..0000000000 --- a/changelog.d/11994.misc +++ /dev/null @@ -1 +0,0 @@ -Move common deduplication code down into `_auth_and_persist_outliers`. diff --git a/changelog.d/11996.misc b/changelog.d/11996.misc deleted file mode 100644 index 6c675fd193..0000000000 --- a/changelog.d/11996.misc +++ /dev/null @@ -1 +0,0 @@ -Limit concurrent joins from applications services. \ No newline at end of file diff --git a/changelog.d/11997.docker b/changelog.d/11997.docker deleted file mode 100644 index 1b3271457e..0000000000 --- a/changelog.d/11997.docker +++ /dev/null @@ -1 +0,0 @@ -The docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage. diff --git a/changelog.d/11999.bugfix b/changelog.d/11999.bugfix deleted file mode 100644 index fd84095900..0000000000 --- a/changelog.d/11999.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix long standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room. diff --git a/changelog.d/12000.feature b/changelog.d/12000.feature deleted file mode 100644 index 246cc87f0b..0000000000 --- a/changelog.d/12000.feature +++ /dev/null @@ -1 +0,0 @@ -Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time. diff --git a/changelog.d/12001.feature b/changelog.d/12001.feature deleted file mode 100644 index dc1153c49e..0000000000 --- a/changelog.d/12001.feature +++ /dev/null @@ -1 +0,0 @@ -Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). diff --git a/changelog.d/12003.doc b/changelog.d/12003.doc deleted file mode 100644 index 1ac8163559..0000000000 --- a/changelog.d/12003.doc +++ /dev/null @@ -1 +0,0 @@ -Explain the meaning of spam checker callbacks' return values. diff --git a/changelog.d/12004.doc b/changelog.d/12004.doc deleted file mode 100644 index 0b4baef210..0000000000 --- a/changelog.d/12004.doc +++ /dev/null @@ -1 +0,0 @@ -Clarify information about external Identity Provider IDs. diff --git a/changelog.d/12005.misc b/changelog.d/12005.misc deleted file mode 100644 index 45e21dbe59..0000000000 --- a/changelog.d/12005.misc +++ /dev/null @@ -1 +0,0 @@ -Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`. diff --git a/changelog.d/12008.removal b/changelog.d/12008.removal deleted file mode 100644 index 57599d9ee9..0000000000 --- a/changelog.d/12008.removal +++ /dev/null @@ -1 +0,0 @@ -Remove support for the legacy structured logging configuration (please see the the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#legacy-structured-logging-configuration-removal) if you are using `structured: true` in the Synapse configuration). diff --git a/changelog.d/12009.feature b/changelog.d/12009.feature deleted file mode 100644 index c8a531481e..0000000000 --- a/changelog.d/12009.feature +++ /dev/null @@ -1 +0,0 @@ -Enable modules to set a custom display name when registering a user. diff --git a/changelog.d/12011.misc b/changelog.d/12011.misc deleted file mode 100644 index 258b0e389f..0000000000 --- a/changelog.d/12011.misc +++ /dev/null @@ -1 +0,0 @@ -Preparation for faster-room-join work: parse msc3706 fields in send_join response. diff --git a/changelog.d/12012.misc b/changelog.d/12012.misc deleted file mode 100644 index a473f41e78..0000000000 --- a/changelog.d/12012.misc +++ /dev/null @@ -1 +0,0 @@ -Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database. diff --git a/changelog.d/12013.misc b/changelog.d/12013.misc deleted file mode 100644 index c0fca8dccb..0000000000 --- a/changelog.d/12013.misc +++ /dev/null @@ -1 +0,0 @@ -Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server. diff --git a/changelog.d/12015.misc b/changelog.d/12015.misc deleted file mode 100644 index 3aa32ab4cf..0000000000 --- a/changelog.d/12015.misc +++ /dev/null @@ -1 +0,0 @@ -Configure `tox` to use `venv` rather than `virtualenv`. diff --git a/changelog.d/12016.misc b/changelog.d/12016.misc deleted file mode 100644 index 8856ef46a9..0000000000 --- a/changelog.d/12016.misc +++ /dev/null @@ -1 +0,0 @@ -Fix bug in `StateFilter.return_expanded()` and add some tests. \ No newline at end of file diff --git a/changelog.d/12018.removal b/changelog.d/12018.removal deleted file mode 100644 index e940b62228..0000000000 --- a/changelog.d/12018.removal +++ /dev/null @@ -1 +0,0 @@ -Drop support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283) unstable flags now that the stable flags are supported. diff --git a/changelog.d/12019.misc b/changelog.d/12019.misc deleted file mode 100644 index b2186320ea..0000000000 --- a/changelog.d/12019.misc +++ /dev/null @@ -1 +0,0 @@ -Use Matrix v1.1 endpoints (`/_matrix/client/v3/auth/...`) in fallback auth HTML forms. \ No newline at end of file diff --git a/changelog.d/12020.feature b/changelog.d/12020.feature deleted file mode 100644 index 1ac9d2060e..0000000000 --- a/changelog.d/12020.feature +++ /dev/null @@ -1 +0,0 @@ -Advertise Matrix 1.1 support on `/_matrix/client/versions`. \ No newline at end of file diff --git a/changelog.d/12021.feature b/changelog.d/12021.feature deleted file mode 100644 index 01378df8ca..0000000000 --- a/changelog.d/12021.feature +++ /dev/null @@ -1 +0,0 @@ -Support only the stable identifier for [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069)'s `is_guest` on `/_matrix/client/v3/account/whoami`. \ No newline at end of file diff --git a/changelog.d/12022.feature b/changelog.d/12022.feature deleted file mode 100644 index 188fb12570..0000000000 --- a/changelog.d/12022.feature +++ /dev/null @@ -1 +0,0 @@ -Advertise Matrix 1.2 support on `/_matrix/client/versions`. \ No newline at end of file diff --git a/changelog.d/12024.bugfix b/changelog.d/12024.bugfix deleted file mode 100644 index 59bcdb93a5..0000000000 --- a/changelog.d/12024.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. diff --git a/changelog.d/12025.misc b/changelog.d/12025.misc deleted file mode 100644 index d9475a7718..0000000000 --- a/changelog.d/12025.misc +++ /dev/null @@ -1 +0,0 @@ -Update the `olddeps` CI job to use an old version of `markupsafe`. diff --git a/changelog.d/12030.misc b/changelog.d/12030.misc deleted file mode 100644 index 607ee97ce6..0000000000 --- a/changelog.d/12030.misc +++ /dev/null @@ -1 +0,0 @@ -Upgrade mypy to version 0.931. diff --git a/changelog.d/12031.misc b/changelog.d/12031.misc deleted file mode 100644 index d4bedc6b97..0000000000 --- a/changelog.d/12031.misc +++ /dev/null @@ -1 +0,0 @@ -Remove legacy `HomeServer.get_datastore()`. diff --git a/changelog.d/12034.misc b/changelog.d/12034.misc deleted file mode 100644 index 8374a63220..0000000000 --- a/changelog.d/12034.misc +++ /dev/null @@ -1 +0,0 @@ -Minor typing fixes. diff --git a/changelog.d/12037.bugfix b/changelog.d/12037.bugfix deleted file mode 100644 index 9295cb4dc0..0000000000 --- a/changelog.d/12037.bugfix +++ /dev/null @@ -1 +0,0 @@ -Properly fix a long-standing bug where wrong data could be inserted in the `event_search` table when using sqlite. This could block running `synapse_port_db` with an "argument of type 'int' is not iterable" error. This bug was partially fixed by a change in Synapse 1.44.0. diff --git a/changelog.d/12039.misc b/changelog.d/12039.misc deleted file mode 100644 index 45e21dbe59..0000000000 --- a/changelog.d/12039.misc +++ /dev/null @@ -1 +0,0 @@ -Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`. diff --git a/changelog.d/12041.misc b/changelog.d/12041.misc deleted file mode 100644 index e56dc093de..0000000000 --- a/changelog.d/12041.misc +++ /dev/null @@ -1 +0,0 @@ -After joining a room, create a dedicated logcontext to process the queued events. diff --git a/changelog.d/12051.misc b/changelog.d/12051.misc deleted file mode 100644 index 9959191352..0000000000 --- a/changelog.d/12051.misc +++ /dev/null @@ -1 +0,0 @@ -Tidy up GitHub Actions config which builds distributions for PyPI. \ No newline at end of file diff --git a/changelog.d/12052.misc b/changelog.d/12052.misc deleted file mode 100644 index 11755ae61b..0000000000 --- a/changelog.d/12052.misc +++ /dev/null @@ -1 +0,0 @@ -Move configuration out of `setup.cfg`. diff --git a/changelog.d/12056.bugfix b/changelog.d/12056.bugfix deleted file mode 100644 index 210e30c63f..0000000000 --- a/changelog.d/12056.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens. \ No newline at end of file diff --git a/changelog.d/12058.feature b/changelog.d/12058.feature deleted file mode 100644 index 7b71692229..0000000000 --- a/changelog.d/12058.feature +++ /dev/null @@ -1 +0,0 @@ -Use room version 9 as the default room version (per [MSC3589](https://github.com/matrix-org/matrix-doc/pull/3589)). diff --git a/changelog.d/12059.misc b/changelog.d/12059.misc deleted file mode 100644 index 9ba4759d99..0000000000 --- a/changelog.d/12059.misc +++ /dev/null @@ -1 +0,0 @@ -Move configuration out of `setup.cfg`. \ No newline at end of file diff --git a/changelog.d/12060.misc b/changelog.d/12060.misc deleted file mode 100644 index d771e6a1b3..0000000000 --- a/changelog.d/12060.misc +++ /dev/null @@ -1 +0,0 @@ -Fix error message when a worker process fails to talk to another worker process. diff --git a/changelog.d/12062.feature b/changelog.d/12062.feature deleted file mode 100644 index 46a606709d..0000000000 --- a/changelog.d/12062.feature +++ /dev/null @@ -1 +0,0 @@ -Add module callbacks to react to user deactivation status changes (i.e. deactivations and reactivations) and profile updates. diff --git a/changelog.d/12063.misc b/changelog.d/12063.misc deleted file mode 100644 index e48c5dd08b..0000000000 --- a/changelog.d/12063.misc +++ /dev/null @@ -1 +0,0 @@ -Fix using the complement.sh script without specifying a dir or a branch. Contributed by Nico on behalf of Famedly. diff --git a/changelog.d/12066.misc b/changelog.d/12066.misc deleted file mode 100644 index 0360dbd61e..0000000000 --- a/changelog.d/12066.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to `tests/rest/client`. diff --git a/changelog.d/12067.feature b/changelog.d/12067.feature deleted file mode 100644 index dc1153c49e..0000000000 --- a/changelog.d/12067.feature +++ /dev/null @@ -1 +0,0 @@ -Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). diff --git a/changelog.d/12068.misc b/changelog.d/12068.misc deleted file mode 100644 index 72b211e4f5..0000000000 --- a/changelog.d/12068.misc +++ /dev/null @@ -1 +0,0 @@ -Add some logging to `/sync` to try and track down #11916. diff --git a/changelog.d/12069.misc b/changelog.d/12069.misc deleted file mode 100644 index 8374a63220..0000000000 --- a/changelog.d/12069.misc +++ /dev/null @@ -1 +0,0 @@ -Minor typing fixes. diff --git a/changelog.d/12070.misc b/changelog.d/12070.misc deleted file mode 100644 index d4bedc6b97..0000000000 --- a/changelog.d/12070.misc +++ /dev/null @@ -1 +0,0 @@ -Remove legacy `HomeServer.get_datastore()`. diff --git a/changelog.d/12072.misc b/changelog.d/12072.misc deleted file mode 100644 index 0360dbd61e..0000000000 --- a/changelog.d/12072.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to `tests/rest/client`. diff --git a/changelog.d/12073.removal b/changelog.d/12073.removal deleted file mode 100644 index 1f39792712..0000000000 --- a/changelog.d/12073.removal +++ /dev/null @@ -1 +0,0 @@ -Remove the unstable `/spaces` endpoint from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/12077.bugfix b/changelog.d/12077.bugfix deleted file mode 100644 index 1bce82082d..0000000000 --- a/changelog.d/12077.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug where Synapse would make additional failing requests over federation for missing data. diff --git a/changelog.d/12084.misc b/changelog.d/12084.misc deleted file mode 100644 index 0360dbd61e..0000000000 --- a/changelog.d/12084.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to `tests/rest/client`. diff --git a/changelog.d/12088.misc b/changelog.d/12088.misc deleted file mode 100644 index ce4213650c..0000000000 --- a/changelog.d/12088.misc +++ /dev/null @@ -1 +0,0 @@ -Inspect application dependencies using `importlib.metadata` or its backport. \ No newline at end of file diff --git a/changelog.d/12089.bugfix b/changelog.d/12089.bugfix deleted file mode 100644 index 27172c4828..0000000000 --- a/changelog.d/12089.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix occasional 'Unhandled error in Deferred' error message. diff --git a/changelog.d/12092.misc b/changelog.d/12092.misc deleted file mode 100644 index 62653d6f8d..0000000000 --- a/changelog.d/12092.misc +++ /dev/null @@ -1 +0,0 @@ -User `assertEqual` instead of the deprecated `assertEquals` in test code. diff --git a/changelog.d/12094.misc b/changelog.d/12094.misc deleted file mode 100644 index 0360dbd61e..0000000000 --- a/changelog.d/12094.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to `tests/rest/client`. diff --git a/changelog.d/12098.bugfix b/changelog.d/12098.bugfix deleted file mode 100644 index 6b696692e3..0000000000 --- a/changelog.d/12098.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in Synapse 1.51.0rc1 where incoming federation transactions containing at least one EDU would be dropped if debug logging was enabled for `synapse.8631_debug`. \ No newline at end of file diff --git a/changelog.d/12099.misc b/changelog.d/12099.misc deleted file mode 100644 index 0553825dbc..0000000000 --- a/changelog.d/12099.misc +++ /dev/null @@ -1 +0,0 @@ -Move experimental support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) to /versions. diff --git a/changelog.d/12100.bugfix b/changelog.d/12100.bugfix deleted file mode 100644 index 181095ad99..0000000000 --- a/changelog.d/12100.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug which could cause push notifications to malfunction if `use_frozen_dicts` was set in the configuration. diff --git a/changelog.d/12105.bugfix b/changelog.d/12105.bugfix deleted file mode 100644 index f42e63e01f..0000000000 --- a/changelog.d/12105.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix an extremely rare, long-standing bug in `ReadWriteLock` that would cause an error when a newly unblocked writer completes instantly. diff --git a/changelog.d/12106.misc b/changelog.d/12106.misc deleted file mode 100644 index d918e9e3b1..0000000000 --- a/changelog.d/12106.misc +++ /dev/null @@ -1 +0,0 @@ -Add `stop_cancellation` utility function to stop `Deferred`s from being cancelled. diff --git a/changelog.d/12109.misc b/changelog.d/12109.misc deleted file mode 100644 index 3295e49f43..0000000000 --- a/changelog.d/12109.misc +++ /dev/null @@ -1 +0,0 @@ -Improve exception handling for concurrent execution. diff --git a/changelog.d/12111.misc b/changelog.d/12111.misc deleted file mode 100644 index be84789c9d..0000000000 --- a/changelog.d/12111.misc +++ /dev/null @@ -1 +0,0 @@ -Advertise support for Python 3.10 in packaging files. \ No newline at end of file diff --git a/changelog.d/12112.docker b/changelog.d/12112.docker deleted file mode 100644 index b9e630653d..0000000000 --- a/changelog.d/12112.docker +++ /dev/null @@ -1 +0,0 @@ -Use Python 3.9 in Docker images by default. \ No newline at end of file diff --git a/changelog.d/12119.misc b/changelog.d/12119.misc deleted file mode 100644 index f02d140f38..0000000000 --- a/changelog.d/12119.misc +++ /dev/null @@ -1 +0,0 @@ -Move CI checks out of tox, to facilitate a move to using poetry. \ No newline at end of file diff --git a/debian/changelog b/debian/changelog index 574930c085..df3db85b8e 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.54.0~rc1) stable; urgency=medium + + * New synapse release 1.54.0~rc1. + + -- Synapse Packaging team Wed, 02 Mar 2022 10:43:22 +0000 + matrix-synapse-py3 (1.53.0) stable; urgency=medium * New synapse release 1.53.0. diff --git a/synapse/__init__.py b/synapse/__init__.py index 903f2e815d..b21e1ed0f3 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.53.0" +__version__ = "1.54.0rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when From d800108bb4e1272235aa6f5f80b2732cee9aa5bf Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 2 Mar 2022 10:54:52 +0000 Subject: [PATCH 40/40] Reword changelog --- CHANGES.md | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 4f0318970e..5485e8d47e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,10 +1,14 @@ Synapse 1.54.0rc1 (2022-03-02) ============================== +Please note that this will be the last release of Synapse that is compatible with Mjolnir 1.3.1 and earlier. +Administrators of servers which have the Mjolnir module installed are advised to upgrade Mjolnir to version 1.3.2 or later. + + Features -------- -- Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services. ([\#11617](https://github.com/matrix-org/synapse/issues/11617)) +- Add support for [MSC3202](https://github.com/matrix-org/matrix-doc/pull/3202): sending one-time key counts and fallback key usage states to Application Services. ([\#11617](https://github.com/matrix-org/synapse/issues/11617)) - Make a `POST` to `/rooms//receipt/m.read/` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push. ([\#11835](https://github.com/matrix-org/synapse/issues/11835)) - Fetch images when previewing Twitter URLs. Contributed by @AndrewRyanChama. ([\#11985](https://github.com/matrix-org/synapse/issues/11985)) - Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time. ([\#12000](https://github.com/matrix-org/synapse/issues/12000)) @@ -22,8 +26,8 @@ Bugfixes - Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. ([\#11992](https://github.com/matrix-org/synapse/issues/11992)) - Fix long standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room. ([\#11999](https://github.com/matrix-org/synapse/issues/11999)) -- Fix 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. ([\#12024](https://github.com/matrix-org/synapse/issues/12024)) -- Properly fix a long-standing bug where wrong data could be inserted in the `event_search` table when using sqlite. This could block running `synapse_port_db` with an "argument of type 'int' is not iterable" error. This bug was partially fixed by a change in Synapse 1.44.0. ([\#12037](https://github.com/matrix-org/synapse/issues/12037)) +- Fix a 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. ([\#12024](https://github.com/matrix-org/synapse/issues/12024)) +- Properly fix a long-standing bug where wrong data could be inserted into the `event_search` table when using SQLite. This could block running `synapse_port_db` with an "argument of type 'int' is not iterable" error. This bug was partially fixed by a change in Synapse 1.44.0. ([\#12037](https://github.com/matrix-org/synapse/issues/12037)) - Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens. ([\#12056](https://github.com/matrix-org/synapse/issues/12056)) - Fix a long-standing bug where Synapse would make additional failing requests over federation for missing data. ([\#12077](https://github.com/matrix-org/synapse/issues/12077)) - Fix occasional 'Unhandled error in Deferred' error message. ([\#12089](https://github.com/matrix-org/synapse/issues/12089)) @@ -35,7 +39,7 @@ Bugfixes Updates to the Docker image --------------------------- -- The docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage. ([\#11997](https://github.com/matrix-org/synapse/issues/11997)) +- The Docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage. ([\#11997](https://github.com/matrix-org/synapse/issues/11997)) - Use Python 3.9 in Docker images by default. ([\#12112](https://github.com/matrix-org/synapse/issues/12112)) @@ -59,35 +63,35 @@ Deprecations and Removals Internal Changes ---------------- -- Make method `get_room_version` use cached `get_room_version_id`. ([\#11808](https://github.com/matrix-org/synapse/issues/11808)) -- Remove unnecessary condition on knock->leave auth rule check. ([\#11900](https://github.com/matrix-org/synapse/issues/11900)) +- Make the `get_room_version` method use `get_room_version_id` to benefit from caching. ([\#11808](https://github.com/matrix-org/synapse/issues/11808)) +- Remove unnecessary condition on knock -> leave auth rule check. ([\#11900](https://github.com/matrix-org/synapse/issues/11900)) - Add tests for device list changes between local users. ([\#11972](https://github.com/matrix-org/synapse/issues/11972)) -- Optimise calculating device_list changes in `/sync`. ([\#11974](https://github.com/matrix-org/synapse/issues/11974)) +- Optimise calculating `device_list` changes in `/sync`. ([\#11974](https://github.com/matrix-org/synapse/issues/11974)) - Add missing type hints to storage classes. ([\#11984](https://github.com/matrix-org/synapse/issues/11984)) - Refactor the search code for improved readability. ([\#11991](https://github.com/matrix-org/synapse/issues/11991)) - Move common deduplication code down into `_auth_and_persist_outliers`. ([\#11994](https://github.com/matrix-org/synapse/issues/11994)) - Limit concurrent joins from applications services. ([\#11996](https://github.com/matrix-org/synapse/issues/11996)) - Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`. ([\#12005](https://github.com/matrix-org/synapse/issues/12005), [\#12039](https://github.com/matrix-org/synapse/issues/12039)) -- Preparation for faster-room-join work: parse msc3706 fields in send_join response. ([\#12011](https://github.com/matrix-org/synapse/issues/12011)) +- Preparation for faster-room-join work: parse MSC3706 fields in send_join response. ([\#12011](https://github.com/matrix-org/synapse/issues/12011)) - Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database. ([\#12012](https://github.com/matrix-org/synapse/issues/12012)) - Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server. ([\#12013](https://github.com/matrix-org/synapse/issues/12013)) - Configure `tox` to use `venv` rather than `virtualenv`. ([\#12015](https://github.com/matrix-org/synapse/issues/12015)) - Fix bug in `StateFilter.return_expanded()` and add some tests. ([\#12016](https://github.com/matrix-org/synapse/issues/12016)) - Use Matrix v1.1 endpoints (`/_matrix/client/v3/auth/...`) in fallback auth HTML forms. ([\#12019](https://github.com/matrix-org/synapse/issues/12019)) - Update the `olddeps` CI job to use an old version of `markupsafe`. ([\#12025](https://github.com/matrix-org/synapse/issues/12025)) -- Upgrade mypy to version 0.931. ([\#12030](https://github.com/matrix-org/synapse/issues/12030)) +- Upgrade Mypy to version 0.931. ([\#12030](https://github.com/matrix-org/synapse/issues/12030)) - Remove legacy `HomeServer.get_datastore()`. ([\#12031](https://github.com/matrix-org/synapse/issues/12031), [\#12070](https://github.com/matrix-org/synapse/issues/12070)) - Minor typing fixes. ([\#12034](https://github.com/matrix-org/synapse/issues/12034), [\#12069](https://github.com/matrix-org/synapse/issues/12069)) - After joining a room, create a dedicated logcontext to process the queued events. ([\#12041](https://github.com/matrix-org/synapse/issues/12041)) - Tidy up GitHub Actions config which builds distributions for PyPI. ([\#12051](https://github.com/matrix-org/synapse/issues/12051)) - Move configuration out of `setup.cfg`. ([\#12052](https://github.com/matrix-org/synapse/issues/12052), [\#12059](https://github.com/matrix-org/synapse/issues/12059)) - Fix error message when a worker process fails to talk to another worker process. ([\#12060](https://github.com/matrix-org/synapse/issues/12060)) -- Fix using the complement.sh script without specifying a dir or a branch. Contributed by Nico on behalf of Famedly. ([\#12063](https://github.com/matrix-org/synapse/issues/12063)) +- Fix using the `complement.sh` script without specifying a directory or a branch. Contributed by Nico on behalf of Famedly. ([\#12063](https://github.com/matrix-org/synapse/issues/12063)) - Add type hints to `tests/rest/client`. ([\#12066](https://github.com/matrix-org/synapse/issues/12066), [\#12072](https://github.com/matrix-org/synapse/issues/12072), [\#12084](https://github.com/matrix-org/synapse/issues/12084), [\#12094](https://github.com/matrix-org/synapse/issues/12094)) - Add some logging to `/sync` to try and track down #11916. ([\#12068](https://github.com/matrix-org/synapse/issues/12068)) - Inspect application dependencies using `importlib.metadata` or its backport. ([\#12088](https://github.com/matrix-org/synapse/issues/12088)) -- User `assertEqual` instead of the deprecated `assertEquals` in test code. ([\#12092](https://github.com/matrix-org/synapse/issues/12092)) -- Move experimental support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) to /versions. ([\#12099](https://github.com/matrix-org/synapse/issues/12099)) +- Use `assertEqual` instead of the deprecated `assertEquals` in test code. ([\#12092](https://github.com/matrix-org/synapse/issues/12092)) +- Move experimental support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) to `/versions`. ([\#12099](https://github.com/matrix-org/synapse/issues/12099)) - Add `stop_cancellation` utility function to stop `Deferred`s from being cancelled. ([\#12106](https://github.com/matrix-org/synapse/issues/12106)) - Improve exception handling for concurrent execution. ([\#12109](https://github.com/matrix-org/synapse/issues/12109)) - Advertise support for Python 3.10 in packaging files. ([\#12111](https://github.com/matrix-org/synapse/issues/12111))