diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 609c2b18c9..10e9a2fd23 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -284,8 +284,8 @@ class _ServiceQueuer: compute one-time key counts and fallback key usages for the users. """ otk_counts = await self._store.count_bulk_e2e_one_time_keys_for_as(users) - # OSTD implement me! - return otk_counts, {} + unused_fbks = await self._store.get_e2e_bulk_unused_fallback_key_types(users) + return otk_counts, unused_fbks class _TransactionController: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index cd7236eb70..5097a7737a 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -27,7 +27,10 @@ import attr from canonicaljson import encode_canonical_json from synapse.api.constants import DeviceKeyAlgorithms -from synapse.appservice import TransactionOneTimeKeyCounts +from synapse.appservice import ( + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -480,6 +483,50 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker "count_bulk_e2e_one_time_keys", _count_bulk_e2e_one_time_keys_txn ) + async def get_e2e_bulk_unused_fallback_key_types( + self, user_ids: Collection[str] + ) -> TransactionUnusedFallbackKeys: + """ + Finds, in bulk, the types of unused fallback keys for all the users specified. + Intended to be used by application services for populating unused fallback + keys in transactions. + + Return structure is of the shape: + user_id -> device_id -> algorithms + """ + + def _get_bulk_e2e_unused_fallback_keys_txn( + txn: LoggingTransaction, + ) -> TransactionUnusedFallbackKeys: + user_in_where_clause, user_parameters = make_in_list_sql_clause( + self.database_engine, "user_id", user_ids + ) + sql = f""" + SELECT user_id, device_id, algorithm + FROM devices + LEFT JOIN e2e_fallback_keys_json USING (user_id, device_id) + WHERE + {user_in_where_clause} + AND NOT used + """ + txn.execute(sql, user_parameters) + + result = {} + + for user_id, device_id, algorithm in txn: + device_unused_keys = result.setdefault(user_id, {}).setdefault( + device_id, [] + ) + if algorithm is not None: + # algorithm will be None if this device has no keys. + device_unused_keys.append(algorithm) + + return result + + return await self.db_pool.runInteraction( + "_get_bulk_e2e_unused_fallback_keys", _get_bulk_e2e_unused_fallback_keys_txn + ) + async def set_e2e_fallback_keys( self, user_id: str, device_id: str, fallback_keys: JsonDict ) -> None: