diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index bc9d9d43f9..84205542e2 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -14,10 +14,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Iterable + from six import iteritems from canonicaljson import encode_canonical_json, json +from twisted.enterprise.adbapi import Connection from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace @@ -333,8 +336,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore): ) def _get_e2e_cross_signing_keys_bulk_txn( - self, txn, users, key_type, from_user_id=None - ): + self, + txn: Connection, + users: Iterable[str], + key_type: str, + from_user_id: str = None, + ) -> dict: """Returns the cross-signing keys for a set of users. Args: @@ -347,68 +354,82 @@ class EndToEndKeyWorkerStore(SQLBaseStore): the keys will be included in the result Returns: - dict mapping from user ID to key data. If a user's cross-signing key - was not found, their user ID will not be in the dict. + dict[str, dict]: mapping from user ID to key data. If a user's + cross-signing key was not found, their user ID will not be in the + dict. """ - sql = ( - "SELECT k.user_id, k.keydata, k.stream_id " - " FROM e2e_cross_signing_keys k" - " JOIN (SELECT user_id, MAX(stream_id) AS stream_id" - " FROM e2e_cross_signing_keys" - " WHERE keytype = ?" - " GROUP BY user_id) s" - " USING (user_id, stream_id)" - " WHERE k.user_id IN (%s) AND k.keytype = ?" - ) % (",".join("?" for u in users)) - query_params = [key_type] - query_params.extend(users) - query_params.append(key_type) - txn.execute(sql, query_params) - rows = self.cursor_to_dict(txn) + # convert to a list if needed, so that we can slice it + if not isinstance(users, list): + users = list(users) result = {} - devices = {} - for row in rows: - user_id = row["user_id"] - key = json.loads(row["keydata"]) - result[user_id] = key - for k in key["keys"].values(): - devices[user_id] = k - if devices and from_user_id: - # if we're asked to get signatures, and we have any devices to get - # signatures for, fetch the signatures - sql = ( - "SELECT target_user_id, key_id, signature " - " FROM e2e_cross_signing_signatures " - " WHERE user_id = ?" - " AND (%s)" - ) % ( - " OR ".join( - "(target_user_id = ? AND target_device_id = ?)" for d in devices - ) + batch_size = 100 + chunks = [users[i : i + batch_size] for i in range(0, len(users), batch_size)] + for user_chunk in chunks: + sql = """ + SELECT k.user_id, k.keydata, k.stream_id + FROM e2e_cross_signing_keys k + INNER JOIN (SELECT user_id, MAX(stream_id) AS stream_id + FROM e2e_cross_signing_keys + WHERE keytype = ? + GROUP BY user_id) s + USING (user_id, stream_id) + WHERE k.user_id IN (%s) AND k.keytype = ? + """ % ( + ",".join("?" for u in user_chunk) ) - query_params = [from_user_id] - for item in devices.items(): - # item is a (user_id, device_id) tuple - query_params.extend(item) + query_params = [key_type] + query_params.extend(user_chunk) + query_params.append(key_type) txn.execute(sql, query_params) rows = self.cursor_to_dict(txn) - # and add the signatures to the appropriate keys + devices = {} for row in rows: - key_id = row["key_id"] - target_user_id = row["target_user_id"] - target_user_key = result[target_user_id] - signatures = target_user_key.setdefault("signatures", {}) - user_sigs = signatures.setdefault(from_user_id, {}) - user_sigs[key_id] = row["signature"] + user_id = row["user_id"] + key = json.loads(row["keydata"]) + result[user_id] = key + for k in key["keys"].values(): + devices[user_id] = k + + if devices and from_user_id: + # if we're asked to get signatures, and we have any devices to get + # signatures for, fetch the signatures + sql = """ + SELECT target_user_id, key_id, signature + FROM e2e_cross_signing_signatures + WHERE user_id = ? + AND (%s) + """ % ( + " OR ".join( + "(target_user_id = ? AND target_device_id = ?)" for d in devices + ) + ) + query_params = [from_user_id] + for item in devices.items(): + # item is a (user_id, device_id) tuple + query_params.extend(item) + + txn.execute(sql, query_params) + rows = self.cursor_to_dict(txn) + + # and add the signatures to the appropriate keys + for row in rows: + key_id = row["key_id"] + target_user_id = row["target_user_id"] + target_user_key = result[target_user_id] + signatures = target_user_key.setdefault("signatures", {}) + user_sigs = signatures.setdefault(from_user_id, {}) + user_sigs[key_id] = row["signature"] return result - def get_e2e_cross_signing_keys_bulk(self, users, key_type, from_user_id=None): + def get_e2e_cross_signing_keys_bulk( + self, users: Iterable[str], key_type: str, from_user_id: str = None + ) -> defer.Deferred: """Returns the cross-signing keys for a set of users. Args: @@ -420,8 +441,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore): the self-signing keys will be included in the result Returns: - Deferred[dict] map of user ID to key data. If a user's cross-signing key - was not found, their user ID will not be in the dict. + Deferred[dict[str, dict]]: map of user ID to key data. If a user's + cross-signing key was not found, their user ID will not be in + the dict. """ return self.runInteraction( "get_e2e_cross_signing_key",