1
0

Compare commits

...

1 Commits

Author SHA1 Message Date
Mathieu Velten
e25c15ea0f Implements part of MSC 3944 by dropping cancelled&duplicated m.room_key_request 2023-06-27 09:49:42 +02:00
5 changed files with 215 additions and 8 deletions

View File

@@ -0,0 +1 @@
Implements bullets 1 and 2 of [MSC 3944](https://github.com/matrix-org/matrix-spec-proposals/pull/3944) related to dropping cancelled and duplicated `m.room_key_request`.

View File

@@ -389,3 +389,6 @@ class ExperimentalConfig(Config):
self.msc4010_push_rules_account_data = experimental.get( self.msc4010_push_rules_account_data = experimental.get(
"msc4010_push_rules_account_data", False "msc4010_push_rules_account_data", False
) )
# MSC3944: Dropping stale send-to-device messages
self.msc3944_enabled: bool = experimental.get("msc3944_enabled", False)

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import logging import logging
from typing import TYPE_CHECKING, Any, Dict from typing import TYPE_CHECKING, Any, Dict
@@ -90,6 +91,8 @@ class DeviceMessageHandler:
burst_count=hs.config.ratelimiting.rc_key_requests.burst_count, burst_count=hs.config.ratelimiting.rc_key_requests.burst_count,
) )
self._msc3944_enabled = hs.config.experimental.msc3944_enabled
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
""" """
Handle receiving to-device messages from remote homeservers. Handle receiving to-device messages from remote homeservers.
@@ -220,7 +223,7 @@ class DeviceMessageHandler:
set_tag(SynapseTags.TO_DEVICE_TYPE, message_type) set_tag(SynapseTags.TO_DEVICE_TYPE, message_type)
set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id) set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id)
local_messages = {} local_messages: Dict[str, Dict[str, JsonDict]] = {}
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items(): for user_id, by_device in messages.items():
# add an opentracing log entry for each message # add an opentracing log entry for each message
@@ -255,16 +258,56 @@ class DeviceMessageHandler:
# we use UserID.from_string to catch invalid user ids # we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)): if self.is_mine(UserID.from_string(user_id)):
messages_by_device = { for device_id, message_content in by_device.items():
device_id: { # Drop any previous identical (same request_id and requesting_device_id)
# room_key_request, ignoring the action property when comparing.
# This handles dropping previous identical and cancelled requests.
if (
self._msc3944_enabled
and message_type == ToDeviceEventTypes.RoomKeyRequest
and user_id == sender_user_id
):
req_id = message_content.get("request_id")
requesting_device_id = message_content.get(
"requesting_device_id"
)
if req_id and requesting_device_id:
previous_request_deleted = False
for (
stream_id,
message_json,
) in await self.store.get_all_device_messages(
user_id, device_id
):
orig_message = json.loads(message_json)
if (
orig_message["type"]
== ToDeviceEventTypes.RoomKeyRequest
):
content = orig_message.get("content", {})
if (
content.get("request_id") == req_id
and content.get("requesting_device_id")
== requesting_device_id
):
if await self.store.delete_device_message(
stream_id
):
previous_request_deleted = True
if (
message_content.get("action") == "request_cancellation"
and previous_request_deleted
):
# Do not store the cancellation since we deleted the matching
# request(s) before it reaches the device.
continue
message = {
"content": message_content, "content": message_content,
"type": message_type, "type": message_type,
"sender": sender_user_id, "sender": sender_user_id,
} }
for device_id, message_content in by_device.items() local_messages.setdefault(user_id, {})[device_id] = message
}
if messages_by_device:
local_messages[user_id] = messages_by_device
else: else:
destination = get_domain_from_id(user_id) destination = get_domain_from_id(user_id)
remote_messages.setdefault(destination, {})[user_id] = by_device remote_messages.setdefault(destination, {})[user_id] = by_device

View File

