From b23abca9e7c0ccb176ddb4f96270cd9dc5b6550d Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 16 May 2024 17:04:26 -0500 Subject: [PATCH] Fix test inheritance See https://github.com/element-hq/synapse/pull/17167#discussion_r1594517041 --- tests/rest/client/test_sendtodevice.py | 282 ++++++++- tests/rest/client/test_sendtodevice_base.py | 268 -------- tests/rest/client/test_sliding_sync.py | 84 ++- tests/rest/client/test_sync.py | 665 ++++++++++---------- 4 files changed, 699 insertions(+), 600 deletions(-) delete mode 100644 tests/rest/client/test_sendtodevice_base.py diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py index a22f51da91..44683fdf12 100644 --- a/tests/rest/client/test_sendtodevice.py +++ b/tests/rest/client/test_sendtodevice.py @@ -1,7 +1,281 @@ -from tests.rest.client.test_sendtodevice_base import SendToDeviceTestCaseBase -from tests.unittest import HomeserverTestCase +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EduTypes +from synapse.rest import admin +from synapse.rest.client import login, sendtodevice, sync +from synapse.server import HomeServer +from synapse.util import Clock + +from tests.unittest import HomeserverTestCase, override_config -class SendToDeviceTestCase(SendToDeviceTestCaseBase, HomeserverTestCase): - # See SendToDeviceTestCaseBase for tests +class NotTested: + """ + We nest the base test class to avoid the tests being run twice by the test runner + when we share/import these tests in other files. Without this, Twisted trial throws + a `KeyError` in the reporter when using multiple jobs (`poetry run trial --jobs=6`). + """ + + class SendToDeviceTestCaseBase(HomeserverTestCase): + """ + Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`. + + In order to run the tests, inherit from this base-class with `HomeserverTestCase`, e.g. + `class SendToDeviceTestCase(SendToDeviceTestCase, HomeserverTestCase)` + """ + + servlets = [ + admin.register_servlets, + login.register_servlets, + sendtodevice.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.sync_endpoint = "/sync" + + def test_user_to_user(self) -> None: + """A to-device message from one user to another should get delivered""" + + user1 = self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # send the message + test_msg = {"foo": "bar"} + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.test/1234", + content={"messages": {user2: {"d2": test_msg}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + # check it appears + channel = self.make_request( + "GET", self.sync_endpoint, access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + expected_result = { + "events": [ + { + "sender": user1, + "type": "m.test", + "content": test_msg, + } + ] + } + self.assertEqual(channel.json_body["to_device"], expected_result) + + # it should re-appear if we do another sync because the to-device message is not + # deleted until we acknowledge it by sending a `?since=...` parameter in the + # next sync request corresponding to the `next_batch` value from the response. + channel = self.make_request( + "GET", self.sync_endpoint, access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["to_device"], expected_result) + + # it should *not* appear if we do an incremental sync + sync_token = channel.json_body["next_batch"] + channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={sync_token}", + access_token=user2_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual( + channel.json_body.get("to_device", {}).get("events", []), [] + ) + + @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) + def test_local_room_key_request(self) -> None: + """m.room_key_request has special-casing; test from local user""" + user1 = self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # send three messages + for i in range(3): + chan = self.make_request( + "PUT", + f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}", + content={"messages": {user2: {"d2": {"idx": i}}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + # now sync: we should get two of the three (because burst_count=2) + channel = self.make_request( + "GET", self.sync_endpoint, access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 2) + for i in range(2): + self.assertEqual( + msgs[i], + { + "sender": user1, + "type": "m.room_key_request", + "content": {"idx": i}, + }, + ) + sync_token = channel.json_body["next_batch"] + + # ... time passes + self.reactor.advance(1) + + # and we can send more messages + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/3", + content={"messages": {user2: {"d2": {"idx": 3}}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + # ... which should arrive + channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={sync_token}", + access_token=user2_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 1) + self.assertEqual( + msgs[0], + {"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}}, + ) + + @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) + def test_remote_room_key_request(self) -> None: + """m.room_key_request has special-casing; test from remote user""" + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + federation_registry = self.hs.get_federation_registry() + + # send three messages + for i in range(3): + self.get_success( + federation_registry.on_edu( + EduTypes.DIRECT_TO_DEVICE, + "remote_server", + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "messages": {user2: {"d2": {"idx": i}}}, + "message_id": f"{i}", + }, + ) + ) + + # now sync: we should get two of the three + channel = self.make_request( + "GET", self.sync_endpoint, access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 2) + for i in range(2): + self.assertEqual( + msgs[i], + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "content": {"idx": i}, + }, + ) + sync_token = channel.json_body["next_batch"] + + # ... time passes + self.reactor.advance(1) + + # and we can send more messages + self.get_success( + federation_registry.on_edu( + EduTypes.DIRECT_TO_DEVICE, + "remote_server", + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "messages": {user2: {"d2": {"idx": 3}}}, + "message_id": "3", + }, + ) + ) + + # ... which should arrive + channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={sync_token}", + access_token=user2_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 1) + self.assertEqual( + msgs[0], + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "content": {"idx": 3}, + }, + ) + + def test_limited_sync(self) -> None: + """If a limited sync for to-devices happens the next /sync should respond immediately.""" + + self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # Do an initial sync + channel = self.make_request( + "GET", self.sync_endpoint, access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + sync_token = channel.json_body["next_batch"] + + # Send 150 to-device messages. We limit to 100 in `/sync` + for i in range(150): + test_msg = {"foo": "bar"} + chan = self.make_request( + "PUT", + f"/_matrix/client/r0/sendToDevice/m.test/1234-{i}", + content={"messages": {user2: {"d2": test_msg}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={sync_token}&timeout=300000", + access_token=user2_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + messages = channel.json_body.get("to_device", {}).get("events", []) + self.assertEqual(len(messages), 100) + sync_token = channel.json_body["next_batch"] + + channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={sync_token}&timeout=300000", + access_token=user2_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + messages = channel.json_body.get("to_device", {}).get("events", []) + self.assertEqual(len(messages), 50) + + +class SendToDeviceTestCase(NotTested.SendToDeviceTestCaseBase): + # See SendToDeviceTestCaseBase above pass diff --git a/tests/rest/client/test_sendtodevice_base.py b/tests/rest/client/test_sendtodevice_base.py deleted file mode 100644 index 5677f4f280..0000000000 --- a/tests/rest/client/test_sendtodevice_base.py +++ /dev/null @@ -1,268 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2021 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# . -# -# Originally licensed under the Apache License, Version 2.0: -# . -# -# [This file includes modifications made by New Vector Limited] -# -# - -from twisted.test.proto_helpers import MemoryReactor - -from synapse.api.constants import EduTypes -from synapse.rest import admin -from synapse.rest.client import login, sendtodevice, sync -from synapse.server import HomeServer -from synapse.util import Clock - -from tests.unittest import HomeserverTestCase, override_config - - -class SendToDeviceTestCaseBase(HomeserverTestCase): - """ - Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`. - - In order to run the tests, inherit from this base-class with `HomeserverTestCase`, e.g. - `class SendToDeviceTestCase(SendToDeviceTestCase, HomeserverTestCase)` - """ - - servlets = [ - admin.register_servlets, - login.register_servlets, - sendtodevice.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sync_endpoint = "/sync" - - def test_user_to_user(self) -> None: - """A to-device message from one user to another should get delivered""" - - user1 = self.register_user("u1", "pass") - user1_tok = self.login("u1", "pass", "d1") - - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - # send the message - test_msg = {"foo": "bar"} - chan = self.make_request( - "PUT", - "/_matrix/client/r0/sendToDevice/m.test/1234", - content={"messages": {user2: {"d2": test_msg}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - # check it appears - channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - expected_result = { - "events": [ - { - "sender": user1, - "type": "m.test", - "content": test_msg, - } - ] - } - self.assertEqual(channel.json_body["to_device"], expected_result) - - # it should re-appear if we do another sync because the to-device message is not - # deleted until we acknowledge it by sending a `?since=...` parameter in the - # next sync request corresponding to the `next_batch` value from the response. - channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - self.assertEqual(channel.json_body["to_device"], expected_result) - - # it should *not* appear if we do an incremental sync - sync_token = channel.json_body["next_batch"] - channel = self.make_request( - "GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), []) - - @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_local_room_key_request(self) -> None: - """m.room_key_request has special-casing; test from local user""" - user1 = self.register_user("u1", "pass") - user1_tok = self.login("u1", "pass", "d1") - - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - # send three messages - for i in range(3): - chan = self.make_request( - "PUT", - f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}", - content={"messages": {user2: {"d2": {"idx": i}}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - # now sync: we should get two of the three (because burst_count=2) - channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 2) - for i in range(2): - self.assertEqual( - msgs[i], - {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}}, - ) - sync_token = channel.json_body["next_batch"] - - # ... time passes - self.reactor.advance(1) - - # and we can send more messages - chan = self.make_request( - "PUT", - "/_matrix/client/r0/sendToDevice/m.room_key_request/3", - content={"messages": {user2: {"d2": {"idx": 3}}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - # ... which should arrive - channel = self.make_request( - "GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 1) - self.assertEqual( - msgs[0], - {"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}}, - ) - - @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_remote_room_key_request(self) -> None: - """m.room_key_request has special-casing; test from remote user""" - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - federation_registry = self.hs.get_federation_registry() - - # send three messages - for i in range(3): - self.get_success( - federation_registry.on_edu( - EduTypes.DIRECT_TO_DEVICE, - "remote_server", - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "messages": {user2: {"d2": {"idx": i}}}, - "message_id": f"{i}", - }, - ) - ) - - # now sync: we should get two of the three - channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 2) - for i in range(2): - self.assertEqual( - msgs[i], - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "content": {"idx": i}, - }, - ) - sync_token = channel.json_body["next_batch"] - - # ... time passes - self.reactor.advance(1) - - # and we can send more messages - self.get_success( - federation_registry.on_edu( - EduTypes.DIRECT_TO_DEVICE, - "remote_server", - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "messages": {user2: {"d2": {"idx": 3}}}, - "message_id": "3", - }, - ) - ) - - # ... which should arrive - channel = self.make_request( - "GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 1) - self.assertEqual( - msgs[0], - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "content": {"idx": 3}, - }, - ) - - def test_limited_sync(self) -> None: - """If a limited sync for to-devices happens the next /sync should respond immediately.""" - - self.register_user("u1", "pass") - user1_tok = self.login("u1", "pass", "d1") - - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - # Do an initial sync - channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - sync_token = channel.json_body["next_batch"] - - # Send 150 to-device messages. We limit to 100 in `/sync` - for i in range(150): - test_msg = {"foo": "bar"} - chan = self.make_request( - "PUT", - f"/_matrix/client/r0/sendToDevice/m.test/1234-{i}", - content={"messages": {user2: {"d2": test_msg}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={sync_token}&timeout=300000", - access_token=user2_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - messages = channel.json_body.get("to_device", {}).get("events", []) - self.assertEqual(len(messages), 100) - sync_token = channel.json_body["next_batch"] - - channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={sync_token}&timeout=300000", - access_token=user2_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - messages = channel.json_body.get("to_device", {}).get("events", []) - self.assertEqual(len(messages), 50) diff --git a/tests/rest/client/test_sliding_sync.py b/tests/rest/client/test_sliding_sync.py index eb2eb397a5..d960ef4cb4 100644 --- a/tests/rest/client/test_sliding_sync.py +++ b/tests/rest/client/test_sliding_sync.py @@ -4,18 +4,18 @@ from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock -# TODO: Uncomment this line when we have a pattern to share tests across files, see -# https://github.com/element-hq/synapse/pull/17167#discussion_r1594517041 -# -# from tests.rest.client.test_sync import DeviceListSyncTestCase -# from tests.rest.client.test_sync import DeviceOneTimeKeysSyncTestCase -# from tests.rest.client.test_sync import DeviceUnusedFallbackKeySyncTestCase -from tests.rest.client.test_sendtodevice_base import SendToDeviceTestCaseBase -from tests.unittest import HomeserverTestCase +from tests.rest.client.test_sendtodevice import NotTested as SendToDeviceNotTested +from tests.rest.client.test_sync import NotTested as SyncNotTested -# Test To-Device messages working correctly with the `/sync/e2ee` endpoint (`to_device`) -class SlidingSyncE2eeSendToDeviceTestCase(SendToDeviceTestCaseBase, HomeserverTestCase): +class SlidingSyncE2eeSendToDeviceTestCase( + SendToDeviceNotTested.SendToDeviceTestCaseBase +): + """ + Test To-Device messages working correctly with the `/sync/e2ee` endpoint + (`to_device`) + """ + def default_config(self) -> JsonDict: config = super().default_config() # Enable sliding sync @@ -23,7 +23,71 @@ class SlidingSyncE2eeSendToDeviceTestCase(SendToDeviceTestCaseBase, HomeserverTe return config def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) # Use the Sliding Sync `/sync/e2ee` endpoint self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee" # See SendToDeviceTestCaseBase for tests + + +class SlidingSyncE2eeDeviceListSyncTestCase(SyncNotTested.DeviceListSyncTestCaseBase): + """ + Test device lists working correctly with the `/sync/e2ee` endpoint (`device_lists`) + """ + + def default_config(self) -> JsonDict: + config = super().default_config() + # Enable sliding sync + config["experimental_features"] = {"msc3575_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) + # Use the Sliding Sync `/sync/e2ee` endpoint + self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee" + + # See DeviceListSyncTestCaseBase for tests + + +class SlidingSyncE2eeDeviceOneTimeKeysSyncTestCase( + SyncNotTested.DeviceOneTimeKeysSyncTestCaseBase +): + """ + Test device one time keys working correctly with the `/sync/e2ee` endpoint + (`device_one_time_keys_count`) + """ + + def default_config(self) -> JsonDict: + config = super().default_config() + # Enable sliding sync + config["experimental_features"] = {"msc3575_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) + # Use the Sliding Sync `/sync/e2ee` endpoint + self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee" + + # See DeviceOneTimeKeysSyncTestCaseBase for tests + + +class SlidingSyncE2eeDeviceUnusedFallbackKeySyncTestCase( + SyncNotTested.DeviceUnusedFallbackKeySyncTestCaseBase +): + """ + Test device unused fallback key types working correctly with the `/sync/e2ee` + endpoint (`device_unused_fallback_key_types`) + """ + + def default_config(self) -> JsonDict: + config = super().default_config() + # Enable sliding sync + config["experimental_features"] = {"msc3575_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) + # Use the Sliding Sync `/sync/e2ee` endpoint + self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee" + + # See DeviceUnusedFallbackKeySyncTestCaseBase for tests diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index b4baa6a385..e89c2d7200 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -688,367 +688,396 @@ class SyncCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.json_body) -class DeviceListSyncTestCase(unittest.HomeserverTestCase): - """Tests regarding device list (`device_lists`) changes.""" +class NotTested: + """ + We nest the base test class to avoid the tests being run twice by the test runner + when we share/import these tests in other files. Without this, Twisted trial throws + a `KeyError` in the reporter when using multiple jobs (`poetry run trial --jobs=6`). + """ - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - devices.register_servlets, - ] + class DeviceListSyncTestCaseBase(unittest.HomeserverTestCase): + """Tests regarding device list (`device_lists`) changes.""" - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sync_endpoint = "/sync" + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + sync.register_servlets, + devices.register_servlets, + ] - def test_receiving_local_device_list_changes(self) -> None: - """Tests that a local users that share a room receive each other's device list - changes. - """ - # Register two users - test_device_id = "TESTDEVICE" - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.sync_endpoint = "/sync" - bob_user_id = self.register_user("bob", "ponyponypony") - bob_access_token = self.login(bob_user_id, "ponyponypony") + def test_receiving_local_device_list_changes(self) -> None: + """Tests that a local users that share a room receive each other's device list + changes. + """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) - # Create a room for them to coexist peacefully in - new_room_id = self.helper.create_room_as( - alice_user_id, is_public=True, tok=alice_access_token - ) - self.assertIsNotNone(new_room_id) + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") - # Have Bob join the room - self.helper.invite( - new_room_id, alice_user_id, bob_user_id, tok=alice_access_token - ) - self.helper.join(new_room_id, bob_user_id, tok=bob_access_token) + # Create a room for them to coexist peacefully in + new_room_id = self.helper.create_room_as( + alice_user_id, is_public=True, tok=alice_access_token + ) + self.assertIsNotNone(new_room_id) - # Now have Bob initiate an initial sync (in order to get a since token) - channel = self.make_request( - "GET", - self.sync_endpoint, - access_token=bob_access_token, - ) - self.assertEqual(channel.code, 200, channel.json_body) - next_batch_token = channel.json_body["next_batch"] + # Have Bob join the room + self.helper.invite( + new_room_id, alice_user_id, bob_user_id, tok=alice_access_token + ) + self.helper.join(new_room_id, bob_user_id, tok=bob_access_token) - # ...and then an incremental sync. This should block until the sync stream is woken up, - # which we hope will happen as a result of Alice updating their device list. - bob_sync_channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000", - access_token=bob_access_token, - # Start the request, then continue on. - await_result=False, - ) + # Now have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + self.sync_endpoint, + access_token=bob_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch_token = channel.json_body["next_batch"] - # Have alice update their device list - channel = self.make_request( - "PUT", - f"/devices/{test_device_id}", - { - "display_name": "New Device Name", - }, - access_token=alice_access_token, - ) - self.assertEqual(channel.code, 200, channel.json_body) + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) - # Check that bob's incremental sync contains the updated device list. - # If not, the client would only receive the device list update on the - # *next* sync. - bob_sync_channel.await_result() - self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) - changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( - "changed", [] - ) - self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body) + # Check that bob's incremental sync contains the updated device list. + # If not, the client would only receive the device list update on the + # *next* sync. + bob_sync_channel.await_result() + self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) - def test_not_receiving_local_device_list_changes(self) -> None: - """Tests a local users DO NOT receive device updates from each other if they do not - share a room. - """ - # Register two users - test_device_id = "TESTDEVICE" - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) + changed_device_lists = bob_sync_channel.json_body.get( + "device_lists", {} + ).get("changed", []) + self.assertIn( + alice_user_id, changed_device_lists, bob_sync_channel.json_body + ) - bob_user_id = self.register_user("bob", "ponyponypony") - bob_access_token = self.login(bob_user_id, "ponyponypony") + def test_not_receiving_local_device_list_changes(self) -> None: + """Tests a local users DO NOT receive device updates from each other if they do not + share a room. + """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) - # These users do not share a room. They are lonely. + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") - # Have Bob initiate an initial sync (in order to get a since token) - channel = self.make_request( - "GET", - self.sync_endpoint, - access_token=bob_access_token, - ) - self.assertEqual(channel.code, 200, channel.json_body) - next_batch_token = channel.json_body["next_batch"] + # These users do not share a room. They are lonely. - # ...and then an incremental sync. This should block until the sync stream is woken up, - # which we hope will happen as a result of Alice updating their device list. - bob_sync_channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000", - access_token=bob_access_token, - # Start the request, then continue on. - await_result=False, - ) + # Have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + self.sync_endpoint, + access_token=bob_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch_token = channel.json_body["next_batch"] - # Have alice update their device list - channel = self.make_request( - "PUT", - f"/devices/{test_device_id}", - { - "display_name": "New Device Name", - }, - access_token=alice_access_token, - ) - self.assertEqual(channel.code, 200, channel.json_body) + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) - # Check that bob's incremental sync does not contain the updated device list. - bob_sync_channel.await_result() - self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) - changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( - "changed", [] - ) - self.assertNotIn( - alice_user_id, changed_device_lists, bob_sync_channel.json_body - ) + # Check that bob's incremental sync does not contain the updated device list. + bob_sync_channel.await_result() + self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) - def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None: - """Tests that a user with no rooms still receives their own device list updates""" - test_device_id = "TESTDEVICE" + changed_device_lists = bob_sync_channel.json_body.get( + "device_lists", {} + ).get("changed", []) + self.assertNotIn( + alice_user_id, changed_device_lists, bob_sync_channel.json_body + ) - # Register a user and login, creating a device - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) + def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None: + """Tests that a user with no rooms still receives their own device list updates""" + test_device_id = "TESTDEVICE" - # Request an initial sync - channel = self.make_request( - "GET", self.sync_endpoint, access_token=alice_access_token - ) - self.assertEqual(channel.code, 200, channel.json_body) - next_batch = channel.json_body["next_batch"] + # Register a user and login, creating a device + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) - # Now, make an incremental sync request. - # It won't return until something has happened - incremental_sync_channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={next_batch}&timeout=30000", - access_token=alice_access_token, - await_result=False, - ) + # Request an initial sync + channel = self.make_request( + "GET", self.sync_endpoint, access_token=alice_access_token + ) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch = channel.json_body["next_batch"] - # Change our device's display name - channel = self.make_request( - "PUT", - f"devices/{test_device_id}", - { - "display_name": "freeze ray", - }, - access_token=alice_access_token, - ) - self.assertEqual(channel.code, 200, channel.json_body) + # Now, make an incremental sync request. + # It won't return until something has happened + incremental_sync_channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={next_batch}&timeout=30000", + access_token=alice_access_token, + await_result=False, + ) - # The sync should now have returned - incremental_sync_channel.await_result(timeout_ms=20000) - self.assertEqual(incremental_sync_channel.code, 200, channel.json_body) + # Change our device's display name + channel = self.make_request( + "PUT", + f"devices/{test_device_id}", + { + "display_name": "freeze ray", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) - # We should have received notification that the (user's) device has changed - device_list_changes = incremental_sync_channel.json_body.get( - "device_lists", {} - ).get("changed", []) + # The sync should now have returned + incremental_sync_channel.await_result(timeout_ms=20000) + self.assertEqual(incremental_sync_channel.code, 200, channel.json_body) - self.assertIn( - alice_user_id, device_list_changes, incremental_sync_channel.json_body - ) + # We should have received notification that the (user's) device has changed + device_list_changes = incremental_sync_channel.json_body.get( + "device_lists", {} + ).get("changed", []) + self.assertIn( + alice_user_id, device_list_changes, incremental_sync_channel.json_body + ) -class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase): - """Tests regarding device one time keys (`device_one_time_keys_count`) changes.""" + class DeviceOneTimeKeysSyncTestCaseBase(unittest.HomeserverTestCase): + """Tests regarding device one time keys (`device_one_time_keys_count`) changes.""" - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - sync.register_servlets, - devices.register_servlets, - ] + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + devices.register_servlets, + ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sync_endpoint = "/sync" - self.e2e_keys_handler = hs.get_e2e_keys_handler() + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.sync_endpoint = "/sync" + self.e2e_keys_handler = hs.get_e2e_keys_handler() - def test_no_device_one_time_keys(self) -> None: - """ - Tests when no one time keys set, it still has the default `signed_curve25519` in - `device_one_time_keys_count` - """ - test_device_id = "TESTDEVICE" + def test_no_device_one_time_keys(self) -> None: + """ + Tests when no one time keys set, it still has the default `signed_curve25519` in + `device_one_time_keys_count` + """ + test_device_id = "TESTDEVICE" - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) - # Request an initial sync - channel = self.make_request( - "GET", self.sync_endpoint, access_token=alice_access_token - ) - self.assertEqual(channel.code, 200, channel.json_body) + # Request an initial sync + channel = self.make_request( + "GET", self.sync_endpoint, access_token=alice_access_token + ) + self.assertEqual(channel.code, 200, channel.json_body) - # Check for those one time key counts - self.assertDictEqual( - channel.json_body["device_one_time_keys_count"], + # Check for those one time key counts + self.assertDictEqual( + channel.json_body["device_one_time_keys_count"], + # Note that "signed_curve25519" is always returned in key count responses + # regardless of whether we uploaded any keys for it. This is necessary until + # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. + {"signed_curve25519": 0}, + channel.json_body["device_one_time_keys_count"], + ) + + def test_returns_device_one_time_keys(self) -> None: + """ + Tests that one time keys for the device/user are counted correctly in the `/sync` + response + """ + test_device_id = "TESTDEVICE" + + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + # Upload one time keys for the user/device + keys: JsonDict = { + "alg1:k1": "key1", + "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}}, + "alg2:k3": {"key": "key3"}, + } + res = self.get_success( + self.e2e_keys_handler.upload_keys_for_user( + alice_user_id, test_device_id, {"one_time_keys": keys} + ) + ) # Note that "signed_curve25519" is always returned in key count responses # regardless of whether we uploaded any keys for it. This is necessary until # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. - {"signed_curve25519": 0}, - channel.json_body["device_one_time_keys_count"], - ) - - def test_returns_device_one_time_keys(self) -> None: - """ - Tests that one time keys for the device/user are counted correctly in the `/sync` - response - """ - test_device_id = "TESTDEVICE" - - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) - - # Upload one time keys for the user/device - keys: JsonDict = { - "alg1:k1": "key1", - "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}}, - "alg2:k3": {"key": "key3"}, - } - res = self.get_success( - self.e2e_keys_handler.upload_keys_for_user( - alice_user_id, test_device_id, {"one_time_keys": keys} + self.assertDictEqual( + res, + {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}, ) - ) - # Note that "signed_curve25519" is always returned in key count responses - # regardless of whether we uploaded any keys for it. This is necessary until - # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. - self.assertDictEqual( - res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}} - ) - # Request an initial sync - channel = self.make_request( - "GET", self.sync_endpoint, access_token=alice_access_token - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check for those one time key counts - self.assertDictEqual( - channel.json_body["device_one_time_keys_count"], - {"alg1": 1, "alg2": 2, "signed_curve25519": 0}, - channel.json_body["device_one_time_keys_count"], - ) - - -class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase): - """Tests regarding device one time keys (`device_unused_fallback_key_types`) changes.""" - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - sync.register_servlets, - devices.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sync_endpoint = "/sync" - self.store = self.hs.get_datastores().main - self.e2e_keys_handler = hs.get_e2e_keys_handler() - - def test_no_device_unused_fallback_key(self) -> None: - """ - Test when no unused fallback key is set, it just returns an empty list. The MSC - says "The device_unused_fallback_key_types parameter must be present if the - server supports fallback keys.", - https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md - """ - test_device_id = "TESTDEVICE" - - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) - - # Request an initial sync - channel = self.make_request( - "GET", self.sync_endpoint, access_token=alice_access_token - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check for those one time key counts - self.assertListEqual( - channel.json_body["device_unused_fallback_key_types"], - [], - channel.json_body["device_unused_fallback_key_types"], - ) - - def test_returns_device_one_time_keys(self) -> None: - """ - Tests that device unused fallback key type is returned correctly in the `/sync` - """ - test_device_id = "TESTDEVICE" - - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) - - # We shouldn't have any unused fallback keys yet - res = self.get_success( - self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id) - ) - self.assertEqual(res, []) - - # Upload a fallback key for the user/device - fallback_key = {"alg1:k1": "fallback_key1"} - self.get_success( - self.e2e_keys_handler.upload_keys_for_user( - alice_user_id, - test_device_id, - {"fallback_keys": fallback_key}, + # Request an initial sync + channel = self.make_request( + "GET", self.sync_endpoint, access_token=alice_access_token ) - ) - # We should now have an unused alg1 key - fallback_res = self.get_success( - self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id) - ) - self.assertEqual(fallback_res, ["alg1"], fallback_res) + self.assertEqual(channel.code, 200, channel.json_body) - # Request an initial sync - channel = self.make_request( - "GET", self.sync_endpoint, access_token=alice_access_token - ) - self.assertEqual(channel.code, 200, channel.json_body) + # Check for those one time key counts + self.assertDictEqual( + channel.json_body["device_one_time_keys_count"], + {"alg1": 1, "alg2": 2, "signed_curve25519": 0}, + channel.json_body["device_one_time_keys_count"], + ) - # Check for the unused fallback key types - self.assertListEqual( - channel.json_body["device_unused_fallback_key_types"], - ["alg1"], - channel.json_body["device_unused_fallback_key_types"], - ) + class DeviceUnusedFallbackKeySyncTestCaseBase(unittest.HomeserverTestCase): + """Tests regarding device one time keys (`device_unused_fallback_key_types`) changes.""" + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + devices.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.sync_endpoint = "/sync" + self.store = self.hs.get_datastores().main + self.e2e_keys_handler = hs.get_e2e_keys_handler() + + def test_no_device_unused_fallback_key(self) -> None: + """ + Test when no unused fallback key is set, it just returns an empty list. The MSC + says "The device_unused_fallback_key_types parameter must be present if the + server supports fallback keys.", + https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md + """ + test_device_id = "TESTDEVICE" + + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + # Request an initial sync + channel = self.make_request( + "GET", self.sync_endpoint, access_token=alice_access_token + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check for those one time key counts + self.assertListEqual( + channel.json_body["device_unused_fallback_key_types"], + [], + channel.json_body["device_unused_fallback_key_types"], + ) + + def test_returns_device_one_time_keys(self) -> None: + """ + Tests that device unused fallback key type is returned correctly in the `/sync` + """ + test_device_id = "TESTDEVICE" + + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + # We shouldn't have any unused fallback keys yet + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types( + alice_user_id, test_device_id + ) + ) + self.assertEqual(res, []) + + # Upload a fallback key for the user/device + fallback_key = {"alg1:k1": "fallback_key1"} + self.get_success( + self.e2e_keys_handler.upload_keys_for_user( + alice_user_id, + test_device_id, + {"fallback_keys": fallback_key}, + ) + ) + # We should now have an unused alg1 key + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types( + alice_user_id, test_device_id + ) + ) + self.assertEqual(fallback_res, ["alg1"], fallback_res) + + # Request an initial sync + channel = self.make_request( + "GET", self.sync_endpoint, access_token=alice_access_token + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check for the unused fallback key types + self.assertListEqual( + channel.json_body["device_unused_fallback_key_types"], + ["alg1"], + channel.json_body["device_unused_fallback_key_types"], + ) + + +class DeviceListSyncTestCase(NotTested.DeviceListSyncTestCaseBase): + # See DeviceListSyncTestCaseBase above + pass + + +class DeviceOneTimeKeysSyncTestCase(NotTested.DeviceOneTimeKeysSyncTestCaseBase): + # See DeviceOneTimeKeysSyncTestCaseBase above + pass + + +class DeviceUnusedFallbackKeySyncTestCase( + NotTested.DeviceUnusedFallbackKeySyncTestCaseBase +): + # See DeviceUnusedFallbackKeySyncTestCaseBase above + pass class ExcludeRoomTestCase(unittest.HomeserverTestCase):