Compare commits
2 Commits
develop
...
mv/key_req
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
921fa8f9ce | ||
|
|
5047c01d3f |
1
changelog.d/15808.misc
Normal file
1
changelog.d/15808.misc
Normal file
@@ -0,0 +1 @@
|
|||||||
|
Ignore key request if the device inbox is already big.
|
||||||
@@ -39,6 +39,9 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
INBOX_SIZE_LIMIT_FOR_KEY_REQUEST = 100
|
||||||
|
|
||||||
|
|
||||||
class DeviceMessageHandler:
|
class DeviceMessageHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
"""
|
"""
|
||||||
@@ -166,7 +169,7 @@ class DeviceMessageHandler:
|
|||||||
found marks the remote cache for the user as stale.
|
found marks the remote cache for the user as stale.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if message_type != "m.room_key_request":
|
if message_type != ToDeviceEventTypes.RoomKeyRequest:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get the sending device IDs
|
# Get the sending device IDs
|
||||||
@@ -286,10 +289,16 @@ class DeviceMessageHandler:
|
|||||||
"org.matrix.opentracing_context": json_encoder.encode(context),
|
"org.matrix.opentracing_context": json_encoder.encode(context),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
device_inbox_size_limit = None
|
||||||
|
if message_type == ToDeviceEventTypes.RoomKeyRequest and self.is_mine(
|
||||||
|
UserID.from_string(user_id)
|
||||||
|
):
|
||||||
|
device_inbox_size_limit = INBOX_SIZE_LIMIT_FOR_KEY_REQUEST
|
||||||
|
|
||||||
# Add messages to the database.
|
# Add messages to the database.
|
||||||
# Retrieve the stream id of the last-processed to-device message.
|
# Retrieve the stream id of the last-processed to-device message.
|
||||||
last_stream_id = await self.store.add_messages_to_device_inbox(
|
last_stream_id = await self.store.add_messages_to_device_inbox(
|
||||||
local_messages, remote_edu_contents
|
local_messages, remote_edu_contents, device_inbox_size_limit
|
||||||
)
|
)
|
||||||
|
|
||||||
# Notify listeners that there are new to-device messages to process,
|
# Notify listeners that there are new to-device messages to process,
|
||||||
|
|||||||
@@ -650,6 +650,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
self,
|
self,
|
||||||
local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
|
local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
|
||||||
remote_messages_by_destination: Dict[str, JsonDict],
|
remote_messages_by_destination: Dict[str, JsonDict],
|
||||||
|
size_limit: Optional[int] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Used to send messages from this server.
|
"""Used to send messages from this server.
|
||||||
|
|
||||||
@@ -666,11 +667,14 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
assert self._can_write_to_device
|
assert self._can_write_to_device
|
||||||
|
|
||||||
def add_messages_txn(
|
def add_messages_txn(
|
||||||
txn: LoggingTransaction, now_ms: int, stream_id: int
|
txn: LoggingTransaction,
|
||||||
|
now_ms: int,
|
||||||
|
stream_id: int,
|
||||||
|
size_limit: Optional[int],
|
||||||
) -> None:
|
) -> None:
|
||||||
# Add the local messages directly to the local inbox.
|
# Add the local messages directly to the local inbox.
|
||||||
self._add_messages_to_local_device_inbox_txn(
|
self._add_messages_to_local_device_inbox_txn(
|
||||||
txn, stream_id, local_messages_by_user_then_device
|
txn, stream_id, local_messages_by_user_then_device, size_limit
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add the remote messages to the federation outbox.
|
# Add the remote messages to the federation outbox.
|
||||||
@@ -731,7 +735,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
async with self._device_inbox_id_gen.get_next() as stream_id:
|
async with self._device_inbox_id_gen.get_next() as stream_id:
|
||||||
now_ms = self._clock.time_msec()
|
now_ms = self._clock.time_msec()
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
|
"add_messages_to_device_inbox",
|
||||||
|
add_messages_txn,
|
||||||
|
now_ms,
|
||||||
|
stream_id,
|
||||||
|
size_limit,
|
||||||
)
|
)
|
||||||
for user_id in local_messages_by_user_then_device.keys():
|
for user_id in local_messages_by_user_then_device.keys():
|
||||||
self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
|
self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
|
||||||
@@ -802,11 +810,23 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
stream_id: int,
|
stream_id: int,
|
||||||
messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
|
messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
|
||||||
|
size_limit: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert self._can_write_to_device
|
assert self._can_write_to_device
|
||||||
|
|
||||||
local_by_user_then_device = {}
|
local_by_user_then_device = {}
|
||||||
for user_id, messages_by_device in messages_by_user_then_device.items():
|
for user_id, messages_by_device in messages_by_user_then_device.items():
|
||||||
|
inbox_sizes = {}
|
||||||
|
if size_limit:
|
||||||
|
sql = """
|
||||||
|
SELECT device_id, COUNT(*) FROM device_inbox
|
||||||
|
WHERE user_id = ?
|
||||||
|
GROUP BY device_id
|
||||||
|
"""
|
||||||
|
txn.execute(sql, (user_id,))
|
||||||
|
for r in txn:
|
||||||
|
inbox_sizes[r[0]] = r[1]
|
||||||
|
|
||||||
messages_json_for_user = {}
|
messages_json_for_user = {}
|
||||||
devices = list(messages_by_device.keys())
|
devices = list(messages_by_device.keys())
|
||||||
if len(devices) == 1 and devices[0] == "*":
|
if len(devices) == 1 and devices[0] == "*":
|
||||||
@@ -822,9 +842,10 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
message_json = json_encoder.encode(messages_by_device["*"])
|
message_json = json_encoder.encode(messages_by_device["*"])
|
||||||
for device_id in devices:
|
for device_id in devices:
|
||||||
# Add the message for all devices for this user on this
|
if size_limit is None or inbox_sizes.get(device_id, 0) < size_limit:
|
||||||
# server.
|
# Add the message for all devices for this user on this
|
||||||
messages_json_for_user[device_id] = message_json
|
# server.
|
||||||
|
messages_json_for_user[device_id] = message_json
|
||||||
else:
|
else:
|
||||||
if not devices:
|
if not devices:
|
||||||
continue
|
continue
|
||||||
@@ -857,7 +878,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
message_json = json_encoder.encode(msg)
|
message_json = json_encoder.encode(msg)
|
||||||
|
|
||||||
messages_json_for_user[device_id] = message_json
|
if size_limit is None or inbox_sizes.get(device_id, 0) < size_limit:
|
||||||
|
messages_json_for_user[device_id] = message_json
|
||||||
|
|
||||||
if messages_json_for_user:
|
if messages_json_for_user:
|
||||||
local_by_user_then_device[user_id] = messages_json_for_user
|
local_by_user_then_device[user_id] = messages_json_for_user
|
||||||
|
|||||||
@@ -23,20 +23,28 @@ from synapse.api.constants import RoomEncryptionAlgorithms
|
|||||||
from synapse.api.errors import NotFoundError, SynapseError
|
from synapse.api.errors import NotFoundError, SynapseError
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
|
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
|
||||||
|
from synapse.handlers.devicemessage import INBOX_SIZE_LIMIT_FOR_KEY_REQUEST
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.databases.main.appservice import _make_exclusive_regex
|
from synapse.storage.databases.main.appservice import _make_exclusive_regex
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict, create_requester
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import make_awaitable
|
from tests.test_utils import make_awaitable
|
||||||
from tests.unittest import override_config
|
from tests.unittest import override_config
|
||||||
|
|
||||||
|
import synapse
|
||||||
|
|
||||||
user1 = "@boris:aaa"
|
user1 = "@boris:aaa"
|
||||||
user2 = "@theresa:bbb"
|
user2 = "@theresa:bbb"
|
||||||
|
|
||||||
|
|
||||||
class DeviceTestCase(unittest.HomeserverTestCase):
|
class DeviceTestCase(unittest.HomeserverTestCase):
|
||||||
|
servlets = [
|
||||||
|
synapse.rest.admin.register_servlets,
|
||||||
|
synapse.rest.client.login.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.appservice_api = mock.Mock()
|
self.appservice_api = mock.Mock()
|
||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
@@ -47,6 +55,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
|||||||
handler = hs.get_device_handler()
|
handler = hs.get_device_handler()
|
||||||
assert isinstance(handler, DeviceHandler)
|
assert isinstance(handler, DeviceHandler)
|
||||||
self.handler = handler
|
self.handler = handler
|
||||||
|
self.msg_handler = hs.get_device_message_handler()
|
||||||
|
self.event_sources = hs.get_event_sources()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
@@ -398,6 +408,79 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_room_key_request_limit(self) -> None:
|
||||||
|
store = self.hs.get_datastores().main
|
||||||
|
|
||||||
|
myuser = self.register_user("myuser", "pass")
|
||||||
|
self.login("myuser", "pass", "device")
|
||||||
|
self.login("myuser", "pass", "device2")
|
||||||
|
|
||||||
|
requester = requester = create_requester(myuser)
|
||||||
|
|
||||||
|
from_token = self.event_sources.get_current_token()
|
||||||
|
|
||||||
|
# for i in range(0, INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 2):
|
||||||
|
# self.get_success(
|
||||||
|
# self.msg_handler.send_device_message(
|
||||||
|
# requester,
|
||||||
|
# "m.room_key",
|
||||||
|
# {
|
||||||
|
# myuser2: {
|
||||||
|
# "device": {
|
||||||
|
# "algorithm": "m.megolm.v1.aes-sha2",
|
||||||
|
# "room_id": "!Cuyf34gef24t:localhost",
|
||||||
|
# "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ",
|
||||||
|
# "session_key": "AgAAAADxKHa9uFxcXzwYoNueL5Xqi69IkD4sni8LlfJL7qNBEY..."
|
||||||
|
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# },
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
|
||||||
|
# to_token = self.event_sources.get_current_token()
|
||||||
|
|
||||||
|
# res = self.get_success(self.store.get_messages_for_device(
|
||||||
|
# myuser2,
|
||||||
|
# "device",
|
||||||
|
# from_token.to_device_key,
|
||||||
|
# to_token.to_device_key,
|
||||||
|
# INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 5,
|
||||||
|
# ))
|
||||||
|
# self.assertEqual(len(res[0]), INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 2)
|
||||||
|
|
||||||
|
# from_token = to_token
|
||||||
|
|
||||||
|
for i in range(0, INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 2):
|
||||||
|
self.get_success(
|
||||||
|
self.msg_handler.send_device_message(
|
||||||
|
requester,
|
||||||
|
"m.room_key_request",
|
||||||
|
{
|
||||||
|
myuser: {
|
||||||
|
"device2": {
|
||||||
|
"action": "request",
|
||||||
|
"request_id": f"request_id_{i}",
|
||||||
|
"requesting_device_id": "device",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
to_token = self.event_sources.get_current_token()
|
||||||
|
|
||||||
|
res = self.get_success(
|
||||||
|
self.store.get_messages_for_device(
|
||||||
|
myuser,
|
||||||
|
"device2",
|
||||||
|
from_token.to_device_key,
|
||||||
|
to_token.to_device_key,
|
||||||
|
INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 5,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(len(res[0]), INBOX_SIZE_LIMIT_FOR_KEY_REQUEST)
|
||||||
|
|
||||||
|
|
||||||
class DehydrationTestCase(unittest.HomeserverTestCase):
|
class DehydrationTestCase(unittest.HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
|||||||
Reference in New Issue
Block a user