Cache _get_e2e_cross_signing_signatures_for_devices (#18899)
This commit is contained in:
1
changelog.d/18899.feature
Normal file
1
changelog.d/18899.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add an in-memory cache to `_get_e2e_cross_signing_signatures_for_devices` to reduce DB load.
|
||||
@@ -2653,8 +2653,7 @@ def make_in_list_sql_clause(
|
||||
|
||||
|
||||
# These overloads ensure that `columns` and `iterable` values have the same length.
|
||||
# Suppress "Single overload definition, multiple required" complaint.
|
||||
@overload # type: ignore[misc]
|
||||
@overload
|
||||
def make_tuple_in_list_sql_clause(
|
||||
database_engine: BaseDatabaseEngine,
|
||||
columns: Tuple[str, str],
|
||||
@@ -2662,6 +2661,14 @@ def make_tuple_in_list_sql_clause(
|
||||
) -> Tuple[str, list]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def make_tuple_in_list_sql_clause(
|
||||
database_engine: BaseDatabaseEngine,
|
||||
columns: Tuple[str, str, str],
|
||||
iterable: Collection[Tuple[Any, Any, Any]],
|
||||
) -> Tuple[str, list]: ...
|
||||
|
||||
|
||||
def make_tuple_in_list_sql_clause(
|
||||
database_engine: BaseDatabaseEngine,
|
||||
columns: Tuple[str, ...],
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Tuple
|
||||
|
||||
@@ -62,6 +63,12 @@ PURGE_HISTORY_CACHE_NAME = "ph_cache_fake"
|
||||
# As above, but for invalidating room caches on room deletion
|
||||
DELETE_ROOM_CACHE_NAME = "dr_cache_fake"
|
||||
|
||||
# This cache takes a (user_id, device_id) tuple as its first argument, which requires
|
||||
# special handling.
|
||||
GET_E2E_CROSS_SIGNING_SIGNATURES_FOR_DEVICE_CACHE_NAME = (
|
||||
"_get_e2e_cross_signing_signatures_for_device"
|
||||
)
|
||||
|
||||
# How long between cache invalidation table cleanups, once we have caught up
|
||||
# with the backlog.
|
||||
REGULAR_CLEANUP_INTERVAL_MS = Config.parse_duration("1h")
|
||||
@@ -270,6 +277,33 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
# room membership.
|
||||
#
|
||||
# self._membership_stream_cache.all_entities_changed(token) # type: ignore[attr-defined]
|
||||
elif (
|
||||
row.cache_func
|
||||
== GET_E2E_CROSS_SIGNING_SIGNATURES_FOR_DEVICE_CACHE_NAME
|
||||
):
|
||||
# "keys" is a list of strings, where each string is a
|
||||
# JSON-encoded representation of the tuple keys, i.e.
|
||||
# keys: ['["@userid:domain", "DEVICEID"]','["@userid2:domain", "DEVICEID2"]']
|
||||
#
|
||||
# This is a side-effect of not being able to send nested
|
||||
# information over replication.
|
||||
for json_str in row.keys:
|
||||
try:
|
||||
user_id, device_id = json.loads(json_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.error(
|
||||
"Failed to deserialise cache key as valid JSON: %s",
|
||||
json_str,
|
||||
)
|
||||
continue
|
||||
|
||||
# Invalidate each key.
|
||||
#
|
||||
# Note: .invalidate takes a tuple of arguments, hence the need
|
||||
# to nest our tuple in another tuple.
|
||||
self._get_e2e_cross_signing_signatures_for_device.invalidate( # type: ignore[attr-defined]
|
||||
((user_id, device_id),)
|
||||
)
|
||||
else:
|
||||
self._attempt_to_invalidate_cache(row.cache_func, row.keys)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#
|
||||
#
|
||||
import abc
|
||||
import json
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -354,15 +355,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
)
|
||||
|
||||
for batch in batch_iter(signature_query, 50):
|
||||
cross_sigs_result = await self.db_pool.runInteraction(
|
||||
"get_e2e_cross_signing_signatures_for_devices",
|
||||
self._get_e2e_cross_signing_signatures_for_devices_txn,
|
||||
batch,
|
||||
cross_sigs_result = (
|
||||
await self._get_e2e_cross_signing_signatures_for_devices(batch)
|
||||
)
|
||||
|
||||
# add each cross-signing signature to the correct device in the result dict.
|
||||
for user_id, key_id, device_id, signature in cross_sigs_result:
|
||||
for (
|
||||
user_id,
|
||||
device_id,
|
||||
), signature_list in cross_sigs_result.items():
|
||||
target_device_result = result[user_id][device_id]
|
||||
|
||||
# We've only looked up cross-signatures for non-deleted devices with key
|
||||
# data.
|
||||
assert target_device_result is not None
|
||||
@@ -373,7 +376,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
signing_user_signatures = target_device_signatures.setdefault(
|
||||
user_id, {}
|
||||
)
|
||||
signing_user_signatures[key_id] = signature
|
||||
|
||||
for key_id, signature in signature_list:
|
||||
signing_user_signatures[key_id] = signature
|
||||
|
||||
log_kv(result)
|
||||
return result
|
||||
@@ -479,41 +484,83 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
|
||||
return result
|
||||
|
||||
def _get_e2e_cross_signing_signatures_for_devices_txn(
|
||||
self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
|
||||
) -> List[Tuple[str, str, str, str]]:
|
||||
"""Get cross-signing signatures for a given list of devices
|
||||
|
||||
Returns signatures made by the owners of the devices.
|
||||
|
||||
Returns: a list of results; each entry in the list is a tuple of
|
||||
(user_id, key_id, target_device_id, signature).
|
||||
@cached()
|
||||
def _get_e2e_cross_signing_signatures_for_device(
|
||||
self,
|
||||
user_id_and_device_id: Tuple[str, str],
|
||||
) -> Sequence[Tuple[str, str]]:
|
||||
"""
|
||||
signature_query_clauses = []
|
||||
signature_query_params = []
|
||||
The single-item version of `_get_e2e_cross_signing_signatures_for_devices`.
|
||||
See @cachedList for why a separate method is needed.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
for user_id, device_id in device_query:
|
||||
signature_query_clauses.append(
|
||||
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
|
||||
@cachedList(
|
||||
cached_method_name="_get_e2e_cross_signing_signatures_for_device",
|
||||
list_name="device_query",
|
||||
)
|
||||
async def _get_e2e_cross_signing_signatures_for_devices(
|
||||
self, device_query: Iterable[Tuple[str, str]]
|
||||
) -> Mapping[Tuple[str, str], Sequence[Tuple[str, str]]]:
|
||||
"""Get cross-signing signatures for a given list of user IDs and devices.
|
||||
|
||||
Args:
|
||||
device_query: An iterable containing tuples of (user ID, device ID).
|
||||
|
||||
Returns:
|
||||
A mapping of results. The keys are the original (user_id, device_id)
|
||||
tuples, while each value is the matching list of tuples of
|
||||
(key_id, signature). The value will be an empty list if no
|
||||
signatures exist for the device.
|
||||
|
||||
Given this method is annotated with `@cachedList`, the return dict's
|
||||
keys match the tuples within `device_query`, so that cache entries can
|
||||
be computed from the corresponding values.
|
||||
|
||||
As results are cached, the return type is immutable.
|
||||
"""
|
||||
|
||||
def _get_e2e_cross_signing_signatures_for_devices_txn(
|
||||
txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
|
||||
) -> Mapping[Tuple[str, str], Sequence[Tuple[str, str]]]:
|
||||
where_clause_sql, where_clause_params = make_tuple_in_list_sql_clause(
|
||||
self.database_engine,
|
||||
columns=("target_user_id", "target_device_id", "user_id"),
|
||||
iterable=[
|
||||
(user_id, device_id, user_id) for user_id, device_id in device_query
|
||||
],
|
||||
)
|
||||
signature_query_params.extend([user_id, device_id, user_id])
|
||||
|
||||
signature_sql = """
|
||||
SELECT user_id, key_id, target_device_id, signature
|
||||
FROM e2e_cross_signing_signatures WHERE %s
|
||||
""" % (" OR ".join("(" + q + ")" for q in signature_query_clauses))
|
||||
signature_sql = f"""
|
||||
SELECT user_id, key_id, target_device_id, signature
|
||||
FROM e2e_cross_signing_signatures WHERE {where_clause_sql}
|
||||
"""
|
||||
|
||||
txn.execute(signature_sql, signature_query_params)
|
||||
return cast(
|
||||
List[
|
||||
Tuple[
|
||||
str,
|
||||
str,
|
||||
str,
|
||||
str,
|
||||
]
|
||||
],
|
||||
txn.fetchall(),
|
||||
txn.execute(signature_sql, where_clause_params)
|
||||
|
||||
devices_and_signatures: Dict[Tuple[str, str], List[Tuple[str, str]]] = {}
|
||||
|
||||
# `@cachedList` requires we return one key for every item in `device_query`.
|
||||
# Pre-populate `devices_and_signatures` with each key so that none are missing.
|
||||
#
|
||||
# If any are missing, they will be cached as `None`, which is not
|
||||
# what callers expect.
|
||||
for user_id, device_id in device_query:
|
||||
devices_and_signatures.setdefault((user_id, device_id), [])
|
||||
|
||||
# Populate the return dictionary with each found key_id and signature.
|
||||
for user_id, key_id, target_device_id, signature in txn.fetchall():
|
||||
signature_tuple = (key_id, signature)
|
||||
devices_and_signatures[(user_id, target_device_id)].append(
|
||||
signature_tuple
|
||||
)
|
||||
|
||||
return devices_and_signatures
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"_get_e2e_cross_signing_signatures_for_devices_txn",
|
||||
_get_e2e_cross_signing_signatures_for_devices_txn,
|
||||
device_query,
|
||||
)
|
||||
|
||||
async def get_e2e_one_time_keys(
|
||||
@@ -1772,26 +1819,71 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
user_id: the user who made the signatures
|
||||
signatures: signatures to add
|
||||
"""
|
||||
await self.db_pool.simple_insert_many(
|
||||
"e2e_cross_signing_signatures",
|
||||
keys=(
|
||||
"user_id",
|
||||
"key_id",
|
||||
"target_user_id",
|
||||
"target_device_id",
|
||||
"signature",
|
||||
),
|
||||
values=[
|
||||
(
|
||||
user_id,
|
||||
item.signing_key_id,
|
||||
item.target_user_id,
|
||||
item.target_device_id,
|
||||
item.signature,
|
||||
)
|
||||
|
||||
def _store_e2e_cross_signing_signatures(
|
||||
txn: LoggingTransaction,
|
||||
signatures: "Iterable[SignatureListItem]",
|
||||
) -> None:
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
"e2e_cross_signing_signatures",
|
||||
keys=(
|
||||
"user_id",
|
||||
"key_id",
|
||||
"target_user_id",
|
||||
"target_device_id",
|
||||
"signature",
|
||||
),
|
||||
values=[
|
||||
(
|
||||
user_id,
|
||||
item.signing_key_id,
|
||||
item.target_user_id,
|
||||
item.target_device_id,
|
||||
item.signature,
|
||||
)
|
||||
for item in signatures
|
||||
],
|
||||
)
|
||||
|
||||
to_invalidate = [
|
||||
# Each entry is a tuple of arguments to
|
||||
# `_get_e2e_cross_signing_signatures_for_device`, which
|
||||
# itself takes a tuple. Hence the double-tuple.
|
||||
((user_id, item.target_device_id),)
|
||||
for item in signatures
|
||||
],
|
||||
desc="add_e2e_signing_key",
|
||||
]
|
||||
|
||||
if to_invalidate:
|
||||
# Invalidate the local cache of this worker.
|
||||
for cache_key in to_invalidate:
|
||||
txn.call_after(
|
||||
self._get_e2e_cross_signing_signatures_for_device.invalidate,
|
||||
cache_key,
|
||||
)
|
||||
|
||||
# Stream cache invalidate keys over replication.
|
||||
#
|
||||
# We can only send a primitive per function argument across
|
||||
# replication.
|
||||
#
|
||||
# Encode the array of strings as a JSON string, and we'll unpack
|
||||
# it on the other side.
|
||||
to_send = [
|
||||
(json.dumps([user_id, item.target_device_id]),)
|
||||
for item in signatures
|
||||
]
|
||||
|
||||
self._send_invalidation_to_replication_bulk(
|
||||
txn,
|
||||
cache_name=self._get_e2e_cross_signing_signatures_for_device.__name__,
|
||||
key_tuples=to_send,
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"add_e2e_signing_key",
|
||||
_store_e2e_cross_signing_signatures,
|
||||
signatures,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -579,9 +579,12 @@ def cachedList(
|
||||
Used to do batch lookups for an already created cache. One of the arguments
|
||||
is specified as a list that is iterated through to lookup keys in the
|
||||
original cache. A new tuple consisting of the (deduplicated) keys that weren't in
|
||||
the cache gets passed to the original function, which is expected to results
|
||||
the cache gets passed to the original function, which is expected to result
|
||||
in a map of key to value for each passed value. The new results are stored in the
|
||||
original cache. Note that any missing values are cached as None.
|
||||
original cache.
|
||||
|
||||
Note that any values in the input that end up being missing from both the
|
||||
cache and the returned dictionary will be cached as `None`.
|
||||
|
||||
Args:
|
||||
cached_method_name: The name of the single-item lookup method.
|
||||
|
||||
Reference in New Issue
Block a user