Cache _get_e2e_cross_signing_signatures_for_devices (#18899)

This commit is contained in:
Andrew Morgan
2025-09-18 12:06:08 +01:00
committed by GitHub
parent 6f9fab1089
commit b596faa4ec
5 changed files with 196 additions and 59 deletions

View File

@@ -0,0 +1 @@
Add an in-memory cache to `_get_e2e_cross_signing_signatures_for_devices` to reduce DB load.

View File

@@ -2653,8 +2653,7 @@ def make_in_list_sql_clause(
# These overloads ensure that `columns` and `iterable` values have the same length.
# Suppress "Single overload definition, multiple required" complaint.
@overload # type: ignore[misc]
@overload
def make_tuple_in_list_sql_clause(
database_engine: BaseDatabaseEngine,
columns: Tuple[str, str],
@@ -2662,6 +2661,14 @@ def make_tuple_in_list_sql_clause(
) -> Tuple[str, list]: ...
@overload
def make_tuple_in_list_sql_clause(
database_engine: BaseDatabaseEngine,
columns: Tuple[str, str, str],
iterable: Collection[Tuple[Any, Any, Any]],
) -> Tuple[str, list]: ...
def make_tuple_in_list_sql_clause(
database_engine: BaseDatabaseEngine,
columns: Tuple[str, ...],

View File

@@ -21,6 +21,7 @@
import itertools
import json
import logging
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Tuple
@@ -62,6 +63,12 @@ PURGE_HISTORY_CACHE_NAME = "ph_cache_fake"
# As above, but for invalidating room caches on room deletion
DELETE_ROOM_CACHE_NAME = "dr_cache_fake"
# This cache takes a list of tuples as its first argument, which requires
# special handling.
GET_E2E_CROSS_SIGNING_SIGNATURES_FOR_DEVICE_CACHE_NAME = (
"_get_e2e_cross_signing_signatures_for_device"
)
# How long between cache invalidation table cleanups, once we have caught up
# with the backlog.
REGULAR_CLEANUP_INTERVAL_MS = Config.parse_duration("1h")
@@ -270,6 +277,33 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
# room membership.
#
# self._membership_stream_cache.all_entities_changed(token) # type: ignore[attr-defined]
elif (
row.cache_func
== GET_E2E_CROSS_SIGNING_SIGNATURES_FOR_DEVICE_CACHE_NAME
):
# "keys" is a list of strings, where each string is a
# JSON-encoded representation of the tuple keys, i.e.
# keys: ['["@userid:domain", "DEVICEID"]','["@userid2:domain", "DEVICEID2"]']
#
# This is a side-effect of not being able to send nested
# information over replication.
for json_str in row.keys:
try:
user_id, device_id = json.loads(json_str)
except (json.JSONDecodeError, TypeError):
logger.error(
"Failed to deserialise cache key as valid JSON: %s",
json_str,
)
continue
# Invalidate each key.
#
# Note: .invalidate takes a tuple of arguments, hence the need
# to nest our tuple in another tuple.
self._get_e2e_cross_signing_signatures_for_device.invalidate( # type: ignore[attr-defined]
((user_id, device_id),)
)
else:
self._attempt_to_invalidate_cache(row.cache_func, row.keys)

View File

@@ -20,6 +20,7 @@
#
#
import abc
import json
from typing import (
TYPE_CHECKING,
Any,
@@ -354,15 +355,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
for batch in batch_iter(signature_query, 50):
cross_sigs_result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures_for_devices",
self._get_e2e_cross_signing_signatures_for_devices_txn,
batch,
cross_sigs_result = (
await self._get_e2e_cross_signing_signatures_for_devices(batch)
)
# add each cross-signing signature to the correct device in the result dict.
for user_id, key_id, device_id, signature in cross_sigs_result:
for (
user_id,
device_id,
), signature_list in cross_sigs_result.items():
target_device_result = result[user_id][device_id]
# We've only looked up cross-signatures for non-deleted devices with key
# data.
assert target_device_result is not None
@@ -373,7 +376,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
signing_user_signatures = target_device_signatures.setdefault(
user_id, {}
)
signing_user_signatures[key_id] = signature
for key_id, signature in signature_list:
signing_user_signatures[key_id] = signature
log_kv(result)
return result
@@ -479,41 +484,83 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return result
def _get_e2e_cross_signing_signatures_for_devices_txn(
self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
) -> List[Tuple[str, str, str, str]]:
"""Get cross-signing signatures for a given list of devices
Returns signatures made by the owners of the devices.
Returns: a list of results; each entry in the list is a tuple of
(user_id, key_id, target_device_id, signature).
@cached()
def _get_e2e_cross_signing_signatures_for_device(
self,
user_id_and_device_id: Tuple[str, str],
) -> Sequence[Tuple[str, str]]:
"""
signature_query_clauses = []
signature_query_params = []
The single-item version of `_get_e2e_cross_signing_signatures_for_devices`.
See @cachedList for why a separate method is needed.
"""
raise NotImplementedError()
for user_id, device_id in device_query:
signature_query_clauses.append(
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
@cachedList(
cached_method_name="_get_e2e_cross_signing_signatures_for_device",
list_name="device_query",
)
async def _get_e2e_cross_signing_signatures_for_devices(
self, device_query: Iterable[Tuple[str, str]]
) -> Mapping[Tuple[str, str], Sequence[Tuple[str, str]]]:
"""Get cross-signing signatures for a given list of user IDs and devices.
Args:
An iterable containing tuples of (user ID, device ID).
Returns:
A mapping of results. The keys are the original (user_id, device_id)
tuple, while the value is the matching list of tuples of
(key_id, signature). The value will be an empty list if no
signatures exist for the device.
Given this method is annotated with `@cachedList`, the return dict's
keys match the tuples within `device_query`, so that cache entries can
be computed from the corresponding values.
As results are cached, the return type is immutable.
"""
def _get_e2e_cross_signing_signatures_for_devices_txn(
txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
) -> Mapping[Tuple[str, str], Sequence[Tuple[str, str]]]:
where_clause_sql, where_clause_params = make_tuple_in_list_sql_clause(
self.database_engine,
columns=("target_user_id", "target_device_id", "user_id"),
iterable=[
(user_id, device_id, user_id) for user_id, device_id in device_query
],
)
signature_query_params.extend([user_id, device_id, user_id])
signature_sql = """
SELECT user_id, key_id, target_device_id, signature
FROM e2e_cross_signing_signatures WHERE %s
""" % (" OR ".join("(" + q + ")" for q in signature_query_clauses))
signature_sql = f"""
SELECT user_id, key_id, target_device_id, signature
FROM e2e_cross_signing_signatures WHERE {where_clause_sql}
"""
txn.execute(signature_sql, signature_query_params)
return cast(
List[
Tuple[
str,
str,
str,
str,
]
],
txn.fetchall(),
txn.execute(signature_sql, where_clause_params)
devices_and_signatures: Dict[Tuple[str, str], List[Tuple[str, str]]] = {}
# `@cachedList` requires we return one key for every item in `device_query`.
# Pre-populate `devices_and_signatures` with each key so that none are missing.
#
# If any are missing, they will be cached as `None`, which is not
# what callers expected.
for user_id, device_id in device_query:
devices_and_signatures.setdefault((user_id, device_id), [])
# Populate the return dictionary with each found key_id and signature.
for user_id, key_id, target_device_id, signature in txn.fetchall():
signature_tuple = (key_id, signature)
devices_and_signatures[(user_id, target_device_id)].append(
signature_tuple
)
return devices_and_signatures
return await self.db_pool.runInteraction(
"_get_e2e_cross_signing_signatures_for_devices_txn",
_get_e2e_cross_signing_signatures_for_devices_txn,
device_query,
)
async def get_e2e_one_time_keys(
@@ -1772,26 +1819,71 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
user_id: the user who made the signatures
signatures: signatures to add
"""
await self.db_pool.simple_insert_many(
"e2e_cross_signing_signatures",
keys=(
"user_id",
"key_id",
"target_user_id",
"target_device_id",
"signature",
),
values=[
(
user_id,
item.signing_key_id,
item.target_user_id,
item.target_device_id,
item.signature,
)
def _store_e2e_cross_signing_signatures(
txn: LoggingTransaction,
signatures: "Iterable[SignatureListItem]",
) -> None:
self.db_pool.simple_insert_many_txn(
txn,
"e2e_cross_signing_signatures",
keys=(
"user_id",
"key_id",
"target_user_id",
"target_device_id",
"signature",
),
values=[
(
user_id,
item.signing_key_id,
item.target_user_id,
item.target_device_id,
item.signature,
)
for item in signatures
],
)
to_invalidate = [
# Each entry is a tuple of arguments to
# `_get_e2e_cross_signing_signatures_for_device`, which
# itself takes a tuple. Hence the double-tuple.
((user_id, item.target_device_id),)
for item in signatures
],
desc="add_e2e_signing_key",
]
if to_invalidate:
# Invalidate the local cache of this worker.
for cache_key in to_invalidate:
txn.call_after(
self._get_e2e_cross_signing_signatures_for_device.invalidate,
cache_key,
)
# Stream cache invalidate keys over replication.
#
# We can only send a primitive per function argument across
# replication.
#
# Encode the array of strings as a JSON string, and we'll unpack
# it on the other side.
to_send = [
(json.dumps([user_id, item.target_device_id]),)
for item in signatures
]
self._send_invalidation_to_replication_bulk(
txn,
cache_name=self._get_e2e_cross_signing_signatures_for_device.__name__,
key_tuples=to_send,
)
await self.db_pool.runInteraction(
"add_e2e_signing_key",
_store_e2e_cross_signing_signatures,
signatures,
)

View File

@@ -579,9 +579,12 @@ def cachedList(
Used to do batch lookups for an already created cache. One of the arguments
is specified as a list that is iterated through to lookup keys in the
original cache. A new tuple consisting of the (deduplicated) keys that weren't in
the cache gets passed to the original function, which is expected to results
the cache gets passed to the original function, which is expected to result
in a map of key to value for each passed value. The new results are stored in the
original cache. Note that any missing values are cached as None.
original cache.
Note that any values in the input that end up being missing from both the
cache and the returned dictionary will be cached as `None`.
Args:
cached_method_name: The name of the single-item lookup method.