From 3da6bc19028cdb28f1b830caf09c9cd69b103425 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 22 May 2024 14:48:03 -0500 Subject: [PATCH] Use `@parameterized_class` As suggested in https://github.com/element-hq/synapse/pull/17167#discussion_r1610255726 --- tests/rest/client/test_sendtodevice.py | 433 ++++++++------- tests/rest/client/test_sliding_sync.py | 93 ---- tests/rest/client/test_sync.py | 709 +++++++++++++------------ 3 files changed, 574 insertions(+), 661 deletions(-) delete mode 100644 tests/rest/client/test_sliding_sync.py diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py index 8e6f372ca1..ce8ba0f209 100644 --- a/tests/rest/client/test_sendtodevice.py +++ b/tests/rest/client/test_sendtodevice.py @@ -18,6 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # +from parameterized import parameterized_class from twisted.test.proto_helpers import MemoryReactor @@ -25,200 +26,164 @@ from synapse.api.constants import EduTypes from synapse.rest import admin from synapse.rest.client import login, sendtodevice, sync from synapse.server import HomeServer +from synapse.types import JsonDict from synapse.util import Clock from tests.unittest import HomeserverTestCase, override_config -class NotTested: +@parameterized_class( + ("sync_endpoint", "experimental_features"), + [ + ("/sync", {}), + ( + "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee", + # Enable sliding sync + {"msc3575_enabled": True}, + ), + ], +) +class SendToDeviceTestCaseBase(HomeserverTestCase): """ - We nest the base test class to avoid the tests being run twice by the test runner - when we share/import these tests in other files. Without this, Twisted trial throws - a `KeyError` in the reporter when using multiple jobs (`poetry run trial --jobs=6`). + Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`. + + In order to run the tests, inherit from this base-class with `HomeserverTestCase`, e.g. + `class SendToDeviceTestCase(SendToDeviceTestCase, HomeserverTestCase)` """ - class SendToDeviceTestCaseBase(HomeserverTestCase): - """ - Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`. + servlets = [ + admin.register_servlets, + login.register_servlets, + sendtodevice.register_servlets, + sync.register_servlets, + ] - In order to run the tests, inherit from this base-class with `HomeserverTestCase`, e.g. - `class SendToDeviceTestCase(SendToDeviceTestCase, HomeserverTestCase)` - """ + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = self.experimental_features + return config - servlets = [ - admin.register_servlets, - login.register_servlets, - sendtodevice.register_servlets, - sync.register_servlets, - ] + def test_user_to_user(self) -> None: + """A to-device message from one user to another should get delivered""" - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sync_endpoint = "/sync" + user1 = self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") - def test_user_to_user(self) -> None: - """A to-device message from one user to another should get delivered""" + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") - user1 = self.register_user("u1", "pass") - user1_tok = self.login("u1", "pass", "d1") + # send the message + test_msg = {"foo": "bar"} + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.test/1234", + content={"messages": {user2: {"d2": test_msg}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") + # check it appears + channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + expected_result = { + "events": [ + { + "sender": user1, + "type": "m.test", + "content": test_msg, + } + ] + } + self.assertEqual(channel.json_body["to_device"], expected_result) - # send the message - test_msg = {"foo": "bar"} + # it should re-appear if we do another sync because the to-device message is not + # deleted until we acknowledge it by sending a `?since=...` parameter in the + # next sync request corresponding to the `next_batch` value from the response. + channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["to_device"], expected_result) + + # it should *not* appear if we do an incremental sync + sync_token = channel.json_body["next_batch"] + channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={sync_token}", + access_token=user2_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), []) + + @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) + def test_local_room_key_request(self) -> None: + """m.room_key_request has special-casing; test from local user""" + user1 = self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # send three messages + for i in range(3): chan = self.make_request( "PUT", - "/_matrix/client/r0/sendToDevice/m.test/1234", - content={"messages": {user2: {"d2": test_msg}}}, + f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}", + content={"messages": {user2: {"d2": {"idx": i}}}}, access_token=user1_tok, ) self.assertEqual(chan.code, 200, chan.result) - # check it appears - channel = self.make_request( - "GET", self.sync_endpoint, access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - expected_result = { - "events": [ - { - "sender": user1, - "type": "m.test", - "content": test_msg, - } - ] - } - self.assertEqual(channel.json_body["to_device"], expected_result) - - # it should re-appear if we do another sync because the to-device message is not - # deleted until we acknowledge it by sending a `?since=...` parameter in the - # next sync request corresponding to the `next_batch` value from the response. - channel = self.make_request( - "GET", self.sync_endpoint, access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - self.assertEqual(channel.json_body["to_device"], expected_result) - - # it should *not* appear if we do an incremental sync - sync_token = channel.json_body["next_batch"] - channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={sync_token}", - access_token=user2_tok, - ) - self.assertEqual(channel.code, 200, channel.result) + # now sync: we should get two of the three (because burst_count=2) + channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 2) + for i in range(2): self.assertEqual( - channel.json_body.get("to_device", {}).get("events", []), [] + msgs[i], + { + "sender": user1, + "type": "m.room_key_request", + "content": {"idx": i}, + }, ) + sync_token = channel.json_body["next_batch"] - @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_local_room_key_request(self) -> None: - """m.room_key_request has special-casing; test from local user""" - user1 = self.register_user("u1", "pass") - user1_tok = self.login("u1", "pass", "d1") + # ... time passes + self.reactor.advance(1) - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") + # and we can send more messages + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/3", + content={"messages": {user2: {"d2": {"idx": 3}}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) - # send three messages - for i in range(3): - chan = self.make_request( - "PUT", - f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}", - content={"messages": {user2: {"d2": {"idx": i}}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) + # ... which should arrive + channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={sync_token}", + access_token=user2_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 1) + self.assertEqual( + msgs[0], + {"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}}, + ) - # now sync: we should get two of the three (because burst_count=2) - channel = self.make_request( - "GET", self.sync_endpoint, access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 2) - for i in range(2): - self.assertEqual( - msgs[i], - { - "sender": user1, - "type": "m.room_key_request", - "content": {"idx": i}, - }, - ) - sync_token = channel.json_body["next_batch"] + @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) + def test_remote_room_key_request(self) -> None: + """m.room_key_request has special-casing; test from remote user""" + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") - # ... time passes - self.reactor.advance(1) + federation_registry = self.hs.get_federation_registry() - # and we can send more messages - chan = self.make_request( - "PUT", - "/_matrix/client/r0/sendToDevice/m.room_key_request/3", - content={"messages": {user2: {"d2": {"idx": 3}}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - # ... which should arrive - channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={sync_token}", - access_token=user2_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 1) - self.assertEqual( - msgs[0], - {"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}}, - ) - - @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_remote_room_key_request(self) -> None: - """m.room_key_request has special-casing; test from remote user""" - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - federation_registry = self.hs.get_federation_registry() - - # send three messages - for i in range(3): - self.get_success( - federation_registry.on_edu( - EduTypes.DIRECT_TO_DEVICE, - "remote_server", - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "messages": {user2: {"d2": {"idx": i}}}, - "message_id": f"{i}", - }, - ) - ) - - # now sync: we should get two of the three - channel = self.make_request( - "GET", self.sync_endpoint, access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 2) - for i in range(2): - self.assertEqual( - msgs[i], - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "content": {"idx": i}, - }, - ) - sync_token = channel.json_body["next_batch"] - - # ... time passes - self.reactor.advance(1) - - # and we can send more messages + # send three messages + for i in range(3): self.get_success( federation_registry.on_edu( EduTypes.DIRECT_TO_DEVICE, @@ -226,77 +191,103 @@ class NotTested: { "sender": "@user:remote_server", "type": "m.room_key_request", - "messages": {user2: {"d2": {"idx": 3}}}, - "message_id": "3", + "messages": {user2: {"d2": {"idx": i}}}, + "message_id": f"{i}", }, ) ) - # ... which should arrive - channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={sync_token}", - access_token=user2_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 1) + # now sync: we should get two of the three + channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 2) + for i in range(2): self.assertEqual( - msgs[0], + msgs[i], { "sender": "@user:remote_server", "type": "m.room_key_request", - "content": {"idx": 3}, + "content": {"idx": i}, }, ) + sync_token = channel.json_body["next_batch"] - def test_limited_sync(self) -> None: - """If a limited sync for to-devices happens the next /sync should respond immediately.""" + # ... time passes + self.reactor.advance(1) - self.register_user("u1", "pass") - user1_tok = self.login("u1", "pass", "d1") - - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - # Do an initial sync - channel = self.make_request( - "GET", self.sync_endpoint, access_token=user2_tok + # and we can send more messages + self.get_success( + federation_registry.on_edu( + EduTypes.DIRECT_TO_DEVICE, + "remote_server", + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "messages": {user2: {"d2": {"idx": 3}}}, + "message_id": "3", + }, ) - self.assertEqual(channel.code, 200, channel.result) - sync_token = channel.json_body["next_batch"] + ) - # Send 150 to-device messages. We limit to 100 in `/sync` - for i in range(150): - test_msg = {"foo": "bar"} - chan = self.make_request( - "PUT", - f"/_matrix/client/r0/sendToDevice/m.test/1234-{i}", - content={"messages": {user2: {"d2": test_msg}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) + # ... which should arrive + channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={sync_token}", + access_token=user2_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 1) + self.assertEqual( + msgs[0], + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "content": {"idx": 3}, + }, + ) - channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={sync_token}&timeout=300000", - access_token=user2_tok, + def test_limited_sync(self) -> None: + """If a limited sync for to-devices happens the next /sync should respond immediately.""" + + self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # Do an initial sync + channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + sync_token = channel.json_body["next_batch"] + + # Send 150 to-device messages. We limit to 100 in `/sync` + for i in range(150): + test_msg = {"foo": "bar"} + chan = self.make_request( + "PUT", + f"/_matrix/client/r0/sendToDevice/m.test/1234-{i}", + content={"messages": {user2: {"d2": test_msg}}}, + access_token=user1_tok, ) - self.assertEqual(channel.code, 200, channel.result) - messages = channel.json_body.get("to_device", {}).get("events", []) - self.assertEqual(len(messages), 100) - sync_token = channel.json_body["next_batch"] + self.assertEqual(chan.code, 200, chan.result) - channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={sync_token}&timeout=300000", - access_token=user2_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - messages = channel.json_body.get("to_device", {}).get("events", []) - self.assertEqual(len(messages), 50) + channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={sync_token}&timeout=300000", + access_token=user2_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + messages = channel.json_body.get("to_device", {}).get("events", []) + self.assertEqual(len(messages), 100) + sync_token = channel.json_body["next_batch"] - -class SendToDeviceTestCase(NotTested.SendToDeviceTestCaseBase): - # See SendToDeviceTestCaseBase above - pass + channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={sync_token}&timeout=300000", + access_token=user2_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + messages = channel.json_body.get("to_device", {}).get("events", []) + self.assertEqual(len(messages), 50) diff --git a/tests/rest/client/test_sliding_sync.py b/tests/rest/client/test_sliding_sync.py deleted file mode 100644 index d960ef4cb4..0000000000 --- a/tests/rest/client/test_sliding_sync.py +++ /dev/null @@ -1,93 +0,0 @@ -from twisted.test.proto_helpers import MemoryReactor - -from synapse.server import HomeServer -from synapse.types import JsonDict -from synapse.util import Clock - -from tests.rest.client.test_sendtodevice import NotTested as SendToDeviceNotTested -from tests.rest.client.test_sync import NotTested as SyncNotTested - - -class SlidingSyncE2eeSendToDeviceTestCase( - SendToDeviceNotTested.SendToDeviceTestCaseBase -): - """ - Test To-Device messages working correctly with the `/sync/e2ee` endpoint - (`to_device`) - """ - - def default_config(self) -> JsonDict: - config = super().default_config() - # Enable sliding sync - config["experimental_features"] = {"msc3575_enabled": True} - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - super().prepare(reactor, clock, hs) - # Use the Sliding Sync `/sync/e2ee` endpoint - self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee" - - # See SendToDeviceTestCaseBase for tests - - -class SlidingSyncE2eeDeviceListSyncTestCase(SyncNotTested.DeviceListSyncTestCaseBase): - """ - Test device lists working correctly with the `/sync/e2ee` endpoint (`device_lists`) - """ - - def default_config(self) -> JsonDict: - config = super().default_config() - # Enable sliding sync - config["experimental_features"] = {"msc3575_enabled": True} - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - super().prepare(reactor, clock, hs) - # Use the Sliding Sync `/sync/e2ee` endpoint - self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee" - - # See DeviceListSyncTestCaseBase for tests - - -class SlidingSyncE2eeDeviceOneTimeKeysSyncTestCase( - SyncNotTested.DeviceOneTimeKeysSyncTestCaseBase -): - """ - Test device one time keys working correctly with the `/sync/e2ee` endpoint - (`device_one_time_keys_count`) - """ - - def default_config(self) -> JsonDict: - config = super().default_config() - # Enable sliding sync - config["experimental_features"] = {"msc3575_enabled": True} - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - super().prepare(reactor, clock, hs) - # Use the Sliding Sync `/sync/e2ee` endpoint - self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee" - - # See DeviceOneTimeKeysSyncTestCaseBase for tests - - -class SlidingSyncE2eeDeviceUnusedFallbackKeySyncTestCase( - SyncNotTested.DeviceUnusedFallbackKeySyncTestCaseBase -): - """ - Test device unused fallback key types working correctly with the `/sync/e2ee` - endpoint (`device_unused_fallback_key_types`) - """ - - def default_config(self) -> JsonDict: - config = super().default_config() - # Enable sliding sync - config["experimental_features"] = {"msc3575_enabled": True} - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - super().prepare(reactor, clock, hs) - # Use the Sliding Sync `/sync/e2ee` endpoint - self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee" - - # See DeviceUnusedFallbackKeySyncTestCaseBase for tests diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index e89c2d7200..8397d89f0d 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -21,7 +21,7 @@ import json from typing import List -from parameterized import parameterized +from parameterized import parameterized, parameterized_class from twisted.test.proto_helpers import MemoryReactor @@ -688,396 +688,411 @@ class SyncCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.json_body) -class NotTested: - """ - We nest the base test class to avoid the tests being run twice by the test runner - when we share/import these tests in other files. Without this, Twisted trial throws - a `KeyError` in the reporter when using multiple jobs (`poetry run trial --jobs=6`). - """ +@parameterized_class( + ("sync_endpoint", "experimental_features"), + [ + ("/sync", {}), + ( + "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee", + # Enable sliding sync + {"msc3575_enabled": True}, + ), + ], +) +class DeviceListSyncTestCaseBase(unittest.HomeserverTestCase): + """Tests regarding device list (`device_lists`) changes.""" - class DeviceListSyncTestCaseBase(unittest.HomeserverTestCase): - """Tests regarding device list (`device_lists`) changes.""" + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + sync.register_servlets, + devices.register_servlets, + ] - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - devices.register_servlets, - ] + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = self.experimental_features + return config - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sync_endpoint = "/sync" + def test_receiving_local_device_list_changes(self) -> None: + """Tests that a local users that share a room receive each other's device list + changes. + """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) - def test_receiving_local_device_list_changes(self) -> None: - """Tests that a local users that share a room receive each other's device list - changes. - """ - # Register two users - test_device_id = "TESTDEVICE" - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") - bob_user_id = self.register_user("bob", "ponyponypony") - bob_access_token = self.login(bob_user_id, "ponyponypony") + # Create a room for them to coexist peacefully in + new_room_id = self.helper.create_room_as( + alice_user_id, is_public=True, tok=alice_access_token + ) + self.assertIsNotNone(new_room_id) - # Create a room for them to coexist peacefully in - new_room_id = self.helper.create_room_as( - alice_user_id, is_public=True, tok=alice_access_token - ) - self.assertIsNotNone(new_room_id) + # Have Bob join the room + self.helper.invite( + new_room_id, alice_user_id, bob_user_id, tok=alice_access_token + ) + self.helper.join(new_room_id, bob_user_id, tok=bob_access_token) - # Have Bob join the room - self.helper.invite( - new_room_id, alice_user_id, bob_user_id, tok=alice_access_token - ) - self.helper.join(new_room_id, bob_user_id, tok=bob_access_token) + # Now have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + self.sync_endpoint, + access_token=bob_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch_token = channel.json_body["next_batch"] - # Now have Bob initiate an initial sync (in order to get a since token) - channel = self.make_request( - "GET", - self.sync_endpoint, - access_token=bob_access_token, - ) - self.assertEqual(channel.code, 200, channel.json_body) - next_batch_token = channel.json_body["next_batch"] + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) - # ...and then an incremental sync. This should block until the sync stream is woken up, - # which we hope will happen as a result of Alice updating their device list. - bob_sync_channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000", - access_token=bob_access_token, - # Start the request, then continue on. - await_result=False, - ) + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) - # Have alice update their device list - channel = self.make_request( - "PUT", - f"/devices/{test_device_id}", - { - "display_name": "New Device Name", - }, - access_token=alice_access_token, - ) - self.assertEqual(channel.code, 200, channel.json_body) + # Check that bob's incremental sync contains the updated device list. + # If not, the client would only receive the device list update on the + # *next* sync. + bob_sync_channel.await_result() + self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) - # Check that bob's incremental sync contains the updated device list. - # If not, the client would only receive the device list update on the - # *next* sync. - bob_sync_channel.await_result() - self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( + "changed", [] + ) + self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body) - changed_device_lists = bob_sync_channel.json_body.get( - "device_lists", {} - ).get("changed", []) - self.assertIn( - alice_user_id, changed_device_lists, bob_sync_channel.json_body - ) + def test_not_receiving_local_device_list_changes(self) -> None: + """Tests a local users DO NOT receive device updates from each other if they do not + share a room. + """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) - def test_not_receiving_local_device_list_changes(self) -> None: - """Tests a local users DO NOT receive device updates from each other if they do not - share a room. - """ - # Register two users - test_device_id = "TESTDEVICE" - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") - bob_user_id = self.register_user("bob", "ponyponypony") - bob_access_token = self.login(bob_user_id, "ponyponypony") + # These users do not share a room. They are lonely. - # These users do not share a room. They are lonely. + # Have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + self.sync_endpoint, + access_token=bob_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch_token = channel.json_body["next_batch"] - # Have Bob initiate an initial sync (in order to get a since token) - channel = self.make_request( - "GET", - self.sync_endpoint, - access_token=bob_access_token, - ) - self.assertEqual(channel.code, 200, channel.json_body) - next_batch_token = channel.json_body["next_batch"] + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) - # ...and then an incremental sync. This should block until the sync stream is woken up, - # which we hope will happen as a result of Alice updating their device list. - bob_sync_channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000", - access_token=bob_access_token, - # Start the request, then continue on. - await_result=False, - ) + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) - # Have alice update their device list - channel = self.make_request( - "PUT", - f"/devices/{test_device_id}", - { - "display_name": "New Device Name", - }, - access_token=alice_access_token, - ) - self.assertEqual(channel.code, 200, channel.json_body) + # Check that bob's incremental sync does not contain the updated device list. + bob_sync_channel.await_result() + self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) - # Check that bob's incremental sync does not contain the updated device list. - bob_sync_channel.await_result() - self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( + "changed", [] + ) + self.assertNotIn( + alice_user_id, changed_device_lists, bob_sync_channel.json_body + ) - changed_device_lists = bob_sync_channel.json_body.get( - "device_lists", {} - ).get("changed", []) - self.assertNotIn( - alice_user_id, changed_device_lists, bob_sync_channel.json_body - ) + def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None: + """Tests that a user with no rooms still receives their own device list updates""" + test_device_id = "TESTDEVICE" - def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None: - """Tests that a user with no rooms still receives their own device list updates""" - test_device_id = "TESTDEVICE" + # Register a user and login, creating a device + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) - # Register a user and login, creating a device - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) + # Request an initial sync + channel = self.make_request( + "GET", self.sync_endpoint, access_token=alice_access_token + ) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch = channel.json_body["next_batch"] - # Request an initial sync - channel = self.make_request( - "GET", self.sync_endpoint, access_token=alice_access_token - ) - self.assertEqual(channel.code, 200, channel.json_body) - next_batch = channel.json_body["next_batch"] + # Now, make an incremental sync request. + # It won't return until something has happened + incremental_sync_channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={next_batch}&timeout=30000", + access_token=alice_access_token, + await_result=False, + ) - # Now, make an incremental sync request. - # It won't return until something has happened - incremental_sync_channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={next_batch}&timeout=30000", - access_token=alice_access_token, - await_result=False, - ) + # Change our device's display name + channel = self.make_request( + "PUT", + f"devices/{test_device_id}", + { + "display_name": "freeze ray", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) - # Change our device's display name - channel = self.make_request( - "PUT", - f"devices/{test_device_id}", - { - "display_name": "freeze ray", - }, - access_token=alice_access_token, - ) - self.assertEqual(channel.code, 200, channel.json_body) + # The sync should now have returned + incremental_sync_channel.await_result(timeout_ms=20000) + self.assertEqual(incremental_sync_channel.code, 200, channel.json_body) - # The sync should now have returned - incremental_sync_channel.await_result(timeout_ms=20000) - self.assertEqual(incremental_sync_channel.code, 200, channel.json_body) + # We should have received notification that the (user's) device has changed + device_list_changes = incremental_sync_channel.json_body.get( + "device_lists", {} + ).get("changed", []) - # We should have received notification that the (user's) device has changed - device_list_changes = incremental_sync_channel.json_body.get( - "device_lists", {} - ).get("changed", []) + self.assertIn( + alice_user_id, device_list_changes, incremental_sync_channel.json_body + ) - self.assertIn( - alice_user_id, device_list_changes, incremental_sync_channel.json_body - ) - class DeviceOneTimeKeysSyncTestCaseBase(unittest.HomeserverTestCase): - """Tests regarding device one time keys (`device_one_time_keys_count`) changes.""" +@parameterized_class( + ("sync_endpoint", "experimental_features"), + [ + ("/sync", {}), + ( + "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee", + # Enable sliding sync + {"msc3575_enabled": True}, + ), + ], +) +class DeviceOneTimeKeysSyncTestCaseBase(unittest.HomeserverTestCase): + """Tests regarding device one time keys (`device_one_time_keys_count`) changes.""" - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - sync.register_servlets, - devices.register_servlets, - ] + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + devices.register_servlets, + ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sync_endpoint = "/sync" - self.e2e_keys_handler = hs.get_e2e_keys_handler() + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = self.experimental_features + return config - def test_no_device_one_time_keys(self) -> None: - """ - Tests when no one time keys set, it still has the default `signed_curve25519` in - `device_one_time_keys_count` - """ - test_device_id = "TESTDEVICE" + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.e2e_keys_handler = hs.get_e2e_keys_handler() - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) + def test_no_device_one_time_keys(self) -> None: + """ + Tests when no one time keys set, it still has the default `signed_curve25519` in + `device_one_time_keys_count` + """ + test_device_id = "TESTDEVICE" - # Request an initial sync - channel = self.make_request( - "GET", self.sync_endpoint, access_token=alice_access_token - ) - self.assertEqual(channel.code, 200, channel.json_body) + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) - # Check for those one time key counts - self.assertDictEqual( - channel.json_body["device_one_time_keys_count"], - # Note that "signed_curve25519" is always returned in key count responses - # regardless of whether we uploaded any keys for it. This is necessary until - # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. - {"signed_curve25519": 0}, - channel.json_body["device_one_time_keys_count"], - ) + # Request an initial sync + channel = self.make_request( + "GET", self.sync_endpoint, access_token=alice_access_token + ) + self.assertEqual(channel.code, 200, channel.json_body) - def test_returns_device_one_time_keys(self) -> None: - """ - Tests that one time keys for the device/user are counted correctly in the `/sync` - response - """ - test_device_id = "TESTDEVICE" - - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) - - # Upload one time keys for the user/device - keys: JsonDict = { - "alg1:k1": "key1", - "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}}, - "alg2:k3": {"key": "key3"}, - } - res = self.get_success( - self.e2e_keys_handler.upload_keys_for_user( - alice_user_id, test_device_id, {"one_time_keys": keys} - ) - ) + # Check for those one time key counts + self.assertDictEqual( + channel.json_body["device_one_time_keys_count"], # Note that "signed_curve25519" is always returned in key count responses # regardless of whether we uploaded any keys for it. This is necessary until # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. - self.assertDictEqual( - res, - {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}, + {"signed_curve25519": 0}, + channel.json_body["device_one_time_keys_count"], + ) + + def test_returns_device_one_time_keys(self) -> None: + """ + Tests that one time keys for the device/user are counted correctly in the `/sync` + response + """ + test_device_id = "TESTDEVICE" + + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + # Upload one time keys for the user/device + keys: JsonDict = { + "alg1:k1": "key1", + "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}}, + "alg2:k3": {"key": "key3"}, + } + res = self.get_success( + self.e2e_keys_handler.upload_keys_for_user( + alice_user_id, test_device_id, {"one_time_keys": keys} ) + ) + # Note that "signed_curve25519" is always returned in key count responses + # regardless of whether we uploaded any keys for it. This is necessary until + # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. + self.assertDictEqual( + res, + {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}, + ) - # Request an initial sync - channel = self.make_request( - "GET", self.sync_endpoint, access_token=alice_access_token + # Request an initial sync + channel = self.make_request( + "GET", self.sync_endpoint, access_token=alice_access_token + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check for those one time key counts + self.assertDictEqual( + channel.json_body["device_one_time_keys_count"], + {"alg1": 1, "alg2": 2, "signed_curve25519": 0}, + channel.json_body["device_one_time_keys_count"], + ) + + +@parameterized_class( + ("sync_endpoint", "experimental_features"), + [ + ("/sync", {}), + ( + "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee", + # Enable sliding sync + {"msc3575_enabled": True}, + ), + ], +) +class DeviceUnusedFallbackKeySyncTestCaseBase(unittest.HomeserverTestCase): + """Tests regarding device one time keys (`device_unused_fallback_key_types`) changes.""" + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + devices.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = self.experimental_features + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = self.hs.get_datastores().main + self.e2e_keys_handler = hs.get_e2e_keys_handler() + + def test_no_device_unused_fallback_key(self) -> None: + """ + Test when no unused fallback key is set, it just returns an empty list. The MSC + says "The device_unused_fallback_key_types parameter must be present if the + server supports fallback keys.", + https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md + """ + test_device_id = "TESTDEVICE" + + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + # Request an initial sync + channel = self.make_request( + "GET", self.sync_endpoint, access_token=alice_access_token + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check for those one time key counts + self.assertListEqual( + channel.json_body["device_unused_fallback_key_types"], + [], + channel.json_body["device_unused_fallback_key_types"], + ) + + def test_returns_device_one_time_keys(self) -> None: + """ + Tests that device unused fallback key type is returned correctly in the `/sync` + """ + test_device_id = "TESTDEVICE" + + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + # We shouldn't have any unused fallback keys yet + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id) + ) + self.assertEqual(res, []) + + # Upload a fallback key for the user/device + fallback_key = {"alg1:k1": "fallback_key1"} + self.get_success( + self.e2e_keys_handler.upload_keys_for_user( + alice_user_id, + test_device_id, + {"fallback_keys": fallback_key}, ) - self.assertEqual(channel.code, 200, channel.json_body) + ) + # We should now have an unused alg1 key + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id) + ) + self.assertEqual(fallback_res, ["alg1"], fallback_res) - # Check for those one time key counts - self.assertDictEqual( - channel.json_body["device_one_time_keys_count"], - {"alg1": 1, "alg2": 2, "signed_curve25519": 0}, - channel.json_body["device_one_time_keys_count"], - ) + # Request an initial sync + channel = self.make_request( + "GET", self.sync_endpoint, access_token=alice_access_token + ) + self.assertEqual(channel.code, 200, channel.json_body) - class DeviceUnusedFallbackKeySyncTestCaseBase(unittest.HomeserverTestCase): - """Tests regarding device one time keys (`device_unused_fallback_key_types`) changes.""" - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - sync.register_servlets, - devices.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sync_endpoint = "/sync" - self.store = self.hs.get_datastores().main - self.e2e_keys_handler = hs.get_e2e_keys_handler() - - def test_no_device_unused_fallback_key(self) -> None: - """ - Test when no unused fallback key is set, it just returns an empty list. The MSC - says "The device_unused_fallback_key_types parameter must be present if the - server supports fallback keys.", - https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md - """ - test_device_id = "TESTDEVICE" - - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) - - # Request an initial sync - channel = self.make_request( - "GET", self.sync_endpoint, access_token=alice_access_token - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check for those one time key counts - self.assertListEqual( - channel.json_body["device_unused_fallback_key_types"], - [], - channel.json_body["device_unused_fallback_key_types"], - ) - - def test_returns_device_one_time_keys(self) -> None: - """ - Tests that device unused fallback key type is returned correctly in the `/sync` - """ - test_device_id = "TESTDEVICE" - - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) - - # We shouldn't have any unused fallback keys yet - res = self.get_success( - self.store.get_e2e_unused_fallback_key_types( - alice_user_id, test_device_id - ) - ) - self.assertEqual(res, []) - - # Upload a fallback key for the user/device - fallback_key = {"alg1:k1": "fallback_key1"} - self.get_success( - self.e2e_keys_handler.upload_keys_for_user( - alice_user_id, - test_device_id, - {"fallback_keys": fallback_key}, - ) - ) - # We should now have an unused alg1 key - fallback_res = self.get_success( - self.store.get_e2e_unused_fallback_key_types( - alice_user_id, test_device_id - ) - ) - self.assertEqual(fallback_res, ["alg1"], fallback_res) - - # Request an initial sync - channel = self.make_request( - "GET", self.sync_endpoint, access_token=alice_access_token - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check for the unused fallback key types - self.assertListEqual( - channel.json_body["device_unused_fallback_key_types"], - ["alg1"], - channel.json_body["device_unused_fallback_key_types"], - ) - - -class DeviceListSyncTestCase(NotTested.DeviceListSyncTestCaseBase): - # See DeviceListSyncTestCaseBase above - pass - - -class DeviceOneTimeKeysSyncTestCase(NotTested.DeviceOneTimeKeysSyncTestCaseBase): - # See DeviceOneTimeKeysSyncTestCaseBase above - pass - - -class DeviceUnusedFallbackKeySyncTestCase( - NotTested.DeviceUnusedFallbackKeySyncTestCaseBase -): - # See DeviceUnusedFallbackKeySyncTestCaseBase above - pass + # Check for the unused fallback key types + self.assertListEqual( + channel.json_body["device_unused_fallback_key_types"], + ["alg1"], + channel.json_body["device_unused_fallback_key_types"], + ) class ExcludeRoomTestCase(unittest.HomeserverTestCase):