@@ -27,6 +27,7 @@ from typing import (
) )
from synapse.api.constants import EventContentFields from synapse.api.constants import EventContentFields
from synapse.api.errors import StoreError
from synapse.logging import issue9533_logger from synapse.logging import issue9533_logger
from synapse.logging.opentracing import ( from synapse.logging.opentracing import (
SynapseTags, SynapseTags,
@@ -891,6 +892,46 @@ class DeviceInboxWorkerStore(SQLBaseStore):
], ],
) )
async def delete_device_message(self, stream_id: int) -> bool:
"""Delete a specific device message from the message inbox.
Args:
stream_id: the stream ID identifying the message.
Returns:
True if the message has been deleted, False if it didn't exist.
"""
try:
await self.db_pool.simple_delete_one(
"device_inbox",
keyvalues={"stream_id": stream_id},
desc="delete_device_message",
)
except StoreError:
# Deletion failed because device message does not exist
return False
return True
async def get_all_device_messages(
self,
user_id: str,
device_id: str,
) -> List[Tuple[int, str]]:
"""Get all device messages in the inbox from a specific device.
Args:
user_id: the user ID of the device we want to query.
device_id: the device ID of the device we want to query.
Returns:
A list of (stream ID, message content) tuples.
"""
rows = await self.db_pool.simple_select_list(
table="device_inbox",
keyvalues={"user_id": user_id, "device_id": device_id},
retcols=("stream_id", "message_json"),
desc="get_all_device_messages",
)
return [(r["stream_id"], r["message_json"]) for r in rows]
class DeviceInboxBackgroundUpdateStore(SQLBaseStore): class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"

View File

@@ -19,13 +19,14 @@ from unittest import mock
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
import synapse
from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import NotFoundError, SynapseError from synapse.api.errors import NotFoundError, SynapseError
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.types import JsonDict from synapse.types import JsonDict, create_requester
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@@ -37,6 +38,11 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase): class DeviceTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.client.login.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.appservice_api = mock.Mock() self.appservice_api = mock.Mock()
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
@@ -47,6 +53,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
handler = hs.get_device_handler() handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler) assert isinstance(handler, DeviceHandler)
self.handler = handler self.handler = handler
self.msg_handler = hs.get_device_message_handler()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
return hs return hs
@@ -398,6 +406,117 @@ class DeviceTestCase(unittest.HomeserverTestCase):
], ],
) )
@override_config({"experimental_features": {"msc3944_enabled": True}})
def test_duplicated_and_cancelled_room_key_request(self) -> None:
myuser = self.register_user("myuser", "pass")
self.login("myuser", "pass", "device")
self.login("myuser", "pass", "device2")
self.login("myuser", "pass", "device3")
requester = requester = create_requester(myuser)
from_token = self.event_sources.get_current_token()
# This room_key_request is for device3 and should not be deleted.
self.get_success(
self.msg_handler.send_device_message(
requester,
"m.room_key_request",
{
myuser: {
"device3": {
"action": "request",
"request_id": "request_id",
"requesting_device_id": "device",
}
}
},
)
)
for _ in range(0, 2):
self.get_success(
self.msg_handler.send_device_message(
requester,
"m.room_key_request",
{
myuser: {
"device2": {
"action": "request",
"request_id": "request_id",
"requesting_device_id": "device",
}
}
},
)
)
to_token = self.event_sources.get_current_token()
# Test that if we queue 2 identical room_key_request,
# only one is delivered to the device.
res = self.get_success(
self.store.get_messages_for_device(
myuser,
"device2",
from_token.to_device_key,
to_token.to_device_key,
)
)
self.assertEqual(len(res[0]), 1)
# room_key_request for device3 should still be around.
res = self.get_success(
self.store.get_messages_for_device(
myuser,
"device3",
from_token.to_device_key,
to_token.to_device_key,
)
)
self.assertEqual(len(res[0]), 1)
self.get_success(
self.msg_handler.send_device_message(
requester,
"m.room_key_request",
{
myuser: {
"device2": {
"action": "request_cancellation",
"request_id": "request_id",
"requesting_device_id": "device",
}
}
},
)
)
to_token = self.event_sources.get_current_token()
# Test that if we cancel a room_key_request, both previous matching
# requests and the cancelled request are not delivered to the device.
res = self.get_success(
self.store.get_messages_for_device(
myuser,
"device2",
from_token.to_device_key,
to_token.to_device_key,
)
)
self.assertEqual(len(res[0]), 0)
# room_key_request for device3 should still be around.
res = self.get_success(
self.store.get_messages_for_device(
myuser,
"device3",
from_token.to_device_key,
to_token.to_device_key,
)
)
self.assertEqual(len(res[0]), 1)
class DehydrationTestCase(unittest.HomeserverTestCase): class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: