1
0

Merge branch 'erikj/refactor_keyring' into erikj/test_send

This commit is contained in:
Erik Johnston
2021-05-04 17:57:57 +01:00
3 changed files with 60 additions and 66 deletions
+54 -53
View File
@@ -16,7 +16,7 @@
import abc
import logging
import urllib
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
import attr
from signedjson.key import (
@@ -41,11 +41,9 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config.key import TrustedKeyServer
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
)
from synapse.events import EventBase
from synapse.events.utils import prune_event_dict
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict
@@ -72,8 +70,6 @@ class VerifyJsonRequest:
minimum_valid_until_ts: time at which we require the signing key to
be valid. (0 implies we don't care)
request_name: The name of the request.
key_ids: The set of key_ids to that could be used to verify the JSON object
key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
@@ -86,14 +82,32 @@ class VerifyJsonRequest:
"""
server_name = attr.ib(type=str)
json_object = attr.ib(type=JsonDict)
json_object_callback = attr.ib(type=Callable[[], JsonDict])
minimum_valid_until_ts = attr.ib(type=int)
request_name = attr.ib(type=str)
key_ids = attr.ib(init=False, type=List[str])
key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
key_ids = attr.ib(type=List[str])
def __attrs_post_init__(self):
self.key_ids = signature_ids(self.json_object, self.server_name)
@staticmethod
def from_json_object(
server_name: str, minimum_valid_until_ms: int, json_object: JsonDict
):
key_ids = signature_ids(json_object, server_name)
return VerifyJsonRequest(
server_name, lambda: json_object, minimum_valid_until_ms, key_ids
)
@staticmethod
def from_event(
server_name: str,
minimum_valid_until_ms: int,
event: EventBase,
):
key_ids = list(event.signatures.get(server_name, []))
return VerifyJsonRequest(
server_name,
lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
minimum_valid_until_ms,
key_ids,
)
class KeyLookupError(ValueError):
@@ -179,8 +193,10 @@ class Keyring:
validity_time: int,
request_name: str,
) -> defer.Deferred:
request = VerifyJsonRequest(
server_name, json_object, validity_time, request_name
request = VerifyJsonRequest.from_json_object(
server_name,
validity_time,
json_object,
)
return defer.ensureDeferred(self._verify_object(request))
@@ -190,14 +206,32 @@ class Keyring:
return [
defer.ensureDeferred(
self._verify_object(
VerifyJsonRequest(
server_name, json_object, validity_time, request_name
VerifyJsonRequest.from_json_object(
server_name,
validity_time,
json_object,
)
)
)
for server_name, json_object, validity_time, request_name in server_and_json
]
def verify_events_for_server(
self, server_and_json: Iterable[Tuple[str, EventBase, int]]
) -> List[defer.Deferred]:
return [
defer.ensureDeferred(
self._verify_object(
VerifyJsonRequest.from_event(
server_name,
validity_time,
event,
)
)
)
for server_name, event, validity_time in server_and_json
]
async def _verify_object(self, verify_request: VerifyJsonRequest):
# TODO: Use a batching thing.
with (await self._server_queue.queue(verify_request.server_name)):
@@ -240,8 +274,9 @@ class Keyring:
for key_id in verify_request.key_ids:
verify_key = found_keys[key_id].verify_key
try:
json_object = verify_request.json_object_callback()
verify_signed_json(
verify_request.json_object,
json_object,
verify_request.server_name,
verify_key,
)
@@ -696,37 +731,3 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
keys.update(response_keys)
return keys
async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
"""Waits for the key to become available, and then performs a verification
Args:
verify_request:
Raises:
SynapseError if there was a problem performing the verification
"""
server_name = verify_request.server_name
with PreserveLoggingContext():
_, key_id, verify_key = await verify_request.key_ready
json_object = verify_request.json_object
try:
verify_signed_json(json_object, server_name, verify_key)
except SignatureVerifyException as e:
logger.debug(
"Error verifying signature for %s:%s:%s with key %s: %s",
server_name,
verify_key.alg,
verify_key.version,
encode_verify_key_base64(verify_key),
str(e),
)
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s: %s"
% (server_name, verify_key.alg, verify_key.version, str(e)),
Codes.UNAUTHORIZED,
)
+5 -12
View File
@@ -137,11 +137,7 @@ class FederationBase:
return deferreds
class PduToCheckSig(
namedtuple(
"PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
)
):
class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
pass
@@ -184,7 +180,6 @@ def _check_sigs_on_pdus(
pdus_to_check = [
PduToCheckSig(
pdu=p,
redacted_pdu_json=prune_event(p).get_pdu_json(),
sender_domain=get_domain_from_id(p.sender),
deferreds=[],
)
@@ -195,13 +190,12 @@ def _check_sigs_on_pdus(
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
more_deferreds = keyring.verify_json_objects_for_server(
more_deferreds = keyring.verify_events_for_server(
[
(
p.sender_domain,
p.redacted_pdu_json,
p.pdu,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_sender
]
@@ -230,13 +224,12 @@ def _check_sigs_on_pdus(
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
]
more_deferreds = keyring.verify_json_objects_for_server(
more_deferreds = keyring.verify_events_for_server(
[
(
get_domain_from_id(p.pdu.event_id),
p.redacted_pdu_json,
p.pdu,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_event_id
]
+1 -1
View File
@@ -215,7 +215,7 @@ class RemoteKey(DirectServeJsonResource):
# ensure the result is sent).
if cache_misses and query_remote_on_cache_miss:
await yieldable_gather_results(
self.fetcher.get_keys,
lambda t: self.fetcher.get_keys(*t),
(
(server_name, list(keys), 0)
for server_name, keys in cache_misses.items()