diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 28c12753c1..833a7a4ff6 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -280,33 +280,25 @@ class E2eKeysHandler(object): defer.Deferred[dict[str, dict[str, dict]]]: map from (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key """ - master_keys = {} - self_signing_keys = {} user_signing_keys = {} - for user_id in query: - # XXX: consider changing the store functions to allow querying - # multiple users simultaneously. - key = yield self.store.get_e2e_cross_signing_key( - user_id, "master", from_user_id - ) - if key: - master_keys[user_id] = key + users = list(query) - key = yield self.store.get_e2e_cross_signing_key( - user_id, "self_signing", from_user_id - ) - if key: - self_signing_keys[user_id] = key + master_keys = yield self.store.get_e2e_cross_signing_keys_bulk( + users, "master", from_user_id + ) + self_signing_keys = yield self.store.get_e2e_cross_signing_keys_bulk( + users, "self_signing", from_user_id + ) + if from_user_id in users: # users can see other users' master and self-signing keys, but can # only see their own user-signing keys - if from_user_id == user_id: - key = yield self.store.get_e2e_cross_signing_key( - user_id, "user_signing", from_user_id - ) - if key: - user_signing_keys[user_id] = key + key = yield self.store.get_e2e_cross_signing_key( + from_user_id, "user_signing", from_user_id + ) + if key: + user_signing_keys[from_user_id] = key return { "master_keys": master_keys, diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 08bcdc4725..d18d1b4c65 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -269,7 +269,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Args: txn (twisted.enterprise.adbapi.Connection): db connection user_id (str): the user whose key is being requested - key_type (str): the type of key that is being set: either 'master' + key_type (str): the type of key that is being requested: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key from_user_id (str): if specified, signatures made by this user on @@ -314,8 +314,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): """Returns a user's cross-signing key. Args: - user_id (str): the user whose self-signing key is being requested - key_type (str): the type of cross-signing key to get + user_id (str): the user whose key is being requested + key_type (str): the type of key that is being requested: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key from_user_id (str): if specified, signatures made by this user on the self-signing key will be included in the result @@ -330,6 +332,104 @@ class EndToEndKeyWorkerStore(SQLBaseStore): from_user_id, ) + def _get_e2e_cross_signing_keys_bulk_txn( + self, txn, users, key_type, from_user_id=None + ): + """Returns the cross-signing keys for a set of users. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + users (iterable[str]): the users whose keys are being requested + key_type (str): the type of keys that are being requested: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key + from_user_id (str): if specified, signatures made by this user on + the keys will be included in the result + + Returns: + dict mapping from user ID to key data. If a user's cross-signing key + was not found, their user ID will not be in the dict. + """ + sql = ( + "SELECT k.user_id, k.keydata, k.stream_id " + " FROM e2e_cross_signing_keys k" + " JOIN (SELECT user_id, MAX(stream_id) AS stream_id" + " FROM e2e_cross_signing_keys" + " WHERE keytype = ?" + " GROUP BY user_id)" + " USING (user_id, stream_id)" + " WHERE k.user_id IN (%s) AND k.keytype = ?" + ) % ( + ",".join("?" for u in users) + ) + query_params = [key_type] + query_params.extend(users) + query_params.append(key_type) + + txn.execute(sql, query_params) + rows = self.cursor_to_dict(txn) + + result = {} + devices = {} + for row in rows: + user_id = row["user_id"] + key = json.loads(row["keydata"]) + result[user_id] = key + for k in key["keys"].values(): + devices[user_id] = k + + if devices and from_user_id: + # if we're asked to get signatures, and we have any devices to get + # signatures for, fetch the signatures + sql = ( + "SELECT target_user_id, key_id, signature " + " FROM e2e_cross_signing_signatures " + " WHERE user_id = ?" + " AND (target_user_id, target_device_id) IN (VALUES %s)" + ) % ( + ",".join("(?,?)" for d in devices) + ) + query_params = [from_user_id] + for item in devices.items(): + query_params.extend(item) + + txn.execute(sql, query_params) + rows = self.cursor_to_dict(txn) + + # and add the signatures to the appropriate keys + for row in rows: + key_id = row["key_id"] + target_user_id = row["target_user_id"] + target_user_key = result[target_user_id] + signatures = target_user_key.setdefault("signatures", {}) + user_sigs = signatures.setdefault(from_user_id, {}) + user_sigs[key_id] = row["signature"] + + return result + + def get_e2e_cross_signing_keys_bulk(self, users, key_type, from_user_id=None): + """Returns the cross-signing keys for a set of users. + + Args: + users (iterable[str]): the users whose keys are being requested + key_type (str): the type of keys that are being requested: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key + from_user_id (str): if specified, signatures made by this user on + the self-signing keys will be included in the result + + Returns: + Deferred[dict] map of user ID to key data. If a user's cross-signing key + was not found, their user ID will not be in the dict. + """ + return self.runInteraction( + "get_e2e_cross_signing_key", + self._get_e2e_cross_signing_keys_bulk_txn, + users, + key_type, + from_user_id, + ) + def get_all_user_signature_changes_for_remotes(self, from_key, to_key): """Return a list of changes from the user signature stream to notify remotes. Note that the user signature stream represents when a user signs their