1
0

Merge pull request #5727 from matrix-org/uhoreg/e2e_cross-signing2-part3

* commit '53d7680e3':
  Update synapse/storage/data_stores/main/devices.py
  rename get_devices_by_remote to get_device_updates_by_remote
  black
  apply changes as a result of PR review
  don't error if federation query doesn't have cross-signing keys
  move get_e2e_cross_signing_key to EndToEndKeyWorkerStore so it works with workers
  black
  vendor-prefix the EDU name until MSC1756 is merged into the spec
  fix unit test
  add news file
  update to work with newer code, and fix formatting
  add missing param
  make black happy
  don't crash if the user doesn't have cross-signing keys
  implement federation parts of cross-signing
This commit is contained in:
Andrew Morgan
2020-03-16 19:23:32 +00:00
7 changed files with 241 additions and 47 deletions

1
changelog.d/5727.feature Normal file
View File

@@ -0,0 +1 @@
Add federation support for cross-signing.

View File

@@ -360,20 +360,20 @@ class PerDestinationQueue(object):
last_device_list = self._last_device_list_stream_id
# Retrieve list of new device updates to send to the destination
now_stream_id, results = yield self._store.get_devices_by_remote(
now_stream_id, results = yield self._store.get_device_updates_by_remote(
self._destination, last_device_list, limit=limit
)
edus = [
Edu(
origin=self._server_name,
destination=self._destination,
edu_type="m.device_list_update",
edu_type=edu_type,
content=content,
)
for content in results
for (edu_type, content) in results
]
assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"
return (edus, now_stream_id)

View File

@@ -459,7 +459,18 @@ class DeviceHandler(DeviceWorkerHandler):
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
return {"user_id": user_id, "stream_id": stream_id, "devices": devices}
master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
self_signing_key = yield self.store.get_e2e_cross_signing_key(
user_id, "self_signing"
)
return {
"user_id": user_id,
"stream_id": stream_id,
"devices": devices,
"master_key": master_key,
"self_signing_key": self_signing_key,
}
@defer.inlineCallbacks
def user_left_room(self, user, room_id):

View File

@@ -36,6 +36,8 @@ from synapse.types import (
get_verify_key_from_cross_signing_key,
)
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -49,10 +51,19 @@ class E2eKeysHandler(object):
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
self._edu_updater = SigningKeyEduUpdater(hs, self)
federation_registry = hs.get_federation_registry()
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
federation_registry.register_edu_handler(
"org.matrix.signing_key_update",
self._edu_updater.incoming_signing_key_update,
)
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
hs.get_federation_registry().register_query_handler(
federation_registry.register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
@@ -208,13 +219,15 @@ class E2eKeysHandler(object):
if user_id in destination_query:
results[user_id] = keys
for user_id, key in remote_result["master_keys"].items():
if user_id in destination_query:
cross_signing_keys["master_keys"][user_id] = key
if "master_keys" in remote_result:
for user_id, key in remote_result["master_keys"].items():
if user_id in destination_query:
cross_signing_keys["master_keys"][user_id] = key
for user_id, key in remote_result["self_signing_keys"].items():
if user_id in destination_query:
cross_signing_keys["self_signing_keys"][user_id] = key
if "self_signing_keys" in remote_result:
for user_id, key in remote_result["self_signing_keys"].items():
if user_id in destination_query:
cross_signing_keys["self_signing_keys"][user_id] = key
except Exception as e:
failure = _exception_to_failure(e)
@@ -252,7 +265,7 @@ class E2eKeysHandler(object):
Returns:
defer.Deferred[dict[str, dict[str, dict]]]: map from
(master|self_signing|user_signing) -> user_id -> key
(master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
"""
master_keys = {}
self_signing_keys = {}
@@ -344,7 +357,16 @@ class E2eKeysHandler(object):
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
return {"device_keys": res}
ret = {"device_keys": res}
# add in the cross-signing keys
cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
device_keys_query, None
)
ret.update(cross_signing_keys)
return ret
@trace
@defer.inlineCallbacks
@@ -1058,3 +1080,100 @@ class SignatureListItem:
target_user_id = attr.ib()
target_device_id = attr.ib()
signature = attr.ib()
class SigningKeyEduUpdater(object):
"""Handles incoming signing key updates from federation and updates the DB"""
def __init__(self, hs, e2e_keys_handler):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
self.e2e_keys_handler = e2e_keys_handler
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
# user_id -> list of updates waiting to be handled.
self._pending_updates = {}
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
# resyncs.
self._seen_updates = ExpiringCache(
cache_name="signing_key_update_edu",
clock=self.clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
)
@defer.inlineCallbacks
def incoming_signing_key_update(self, origin, edu_content):
"""Called on incoming signing key update from federation. Responsible for
parsing the EDU and adding to pending updates list.
Args:
origin (string): the server that sent the EDU
edu_content (dict): the contents of the EDU
"""
user_id = edu_content.pop("user_id")
master_key = edu_content.pop("master_key", None)
self_signing_key = edu_content.pop("self_signing_key", None)
if get_domain_from_id(user_id) != origin:
logger.warning("Got signing key update edu for %r from %r", user_id, origin)
return
room_ids = yield self.store.get_rooms_for_user(user_id)
if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
return
self._pending_updates.setdefault(user_id, []).append(
(master_key, self_signing_key)
)
yield self._handle_signing_key_updates(user_id)
@defer.inlineCallbacks
def _handle_signing_key_updates(self, user_id):
"""Actually handle pending updates.
Args:
user_id (string): the user whose updates we are processing
"""
device_handler = self.e2e_keys_handler.device_handler
with (yield self._remote_edu_linearizer.queue(user_id)):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates
return
device_ids = []
logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates:
if master_key:
yield self.store.set_e2e_cross_signing_key(
user_id, "master", master_key
)
_, verify_key = get_verify_key_from_cross_signing_key(master_key)
# verify_key is a VerifyKey from signedjson, which uses
# .version to denote the portion of the key ID after the
# algorithm and colon, which is the device ID
device_ids.append(verify_key.version)
if self_signing_key:
yield self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key
)
_, verify_key = get_verify_key_from_cross_signing_key(
self_signing_key
)
device_ids.append(verify_key.version)
yield device_handler.notify_device_update(user_id, device_ids)

View File

@@ -37,6 +37,7 @@ from synapse.storage._base import (
make_in_list_sql_clause,
)
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.types import get_verify_key_from_cross_signing_key
from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
@@ -90,13 +91,18 @@ class DeviceWorkerStore(SQLBaseStore):
@trace
@defer.inlineCallbacks
def get_devices_by_remote(self, destination, from_stream_id, limit):
"""Get stream of updates to send to remote servers
def get_device_updates_by_remote(self, destination, from_stream_id, limit):
"""Get a stream of device updates to send to the given remote server.
Args:
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
limit (int): Maximum number of device updates to return
Returns:
Deferred[tuple[int, list[dict]]]:
Deferred[tuple[int, list[tuple[string,dict]]]]:
current stream id (ie, the stream id of the last update included in the
response), and the list of updates
response), and the list of updates, where each update is a pair of EDU
type and EDU contents
"""
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -117,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore):
# stream_id; the rationale being that such a large device list update
# is likely an error.
updates = yield self.runInteraction(
"get_devices_by_remote",
self._get_devices_by_remote_txn,
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
@@ -129,6 +135,37 @@ class DeviceWorkerStore(SQLBaseStore):
if not updates:
return now_stream_id, []
# get the cross-signing keys of the users in the list, so that we can
# determine which of the device changes were cross-signing keys
users = set(r[0] for r in updates)
master_key_by_user = {}
self_signing_key_by_user = {}
for user in users:
cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
cross_signing_key
)
# verify_key is a VerifyKey from signedjson, which uses
# .version to denote the portion of the key ID after the
# algorithm and colon, which is the device ID
master_key_by_user[user] = {
"key_info": cross_signing_key,
"device_id": verify_key.version,
}
cross_signing_key = yield self.get_e2e_cross_signing_key(
user, "self_signing"
)
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
cross_signing_key
)
self_signing_key_by_user[user] = {
"key_info": cross_signing_key,
"device_id": verify_key.version,
}
# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
if len(updates) > limit:
@@ -153,20 +190,33 @@ class DeviceWorkerStore(SQLBaseStore):
# context which created the Edu.
query_map = {}
for update in updates:
if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
cross_signing_keys_by_user = {}
for user_id, device_id, update_stream_id, update_context in updates:
if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
# Stop processing updates
break
key = (update[0], update[1])
if (
user_id in master_key_by_user
and device_id == master_key_by_user[user_id]["device_id"]
):
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["master_key"] = master_key_by_user[user_id]["key_info"]
elif (
user_id in self_signing_key_by_user
and device_id == self_signing_key_by_user[user_id]["device_id"]
):
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["self_signing_key"] = self_signing_key_by_user[user_id][
"key_info"
]
else:
key = (user_id, device_id)
update_context = update[3]
update_stream_id = update[2]
previous_update_stream_id, _ = query_map.get(key, (0, None))
previous_update_stream_id, _ = query_map.get(key, (0, None))
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)
# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
@@ -176,16 +226,22 @@ class DeviceWorkerStore(SQLBaseStore):
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
if not query_map:
if not query_map and not cross_signing_keys_by_user:
return stream_id_cutoff, []
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
# add the updated cross-signing keys to the results list
for user_id, result in iteritems(cross_signing_keys_by_user):
result["user_id"] = user_id
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
results.append(("org.matrix.signing_key_update", result))
return now_stream_id, results
def _get_devices_by_remote_txn(
def _get_device_updates_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id, limit
):
"""Return device update information for a given remote destination
@@ -200,6 +256,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
List: List of device updates
"""
# get the list of device updates that need to be sent
sql = """
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
@@ -225,12 +282,16 @@ class DeviceWorkerStore(SQLBaseStore):
List[Dict]: List of objects representing an device update EDU
"""
devices = yield self.runInteraction(
"_get_e2e_device_keys_txn",
self._get_e2e_device_keys_txn,
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
devices = (
yield self.runInteraction(
"_get_e2e_device_keys_txn",
self._get_e2e_device_keys_txn,
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
)
if query_map
else {}
)
results = []
@@ -262,7 +323,7 @@ class DeviceWorkerStore(SQLBaseStore):
else:
result["deleted"] = True
results.append(result)
results.append(("m.device_list_update", result))
return results

View File

@@ -73,7 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"get_received_txn_response",
"set_received_txn_response",
"get_destination_retry_timings",
"get_devices_by_remote",
"get_device_updates_by_remote",
# Bits that user_directory needs
"get_user_directory_stream_pos",
"get_current_state_deltas",
@@ -109,7 +109,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res
)
self.datastore.get_devices_by_remote.return_value = (0, [])
self.datastore.get_device_updates_by_remote.return_value = (0, [])
def get_received_txn_response(*args):
return defer.succeed(None)

View File

@@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
@defer.inlineCallbacks
def test_get_devices_by_remote(self):
def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id
@@ -81,7 +81,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
# Get all device updates ever meant for this remote
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"somehost", -1, limit=100
)
@@ -89,7 +89,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self._check_devices_in_updates(device_ids, device_updates)
@defer.inlineCallbacks
def test_get_devices_by_remote_limited(self):
def test_get_device_updates_by_remote_limited(self):
# Test breaking the update limit in 1, 101, and 1 device_id segments
# first add one device
@@ -115,20 +115,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
#
# first we should get a single update
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", -1, limit=100
)
self._check_devices_in_updates(device_ids1, device_updates)
# Then we should get an empty list back as the 101 devices broke the limit
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", now_stream_id, limit=100
)
self.assertEqual(len(device_updates), 0)
# The 101 devices should've been cleared, so we should now just get one device
# update
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", now_stream_id, limit=100
)
self._check_devices_in_updates(device_ids3, device_updates)
@@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
received_device_ids = {update["device_id"] for update in device_updates}
received_device_ids = {
update["device_id"] for edu_type, update in device_updates
}
self.assertEqual(received_device_ids, set(expected_device_ids))
@defer.inlineCallbacks