stuff we want to bring over

This commit is contained in:
Andrew Morgan
2025-09-29 17:50:48 +01:00
parent a2ef624153
commit e049a65e9e
6 changed files with 360 additions and 80 deletions

View File

@@ -20,12 +20,22 @@
#
#
import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
)
import attr
from canonicaljson import encode_canonical_json
from signedjson.key import VerifyKey, decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from typing_extensions import TypeAlias
from unpaddedbase64 import decode_base64
from twisted.internet import defer
@@ -60,6 +70,47 @@ logger = logging.getLogger(__name__)
ONE_TIME_KEY_UPLOAD = "one_time_key_upload_lock"
@attr.s(frozen=True, slots=True, auto_attribs=True)
class DeviceKeys:
algorithms: List[str]
"""The encryption algorithms supported by this device."""
device_id: str
"""The ID of the device these keys belong to. Must match the device ID used when logging in."""
keys: Mapping[str, str]
"""
Public identity keys. The names of the properties should be in the
format `<algorithm>:<device_id>`. The keys themselves should be encoded as
specified by the key algorithm.
"""
signatures: Mapping[UserID, Mapping[str, str]]
"""Signatures for the device key object. A map from user ID, to a map from "<algorithm>:<device_id>" to the signature."""
user_id: UserID
"""The ID of the user the device belongs to. Must match the user ID used when logging in."""
@attr.s(frozen=True, slots=True, auto_attribs=True)
class KeyObject:
key: str
"""The key, encoded using unpadded base64."""
signatures: Mapping[UserID, Mapping[str, str]]
"""Signature for the device. Mapped from user ID to another map of key signing identifier to the signature itself.
See the following for more detail: https://spec.matrix.org/v1.16/appendices/#signing-details
"""
fallback: bool = False
"""Whether this is a fallback key."""
FallbackKeys: TypeAlias = Mapping[str, Union[str, KeyObject]]
OneTimeKeys: TypeAlias = Mapping[str, Union[str, KeyObject]]
class E2eKeysHandler:
def __init__(self, hs: "HomeServer"):
self.config = hs.config
@@ -833,7 +884,12 @@ class E2eKeysHandler:
@tag_args
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
self,
user_id: str,
device_id: str,
device_keys: Optional[DeviceKeys],
fallback_keys: Optional[FallbackKeys],
one_time_keys: Optional[OneTimeKeys],
) -> JsonDict:
"""
Args:
@@ -847,18 +903,16 @@ class E2eKeysHandler:
"""
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
device_keys = keys.get("device_keys", None)
if device_keys and isinstance(device_keys, dict):
if device_keys:
# Validate that user_id and device_id match the requesting user
if (
device_keys["user_id"] == user_id
and device_keys["device_id"] == device_id
device_keys.user_id.to_string() == user_id
and device_keys.device_id == device_id
):
await self.upload_device_keys_for_user(
user_id=user_id,
device_id=device_id,
keys={"device_keys": device_keys},
user_id,
device_id,
device_keys,
)
else:
log_kv(
@@ -870,8 +924,7 @@ class E2eKeysHandler:
else:
log_kv({"message": "Did not update device_keys", "reason": "not a dict"})
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys and isinstance(one_time_keys, dict):
if one_time_keys:
log_kv(
{
"message": "Updating one_time_keys for device.",
@@ -888,10 +941,8 @@ class E2eKeysHandler:
log_kv(
{"message": "Did not update one_time_keys", "reason": "no keys given"}
)
fallback_keys = keys.get("fallback_keys") or keys.get(
"org.matrix.msc2732.fallback_keys"
)
if fallback_keys and isinstance(fallback_keys, dict):
if fallback_keys:
log_kv(
{
"message": "Updating fallback_keys for device.",
@@ -900,8 +951,6 @@ class E2eKeysHandler:
}
)
await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys)
elif fallback_keys:
log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"})
else:
log_kv(
{"message": "Did not update fallback_keys", "reason": "no keys given"}
@@ -914,7 +963,10 @@ class E2eKeysHandler:
@tag_args
async def upload_device_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
self,
user_id: str,
device_id: str,
device_keys: DeviceKeys,
) -> None:
"""
Args:
@@ -925,7 +977,6 @@ class E2eKeysHandler:
"""
time_now = self.clock.time_msec()
device_keys = keys["device_keys"]
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id,
@@ -955,7 +1006,11 @@ class E2eKeysHandler:
await self.device_handler.check_device_registered(user_id, device_id)
async def _upload_one_time_keys_for_user(
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
self,
user_id: str,
device_id: str,
time_now: int,
one_time_keys: OneTimeKeys,
) -> None:
# We take out a lock so that we don't have to worry about a client
# sending duplicate requests.
@@ -1742,20 +1797,20 @@ def _exception_to_failure(e: Exception) -> JsonDict:
return {"status": 503, "message": str(e)}
def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
def _one_time_keys_match(old_key_json: str, new_key: Union[str, KeyObject]) -> bool:
old_key = json_decoder.decode(old_key_json)
# if either is a string rather than an object, they must match exactly
if not isinstance(old_key, dict) or not isinstance(new_key, dict):
if isinstance(old_key, str) or isinstance(new_key, str):
return old_key == new_key
# otherwise, we strip off the 'signatures' if any, because it's legitimate
# for different upload attempts to have different signatures.
old_key.pop("signatures", None)
new_key_copy = dict(new_key)
new_key_copy.pop("signatures", None)
# new_key must be a `KeyObject`
return old_key == new_key_copy
# Otherwise, check whether the embedded keys match.
#
# We ignore signatures, because it's legitimate for different upload
# attempts to have different signatures.
return old_key["key"] == new_key.key
@attr.s(slots=True, auto_attribs=True)

View File

@@ -23,7 +23,25 @@
import logging
import re
from collections import Counter
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Dict,
List,
Mapping,
Optional,
Tuple,
Union,
)
from pydantic import (
AfterValidator,
AliasChoices,
Field,
StrictBool,
StrictStr,
)
from synapse.api.auth.mas import MasDelegatedAuth
from synapse.api.errors import (
@@ -31,9 +49,16 @@ from synapse.api.errors import (
InvalidAPICallError,
SynapseError,
)
from synapse.handlers.e2e_keys import (
DeviceKeys,
FallbackKeys,
KeyObject,
OneTimeKeys,
)
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_and_validate_json_object_from_request,
parse_integer,
parse_json_object_from_request,
parse_string,
@@ -41,7 +66,8 @@ from synapse.http.servlet import (
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag
from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.types import JsonDict, StreamToken
from synapse.types import JsonDict, StreamToken, UserIDType
from synapse.types.rest import RequestBodyModel
from synapse.util.cancellation import cancellable
if TYPE_CHECKING:
@@ -111,12 +137,94 @@ class KeyUploadServlet(RestServlet):
self._clock = hs.get_clock()
self._store = hs.get_datastores().main
class KeyUploadRequestBody(RequestBodyModel):
"""
The body of a `POST /_matrix/client/v3/keys/upload` request.
Based on https://spec.matrix.org/v1.16/client-server-api/#post_matrixclientv3keysupload.
"""
class DeviceKeys(RequestBodyModel):
algorithms: List[StrictStr]
"""The encryption algorithms supported by this device."""
device_id: StrictStr
"""The ID of the device these keys belong to. Must match the device ID used when logging in."""
keys: Mapping[StrictStr, StrictStr]
"""
Public identity keys. The names of the properties should be in the
format `<algorithm>:<device_id>`. The keys themselves should be encoded as
specified by the key algorithm.
"""
signatures: Mapping[UserIDType, Mapping[StrictStr, StrictStr]]
"""Signatures for the device key object. A map from user ID, to a map from "<algorithm>:<device_id>" to the signature."""
user_id: UserIDType
"""The ID of the user the device belongs to. Must match the user ID used when logging in."""
class KeyObject(RequestBodyModel):
key: StrictStr
"""The key, encoded using unpadded base64."""
# TODO: Is this only allowed on fallback keys?
fallback: StrictBool = False
"""Whether this is a fallback key."""
signatures: Mapping[UserIDType, Mapping[StrictStr, StrictStr]]
"""Signature for the device. Mapped from user ID to another map of key signing identifier to the signature itself.
See the following for more detail: https://spec.matrix.org/v1.16/appendices/#signing-details
"""
device_keys: Optional[DeviceKeys] = None
"""Identity keys for the device. May be absent if no new identity keys are required."""
fallback_keys: Optional[Mapping[StrictStr, Union[StrictStr, KeyObject]]] = (
Field(
default_factory=lambda: None,
validation_alias=AliasChoices(
"fallback_keys",
# Accept this field alias, which is the unstable equivalent to
# the `fallback_keys` field from MSC2732.
"org.matrix.msc2732.fallback_keys",
),
serialization_alias="fallback_keys",
)
)
"""
The public key which should be used if the device's one-time keys are
exhausted. The fallback key is not deleted once used, but should be
replaced when additional one-time keys are being uploaded. The server
will notify the client of the fallback key being used through `/sync`.
There can only be at most one key per algorithm uploaded, and the server
will only persist one key per algorithm.
When uploading a signed key, an additional fallback: true key should be
included to denote that the key is a fallback key.
May be absent if a new fallback key is not required.
"""
one_time_keys: Optional[Mapping[StrictStr, Union[StrictStr, KeyObject]]] = None
"""
One-time public keys for “pre-key” messages. The names of the properties
should be in the format `<algorithm>:<key_id>`.
The format of the key is determined by the key algorithm, see:
https://spec.matrix.org/v1.16/client-server-api/#key-algorithms.
"""
async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
body = parse_and_validate_json_object_from_request(
request, self.KeyUploadRequestBody
)
if device_id is not None:
# Providing the device_id should only be done for setting keys
@@ -149,12 +257,75 @@ class KeyUploadServlet(RestServlet):
400, "To upload keys, you must pass device_id when authenticating"
)
# Map the pydantic model to domain objects.
device_keys, fallback_keys, one_time_keys = (
self._map_pydantic_model_to_domain_objects(body)
)
result = await self.e2e_keys_handler.upload_keys_for_user(
user_id=user_id, device_id=device_id, keys=body
user_id,
device_id,
device_keys,
fallback_keys,
one_time_keys,
)
return 200, result
def _map_pydantic_model_to_domain_objects(
self, body: KeyUploadRequestBody
) -> Tuple[
Optional[DeviceKeys],
Optional[FallbackKeys],
Optional[OneTimeKeys],
]:
"""Map a validated pydantic model to internal data classes."""
device_keys: Optional[DeviceKeys] = None
if body.device_keys is not None:
device_keys = DeviceKeys(
algorithms=body.device_keys.algorithms,
device_id=body.device_keys.device_id,
keys=body.device_keys.keys,
signatures=body.device_keys.signatures,
user_id=body.device_keys.user_id,
)
fallback_keys: Optional[FallbackKeys] = None
if body.fallback_keys is not None:
fallback_keys = {}
for (
algorithm_and_key_id,
public_key_or_object,
) in body.fallback_keys.items():
if isinstance(public_key_or_object, str):
fallback_keys[algorithm_and_key_id] = public_key_or_object
else:
fallback_key_object: KeyUploadServlet.KeyUploadRequestBody.KeyObject = public_key_or_object
fallback_keys[algorithm_and_key_id] = KeyObject(
key=fallback_key_object.key,
signatures=fallback_key_object.signatures,
fallback=fallback_key_object.fallback,
)
one_time_keys: Optional[OneTimeKeys] = None
if body.one_time_keys is not None:
one_time_keys = {}
for (
algorithm_and_key_id,
public_key_or_object,
) in body.one_time_keys.items():
if isinstance(public_key_or_object, str):
one_time_keys[algorithm_and_key_id] = public_key_or_object
else:
one_time_key_object: KeyUploadServlet.KeyUploadRequestBody.KeyObject = public_key_or_object
one_time_keys[algorithm_and_key_id] = KeyObject(
key=one_time_key_object.key,
signatures=one_time_key_object.signatures,
fallback=one_time_key_object.fallback,
)
return device_keys, fallback_keys, one_time_keys
class KeyQueryServlet(RestServlet):
"""

View File

@@ -67,7 +67,7 @@ from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.handlers.e2e_keys import SignatureListItem
from synapse.handlers.e2e_keys import DeviceKeys, FallbackKeys, SignatureListItem
from synapse.server import HomeServer
@@ -802,7 +802,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
async def set_e2e_fallback_keys(
self, user_id: str, device_id: str, fallback_keys: JsonDict
self,
user_id: str,
device_id: str,
fallback_keys: "FallbackKeys",
) -> None:
"""Set the user's e2e fallback keys.
@@ -829,7 +832,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
txn: LoggingTransaction,
user_id: str,
device_id: str,
fallback_keys: JsonDict,
fallback_keys: "FallbackKeys",
) -> None:
"""Set the user's e2e fallback keys.
@@ -1650,16 +1653,20 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
async def set_e2e_device_keys(
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
self,
user_id: str,
device_id: str,
time_now: int,
device_keys: "DeviceKeys",
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
Args:
user_id: user_id of the user to store keys for
device_id: device_id of the device to store keys for
time_now: time at the request to store the keys
device_keys: the keys to store
Args:
user_id: user_id of the user to store keys for
device_id: device_id of the device to store keys for
time_now: time at the request to store the keys
device_keys: the keys to store
"""
return await self.db_pool.runInteraction(
@@ -1677,7 +1684,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
user_id: str,
device_id: str,
time_now: int,
device_keys: JsonDict,
device_keys: "DeviceKeys",
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.

View File

@@ -27,6 +27,7 @@ from enum import Enum
from typing import (
TYPE_CHECKING,
AbstractSet,
Annotated,
Any,
ClassVar,
Dict,
@@ -48,6 +49,7 @@ from typing import (
import attr
from immutabledict import immutabledict
from pydantic import BeforeValidator, PlainSerializer, WithJsonSchema
from signedjson.key import decode_verify_key_bytes
from signedjson.types import VerifyKey
from typing_extensions import Self
@@ -361,6 +363,37 @@ class RoomIdWithDomain(DomainSpecificString):
SIGIL = "!"
def _parse_user_id(user_id_str: Any) -> Any:
if isinstance(user_id_str, str):
try:
return UserID.from_string(user_id_str)
except Exception as e:
raise ValueError(
f"Unable to parse string '{user_id_str}' as valid Matrix User ID"
) from e
raise ValueError(f"Expected a string, found {type(user_id_str)}")
UserIDType = Annotated[
UserID,
BeforeValidator(_parse_user_id),
PlainSerializer(lambda uid: uid.to_string(), return_type=str),
WithJsonSchema(
{
"type": "string",
"description": "Matrix User ID",
"pattern": r"^@[^:]+:[^:]+$",
"examples": ["@alice:example.org"],
}
),
]
"""
A User ID type that can be used in Pydantic models.
Validates that the input value is a `str` and can be parsed as a Matrix User ID.
"""
# the set of urlsafe base64 characters, no padding.
ROOM_ID_PATTERN_DOMAINLESS = re.compile(r"^[A-Za-z0-9\-_]{43}$")

View File

@@ -52,6 +52,7 @@ from synapse.types import (
StreamToken,
ThreadSubscriptionsToken,
UserID,
UserIDType,
)
from synapse.types.rest.client import SlidingSyncBody
@@ -67,7 +68,7 @@ class SlidingSyncConfig(SlidingSyncBody):
extra fields that we need in the handler
"""
user: UserID
user: UserIDType
requester: Requester
# Pydantic config

View File

@@ -21,7 +21,9 @@
from twisted.internet.testing import MemoryReactor
from synapse.handlers.e2e_keys import DeviceKeys
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -30,47 +32,55 @@ from tests.unittest import HomeserverTestCase
class EndToEndKeyStoreTestCase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.now_ms = 1470174257070
self.test_user_id = "@alice:test"
self.test_device_id = "TEST_DEVICE"
self.test_device_keys = self._create_test_device_keys(self.test_user_id, self.test_device_id)
def _create_test_device_keys(self, user_id: str, device_id: str, public_key: str = "test_public_key") -> DeviceKeys:
"""Create and return a test `DeviceKeys` object."""
return DeviceKeys(
algorithms=["ed25519"],
device_id=device_id,
keys={
f"ed25519:{device_id}": public_key,
},
signatures={},
user_id=UserID.from_string(user_id),
)
def test_key_without_device_name(self) -> None:
now = 1470174257070
json = {"key": "value"}
self.get_success(self.store.store_device(self.test_user_id, self.test_device_id, None))
self.get_success(self.store.store_device("user", "device", None))
self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
self.get_success(self.store.set_e2e_device_keys(self.test_user_id, self.test_device_id, self.now_ms, self.test_device_keys))
res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
self.store.get_e2e_device_keys_for_cs_api(((self.test_user_id, self.test_device_id),))
)
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
self.assertLessEqual(json.items(), dev.items())
self.assertIn(self.test_user_id, res)
self.assertIn(self.test_device_id, res[self.test_user_id])
device_keys = res[self.test_user_id][self.test_device_id]
print(device_keys)
def test_reupload_key(self) -> None:
now = 1470174257070
json = {"key": "value"}
self.get_success(self.store.store_device("user", "device", None))
changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json)
self.store.set_e2e_device_keys("user", "device", self.now_ms, self.test_device_keys)
)
self.assertTrue(changed)
# If we try to upload the same key then we should be told nothing
# changed
changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json)
self.store.set_e2e_device_keys("user", "device", self.now_ms, self.test_device_keys)
)
self.assertFalse(changed)
def test_get_key_with_device_name(self) -> None:
now = 1470174257070
json = {"key": "value"}
self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
self.get_success(self.store.store_device("user", "device", "display_name"))
self.get_success(self.store.set_e2e_device_keys(self.test_user_id, self.test_device_id, self.now_ms, self.test_device_keys))
self.get_success(self.store.store_device(self.test_user_id, self.test_device_id, "display_name"))
res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
@@ -87,34 +97,37 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
)
def test_multiple_devices(self) -> None:
now = 1470174257070
user_one = "@user1:test"
user_two = "@user2:test"
device_id_one = "DEVICE_ID_1"
device_id_two = "DEVICE_ID_2"
self.get_success(self.store.store_device("user1", "device1", None))
self.get_success(self.store.store_device("user1", "device2", None))
self.get_success(self.store.store_device("user2", "device1", None))
self.get_success(self.store.store_device("user2", "device2", None))
self.get_success(self.store.store_device(user_one, device_id_one, None))
self.get_success(self.store.store_device(user_one, device_id_two, None))
self.get_success(self.store.store_device(user_two, device_id_one, None))
self.get_success(self.store.store_device(user_two, device_id_two, None))
self.get_success(
self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
self.store.set_e2e_device_keys(user_one, device_id_one, self.now_ms, self._create_test_device_keys(user_one, device_id_one, "json11"))
)
self.get_success(
self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
self.store.set_e2e_device_keys(user_one, device_id_two, self.now_ms, self._create_test_device_keys(user_one, device_id_two, "json12"))
)
self.get_success(
self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
self.store.set_e2e_device_keys(user_two, device_id_one, self.now_ms, self._create_test_device_keys(user_two, device_id_one, "json21"))
)
self.get_success(
self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
self.store.set_e2e_device_keys(user_two, device_id_two, self.now_ms, self._create_test_device_keys(user_two, device_id_two, "json22"))
)
res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api(
(("user1", "device1"), ("user2", "device2"))
((user_one, device_id_one), (user_two, device_id_two))
)
)
self.assertIn("user1", res)
self.assertIn("device1", res["user1"])
self.assertNotIn("device2", res["user1"])
self.assertIn("user2", res)
self.assertNotIn("device1", res["user2"])
self.assertIn("device2", res["user2"])
self.assertIn(user_one, res)
self.assertIn(device_id_one, res[user_one])
self.assertNotIn(device_id_two, res[user_one])
self.assertIn(user_two, res)
self.assertNotIn(device_id_one, res[user_two])
self.assertIn(device_id_two, res[user_two])