Merge branch 'erikj/refactor_keyring' into erikj/test_send
This commit is contained in:
+176
-357
@@ -16,8 +16,7 @@
|
||||
import abc
|
||||
import logging
|
||||
import urllib
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from signedjson.key import (
|
||||
@@ -45,14 +44,13 @@ from synapse.config.key import TrustedKeyServer
|
||||
from synapse.logging.context import (
|
||||
PreserveLoggingContext,
|
||||
make_deferred_yieldable,
|
||||
preserve_fn,
|
||||
run_in_background,
|
||||
)
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.keys import FetchKeyResult
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async_helpers import yieldable_gather_results
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.util.async_helpers import Linearizer, yieldable_gather_results
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -102,6 +100,62 @@ class KeyLookupError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class _QueueValue:
|
||||
server_name = attr.ib(type=str)
|
||||
minimum_valid_until_ts = attr.ib(type=int)
|
||||
key_ids = attr.ib(type=List[str])
|
||||
|
||||
|
||||
class _Queue:
|
||||
def __init__(self, name, clock, process_items):
|
||||
self._name = name
|
||||
self._clock = clock
|
||||
self._is_processing = False
|
||||
self._next_values = []
|
||||
|
||||
self.process_items = process_items
|
||||
|
||||
async def add_to_queue(self, value: _QueueValue) -> Dict[str, FetchKeyResult]:
|
||||
d = defer.Deferred()
|
||||
self._next_values.append((value, d))
|
||||
|
||||
if self._is_processing:
|
||||
return await d
|
||||
|
||||
run_as_background_process(self._name, self._unsafe_process)
|
||||
|
||||
return await d
|
||||
|
||||
async def _unsafe_process(self):
|
||||
# We purposefully defer to the next loop.
|
||||
await self._clock.sleep(0)
|
||||
|
||||
try:
|
||||
if self._is_processing:
|
||||
return
|
||||
|
||||
self._is_processing = True
|
||||
|
||||
while self._next_values:
|
||||
next_values = self._next_values
|
||||
self._next_values = []
|
||||
|
||||
try:
|
||||
values = [value for value, _ in next_values]
|
||||
results = await self.process_items(values)
|
||||
|
||||
for value, deferred in next_values:
|
||||
deferred.callback(results.get(value.server_name, {}))
|
||||
|
||||
except Exception as e:
|
||||
for _, deferred in next_values:
|
||||
deferred.errback(e)
|
||||
|
||||
finally:
|
||||
self._is_processing = False
|
||||
|
||||
|
||||
class Keyring:
|
||||
def __init__(
|
||||
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
|
||||
@@ -116,12 +170,7 @@ class Keyring:
|
||||
)
|
||||
self._key_fetchers = key_fetchers
|
||||
|
||||
# map from server name to Deferred. Has an entry for each server with
|
||||
# an ongoing key download; the Deferred completes once the download
|
||||
# completes.
|
||||
#
|
||||
# These are regular, logcontext-agnostic Deferreds.
|
||||
self.key_downloads = {} # type: Dict[str, defer.Deferred]
|
||||
self._server_queue = Linearizer("keyring_server")
|
||||
|
||||
def verify_json_for_server(
|
||||
self,
|
||||
@@ -130,365 +179,129 @@ class Keyring:
|
||||
validity_time: int,
|
||||
request_name: str,
|
||||
) -> defer.Deferred:
|
||||
"""Verify that a JSON object has been signed by a given server
|
||||
|
||||
Args:
|
||||
server_name: name of the server which must have signed this object
|
||||
|
||||
json_object: object to be checked
|
||||
|
||||
validity_time: timestamp at which we require the signing key to
|
||||
be valid. (0 implies we don't care)
|
||||
|
||||
request_name: an identifier for this json object (eg, an event id)
|
||||
for logging.
|
||||
|
||||
Returns:
|
||||
Deferred[None]: completes if the the object was correctly signed, otherwise
|
||||
errbacks with an error
|
||||
"""
|
||||
req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
|
||||
requests = (req,)
|
||||
return make_deferred_yieldable(self._verify_objects(requests)[0])
|
||||
request = VerifyJsonRequest(
|
||||
server_name, json_object, validity_time, request_name
|
||||
)
|
||||
return defer.ensureDeferred(self._verify_object(request))
|
||||
|
||||
def verify_json_objects_for_server(
|
||||
self, server_and_json: Iterable[Tuple[str, dict, int, str]]
|
||||
) -> List[defer.Deferred]:
|
||||
"""Bulk verifies signatures of json objects, bulk fetching keys as
|
||||
necessary.
|
||||
|
||||
Args:
|
||||
server_and_json:
|
||||
Iterable of (server_name, json_object, validity_time, request_name)
|
||||
tuples.
|
||||
|
||||
validity_time is a timestamp at which the signing key must be
|
||||
valid.
|
||||
|
||||
request_name is an identifier for this json object (eg, an event id)
|
||||
for logging.
|
||||
|
||||
Returns:
|
||||
List<Deferred[None]>: for each input triplet, a deferred indicating success
|
||||
or failure to verify each json object's signature for the given
|
||||
server_name. The deferreds run their callbacks in the sentinel
|
||||
logcontext.
|
||||
"""
|
||||
return self._verify_objects(
|
||||
VerifyJsonRequest(server_name, json_object, validity_time, request_name)
|
||||
for server_name, json_object, validity_time, request_name in server_and_json
|
||||
)
|
||||
|
||||
def _verify_objects(
|
||||
self, verify_requests: Iterable[VerifyJsonRequest]
|
||||
) -> List[defer.Deferred]:
|
||||
"""Does the work of verify_json_[objects_]for_server
|
||||
|
||||
|
||||
Args:
|
||||
verify_requests: Iterable of verification requests.
|
||||
|
||||
Returns:
|
||||
List<Deferred[None]>: for each input item, a deferred indicating success
|
||||
or failure to verify each json object's signature for the given
|
||||
server_name. The deferreds run their callbacks in the sentinel
|
||||
logcontext.
|
||||
"""
|
||||
# a list of VerifyJsonRequests which are awaiting a key lookup
|
||||
key_lookups = []
|
||||
handle = preserve_fn(_handle_key_deferred)
|
||||
|
||||
def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
|
||||
"""Process an entry in the request list
|
||||
|
||||
Adds a key request to key_lookups, and returns a deferred which
|
||||
will complete or fail (in the sentinel context) when verification completes.
|
||||
"""
|
||||
if not verify_request.key_ids:
|
||||
return defer.fail(
|
||||
SynapseError(
|
||||
400,
|
||||
"Not signed by %s" % (verify_request.server_name,),
|
||||
Codes.UNAUTHORIZED,
|
||||
return [
|
||||
defer.ensureDeferred(
|
||||
self._verify_object(
|
||||
VerifyJsonRequest(
|
||||
server_name, json_object, validity_time, request_name
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Verifying %s for %s with key_ids %s, min_validity %i",
|
||||
verify_request.request_name,
|
||||
verify_request.server_name,
|
||||
verify_request.key_ids,
|
||||
verify_request.minimum_valid_until_ts,
|
||||
)
|
||||
for server_name, json_object, validity_time, request_name in server_and_json
|
||||
]
|
||||
|
||||
# add the key request to the queue, but don't start it off yet.
|
||||
key_lookups.append(verify_request)
|
||||
async def _verify_object(self, verify_request: VerifyJsonRequest):
|
||||
# TODO: Use a batching thing.
|
||||
with (await self._server_queue.queue(verify_request.server_name)):
|
||||
found_keys: Dict[str, FetchKeyResult] = {}
|
||||
missing_key_ids = set(verify_request.key_ids)
|
||||
for fetcher in self._key_fetchers:
|
||||
if not missing_key_ids:
|
||||
break
|
||||
|
||||
# now run _handle_key_deferred, which will wait for the key request
|
||||
# to complete and then do the verification.
|
||||
#
|
||||
# We want _handle_key_request to log to the right context, so we
|
||||
# wrap it with preserve_fn (aka run_in_background)
|
||||
return handle(verify_request)
|
||||
|
||||
results = [process(r) for r in verify_requests]
|
||||
|
||||
if key_lookups:
|
||||
run_in_background(self._start_key_lookups, key_lookups)
|
||||
|
||||
return results
|
||||
|
||||
async def _start_key_lookups(
|
||||
self, verify_requests: List[VerifyJsonRequest]
|
||||
) -> None:
|
||||
"""Sets off the key fetches for each verify request
|
||||
|
||||
Once each fetch completes, verify_request.key_ready will be resolved.
|
||||
|
||||
Args:
|
||||
verify_requests:
|
||||
"""
|
||||
|
||||
try:
|
||||
# map from server name to a set of outstanding request ids
|
||||
server_to_request_ids = {} # type: Dict[str, Set[int]]
|
||||
|
||||
for verify_request in verify_requests:
|
||||
server_name = verify_request.server_name
|
||||
request_id = id(verify_request)
|
||||
server_to_request_ids.setdefault(server_name, set()).add(request_id)
|
||||
|
||||
# Wait for any previous lookups to complete before proceeding.
|
||||
await self.wait_for_previous_lookups(server_to_request_ids.keys())
|
||||
|
||||
# take out a lock on each of the servers by sticking a Deferred in
|
||||
# key_downloads
|
||||
for server_name in server_to_request_ids.keys():
|
||||
self.key_downloads[server_name] = defer.Deferred()
|
||||
logger.debug("Got key lookup lock on %s", server_name)
|
||||
|
||||
# When we've finished fetching all the keys for a given server_name,
|
||||
# drop the lock by resolving the deferred in key_downloads.
|
||||
def drop_server_lock(server_name):
|
||||
d = self.key_downloads.pop(server_name)
|
||||
d.callback(None)
|
||||
|
||||
def lookup_done(res, verify_request):
|
||||
server_name = verify_request.server_name
|
||||
server_requests = server_to_request_ids[server_name]
|
||||
server_requests.remove(id(verify_request))
|
||||
|
||||
# if there are no more requests for this server, we can drop the lock.
|
||||
if not server_requests:
|
||||
logger.debug("Releasing key lookup lock on %s", server_name)
|
||||
drop_server_lock(server_name)
|
||||
|
||||
return res
|
||||
|
||||
for verify_request in verify_requests:
|
||||
verify_request.key_ready.addBoth(lookup_done, verify_request)
|
||||
|
||||
# Actually start fetching keys.
|
||||
self._get_server_verify_keys(verify_requests)
|
||||
except Exception:
|
||||
logger.exception("Error starting key lookups")
|
||||
|
||||
async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
|
||||
"""Waits for any previous key lookups for the given servers to finish.
|
||||
|
||||
Args:
|
||||
server_names: list of servers which we want to look up
|
||||
|
||||
Returns:
|
||||
Resolves once all key lookups for the given servers have
|
||||
completed. Follows the synapse rules of logcontext preservation.
|
||||
"""
|
||||
loop_count = 1
|
||||
while True:
|
||||
wait_on = [
|
||||
(server_name, self.key_downloads[server_name])
|
||||
for server_name in server_names
|
||||
if server_name in self.key_downloads
|
||||
]
|
||||
if not wait_on:
|
||||
break
|
||||
logger.info(
|
||||
"Waiting for existing lookups for %s to complete [loop %i]",
|
||||
[w[0] for w in wait_on],
|
||||
loop_count,
|
||||
)
|
||||
with PreserveLoggingContext():
|
||||
await defer.DeferredList((w[1] for w in wait_on))
|
||||
|
||||
loop_count += 1
|
||||
|
||||
def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
|
||||
"""Tries to find at least one key for each verify request
|
||||
|
||||
For each verify_request, verify_request.key_ready is called back with
|
||||
params (server_name, key_id, VerifyKey) if a key is found, or errbacked
|
||||
with a SynapseError if none of the keys are found.
|
||||
|
||||
Args:
|
||||
verify_requests: list of verify requests
|
||||
"""
|
||||
|
||||
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
|
||||
|
||||
async def do_iterations():
|
||||
try:
|
||||
with Measure(self.clock, "get_server_verify_keys"):
|
||||
for f in self._key_fetchers:
|
||||
if not remaining_requests:
|
||||
return
|
||||
await self._attempt_key_fetches_with_fetcher(
|
||||
f, remaining_requests
|
||||
)
|
||||
|
||||
# look for any requests which weren't satisfied
|
||||
while remaining_requests:
|
||||
verify_request = remaining_requests.pop()
|
||||
rq_str = (
|
||||
"VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
|
||||
% (
|
||||
verify_request.server_name,
|
||||
verify_request.key_ids,
|
||||
verify_request.minimum_valid_until_ts,
|
||||
)
|
||||
)
|
||||
|
||||
# If we run the errback immediately, it may cancel our
|
||||
# loggingcontext while we are still in it, so instead we
|
||||
# schedule it for the next time round the reactor.
|
||||
#
|
||||
# (this also ensures that we don't get a stack overflow if we
|
||||
# has a massive queue of lookups waiting for this server).
|
||||
self.clock.call_later(
|
||||
0,
|
||||
verify_request.key_ready.errback,
|
||||
SynapseError(
|
||||
401,
|
||||
"Failed to find any key to satisfy %s" % (rq_str,),
|
||||
Codes.UNAUTHORIZED,
|
||||
),
|
||||
)
|
||||
except Exception as err:
|
||||
# we don't really expect to get here, because any errors should already
|
||||
# have been caught and logged. But if we do, let's log the error and make
|
||||
# sure that all of the deferreds are resolved.
|
||||
logger.error("Unexpected error in _get_server_verify_keys: %s", err)
|
||||
with PreserveLoggingContext():
|
||||
for verify_request in remaining_requests:
|
||||
if not verify_request.key_ready.called:
|
||||
verify_request.key_ready.errback(err)
|
||||
|
||||
run_in_background(do_iterations)
|
||||
|
||||
async def _attempt_key_fetches_with_fetcher(
|
||||
self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
|
||||
):
|
||||
"""Use a key fetcher to attempt to satisfy some key requests
|
||||
|
||||
Args:
|
||||
fetcher: fetcher to use to fetch the keys
|
||||
remaining_requests: outstanding key requests.
|
||||
Any successfully-completed requests will be removed from the list.
|
||||
"""
|
||||
# The keys to fetch.
|
||||
# server_name -> key_id -> min_valid_ts
|
||||
missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]]
|
||||
|
||||
for verify_request in remaining_requests:
|
||||
# any completed requests should already have been removed
|
||||
assert not verify_request.key_ready.called
|
||||
keys_for_server = missing_keys[verify_request.server_name]
|
||||
|
||||
for key_id in verify_request.key_ids:
|
||||
# If we have several requests for the same key, then we only need to
|
||||
# request that key once, but we should do so with the greatest
|
||||
# min_valid_until_ts of the requests, so that we can satisfy all of
|
||||
# the requests.
|
||||
keys_for_server[key_id] = max(
|
||||
keys_for_server.get(key_id, -1),
|
||||
keys = await fetcher.get_keys(
|
||||
verify_request.server_name,
|
||||
list(missing_key_ids),
|
||||
verify_request.minimum_valid_until_ts,
|
||||
)
|
||||
|
||||
results = await fetcher.get_keys(missing_keys)
|
||||
for key_id, key in keys.items():
|
||||
if not key:
|
||||
continue
|
||||
|
||||
completed = []
|
||||
for verify_request in remaining_requests:
|
||||
server_name = verify_request.server_name
|
||||
if key.valid_until_ts < verify_request.minimum_valid_until_ts:
|
||||
continue
|
||||
|
||||
existing_key = found_keys.get(key_id)
|
||||
if existing_key:
|
||||
if key.valid_until_ts <= existing_key.valid_until_ts:
|
||||
continue
|
||||
|
||||
found_keys[key_id] = key
|
||||
|
||||
missing_key_ids.difference_update(found_keys)
|
||||
|
||||
if missing_key_ids:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Missing keys for %s: %s"
|
||||
% (verify_request.server_name, missing_key_ids),
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
# see if any of the keys we got this time are sufficient to
|
||||
# complete this VerifyJsonRequest.
|
||||
result_keys = results.get(server_name, {})
|
||||
for key_id in verify_request.key_ids:
|
||||
fetch_key_result = result_keys.get(key_id)
|
||||
if not fetch_key_result:
|
||||
# we didn't get a result for this key
|
||||
continue
|
||||
|
||||
if (
|
||||
fetch_key_result.valid_until_ts
|
||||
< verify_request.minimum_valid_until_ts
|
||||
):
|
||||
# key was not valid at this point
|
||||
continue
|
||||
|
||||
# we have a valid key for this request. If we run the callback
|
||||
# immediately, it may cancel our loggingcontext while we are still in
|
||||
# it, so instead we schedule it for the next time round the reactor.
|
||||
#
|
||||
# (this also ensures that we don't get a stack overflow if we had
|
||||
# a massive queue of lookups waiting for this server).
|
||||
logger.debug(
|
||||
"Found key %s:%s for %s",
|
||||
server_name,
|
||||
key_id,
|
||||
verify_request.request_name,
|
||||
)
|
||||
self.clock.call_later(
|
||||
0,
|
||||
verify_request.key_ready.callback,
|
||||
(server_name, key_id, fetch_key_result.verify_key),
|
||||
)
|
||||
completed.append(verify_request)
|
||||
break
|
||||
|
||||
remaining_requests.difference_update(completed)
|
||||
verify_key = found_keys[key_id].verify_key
|
||||
try:
|
||||
verify_signed_json(
|
||||
verify_request.json_object,
|
||||
verify_request.server_name,
|
||||
verify_key,
|
||||
)
|
||||
except SignatureVerifyException as e:
|
||||
logger.debug(
|
||||
"Error verifying signature for %s:%s:%s with key %s: %s",
|
||||
verify_request.server_name,
|
||||
verify_key.alg,
|
||||
verify_key.version,
|
||||
encode_verify_key_base64(verify_key),
|
||||
str(e),
|
||||
)
|
||||
raise SynapseError(
|
||||
401,
|
||||
"Invalid signature for server %s with key %s:%s: %s"
|
||||
% (
|
||||
verify_request.server_name,
|
||||
verify_key.alg,
|
||||
verify_key.version,
|
||||
str(e),
|
||||
),
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
|
||||
class KeyFetcher(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
async def get_keys(
|
||||
self, keys_to_fetch: Dict[str, Dict[str, int]]
|
||||
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||
"""
|
||||
Args:
|
||||
keys_to_fetch:
|
||||
the keys to be fetched. server_name -> key_id -> min_valid_ts
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._queue = _Queue(self.__class__.__name__, hs.get_clock(), self._fetch_keys)
|
||||
|
||||
Returns:
|
||||
Map from server_name -> key_id -> FetchKeyResult
|
||||
"""
|
||||
raise NotImplementedError
|
||||
async def get_keys(
|
||||
self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||
) -> Dict[str, FetchKeyResult]:
|
||||
return await self._queue.add_to_queue(
|
||||
_QueueValue(
|
||||
server_name=server_name,
|
||||
key_ids=key_ids,
|
||||
minimum_valid_until_ts=minimum_valid_until_ts,
|
||||
)
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def _fetch_keys(
|
||||
self, keys_to_fetch: List[_QueueValue]
|
||||
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||
pass
|
||||
|
||||
|
||||
class StoreKeyFetcher(KeyFetcher):
|
||||
"""KeyFetcher impl which fetches keys from our data store"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def get_keys(
|
||||
self, keys_to_fetch: Dict[str, Dict[str, int]]
|
||||
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||
"""see KeyFetcher.get_keys"""
|
||||
|
||||
async def _fetch_keys(self, keys_to_fetch: List[_QueueValue]):
|
||||
key_ids_to_fetch = (
|
||||
(server_name, key_id)
|
||||
for server_name, keys_for_server in keys_to_fetch.items()
|
||||
for key_id in keys_for_server.keys()
|
||||
(queue_value.server_name, key_id)
|
||||
for queue_value in keys_to_fetch
|
||||
for key_id in queue_value.key_ids
|
||||
)
|
||||
|
||||
res = await self.store.get_server_verify_keys(key_ids_to_fetch)
|
||||
@@ -500,6 +313,8 @@ class StoreKeyFetcher(KeyFetcher):
|
||||
|
||||
class BaseV2KeyFetcher(KeyFetcher):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.config = hs.config
|
||||
|
||||
@@ -607,10 +422,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||
self.client = hs.get_federation_http_client()
|
||||
self.key_servers = self.config.key_servers
|
||||
|
||||
async def get_keys(
|
||||
self, keys_to_fetch: Dict[str, Dict[str, int]]
|
||||
async def _fetch_keys(
|
||||
self, keys_to_fetch: List[_QueueValue]
|
||||
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||
"""see KeyFetcher.get_keys"""
|
||||
"""see KeyFetcher._fetch_keys"""
|
||||
|
||||
async def get_key(key_server: TrustedKeyServer) -> Dict:
|
||||
try:
|
||||
@@ -646,12 +461,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||
return union_of_keys
|
||||
|
||||
async def get_server_verify_key_v2_indirect(
|
||||
self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
|
||||
self, keys_to_fetch: List[_QueueValue], key_server: TrustedKeyServer
|
||||
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||
"""
|
||||
Args:
|
||||
keys_to_fetch:
|
||||
the keys to be fetched. server_name -> key_id -> min_valid_ts
|
||||
the keys to be fetched.
|
||||
|
||||
key_server: notary server to query for the keys
|
||||
|
||||
@@ -665,7 +480,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||
perspective_name = key_server.server_name
|
||||
logger.info(
|
||||
"Requesting keys %s from notary server %s",
|
||||
keys_to_fetch.items(),
|
||||
keys_to_fetch,
|
||||
perspective_name,
|
||||
)
|
||||
|
||||
@@ -675,11 +490,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||
path="/_matrix/key/v2/query",
|
||||
data={
|
||||
"server_keys": {
|
||||
server_name: {
|
||||
key_id: {"minimum_valid_until_ts": min_valid_ts}
|
||||
for key_id, min_valid_ts in server_keys.items()
|
||||
queue_value.server_name: {
|
||||
key_id: {
|
||||
"minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
|
||||
}
|
||||
for key_id in queue_value.key_ids
|
||||
}
|
||||
for server_name, server_keys in keys_to_fetch.items()
|
||||
for queue_value in keys_to_fetch
|
||||
}
|
||||
},
|
||||
)
|
||||
@@ -779,8 +596,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||
self.clock = hs.get_clock()
|
||||
self.client = hs.get_federation_http_client()
|
||||
|
||||
async def get_keys(
|
||||
self, keys_to_fetch: Dict[str, Dict[str, int]]
|
||||
async def _fetch_keys(
|
||||
self, keys_to_fetch: List[_QueueValue]
|
||||
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -793,8 +610,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||
|
||||
results = {}
|
||||
|
||||
async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
|
||||
server_name, key_ids = key_to_fetch_item
|
||||
async def get_key(key_to_fetch_item: _QueueValue) -> None:
|
||||
server_name = key_to_fetch_item.server_name
|
||||
key_ids = key_to_fetch_item.key_ids
|
||||
|
||||
try:
|
||||
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
|
||||
results[server_name] = keys
|
||||
@@ -805,7 +624,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||
except Exception:
|
||||
logger.exception("Error getting keys %s from %s", key_ids, server_name)
|
||||
|
||||
await yieldable_gather_results(get_key, keys_to_fetch.items())
|
||||
await yieldable_gather_results(get_key, keys_to_fetch)
|
||||
return results
|
||||
|
||||
async def get_server_verify_key_v2_direct(
|
||||
|
||||
@@ -22,6 +22,7 @@ from synapse.crypto.keyring import ServerKeyFetcher
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.async_helpers import yieldable_gather_results
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -213,7 +214,13 @@ class RemoteKey(DirectServeJsonResource):
|
||||
# If there is a cache miss, request the missing keys, then recurse (and
|
||||
# ensure the result is sent).
|
||||
if cache_misses and query_remote_on_cache_miss:
|
||||
await self.fetcher.get_keys(cache_misses)
|
||||
await yieldable_gather_results(
|
||||
self.fetcher.get_keys,
|
||||
(
|
||||
(server_name, list(keys), 0)
|
||||
for server_name, keys in cache_misses.items()
|
||||
),
|
||||
)
|
||||
await self.query_keys(request, query, query_remote_on_cache_miss=False)
|
||||
else:
|
||||
signed_keys = []
|
||||
|
||||
Reference in New Issue
Block a user