Merge pull request #5727 from matrix-org/uhoreg/e2e_cross-signing2-part3
* commit '53d7680e3': Update synapse/storage/data_stores/main/devices.py rename get_devices_by_remote to get_device_updates_by_remote black apply changes as a result of PR review don't error if federation query doesn't have cross-signing keys move get_e2e_cross_signing_key to EndToEndKeyWorkerStore so it works with workers black vendor-prefix the EDU name until MSC1756 is merged into the spec fix unit test add news file update to work with newer code, and fix formatting add missing param make black happy don't crash if the user doesn't have cross-signing keys implement federation parts of cross-signing
This commit is contained in:
1
changelog.d/5727.feature
Normal file
1
changelog.d/5727.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add federation support for cross-signing.
|
||||
@@ -360,20 +360,20 @@ class PerDestinationQueue(object):
|
||||
last_device_list = self._last_device_list_stream_id
|
||||
|
||||
# Retrieve list of new device updates to send to the destination
|
||||
now_stream_id, results = yield self._store.get_devices_by_remote(
|
||||
now_stream_id, results = yield self._store.get_device_updates_by_remote(
|
||||
self._destination, last_device_list, limit=limit
|
||||
)
|
||||
edus = [
|
||||
Edu(
|
||||
origin=self._server_name,
|
||||
destination=self._destination,
|
||||
edu_type="m.device_list_update",
|
||||
edu_type=edu_type,
|
||||
content=content,
|
||||
)
|
||||
for content in results
|
||||
for (edu_type, content) in results
|
||||
]
|
||||
|
||||
assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
|
||||
assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"
|
||||
|
||||
return (edus, now_stream_id)
|
||||
|
||||
|
||||
@@ -459,7 +459,18 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||
@defer.inlineCallbacks
|
||||
def on_federation_query_user_devices(self, user_id):
|
||||
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
|
||||
return {"user_id": user_id, "stream_id": stream_id, "devices": devices}
|
||||
master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
|
||||
self_signing_key = yield self.store.get_e2e_cross_signing_key(
|
||||
user_id, "self_signing"
|
||||
)
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"stream_id": stream_id,
|
||||
"devices": devices,
|
||||
"master_key": master_key,
|
||||
"self_signing_key": self_signing_key,
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_left_room(self, user, room_id):
|
||||
|
||||
@@ -36,6 +36,8 @@ from synapse.types import (
|
||||
get_verify_key_from_cross_signing_key,
|
||||
)
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -49,10 +51,19 @@ class E2eKeysHandler(object):
|
||||
self.is_mine = hs.is_mine
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
self._edu_updater = SigningKeyEduUpdater(hs, self)
|
||||
|
||||
federation_registry = hs.get_federation_registry()
|
||||
|
||||
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
|
||||
federation_registry.register_edu_handler(
|
||||
"org.matrix.signing_key_update",
|
||||
self._edu_updater.incoming_signing_key_update,
|
||||
)
|
||||
# doesn't really work as part of the generic query API, because the
|
||||
# query request requires an object POST, but we abuse the
|
||||
# "query handler" interface.
|
||||
hs.get_federation_registry().register_query_handler(
|
||||
federation_registry.register_query_handler(
|
||||
"client_keys", self.on_federation_query_client_keys
|
||||
)
|
||||
|
||||
@@ -208,13 +219,15 @@ class E2eKeysHandler(object):
|
||||
if user_id in destination_query:
|
||||
results[user_id] = keys
|
||||
|
||||
for user_id, key in remote_result["master_keys"].items():
|
||||
if user_id in destination_query:
|
||||
cross_signing_keys["master_keys"][user_id] = key
|
||||
if "master_keys" in remote_result:
|
||||
for user_id, key in remote_result["master_keys"].items():
|
||||
if user_id in destination_query:
|
||||
cross_signing_keys["master_keys"][user_id] = key
|
||||
|
||||
for user_id, key in remote_result["self_signing_keys"].items():
|
||||
if user_id in destination_query:
|
||||
cross_signing_keys["self_signing_keys"][user_id] = key
|
||||
if "self_signing_keys" in remote_result:
|
||||
for user_id, key in remote_result["self_signing_keys"].items():
|
||||
if user_id in destination_query:
|
||||
cross_signing_keys["self_signing_keys"][user_id] = key
|
||||
|
||||
except Exception as e:
|
||||
failure = _exception_to_failure(e)
|
||||
@@ -252,7 +265,7 @@ class E2eKeysHandler(object):
|
||||
|
||||
Returns:
|
||||
defer.Deferred[dict[str, dict[str, dict]]]: map from
|
||||
(master|self_signing|user_signing) -> user_id -> key
|
||||
(master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
|
||||
"""
|
||||
master_keys = {}
|
||||
self_signing_keys = {}
|
||||
@@ -344,7 +357,16 @@ class E2eKeysHandler(object):
|
||||
"""
|
||||
device_keys_query = query_body.get("device_keys", {})
|
||||
res = yield self.query_local_devices(device_keys_query)
|
||||
return {"device_keys": res}
|
||||
ret = {"device_keys": res}
|
||||
|
||||
# add in the cross-signing keys
|
||||
cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
|
||||
device_keys_query, None
|
||||
)
|
||||
|
||||
ret.update(cross_signing_keys)
|
||||
|
||||
return ret
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
@@ -1058,3 +1080,100 @@ class SignatureListItem:
|
||||
target_user_id = attr.ib()
|
||||
target_device_id = attr.ib()
|
||||
signature = attr.ib()
|
||||
|
||||
|
||||
class SigningKeyEduUpdater(object):
|
||||
"""Handles incoming signing key updates from federation and updates the DB"""
|
||||
|
||||
def __init__(self, hs, e2e_keys_handler):
|
||||
self.store = hs.get_datastore()
|
||||
self.federation = hs.get_federation_client()
|
||||
self.clock = hs.get_clock()
|
||||
self.e2e_keys_handler = e2e_keys_handler
|
||||
|
||||
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
|
||||
|
||||
# user_id -> list of updates waiting to be handled.
|
||||
self._pending_updates = {}
|
||||
|
||||
# Recently seen stream ids. We don't bother keeping these in the DB,
|
||||
# but they're useful to have them about to reduce the number of spurious
|
||||
# resyncs.
|
||||
self._seen_updates = ExpiringCache(
|
||||
cache_name="signing_key_update_edu",
|
||||
clock=self.clock,
|
||||
max_len=10000,
|
||||
expiry_ms=30 * 60 * 1000,
|
||||
iterable=True,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def incoming_signing_key_update(self, origin, edu_content):
|
||||
"""Called on incoming signing key update from federation. Responsible for
|
||||
parsing the EDU and adding to pending updates list.
|
||||
|
||||
Args:
|
||||
origin (string): the server that sent the EDU
|
||||
edu_content (dict): the contents of the EDU
|
||||
"""
|
||||
|
||||
user_id = edu_content.pop("user_id")
|
||||
master_key = edu_content.pop("master_key", None)
|
||||
self_signing_key = edu_content.pop("self_signing_key", None)
|
||||
|
||||
if get_domain_from_id(user_id) != origin:
|
||||
logger.warning("Got signing key update edu for %r from %r", user_id, origin)
|
||||
return
|
||||
|
||||
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||
if not room_ids:
|
||||
# We don't share any rooms with this user. Ignore update, as we
|
||||
# probably won't get any further updates.
|
||||
return
|
||||
|
||||
self._pending_updates.setdefault(user_id, []).append(
|
||||
(master_key, self_signing_key)
|
||||
)
|
||||
|
||||
yield self._handle_signing_key_updates(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_signing_key_updates(self, user_id):
|
||||
"""Actually handle pending updates.
|
||||
|
||||
Args:
|
||||
user_id (string): the user whose updates we are processing
|
||||
"""
|
||||
|
||||
device_handler = self.e2e_keys_handler.device_handler
|
||||
|
||||
with (yield self._remote_edu_linearizer.queue(user_id)):
|
||||
pending_updates = self._pending_updates.pop(user_id, [])
|
||||
if not pending_updates:
|
||||
# This can happen since we batch updates
|
||||
return
|
||||
|
||||
device_ids = []
|
||||
|
||||
logger.info("pending updates: %r", pending_updates)
|
||||
|
||||
for master_key, self_signing_key in pending_updates:
|
||||
if master_key:
|
||||
yield self.store.set_e2e_cross_signing_key(
|
||||
user_id, "master", master_key
|
||||
)
|
||||
_, verify_key = get_verify_key_from_cross_signing_key(master_key)
|
||||
# verify_key is a VerifyKey from signedjson, which uses
|
||||
# .version to denote the portion of the key ID after the
|
||||
# algorithm and colon, which is the device ID
|
||||
device_ids.append(verify_key.version)
|
||||
if self_signing_key:
|
||||
yield self.store.set_e2e_cross_signing_key(
|
||||
user_id, "self_signing", self_signing_key
|
||||
)
|
||||
_, verify_key = get_verify_key_from_cross_signing_key(
|
||||
self_signing_key
|
||||
)
|
||||
device_ids.append(verify_key.version)
|
||||
|
||||
yield device_handler.notify_device_update(user_id, device_ids)
|
||||
|
||||
@@ -37,6 +37,7 @@ from synapse.storage._base import (
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||
from synapse.types import get_verify_key_from_cross_signing_key
|
||||
from synapse.util import batch_iter
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
|
||||
|
||||
@@ -90,13 +91,18 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def get_devices_by_remote(self, destination, from_stream_id, limit):
|
||||
"""Get stream of updates to send to remote servers
|
||||
def get_device_updates_by_remote(self, destination, from_stream_id, limit):
|
||||
"""Get a stream of device updates to send to the given remote server.
|
||||
|
||||
Args:
|
||||
destination (str): The host the device updates are intended for
|
||||
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
||||
limit (int): Maximum number of device updates to return
|
||||
Returns:
|
||||
Deferred[tuple[int, list[dict]]]:
|
||||
Deferred[tuple[int, list[tuple[string,dict]]]]:
|
||||
current stream id (ie, the stream id of the last update included in the
|
||||
response), and the list of updates
|
||||
response), and the list of updates, where each update is a pair of EDU
|
||||
type and EDU contents
|
||||
"""
|
||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||
|
||||
@@ -117,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
# stream_id; the rationale being that such a large device list update
|
||||
# is likely an error.
|
||||
updates = yield self.runInteraction(
|
||||
"get_devices_by_remote",
|
||||
self._get_devices_by_remote_txn,
|
||||
"get_device_updates_by_remote",
|
||||
self._get_device_updates_by_remote_txn,
|
||||
destination,
|
||||
from_stream_id,
|
||||
now_stream_id,
|
||||
@@ -129,6 +135,37 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
if not updates:
|
||||
return now_stream_id, []
|
||||
|
||||
# get the cross-signing keys of the users in the list, so that we can
|
||||
# determine which of the device changes were cross-signing keys
|
||||
users = set(r[0] for r in updates)
|
||||
master_key_by_user = {}
|
||||
self_signing_key_by_user = {}
|
||||
for user in users:
|
||||
cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
|
||||
if cross_signing_key:
|
||||
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
||||
cross_signing_key
|
||||
)
|
||||
# verify_key is a VerifyKey from signedjson, which uses
|
||||
# .version to denote the portion of the key ID after the
|
||||
# algorithm and colon, which is the device ID
|
||||
master_key_by_user[user] = {
|
||||
"key_info": cross_signing_key,
|
||||
"device_id": verify_key.version,
|
||||
}
|
||||
|
||||
cross_signing_key = yield self.get_e2e_cross_signing_key(
|
||||
user, "self_signing"
|
||||
)
|
||||
if cross_signing_key:
|
||||
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
||||
cross_signing_key
|
||||
)
|
||||
self_signing_key_by_user[user] = {
|
||||
"key_info": cross_signing_key,
|
||||
"device_id": verify_key.version,
|
||||
}
|
||||
|
||||
# if we have exceeded the limit, we need to exclude any results with the
|
||||
# same stream_id as the last row.
|
||||
if len(updates) > limit:
|
||||
@@ -153,20 +190,33 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
# context which created the Edu.
|
||||
|
||||
query_map = {}
|
||||
for update in updates:
|
||||
if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
|
||||
cross_signing_keys_by_user = {}
|
||||
for user_id, device_id, update_stream_id, update_context in updates:
|
||||
if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
|
||||
# Stop processing updates
|
||||
break
|
||||
|
||||
key = (update[0], update[1])
|
||||
if (
|
||||
user_id in master_key_by_user
|
||||
and device_id == master_key_by_user[user_id]["device_id"]
|
||||
):
|
||||
result = cross_signing_keys_by_user.setdefault(user_id, {})
|
||||
result["master_key"] = master_key_by_user[user_id]["key_info"]
|
||||
elif (
|
||||
user_id in self_signing_key_by_user
|
||||
and device_id == self_signing_key_by_user[user_id]["device_id"]
|
||||
):
|
||||
result = cross_signing_keys_by_user.setdefault(user_id, {})
|
||||
result["self_signing_key"] = self_signing_key_by_user[user_id][
|
||||
"key_info"
|
||||
]
|
||||
else:
|
||||
key = (user_id, device_id)
|
||||
|
||||
update_context = update[3]
|
||||
update_stream_id = update[2]
|
||||
previous_update_stream_id, _ = query_map.get(key, (0, None))
|
||||
|
||||
previous_update_stream_id, _ = query_map.get(key, (0, None))
|
||||
|
||||
if update_stream_id > previous_update_stream_id:
|
||||
query_map[key] = (update_stream_id, update_context)
|
||||
if update_stream_id > previous_update_stream_id:
|
||||
query_map[key] = (update_stream_id, update_context)
|
||||
|
||||
# If we didn't find any updates with a stream_id lower than the cutoff, it
|
||||
# means that there are more than limit updates all of which have the same
|
||||
@@ -176,16 +226,22 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
# devices, in which case E2E isn't going to work well anyway. We'll just
|
||||
# skip that stream_id and return an empty list, and continue with the next
|
||||
# stream_id next time.
|
||||
if not query_map:
|
||||
if not query_map and not cross_signing_keys_by_user:
|
||||
return stream_id_cutoff, []
|
||||
|
||||
results = yield self._get_device_update_edus_by_remote(
|
||||
destination, from_stream_id, query_map
|
||||
)
|
||||
|
||||
# add the updated cross-signing keys to the results list
|
||||
for user_id, result in iteritems(cross_signing_keys_by_user):
|
||||
result["user_id"] = user_id
|
||||
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
|
||||
results.append(("org.matrix.signing_key_update", result))
|
||||
|
||||
return now_stream_id, results
|
||||
|
||||
def _get_devices_by_remote_txn(
|
||||
def _get_device_updates_by_remote_txn(
|
||||
self, txn, destination, from_stream_id, now_stream_id, limit
|
||||
):
|
||||
"""Return device update information for a given remote destination
|
||||
@@ -200,6 +256,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
Returns:
|
||||
List: List of device updates
|
||||
"""
|
||||
# get the list of device updates that need to be sent
|
||||
sql = """
|
||||
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
|
||||
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
|
||||
@@ -225,12 +282,16 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
List[Dict]: List of objects representing an device update EDU
|
||||
|
||||
"""
|
||||
devices = yield self.runInteraction(
|
||||
"_get_e2e_device_keys_txn",
|
||||
self._get_e2e_device_keys_txn,
|
||||
query_map.keys(),
|
||||
include_all_devices=True,
|
||||
include_deleted_devices=True,
|
||||
devices = (
|
||||
yield self.runInteraction(
|
||||
"_get_e2e_device_keys_txn",
|
||||
self._get_e2e_device_keys_txn,
|
||||
query_map.keys(),
|
||||
include_all_devices=True,
|
||||
include_deleted_devices=True,
|
||||
)
|
||||
if query_map
|
||||
else {}
|
||||
)
|
||||
|
||||
results = []
|
||||
@@ -262,7 +323,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
else:
|
||||
result["deleted"] = True
|
||||
|
||||
results.append(result)
|
||||
results.append(("m.device_list_update", result))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||
"get_received_txn_response",
|
||||
"set_received_txn_response",
|
||||
"get_destination_retry_timings",
|
||||
"get_devices_by_remote",
|
||||
"get_device_updates_by_remote",
|
||||
# Bits that user_directory needs
|
||||
"get_user_directory_stream_pos",
|
||||
"get_current_state_deltas",
|
||||
@@ -109,7 +109,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||
retry_timings_res
|
||||
)
|
||||
|
||||
self.datastore.get_devices_by_remote.return_value = (0, [])
|
||||
self.datastore.get_device_updates_by_remote.return_value = (0, [])
|
||||
|
||||
def get_received_txn_response(*args):
|
||||
return defer.succeed(None)
|
||||
|
||||
@@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_devices_by_remote(self):
|
||||
def test_get_device_updates_by_remote(self):
|
||||
device_ids = ["device_id1", "device_id2"]
|
||||
|
||||
# Add two device updates with a single stream_id
|
||||
@@ -81,7 +81,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||
)
|
||||
|
||||
# Get all device updates ever meant for this remote
|
||||
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
||||
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||
"somehost", -1, limit=100
|
||||
)
|
||||
|
||||
@@ -89,7 +89,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||
self._check_devices_in_updates(device_ids, device_updates)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_devices_by_remote_limited(self):
|
||||
def test_get_device_updates_by_remote_limited(self):
|
||||
# Test breaking the update limit in 1, 101, and 1 device_id segments
|
||||
|
||||
# first add one device
|
||||
@@ -115,20 +115,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||
#
|
||||
|
||||
# first we should get a single update
|
||||
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
||||
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||
"someotherhost", -1, limit=100
|
||||
)
|
||||
self._check_devices_in_updates(device_ids1, device_updates)
|
||||
|
||||
# Then we should get an empty list back as the 101 devices broke the limit
|
||||
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
||||
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||
"someotherhost", now_stream_id, limit=100
|
||||
)
|
||||
self.assertEqual(len(device_updates), 0)
|
||||
|
||||
# The 101 devices should've been cleared, so we should now just get one device
|
||||
# update
|
||||
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
||||
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||
"someotherhost", now_stream_id, limit=100
|
||||
)
|
||||
self._check_devices_in_updates(device_ids3, device_updates)
|
||||
@@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||
"""Check that an specific device ids exist in a list of device update EDUs"""
|
||||
self.assertEqual(len(device_updates), len(expected_device_ids))
|
||||
|
||||
received_device_ids = {update["device_id"] for update in device_updates}
|
||||
received_device_ids = {
|
||||
update["device_id"] for edu_type, update in device_updates
|
||||
}
|
||||
self.assertEqual(received_device_ids, set(expected_device_ids))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
||||
Reference in New Issue
Block a user