Count the OTKs in bulk
This commit is contained in:
@@ -280,8 +280,9 @@ class _ServiceQueuer:
|
||||
Given a list of application service users that are interesting,
|
||||
compute one-time key counts and fallback key usages for the users.
|
||||
"""
|
||||
otk_counts = await self._store.count_bulk_e2e_one_time_keys_for_as(users)
|
||||
# OSTD implement me!
|
||||
return {}, {}
|
||||
return otk_counts, {}
|
||||
|
||||
|
||||
class _TransactionController:
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from canonicaljson import encode_canonical_json
|
||||
@@ -22,9 +22,14 @@ from canonicaljson import encode_canonical_json
|
||||
from twisted.enterprise.adbapi import Connection
|
||||
|
||||
from synapse.api.constants import DeviceKeyAlgorithms
|
||||
from synapse.appservice import TransactionOneTimeKeyCounts
|
||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingTransaction,
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.types import JsonDict
|
||||
@@ -397,6 +402,49 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
||||
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
||||
)
|
||||
|
||||
async def count_bulk_e2e_one_time_keys_for_as(
|
||||
self, user_ids: Collection[str]
|
||||
) -> TransactionOneTimeKeyCounts:
|
||||
"""
|
||||
Counts, in bulk, the one-time keys for all the users specified.
|
||||
Intended to be used by application services for populating OTK counts in
|
||||
transactions.
|
||||
|
||||
Return structure is of the shape:
|
||||
user_id -> device_id -> algorithm -> count
|
||||
"""
|
||||
|
||||
def _count_bulk_e2e_one_time_keys_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> TransactionOneTimeKeyCounts:
|
||||
user_in_where_clause, user_parameters = make_in_list_sql_clause(
|
||||
self.database_engine, "user_id", user_ids
|
||||
)
|
||||
sql = f"""
|
||||
SELECT user_id, device_id, algorithm, COUNT(key_id)
|
||||
FROM devices
|
||||
LEFT JOIN e2e_one_time_keys_json USING (user_id, device_id)
|
||||
WHERE {user_in_where_clause}
|
||||
GROUP BY user_id, device_id, algorithm
|
||||
"""
|
||||
txn.execute(sql, user_parameters)
|
||||
|
||||
result = {}
|
||||
|
||||
for user_id, device_id, algorithm, count in txn:
|
||||
device_count_by_algo = result.setdefault(user_id, {}).setdefault(
|
||||
device_id, {}
|
||||
)
|
||||
if algorithm is not None:
|
||||
# algorithm will be None if this device has no keys.
|
||||
device_count_by_algo[algorithm] = count
|
||||
|
||||
return result
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"count_bulk_e2e_one_time_keys", _count_bulk_e2e_one_time_keys_txn
|
||||
)
|
||||
|
||||
async def set_e2e_fallback_keys(
|
||||
self, user_id: str, device_id: str, fallback_keys: JsonDict
|
||||
) -> None:
|
||||
|
||||
Reference in New Issue
Block a user