Compare commits
5 Commits
erikj/fix_
...
rav/out_of
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9347b961b2 | ||
|
|
6dbad83998 | ||
|
|
4c586567f6 | ||
|
|
27546ac171 | ||
|
|
dd45ba4d67 |
1
changelog.d/16579.feature
Normal file
1
changelog.d/16579.feature
Normal file
@@ -0,0 +1 @@
|
||||
Experimental support for [MSC4072](https://github.com/matrix-org/matrix-spec-proposals/pull/4072): Return a result for all devices requested in a `/keys/claim` request.
|
||||
@@ -15,7 +15,6 @@
|
||||
import enum
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import attr
|
||||
import attr.validators
|
||||
|
||||
from synapse.api.errors import LimitExceededError
|
||||
@@ -419,3 +418,9 @@ class ExperimentalConfig(Config):
|
||||
self.msc4028_push_encrypted_events = experimental.get(
|
||||
"msc4028_push_encrypted_events", False
|
||||
)
|
||||
|
||||
# MSC4072: Return an empty dict from /keys/claim for unknown devices or those
|
||||
# with exhausted OTKs
|
||||
self.msc4072_empty_dict_for_exhausted_devices = experimental.get(
|
||||
"msc4072_empty_dict_for_exhausted_devices", False
|
||||
)
|
||||
|
||||
@@ -84,7 +84,7 @@ from synapse.replication.http.federation import (
|
||||
from synapse.storage.databases.main.lock import Lock
|
||||
from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
|
||||
from synapse.storage.roommember import MemberSummary
|
||||
from synapse.types import JsonDict, StateMap, get_domain_from_id
|
||||
from synapse.types import JsonDict, JsonSerializable, StateMap, get_domain_from_id
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
@@ -1000,19 +1000,13 @@ class FederationServer(FederationBase):
|
||||
self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
|
||||
) -> Dict[str, Any]:
|
||||
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
|
||||
results = await self._e2e_keys_handler.claim_local_one_time_keys(
|
||||
query, always_include_fallback_keys=always_include_fallback_keys
|
||||
json_result: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {}
|
||||
await self._e2e_keys_handler.claim_local_one_time_keys(
|
||||
query,
|
||||
always_include_fallback_keys=always_include_fallback_keys,
|
||||
result_dict=json_result,
|
||||
)
|
||||
|
||||
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
for result in results:
|
||||
for user_id, device_keys in result.items():
|
||||
for device_id, keys in device_keys.items():
|
||||
for key_id, key in keys.items():
|
||||
json_result.setdefault(user_id, {}).setdefault(device_id, {})[
|
||||
key_id
|
||||
] = key
|
||||
|
||||
logger.info(
|
||||
"Claimed one-time-keys: %s",
|
||||
",".join(
|
||||
|
||||
@@ -861,7 +861,7 @@ class ApplicationServicesHandler:
|
||||
|
||||
Returns:
|
||||
A tuple of:
|
||||
A map of user ID -> a map device ID -> a map of key ID -> JSON.
|
||||
A map of user ID -> a map device ID -> a map of key ID -> key.
|
||||
|
||||
A copy of the input which has not been fulfilled (either because
|
||||
they are not appservice users or the appservice does not support
|
||||
|
||||
@@ -32,6 +32,7 @@ from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
JsonMapping,
|
||||
JsonSerializable,
|
||||
UserID,
|
||||
get_domain_from_id,
|
||||
get_verify_key_from_cross_signing_key,
|
||||
@@ -560,7 +561,8 @@ class E2eKeysHandler:
|
||||
self,
|
||||
local_query: List[Tuple[str, str, str, int]],
|
||||
always_include_fallback_keys: bool,
|
||||
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
|
||||
result_dict: Dict[str, Dict[str, Dict[str, JsonSerializable]]],
|
||||
) -> None:
|
||||
"""Claim one time keys for local users.
|
||||
|
||||
1. Attempt to claim OTKs from the database.
|
||||
@@ -570,18 +572,34 @@ class E2eKeysHandler:
|
||||
Args:
|
||||
local_query: An iterable of tuples of (user ID, device ID, algorithm).
|
||||
always_include_fallback_keys: True to always include fallback keys.
|
||||
|
||||
Returns:
|
||||
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
|
||||
result_dict: A dict to update with the results.
|
||||
{user_id -> { device_id -> { key_id -> key string/object }}}
|
||||
"""
|
||||
|
||||
def update_result_dict(
|
||||
results: Mapping[str, Mapping[str, Mapping[str, JsonSerializable]]]
|
||||
) -> None:
|
||||
"""Stash results from a store query in `result_dict`"""
|
||||
for user_id, device_keys in results.items():
|
||||
user_result_dict = result_dict.setdefault(user_id, {})
|
||||
for device_id, keys in device_keys.items():
|
||||
device_result_dict = user_result_dict.setdefault(device_id, {})
|
||||
device_result_dict.update(keys)
|
||||
|
||||
# Cap the number of OTKs that can be claimed at once to avoid abuse.
|
||||
local_query = [
|
||||
(user_id, device_id, algorithm, min(count, 5))
|
||||
for user_id, device_id, algorithm, count in local_query
|
||||
]
|
||||
|
||||
# prepopulate the response to make sure that all queried users/devices are
|
||||
# included, even if the user/device is unknown or has run out of OTKs
|
||||
if self.config.experimental.msc4072_empty_dict_for_exhausted_devices:
|
||||
for user_id, device_id, _, _ in local_query:
|
||||
result_dict.setdefault(user_id, {}).setdefault(device_id, {})
|
||||
|
||||
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
|
||||
update_result_dict(otk_results)
|
||||
|
||||
# If the application services have not provided any keys via the C-S
|
||||
# API, query it directly for one-time keys.
|
||||
@@ -592,6 +610,7 @@ class E2eKeysHandler:
|
||||
appservice_results,
|
||||
not_found,
|
||||
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
|
||||
update_result_dict(appservice_results)
|
||||
else:
|
||||
appservice_results = {}
|
||||
|
||||
@@ -646,10 +665,7 @@ class E2eKeysHandler:
|
||||
# For each user that does not have a one-time keys available, see if
|
||||
# there is a fallback key.
|
||||
fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)
|
||||
|
||||
# Return the results in order, each item from the input query should
|
||||
# only appear once in the combined list.
|
||||
return (otk_results, appservice_results, fallback_results)
|
||||
update_result_dict(fallback_results)
|
||||
|
||||
@trace
|
||||
async def claim_one_time_keys(
|
||||
@@ -659,6 +675,25 @@ class E2eKeysHandler:
|
||||
timeout: Optional[int],
|
||||
always_include_fallback_keys: bool,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Handle a /keys/claim request.
|
||||
|
||||
Handles requests for local users with a db lookup, and makes federation
|
||||
requests for remote users.
|
||||
|
||||
Args:
|
||||
query: map from user ID, to map from device ID, to map from algorithm name
|
||||
to number of keys needed
|
||||
(``{user_id: {device_id: {algorithm: number_of keys}}}``)
|
||||
|
||||
user: The user id of the requesting user
|
||||
|
||||
timeout: number of milliseconds to wait for the response from remote servers.
|
||||
``config.federation.client_timeout_ms`` by default.
|
||||
|
||||
always_include_fallback_keys: True to always include fallback keys, even
|
||||
for devices which still have one-time keys.
|
||||
"""
|
||||
local_query: List[Tuple[str, str, str, int]] = []
|
||||
remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
|
||||
|
||||
@@ -672,22 +707,19 @@ class E2eKeysHandler:
|
||||
domain = get_domain_from_id(user_id)
|
||||
remote_queries.setdefault(domain, {})[user_id] = one_time_keys
|
||||
|
||||
set_tag("local_key_query", str(local_query))
|
||||
set_tag("remote_key_query", str(remote_queries))
|
||||
|
||||
results = await self.claim_local_one_time_keys(
|
||||
local_query, always_include_fallback_keys
|
||||
log_kv(
|
||||
{
|
||||
"message": "claiming one time keys",
|
||||
"local query": local_query,
|
||||
"remote queries, by server": remote_queries,
|
||||
}
|
||||
)
|
||||
|
||||
# A map of user ID -> device ID -> key ID -> key.
|
||||
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
for result in results:
|
||||
for user_id, device_keys in result.items():
|
||||
for device_id, keys in device_keys.items():
|
||||
for key_id, key in keys.items():
|
||||
json_result.setdefault(user_id, {}).setdefault(
|
||||
device_id, {}
|
||||
).update({key_id: key})
|
||||
json_result: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {}
|
||||
await self.claim_local_one_time_keys(
|
||||
local_query, always_include_fallback_keys, json_result
|
||||
)
|
||||
|
||||
# Remote failures.
|
||||
failures: Dict[str, JsonDict] = {}
|
||||
@@ -700,9 +732,18 @@ class E2eKeysHandler:
|
||||
remote_result = await self.federation.claim_client_keys(
|
||||
user, destination, device_keys, timeout=timeout
|
||||
)
|
||||
for user_id, keys in remote_result["one_time_keys"].items():
|
||||
if user_id in device_keys:
|
||||
json_result[user_id] = keys
|
||||
try:
|
||||
destination_result = filter_remote_claimed_keys(
|
||||
device_keys,
|
||||
remote_result,
|
||||
self.config.experimental.msc4072_empty_dict_for_exhausted_devices,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error parsing /keys/claim response from server {destination}",
|
||||
e,
|
||||
)
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
failure = _exception_to_failure(e)
|
||||
@@ -710,6 +751,11 @@ class E2eKeysHandler:
|
||||
set_tag("error", True)
|
||||
set_tag("reason", str(failure))
|
||||
|
||||
else:
|
||||
# only populate json_result once we know there will not be an entry in
|
||||
# failures for this destination.
|
||||
json_result.update(destination_result)
|
||||
|
||||
await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
@@ -1625,3 +1671,51 @@ class SigningKeyEduUpdater:
|
||||
device_ids = device_ids + new_device_ids
|
||||
|
||||
await self._device_handler.notify_device_update(user_id, device_ids)
|
||||
|
||||
|
||||
def filter_remote_claimed_keys(
|
||||
destination_query: Dict[str, Dict[str, Dict[str, int]]],
|
||||
remote_response: JsonDict,
|
||||
msc4072_empty_dict_for_exhausted_devices: bool,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Process the response from a federation /keys/claim request
|
||||
|
||||
Checks that there are no redundant entries, and that all the entries that
|
||||
should be there are present.
|
||||
|
||||
Args:
|
||||
destination_query: user->device->key map that was sent in the request to
|
||||
this server
|
||||
remote_response: response from the remote server
|
||||
msc4072_empty_dict_for_exhausted_devices: true to include an entry in the
|
||||
result for every queried device
|
||||
|
||||
Returns:
|
||||
user->device->key map to be merged into the results
|
||||
"""
|
||||
remote_otks = remote_response["one_time_keys"]
|
||||
|
||||
destination_result: JsonDict = {}
|
||||
|
||||
if msc4072_empty_dict_for_exhausted_devices:
|
||||
# We need to make sure there is an entry in destination_result for
|
||||
# every queried (user, device) even if the remote server did not
|
||||
# populate it; so we iterate the query and populate
|
||||
# destination_result based on the federation result.
|
||||
for user_id, user_query in destination_query.items():
|
||||
remote_user_result = remote_otks.get(user_id, {})
|
||||
destination_user_result = destination_result[user_id] = {}
|
||||
for device_id in user_query.keys():
|
||||
destination_user_result[device_id] = remote_user_result.get(
|
||||
device_id, {}
|
||||
)
|
||||
else:
|
||||
# We need to make sure that remote servers do not poison the
|
||||
# result with data for users which do not belong to it, so we only
|
||||
# copy data for users that were queried.
|
||||
for user_id, keys in remote_otks.items():
|
||||
if user_id in destination_query:
|
||||
destination_result[user_id] = keys
|
||||
|
||||
return destination_result
|
||||
|
||||
@@ -52,7 +52,7 @@ from synapse.storage.database import (
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.types import JsonDict, JsonMapping
|
||||
from synapse.types import JsonDict, JsonMapping, JsonSerializable
|
||||
from synapse.util import json_decoder, json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.cancellation import cancellable
|
||||
@@ -1112,7 +1112,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
async def claim_e2e_one_time_keys(
|
||||
self, query_list: Iterable[Tuple[str, str, str, int]]
|
||||
) -> Tuple[
|
||||
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
|
||||
Dict[str, Dict[str, Dict[str, JsonSerializable]]],
|
||||
List[Tuple[str, str, str, int]],
|
||||
]:
|
||||
"""Take a list of one time keys out of the database.
|
||||
|
||||
@@ -1121,7 +1122,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
|
||||
Returns:
|
||||
A tuple pf:
|
||||
A map of user ID -> a map device ID -> a map of key ID -> JSON.
|
||||
A map of user ID -> a map device ID -> a map of key ID -> key
|
||||
|
||||
A copy of the input which has not been fulfilled.
|
||||
"""
|
||||
@@ -1214,7 +1215,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
|
||||
]
|
||||
|
||||
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
results: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {}
|
||||
missing: List[Tuple[str, str, str, int]] = []
|
||||
for user_id, device_id, algorithm, count in query_list:
|
||||
if self.database_engine.supports_returning:
|
||||
@@ -1240,7 +1241,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
device_id, {}
|
||||
)
|
||||
for claim_row in claim_rows:
|
||||
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
|
||||
# The shape of the key depends on the algorithm: it is a dict for
|
||||
# signed_curve25519, or a string for curve25519. In general, it
|
||||
# is whatever the client chose to upload, since we dont validate it.
|
||||
decoded_key: JsonSerializable = json_decoder.decode(claim_row[1])
|
||||
device_results[claim_row[0]] = decoded_key
|
||||
# Did we get enough OTKs?
|
||||
count -= len(claim_rows)
|
||||
if count:
|
||||
@@ -1250,7 +1255,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
|
||||
async def claim_e2e_fallback_keys(
|
||||
self, query_list: Iterable[Tuple[str, str, str, bool]]
|
||||
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
|
||||
) -> Dict[str, Dict[str, Dict[str, JsonSerializable]]]:
|
||||
"""Take a list of fallback keys out of the database.
|
||||
|
||||
Args:
|
||||
@@ -1260,7 +1265,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
Returns:
|
||||
A map of user ID -> a map device ID -> a map of key ID -> JSON.
|
||||
"""
|
||||
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
results: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {}
|
||||
for user_id, device_id, algorithm, mark_as_used in query_list:
|
||||
row = await self.db_pool.simple_select_one(
|
||||
table="e2e_fallback_keys_json",
|
||||
@@ -1298,7 +1303,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
)
|
||||
|
||||
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
|
||||
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
|
||||
# The shape of the key depends on the algorithm: it is a dict for
|
||||
# signed_curve25519, or a string for curve25519. In general, it
|
||||
# is whatever the client chose to upload, since we dont validate it.
|
||||
decoded_key: JsonSerializable = json_decoder.decode(key_json)
|
||||
device_results[f"{algorithm}:{key_id}"] = decoded_key
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -144,35 +144,81 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
SynapseError,
|
||||
)
|
||||
|
||||
def test_claim_one_time_key(self) -> None:
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
keys = {"alg1:k1": "key1"}
|
||||
@parameterized.expand([(True,), (False,)])
|
||||
def test_claim_one_time_key(self, msc4072: bool) -> None:
|
||||
self.hs.config.experimental.msc4072_empty_dict_for_exhausted_devices = msc4072
|
||||
|
||||
local_known_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
local_unknown_user = "@charlie:" + self.hs.hostname
|
||||
|
||||
remote_known_user = "@dave:xyz"
|
||||
remote_unknown_user = "@errol:xyz"
|
||||
|
||||
# upload a key for the local user
|
||||
res = self.get_success(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys}
|
||||
local_known_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}}
|
||||
)
|
||||
)
|
||||
self.assertDictEqual(
|
||||
res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
|
||||
)
|
||||
|
||||
# mock out the response for remote users. We pretend that the remote server
|
||||
# hasn't heard of MSC4072 and returns an incomplete result. (Even once
|
||||
# MSC4072 is stable, we still need to handle incomplete results.)
|
||||
#
|
||||
# we also include a spurious result to check it gets filtered out.
|
||||
self.hs.get_federation_client().claim_client_keys = mock.AsyncMock( # type: ignore[method-assign]
|
||||
return_value={
|
||||
"one_time_keys": {
|
||||
remote_known_user: {"ghi": {"alg1": "keykey"}},
|
||||
"@other:xyz": {"zzz": {"alg1": "dodgykey"}},
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
res2 = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{local_user: {device_id: {"alg1": 1}}},
|
||||
{
|
||||
local_known_user: {device_id: {"alg1": 1}, "abc": {"alg2": 1}},
|
||||
local_unknown_user: {"def": {"alg1": 1}},
|
||||
remote_known_user: {"ghi": {"alg1": 1}, "jkl": {"alg1": 1}},
|
||||
remote_unknown_user: {"mno": {"alg1": 1}},
|
||||
},
|
||||
self.requester,
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
res2,
|
||||
{
|
||||
"failures": {},
|
||||
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
|
||||
},
|
||||
)
|
||||
|
||||
if msc4072:
|
||||
# empty result for each unknown device
|
||||
self.assertEqual(
|
||||
res2,
|
||||
{
|
||||
"failures": {},
|
||||
"one_time_keys": {
|
||||
local_known_user: {device_id: {"alg1:k1": "key1"}, "abc": {}},
|
||||
local_unknown_user: {"def": {}},
|
||||
remote_known_user: {"ghi": {"alg1": "keykey"}, "jkl": {}},
|
||||
remote_unknown_user: {"mno": {}},
|
||||
},
|
||||
},
|
||||
)
|
||||
else:
|
||||
# only known devices
|
||||
self.assertEqual(
|
||||
res2,
|
||||
{
|
||||
"failures": {},
|
||||
"one_time_keys": {
|
||||
local_known_user: {device_id: {"alg1:k1": "key1"}},
|
||||
remote_known_user: {"ghi": {"alg1": "keykey"}},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def test_fallback_key(self) -> None:
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
|
||||
Reference in New Issue
Block a user