stuff we want to bring over
This commit is contained in:
@@ -20,12 +20,22 @@
|
||||
#
|
||||
#
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import attr
|
||||
from canonicaljson import encode_canonical_json
|
||||
from signedjson.key import VerifyKey, decode_verify_key_bytes
|
||||
from signedjson.sign import SignatureVerifyException, verify_signed_json
|
||||
from typing_extensions import TypeAlias
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
from twisted.internet import defer
|
||||
@@ -60,6 +70,47 @@ logger = logging.getLogger(__name__)
|
||||
ONE_TIME_KEY_UPLOAD = "one_time_key_upload_lock"
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||
class DeviceKeys:
|
||||
algorithms: List[str]
|
||||
"""The encryption algorithms supported by this device."""
|
||||
|
||||
device_id: str
|
||||
"""The ID of the device these keys belong to. Must match the device ID used when logging in."""
|
||||
|
||||
keys: Mapping[str, str]
|
||||
"""
|
||||
Public identity keys. The names of the properties should be in the
|
||||
format `<algorithm>:<device_id>`. The keys themselves should be encoded as
|
||||
specified by the key algorithm.
|
||||
"""
|
||||
|
||||
signatures: Mapping[UserID, Mapping[str, str]]
|
||||
"""Signatures for the device key object. A map from user ID, to a map from "<algorithm>:<device_id>" to the signature."""
|
||||
|
||||
user_id: UserID
|
||||
"""The ID of the user the device belongs to. Must match the user ID used when logging in."""
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||
class KeyObject:
|
||||
key: str
|
||||
"""The key, encoded using unpadded base64."""
|
||||
|
||||
signatures: Mapping[UserID, Mapping[str, str]]
|
||||
"""Signature for the device. Mapped from user ID to another map of key signing identifier to the signature itself.
|
||||
|
||||
See the following for more detail: https://spec.matrix.org/v1.16/appendices/#signing-details
|
||||
"""
|
||||
|
||||
fallback: bool = False
|
||||
"""Whether this is a fallback key."""
|
||||
|
||||
|
||||
FallbackKeys: TypeAlias = Mapping[str, Union[str, KeyObject]]
|
||||
OneTimeKeys: TypeAlias = Mapping[str, Union[str, KeyObject]]
|
||||
|
||||
|
||||
class E2eKeysHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.config = hs.config
|
||||
@@ -833,7 +884,12 @@ class E2eKeysHandler:
|
||||
|
||||
@tag_args
|
||||
async def upload_keys_for_user(
|
||||
self, user_id: str, device_id: str, keys: JsonDict
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
device_keys: Optional[DeviceKeys],
|
||||
fallback_keys: Optional[FallbackKeys],
|
||||
one_time_keys: Optional[OneTimeKeys],
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Args:
|
||||
@@ -847,18 +903,16 @@ class E2eKeysHandler:
|
||||
"""
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
# TODO: Validate the JSON to make sure it has the right keys.
|
||||
device_keys = keys.get("device_keys", None)
|
||||
if device_keys and isinstance(device_keys, dict):
|
||||
if device_keys:
|
||||
# Validate that user_id and device_id match the requesting user
|
||||
if (
|
||||
device_keys["user_id"] == user_id
|
||||
and device_keys["device_id"] == device_id
|
||||
device_keys.user_id.to_string() == user_id
|
||||
and device_keys.device_id == device_id
|
||||
):
|
||||
await self.upload_device_keys_for_user(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
keys={"device_keys": device_keys},
|
||||
user_id,
|
||||
device_id,
|
||||
device_keys,
|
||||
)
|
||||
else:
|
||||
log_kv(
|
||||
@@ -870,8 +924,7 @@ class E2eKeysHandler:
|
||||
else:
|
||||
log_kv({"message": "Did not update device_keys", "reason": "not a dict"})
|
||||
|
||||
one_time_keys = keys.get("one_time_keys", None)
|
||||
if one_time_keys and isinstance(one_time_keys, dict):
|
||||
if one_time_keys:
|
||||
log_kv(
|
||||
{
|
||||
"message": "Updating one_time_keys for device.",
|
||||
@@ -888,10 +941,8 @@ class E2eKeysHandler:
|
||||
log_kv(
|
||||
{"message": "Did not update one_time_keys", "reason": "no keys given"}
|
||||
)
|
||||
fallback_keys = keys.get("fallback_keys") or keys.get(
|
||||
"org.matrix.msc2732.fallback_keys"
|
||||
)
|
||||
if fallback_keys and isinstance(fallback_keys, dict):
|
||||
|
||||
if fallback_keys:
|
||||
log_kv(
|
||||
{
|
||||
"message": "Updating fallback_keys for device.",
|
||||
@@ -900,8 +951,6 @@ class E2eKeysHandler:
|
||||
}
|
||||
)
|
||||
await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys)
|
||||
elif fallback_keys:
|
||||
log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"})
|
||||
else:
|
||||
log_kv(
|
||||
{"message": "Did not update fallback_keys", "reason": "no keys given"}
|
||||
@@ -914,7 +963,10 @@ class E2eKeysHandler:
|
||||
|
||||
@tag_args
|
||||
async def upload_device_keys_for_user(
|
||||
self, user_id: str, device_id: str, keys: JsonDict
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
device_keys: DeviceKeys,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
@@ -925,7 +977,6 @@ class E2eKeysHandler:
|
||||
"""
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
device_keys = keys["device_keys"]
|
||||
logger.info(
|
||||
"Updating device_keys for device %r for user %s at %d",
|
||||
device_id,
|
||||
@@ -955,7 +1006,11 @@ class E2eKeysHandler:
|
||||
await self.device_handler.check_device_registered(user_id, device_id)
|
||||
|
||||
async def _upload_one_time_keys_for_user(
|
||||
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
time_now: int,
|
||||
one_time_keys: OneTimeKeys,
|
||||
) -> None:
|
||||
# We take out a lock so that we don't have to worry about a client
|
||||
# sending duplicate requests.
|
||||
@@ -1742,20 +1797,20 @@ def _exception_to_failure(e: Exception) -> JsonDict:
|
||||
return {"status": 503, "message": str(e)}
|
||||
|
||||
|
||||
def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
|
||||
def _one_time_keys_match(old_key_json: str, new_key: Union[str, KeyObject]) -> bool:
|
||||
old_key = json_decoder.decode(old_key_json)
|
||||
|
||||
# if either is a string rather than an object, they must match exactly
|
||||
if not isinstance(old_key, dict) or not isinstance(new_key, dict):
|
||||
if isinstance(old_key, str) or isinstance(new_key, str):
|
||||
return old_key == new_key
|
||||
|
||||
# otherwise, we strip off the 'signatures' if any, because it's legitimate
|
||||
# for different upload attempts to have different signatures.
|
||||
old_key.pop("signatures", None)
|
||||
new_key_copy = dict(new_key)
|
||||
new_key_copy.pop("signatures", None)
|
||||
# new_key must be a `KeyObject`
|
||||
|
||||
return old_key == new_key_copy
|
||||
# Otherwise, check whether the embedded keys match.
|
||||
#
|
||||
# We ignore signatures, because it's legitimate for different upload
|
||||
# attempts to have different signatures.
|
||||
return old_key["key"] == new_key.key
|
||||
|
||||
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
|
||||
@@ -23,7 +23,25 @@
|
||||
import logging
|
||||
import re
|
||||
from collections import Counter
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import (
|
||||
AfterValidator,
|
||||
AliasChoices,
|
||||
Field,
|
||||
StrictBool,
|
||||
StrictStr,
|
||||
)
|
||||
|
||||
from synapse.api.auth.mas import MasDelegatedAuth
|
||||
from synapse.api.errors import (
|
||||
@@ -31,9 +49,16 @@ from synapse.api.errors import (
|
||||
InvalidAPICallError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.handlers.e2e_keys import (
|
||||
DeviceKeys,
|
||||
FallbackKeys,
|
||||
KeyObject,
|
||||
OneTimeKeys,
|
||||
)
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
parse_and_validate_json_object_from_request,
|
||||
parse_integer,
|
||||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
@@ -41,7 +66,8 @@ from synapse.http.servlet import (
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import log_kv, set_tag
|
||||
from synapse.rest.client._base import client_patterns, interactive_auth_handler
|
||||
from synapse.types import JsonDict, StreamToken
|
||||
from synapse.types import JsonDict, StreamToken, UserIDType
|
||||
from synapse.types.rest import RequestBodyModel
|
||||
from synapse.util.cancellation import cancellable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -111,12 +137,94 @@ class KeyUploadServlet(RestServlet):
|
||||
self._clock = hs.get_clock()
|
||||
self._store = hs.get_datastores().main
|
||||
|
||||
class KeyUploadRequestBody(RequestBodyModel):
|
||||
"""
|
||||
The body of a `POST /_matrix/client/v3/keys/upload` request.
|
||||
|
||||
Based on https://spec.matrix.org/v1.16/client-server-api/#post_matrixclientv3keysupload.
|
||||
"""
|
||||
|
||||
class DeviceKeys(RequestBodyModel):
|
||||
algorithms: List[StrictStr]
|
||||
"""The encryption algorithms supported by this device."""
|
||||
|
||||
device_id: StrictStr
|
||||
"""The ID of the device these keys belong to. Must match the device ID used when logging in."""
|
||||
|
||||
keys: Mapping[StrictStr, StrictStr]
|
||||
"""
|
||||
Public identity keys. The names of the properties should be in the
|
||||
format `<algorithm>:<device_id>`. The keys themselves should be encoded as
|
||||
specified by the key algorithm.
|
||||
"""
|
||||
|
||||
signatures: Mapping[UserIDType, Mapping[StrictStr, StrictStr]]
|
||||
"""Signatures for the device key object. A map from user ID, to a map from "<algorithm>:<device_id>" to the signature."""
|
||||
|
||||
user_id: UserIDType
|
||||
"""The ID of the user the device belongs to. Must match the user ID used when logging in."""
|
||||
|
||||
class KeyObject(RequestBodyModel):
|
||||
key: StrictStr
|
||||
"""The key, encoded using unpadded base64."""
|
||||
|
||||
# TODO: Is this only allowed on fallback keys?
|
||||
fallback: StrictBool = False
|
||||
"""Whether this is a fallback key."""
|
||||
|
||||
signatures: Mapping[UserIDType, Mapping[StrictStr, StrictStr]]
|
||||
"""Signature for the device. Mapped from user ID to another map of key signing identifier to the signature itself.
|
||||
|
||||
See the following for more detail: https://spec.matrix.org/v1.16/appendices/#signing-details
|
||||
"""
|
||||
|
||||
device_keys: Optional[DeviceKeys] = None
|
||||
"""Identity keys for the device. May be absent if no new identity keys are required."""
|
||||
|
||||
fallback_keys: Optional[Mapping[StrictStr, Union[StrictStr, KeyObject]]] = (
|
||||
Field(
|
||||
default_factory=lambda: None,
|
||||
validation_alias=AliasChoices(
|
||||
"fallback_keys",
|
||||
# Accept this field alias, which is the unstable equivalent to
|
||||
# the `fallback_keys` field from MSC2732.
|
||||
"org.matrix.msc2732.fallback_keys",
|
||||
),
|
||||
serialization_alias="fallback_keys",
|
||||
)
|
||||
)
|
||||
"""
|
||||
The public key which should be used if the device's one-time keys are
|
||||
exhausted. The fallback key is not deleted once used, but should be
|
||||
replaced when additional one-time keys are being uploaded. The server
|
||||
will notify the client of the fallback key being used through `/sync`.
|
||||
|
||||
There can only be at most one key per algorithm uploaded, and the server
|
||||
will only persist one key per algorithm.
|
||||
|
||||
When uploading a signed key, an additional fallback: true key should be
|
||||
included to denote that the key is a fallback key.
|
||||
|
||||
May be absent if a new fallback key is not required.
|
||||
"""
|
||||
|
||||
one_time_keys: Optional[Mapping[StrictStr, Union[StrictStr, KeyObject]]] = None
|
||||
"""
|
||||
One-time public keys for “pre-key” messages. The names of the properties
|
||||
should be in the format `<algorithm>:<key_id>`.
|
||||
|
||||
The format of the key is determined by the key algorithm, see:
|
||||
https://spec.matrix.org/v1.16/client-server-api/#key-algorithms.
|
||||
"""
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, device_id: Optional[str]
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
user_id = requester.user.to_string()
|
||||
body = parse_json_object_from_request(request)
|
||||
body = parse_and_validate_json_object_from_request(
|
||||
request, self.KeyUploadRequestBody
|
||||
)
|
||||
|
||||
if device_id is not None:
|
||||
# Providing the device_id should only be done for setting keys
|
||||
@@ -149,12 +257,75 @@ class KeyUploadServlet(RestServlet):
|
||||
400, "To upload keys, you must pass device_id when authenticating"
|
||||
)
|
||||
|
||||
# Map the pydantic model to domain objects.
|
||||
device_keys, fallback_keys, one_time_keys = (
|
||||
self._map_pydantic_model_to_domain_objects(body)
|
||||
)
|
||||
|
||||
result = await self.e2e_keys_handler.upload_keys_for_user(
|
||||
user_id=user_id, device_id=device_id, keys=body
|
||||
user_id,
|
||||
device_id,
|
||||
device_keys,
|
||||
fallback_keys,
|
||||
one_time_keys,
|
||||
)
|
||||
|
||||
return 200, result
|
||||
|
||||
def _map_pydantic_model_to_domain_objects(
|
||||
self, body: KeyUploadRequestBody
|
||||
) -> Tuple[
|
||||
Optional[DeviceKeys],
|
||||
Optional[FallbackKeys],
|
||||
Optional[OneTimeKeys],
|
||||
]:
|
||||
"""Map a validated pydantic model to internal data classes."""
|
||||
device_keys: Optional[DeviceKeys] = None
|
||||
if body.device_keys is not None:
|
||||
device_keys = DeviceKeys(
|
||||
algorithms=body.device_keys.algorithms,
|
||||
device_id=body.device_keys.device_id,
|
||||
keys=body.device_keys.keys,
|
||||
signatures=body.device_keys.signatures,
|
||||
user_id=body.device_keys.user_id,
|
||||
)
|
||||
|
||||
fallback_keys: Optional[FallbackKeys] = None
|
||||
if body.fallback_keys is not None:
|
||||
fallback_keys = {}
|
||||
for (
|
||||
algorithm_and_key_id,
|
||||
public_key_or_object,
|
||||
) in body.fallback_keys.items():
|
||||
if isinstance(public_key_or_object, str):
|
||||
fallback_keys[algorithm_and_key_id] = public_key_or_object
|
||||
else:
|
||||
fallback_key_object: KeyUploadServlet.KeyUploadRequestBody.KeyObject = public_key_or_object
|
||||
fallback_keys[algorithm_and_key_id] = KeyObject(
|
||||
key=fallback_key_object.key,
|
||||
signatures=fallback_key_object.signatures,
|
||||
fallback=fallback_key_object.fallback,
|
||||
)
|
||||
|
||||
one_time_keys: Optional[OneTimeKeys] = None
|
||||
if body.one_time_keys is not None:
|
||||
one_time_keys = {}
|
||||
for (
|
||||
algorithm_and_key_id,
|
||||
public_key_or_object,
|
||||
) in body.one_time_keys.items():
|
||||
if isinstance(public_key_or_object, str):
|
||||
one_time_keys[algorithm_and_key_id] = public_key_or_object
|
||||
else:
|
||||
one_time_key_object: KeyUploadServlet.KeyUploadRequestBody.KeyObject = public_key_or_object
|
||||
one_time_keys[algorithm_and_key_id] = KeyObject(
|
||||
key=one_time_key_object.key,
|
||||
signatures=one_time_key_object.signatures,
|
||||
fallback=one_time_key_object.fallback,
|
||||
)
|
||||
|
||||
return device_keys, fallback_keys, one_time_keys
|
||||
|
||||
|
||||
class KeyQueryServlet(RestServlet):
|
||||
"""
|
||||
|
||||
@@ -67,7 +67,7 @@ from synapse.util.cancellation import cancellable
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.handlers.e2e_keys import SignatureListItem
|
||||
from synapse.handlers.e2e_keys import DeviceKeys, FallbackKeys, SignatureListItem
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
@@ -802,7 +802,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
)
|
||||
|
||||
async def set_e2e_fallback_keys(
|
||||
self, user_id: str, device_id: str, fallback_keys: JsonDict
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
fallback_keys: "FallbackKeys",
|
||||
) -> None:
|
||||
"""Set the user's e2e fallback keys.
|
||||
|
||||
@@ -829,7 +832,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
fallback_keys: JsonDict,
|
||||
fallback_keys: "FallbackKeys",
|
||||
) -> None:
|
||||
"""Set the user's e2e fallback keys.
|
||||
|
||||
@@ -1650,16 +1653,20 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
)
|
||||
|
||||
async def set_e2e_device_keys(
|
||||
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
time_now: int,
|
||||
device_keys: "DeviceKeys",
|
||||
) -> bool:
|
||||
"""Stores device keys for a device. Returns whether there was a change
|
||||
or the keys were already in the database.
|
||||
|
||||
Args:
|
||||
user_id: user_id of the user to store keys for
|
||||
device_id: device_id of the device to store keys for
|
||||
time_now: time at the request to store the keys
|
||||
device_keys: the keys to store
|
||||
Args:
|
||||
user_id: user_id of the user to store keys for
|
||||
device_id: device_id of the device to store keys for
|
||||
time_now: time at the request to store the keys
|
||||
device_keys: the keys to store
|
||||
"""
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
@@ -1677,7 +1684,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
time_now: int,
|
||||
device_keys: JsonDict,
|
||||
device_keys: "DeviceKeys",
|
||||
) -> bool:
|
||||
"""Stores device keys for a device. Returns whether there was a change
|
||||
or the keys were already in the database.
|
||||
|
||||
@@ -27,6 +27,7 @@ from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AbstractSet,
|
||||
Annotated,
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
@@ -48,6 +49,7 @@ from typing import (
|
||||
|
||||
import attr
|
||||
from immutabledict import immutabledict
|
||||
from pydantic import BeforeValidator, PlainSerializer, WithJsonSchema
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.types import VerifyKey
|
||||
from typing_extensions import Self
|
||||
@@ -361,6 +363,37 @@ class RoomIdWithDomain(DomainSpecificString):
|
||||
SIGIL = "!"
|
||||
|
||||
|
||||
def _parse_user_id(user_id_str: Any) -> Any:
|
||||
if isinstance(user_id_str, str):
|
||||
try:
|
||||
return UserID.from_string(user_id_str)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Unable to parse string '{user_id_str}' as valid Matrix User ID"
|
||||
) from e
|
||||
raise ValueError(f"Expected a string, found {type(user_id_str)}")
|
||||
|
||||
|
||||
UserIDType = Annotated[
|
||||
UserID,
|
||||
BeforeValidator(_parse_user_id),
|
||||
PlainSerializer(lambda uid: uid.to_string(), return_type=str),
|
||||
WithJsonSchema(
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Matrix User ID",
|
||||
"pattern": r"^@[^:]+:[^:]+$",
|
||||
"examples": ["@alice:example.org"],
|
||||
}
|
||||
),
|
||||
]
|
||||
"""
|
||||
A User ID type that can be used in Pydantic models.
|
||||
|
||||
Validates that the input value is a `str` and can be parsed as a Matrix User ID.
|
||||
"""
|
||||
|
||||
|
||||
# the set of urlsafe base64 characters, no padding.
|
||||
ROOM_ID_PATTERN_DOMAINLESS = re.compile(r"^[A-Za-z0-9\-_]{43}$")
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ from synapse.types import (
|
||||
StreamToken,
|
||||
ThreadSubscriptionsToken,
|
||||
UserID,
|
||||
UserIDType,
|
||||
)
|
||||
from synapse.types.rest.client import SlidingSyncBody
|
||||
|
||||
@@ -67,7 +68,7 @@ class SlidingSyncConfig(SlidingSyncBody):
|
||||
extra fields that we need in the handler
|
||||
"""
|
||||
|
||||
user: UserID
|
||||
user: UserIDType
|
||||
requester: Requester
|
||||
|
||||
# Pydantic config
|
||||
|
||||
@@ -21,7 +21,9 @@
|
||||
|
||||
from twisted.internet.testing import MemoryReactor
|
||||
|
||||
from synapse.handlers.e2e_keys import DeviceKeys
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
@@ -30,47 +32,55 @@ from tests.unittest import HomeserverTestCase
|
||||
class EndToEndKeyStoreTestCase(HomeserverTestCase):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
self.now_ms = 1470174257070
|
||||
self.test_user_id = "@alice:test"
|
||||
self.test_device_id = "TEST_DEVICE"
|
||||
self.test_device_keys = self._create_test_device_keys(self.test_user_id, self.test_device_id)
|
||||
|
||||
def _create_test_device_keys(self, user_id: str, device_id: str, public_key: str = "test_public_key") -> DeviceKeys:
|
||||
"""Create and return a test `DeviceKeys` object."""
|
||||
return DeviceKeys(
|
||||
algorithms=["ed25519"],
|
||||
device_id=device_id,
|
||||
keys={
|
||||
f"ed25519:{device_id}": public_key,
|
||||
},
|
||||
signatures={},
|
||||
user_id=UserID.from_string(user_id),
|
||||
)
|
||||
|
||||
def test_key_without_device_name(self) -> None:
|
||||
now = 1470174257070
|
||||
json = {"key": "value"}
|
||||
self.get_success(self.store.store_device(self.test_user_id, self.test_device_id, None))
|
||||
|
||||
self.get_success(self.store.store_device("user", "device", None))
|
||||
|
||||
self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
|
||||
self.get_success(self.store.set_e2e_device_keys(self.test_user_id, self.test_device_id, self.now_ms, self.test_device_keys))
|
||||
|
||||
res = self.get_success(
|
||||
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
|
||||
self.store.get_e2e_device_keys_for_cs_api(((self.test_user_id, self.test_device_id),))
|
||||
)
|
||||
self.assertIn("user", res)
|
||||
self.assertIn("device", res["user"])
|
||||
dev = res["user"]["device"]
|
||||
self.assertLessEqual(json.items(), dev.items())
|
||||
self.assertIn(self.test_user_id, res)
|
||||
self.assertIn(self.test_device_id, res[self.test_user_id])
|
||||
device_keys = res[self.test_user_id][self.test_device_id]
|
||||
|
||||
print(device_keys)
|
||||
|
||||
def test_reupload_key(self) -> None:
|
||||
now = 1470174257070
|
||||
json = {"key": "value"}
|
||||
|
||||
self.get_success(self.store.store_device("user", "device", None))
|
||||
|
||||
changed = self.get_success(
|
||||
self.store.set_e2e_device_keys("user", "device", now, json)
|
||||
self.store.set_e2e_device_keys("user", "device", self.now_ms, self.test_device_keys)
|
||||
)
|
||||
self.assertTrue(changed)
|
||||
|
||||
# If we try to upload the same key then we should be told nothing
|
||||
# changed
|
||||
changed = self.get_success(
|
||||
self.store.set_e2e_device_keys("user", "device", now, json)
|
||||
self.store.set_e2e_device_keys("user", "device", self.now_ms, self.test_device_keys)
|
||||
)
|
||||
self.assertFalse(changed)
|
||||
|
||||
def test_get_key_with_device_name(self) -> None:
|
||||
now = 1470174257070
|
||||
json = {"key": "value"}
|
||||
|
||||
self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
|
||||
self.get_success(self.store.store_device("user", "device", "display_name"))
|
||||
self.get_success(self.store.set_e2e_device_keys(self.test_user_id, self.test_device_id, self.now_ms, self.test_device_keys))
|
||||
self.get_success(self.store.store_device(self.test_user_id, self.test_device_id, "display_name"))
|
||||
|
||||
res = self.get_success(
|
||||
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
|
||||
@@ -87,34 +97,37 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
def test_multiple_devices(self) -> None:
|
||||
now = 1470174257070
|
||||
user_one = "@user1:test"
|
||||
user_two = "@user2:test"
|
||||
device_id_one = "DEVICE_ID_1"
|
||||
device_id_two = "DEVICE_ID_2"
|
||||
|
||||
self.get_success(self.store.store_device("user1", "device1", None))
|
||||
self.get_success(self.store.store_device("user1", "device2", None))
|
||||
self.get_success(self.store.store_device("user2", "device1", None))
|
||||
self.get_success(self.store.store_device("user2", "device2", None))
|
||||
self.get_success(self.store.store_device(user_one, device_id_one, None))
|
||||
self.get_success(self.store.store_device(user_one, device_id_two, None))
|
||||
self.get_success(self.store.store_device(user_two, device_id_one, None))
|
||||
self.get_success(self.store.store_device(user_two, device_id_two, None))
|
||||
|
||||
self.get_success(
|
||||
self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
|
||||
self.store.set_e2e_device_keys(user_one, device_id_one, self.now_ms, self._create_test_device_keys(user_one, device_id_one, "json11"))
|
||||
)
|
||||
self.get_success(
|
||||
self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
|
||||
self.store.set_e2e_device_keys(user_one, device_id_two, self.now_ms, self._create_test_device_keys(user_one, device_id_two, "json12"))
|
||||
)
|
||||
self.get_success(
|
||||
self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
|
||||
self.store.set_e2e_device_keys(user_two, device_id_one, self.now_ms, self._create_test_device_keys(user_two, device_id_one, "json21"))
|
||||
)
|
||||
self.get_success(
|
||||
self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
|
||||
self.store.set_e2e_device_keys(user_two, device_id_two, self.now_ms, self._create_test_device_keys(user_two, device_id_two, "json22"))
|
||||
)
|
||||
|
||||
res = self.get_success(
|
||||
self.store.get_e2e_device_keys_for_cs_api(
|
||||
(("user1", "device1"), ("user2", "device2"))
|
||||
((user_one, device_id_one), (user_two, device_id_two))
|
||||
)
|
||||
)
|
||||
self.assertIn("user1", res)
|
||||
self.assertIn("device1", res["user1"])
|
||||
self.assertNotIn("device2", res["user1"])
|
||||
self.assertIn("user2", res)
|
||||
self.assertNotIn("device1", res["user2"])
|
||||
self.assertIn("device2", res["user2"])
|
||||
self.assertIn(user_one, res)
|
||||
self.assertIn(device_id_one, res[user_one])
|
||||
self.assertNotIn(device_id_two, res[user_one])
|
||||
self.assertIn(user_two, res)
|
||||
self.assertNotIn(device_id_one, res[user_two])
|
||||
self.assertIn(device_id_two, res[user_two])
|
||||
|
||||
Reference in New Issue
Block a user