Compare commits
11 Commits
clokep/sta
...
uhoreg/deh
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8811973107 | ||
|
|
afd9aa673a | ||
|
|
4f1619c4eb | ||
|
|
b7398f7489 | ||
|
|
c7c8f2822d | ||
|
|
460ebc558a | ||
|
|
96d9fc3410 | ||
|
|
b59bc664e6 | ||
|
|
0f9e402e16 | ||
|
|
e60a99deb6 | ||
|
|
52ddb79781 |
1
changelog.d/7955.feature
Normal file
1
changelog.d/7955.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add support for device dehydration. (MSC2697)
|
||||
@@ -14,8 +14,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from synapse.api import errors
|
||||
from synapse.api.constants import EventTypes
|
||||
@@ -28,6 +29,7 @@ from synapse.api.errors import (
|
||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
RoomStreamToken,
|
||||
get_domain_from_id,
|
||||
get_verify_key_from_cross_signing_key,
|
||||
@@ -489,6 +491,137 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||
# receive device updates. Mark this in DB.
|
||||
await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
|
||||
|
||||
async def store_dehydrated_device(
|
||||
self,
|
||||
user_id: str,
|
||||
device_data: JsonDict,
|
||||
initial_device_display_name: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Store a dehydrated device for a user. If the user had a previous
|
||||
dehydrated device, it is removed.
|
||||
|
||||
Args:
|
||||
user_id: the user that we are storing the device for
|
||||
device_data: the dehydrated device information
|
||||
initial_device_display_name: The display name to use for the device
|
||||
Returns:
|
||||
device id of the dehydrated device
|
||||
"""
|
||||
device_id = await self.check_device_registered(
|
||||
user_id, None, initial_device_display_name,
|
||||
)
|
||||
old_device_id = await self.store.store_dehydrated_device(
|
||||
user_id, device_id, device_data
|
||||
)
|
||||
if old_device_id is not None:
|
||||
await self.delete_device(user_id, old_device_id)
|
||||
return device_id
|
||||
|
||||
async def get_dehydrated_device(self, user_id: str) -> Tuple[str, JsonDict]:
|
||||
"""Retrieve the information for a dehydrated device.
|
||||
|
||||
Args:
|
||||
user_id: the user whose dehydrated device we are looking for
|
||||
Returns:
|
||||
a tuple whose first item is the device ID, and the second item is
|
||||
the dehydrated device information
|
||||
"""
|
||||
return await self.store.get_dehydrated_device(user_id)
|
||||
|
||||
async def create_dehydration_token(
|
||||
self, user_id: str, device_id: str, login_submission: JsonDict
|
||||
) -> str:
|
||||
"""Create a token for a client to fulfill a dehydration request.
|
||||
|
||||
Args:
|
||||
user_id: the user that we are creating the token for
|
||||
device_id: the device ID for the dehydrated device. This is to
|
||||
ensure that the device still exists when the user tells us
|
||||
they want to use the dehydrated device.
|
||||
login_submission: the contents of the login request.
|
||||
Returns:
|
||||
the dehydration token
|
||||
"""
|
||||
return await self.store.create_dehydration_token(
|
||||
user_id, device_id, login_submission
|
||||
)
|
||||
|
||||
async def rehydrate_device(self, token: str) -> dict:
|
||||
"""Process a rehydration request from the user.
|
||||
|
||||
Args:
|
||||
token: the dehydration token
|
||||
Returns:
|
||||
the login result, including the user's access token and device ID
|
||||
"""
|
||||
# FIXME: if can't find token, return 404
|
||||
token_info = await self.store.clear_dehydration_token(token, True)
|
||||
|
||||
# normally, the constructor would do self.registration_handler =
|
||||
# self.hs.get_registration_handler(), but doing that results in a
|
||||
# circular dependency in the handlers. So do this for now
|
||||
registration_handler = self.hs.get_registration_handler()
|
||||
|
||||
if token_info["dehydrated"]:
|
||||
# create access token for dehydrated device
|
||||
initial_display_name = (
|
||||
None # FIXME: get display name from login submission?
|
||||
)
|
||||
device_id, access_token = await registration_handler.register_device(
|
||||
token_info.get("user_id"),
|
||||
token_info.get("device_id"),
|
||||
initial_display_name,
|
||||
)
|
||||
|
||||
return {
|
||||
"user_id": token_info["user_id"],
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
|
||||
else:
|
||||
# create device and access token from original login submission
|
||||
login_submission = token_info["login_submission"]
|
||||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = await registration_handler.register_device(
|
||||
token_info.get("user_id"), device_id, initial_display_name
|
||||
)
|
||||
|
||||
return {
|
||||
"user_id": token.info["user_id"],
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
|
||||
async def cancel_rehydrate(self, token: str) -> dict:
|
||||
"""Cancel a rehydration request from the user and complete the user's login.
|
||||
|
||||
Args:
|
||||
token: the dehydration token
|
||||
Returns:
|
||||
the login result, including the user's access token and device ID
|
||||
"""
|
||||
# FIXME: if can't find token, return 404
|
||||
token_info = await self.store.clear_dehydration_token(token, False)
|
||||
# create device and access token from original login submission
|
||||
login_submission = token_info["login_submission"]
|
||||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
registration_handler = self.hs.get_registration_handler()
|
||||
device_id, access_token = await registration_handler.register_device(
|
||||
token_info.get("user_id"), device_id, initial_display_name
|
||||
)
|
||||
|
||||
return {
|
||||
"user_id": token_info.get("user_id"),
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
|
||||
|
||||
def _update_device_from_client_ips(device, client_ips):
|
||||
ip = client_ips.get((device["user_id"], device["device_id"]), {})
|
||||
|
||||
@@ -496,6 +496,22 @@ class E2eKeysHandler(object):
|
||||
log_kv(
|
||||
{"message": "Did not update one_time_keys", "reason": "no keys given"}
|
||||
)
|
||||
fallback_keys = keys.get("fallback_keys", None)
|
||||
if fallback_keys and isinstance(fallback_keys, dict):
|
||||
log_kv(
|
||||
{
|
||||
"message": "Updating fallback_keys for device.",
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
}
|
||||
)
|
||||
await self.store.set_e2e_fallback_keys(
|
||||
user_id, device_id, fallback_keys
|
||||
)
|
||||
else:
|
||||
log_kv(
|
||||
{"message": "Did not update fallback_keys", "reason": "no keys given"}
|
||||
)
|
||||
|
||||
# the device should have been registered already, but it may have been
|
||||
# deleted due to a race with a DELETE request. Or we may be using an
|
||||
|
||||
@@ -203,6 +203,8 @@ class SyncResult:
|
||||
device_lists: List of user_ids whose devices have changed
|
||||
device_one_time_keys_count: Dict of algorithm to count for one time keys
|
||||
for this device
|
||||
device_unused_fallback_keys: List of key types that have an unused fallback
|
||||
key
|
||||
groups: Group updates, if any
|
||||
"""
|
||||
|
||||
@@ -215,6 +217,7 @@ class SyncResult:
|
||||
to_device = attr.ib(type=List[JsonDict])
|
||||
device_lists = attr.ib(type=DeviceLists)
|
||||
device_one_time_keys_count = attr.ib(type=JsonDict)
|
||||
device_unused_fallback_keys = attr.ib(type=List[str])
|
||||
groups = attr.ib(type=Optional[GroupsSyncResult])
|
||||
|
||||
def __nonzero__(self) -> bool:
|
||||
@@ -1024,10 +1027,14 @@ class SyncHandler(object):
|
||||
logger.debug("Fetching OTK data")
|
||||
device_id = sync_config.device_id
|
||||
one_time_key_counts = {} # type: JsonDict
|
||||
unused_fallback_keys = [] # type: list
|
||||
if device_id:
|
||||
one_time_key_counts = await self.store.count_e2e_one_time_keys(
|
||||
user_id, device_id
|
||||
)
|
||||
unused_fallback_keys = await self.store.get_e2e_unused_fallback_keys(
|
||||
user_id, device_id
|
||||
)
|
||||
|
||||
logger.debug("Fetching group data")
|
||||
await self._generate_sync_entry_for_groups(sync_result_builder)
|
||||
@@ -1051,6 +1058,7 @@ class SyncHandler(object):
|
||||
device_lists=device_lists,
|
||||
groups=sync_result_builder.groups,
|
||||
device_one_time_keys_count=one_time_key_counts,
|
||||
device_unused_fallback_keys=unused_fallback_keys,
|
||||
next_batch=sync_result_builder.now_token,
|
||||
)
|
||||
|
||||
|
||||
@@ -103,6 +103,7 @@ class LoginRestServlet(RestServlet):
|
||||
self.oidc_enabled = hs.config.oidc_enabled
|
||||
|
||||
self.auth_handler = self.hs.get_auth_handler()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.registration_handler = hs.get_registration_handler()
|
||||
self.handlers = hs.get_handlers()
|
||||
self._well_known_builder = WellKnownBuilder(hs)
|
||||
@@ -339,6 +340,29 @@ class LoginRestServlet(RestServlet):
|
||||
)
|
||||
user_id = canonical_uid
|
||||
|
||||
if login_submission.get("org.matrix.msc2697.restore_device"):
|
||||
# user requested to rehydrate a device, so check if there they have
|
||||
# a dehydrated device, and if so, allow them to try to rehydrate it
|
||||
(
|
||||
device_id,
|
||||
dehydrated_device,
|
||||
) = await self.device_handler.get_dehydrated_device(user_id)
|
||||
if dehydrated_device:
|
||||
token = await self.device_handler.create_dehydration_token(
|
||||
user_id, device_id, login_submission
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_data": dehydrated_device,
|
||||
"device_id": device_id,
|
||||
"dehydration_token": token,
|
||||
}
|
||||
|
||||
# FIXME: call callback?
|
||||
|
||||
return result
|
||||
|
||||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = await self.registration_handler.register_device(
|
||||
@@ -401,6 +425,96 @@ class LoginRestServlet(RestServlet):
|
||||
return result
|
||||
|
||||
|
||||
class RestoreDeviceServlet(RestServlet):
|
||||
"""Complete a rehydration request, either by letting the client use the
|
||||
dehydrated device, or by creating a new device for the user.
|
||||
|
||||
POST /org.matrix.msc2697/restore_device
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"rehydrate": true,
|
||||
"dehydration_token": "an_opaque_token"
|
||||
}
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
Content-Type: application/json
|
||||
|
||||
{ // same format as the result from a /login request
|
||||
"user_id": "@alice:example.org",
|
||||
"device_id": "dehydrated_device",
|
||||
"access_token": "another_opaque_token"
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/org.matrix.msc2697/restore_device")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RestoreDeviceServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self._well_known_builder = WellKnownBuilder(hs)
|
||||
|
||||
async def on_POST(self, request: SynapseRequest):
|
||||
submission = parse_json_object_from_request(request)
|
||||
|
||||
if submission.get("rehydrate"):
|
||||
result = await self.device_handler.rehydrate_device(
|
||||
submission["dehydration_token"]
|
||||
)
|
||||
else:
|
||||
result = await self.device_handler.cancel_rehydrate(
|
||||
submission["dehydration_token"]
|
||||
)
|
||||
well_known_data = self._well_known_builder.get_well_known()
|
||||
if well_known_data:
|
||||
result["well_known"] = well_known_data
|
||||
return (200, result)
|
||||
|
||||
|
||||
class StoreDeviceServlet(RestServlet):
|
||||
"""Store a dehydrated device.
|
||||
|
||||
POST /org.matrix.msc2697/device/dehydrate
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"device_data": {
|
||||
"algorithm": "m.dehydration.v1.olm",
|
||||
"account": "dehydrated_device"
|
||||
}
|
||||
}
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"device_id": "dehydrated_device_id"
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/org.matrix.msc2697/device/dehydrate")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(StoreDeviceServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
async def on_POST(self, request: SynapseRequest):
|
||||
submission = parse_json_object_from_request(request)
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
device_id = await self.device_handler.store_dehydrated_device(
|
||||
requester.user.to_string(),
|
||||
submission["device_data"],
|
||||
submission.get("initial_device_display_name", None)
|
||||
)
|
||||
return 200, {"device_id": device_id}
|
||||
|
||||
|
||||
class BaseSSORedirectServlet(RestServlet):
|
||||
"""Common base class for /login/sso/redirect impls"""
|
||||
|
||||
@@ -499,6 +613,8 @@ class OIDCRedirectServlet(BaseSSORedirectServlet):
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
LoginRestServlet(hs).register(http_server)
|
||||
RestoreDeviceServlet(hs).register(http_server)
|
||||
StoreDeviceServlet(hs).register(http_server)
|
||||
if hs.config.cas_enabled:
|
||||
CasRedirectServlet(hs).register(http_server)
|
||||
CasTicketServlet(hs).register(http_server)
|
||||
|
||||
@@ -67,6 +67,7 @@ class KeyUploadServlet(RestServlet):
|
||||
super(KeyUploadServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@trace(opname="upload_keys")
|
||||
async def on_POST(self, request, device_id):
|
||||
@@ -78,20 +79,25 @@ class KeyUploadServlet(RestServlet):
|
||||
# passing the device_id here is deprecated; however, we allow it
|
||||
# for now for compatibility with older clients.
|
||||
if requester.device_id is not None and device_id != requester.device_id:
|
||||
set_tag("error", True)
|
||||
log_kv(
|
||||
{
|
||||
"message": "Client uploading keys for a different device",
|
||||
"logged_in_id": requester.device_id,
|
||||
"key_being_uploaded": device_id,
|
||||
}
|
||||
)
|
||||
logger.warning(
|
||||
"Client uploading keys for a different device "
|
||||
"(logged in as %s, uploading for %s)",
|
||||
requester.device_id,
|
||||
device_id,
|
||||
)
|
||||
(
|
||||
dehydrated_device_id,
|
||||
_,
|
||||
) = await self.device_handler.get_dehydrated_device(user_id)
|
||||
if device_id != dehydrated_device_id:
|
||||
set_tag("error", True)
|
||||
log_kv(
|
||||
{
|
||||
"message": "Client uploading keys for a different device",
|
||||
"logged_in_id": requester.device_id,
|
||||
"key_being_uploaded": device_id,
|
||||
}
|
||||
)
|
||||
logger.warning(
|
||||
"Client uploading keys for a different device "
|
||||
"(logged in as %s, uploading for %s)",
|
||||
requester.device_id,
|
||||
device_id,
|
||||
)
|
||||
else:
|
||||
device_id = requester.device_id
|
||||
|
||||
|
||||
@@ -237,6 +237,7 @@ class SyncRestServlet(RestServlet):
|
||||
"leave": sync_result.groups.leave,
|
||||
},
|
||||
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
|
||||
"device_unused_fallback_keys": sync_result.device_unused_fallback_keys,
|
||||
"next_batch": sync_result.next_batch.to_string(),
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from canonicaljson import json
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
from synapse.api.errors import Codes, StoreError
|
||||
@@ -33,9 +34,14 @@ from synapse.storage.database import (
|
||||
)
|
||||
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import Cache, cached, cachedList
|
||||
from synapse.util.caches.descriptors import (
|
||||
Cache,
|
||||
cached,
|
||||
cachedInlineCallbacks,
|
||||
cachedList,
|
||||
)
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.stringutils import shortstr
|
||||
from synapse.util.stringutils import random_string, shortstr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -746,6 +752,168 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
_mark_remote_user_device_list_as_unsubscribed_txn,
|
||||
)
|
||||
|
||||
async def get_dehydrated_device(self, user_id: str) -> Tuple[str, JsonDict]:
|
||||
"""Retrieve the information for a dehydrated device.
|
||||
|
||||
Args:
|
||||
user_id: the user whose dehydrated device we are looking for
|
||||
Returns:
|
||||
a tuple whose first item is the device ID, and the second item is
|
||||
the dehydrated device information
|
||||
"""
|
||||
# FIXME: make sure device ID still exists in devices table
|
||||
row = await self.db_pool.simple_select_one(
|
||||
table="dehydrated_devices",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=["device_id", "device_data"],
|
||||
allow_none=True,
|
||||
)
|
||||
return (row["device_id"], json.loads(row["device_data"])) if row else (None, None)
|
||||
|
||||
def _store_dehydrated_device_txn(
|
||||
self, txn, user_id: str, device_id: str, device_data: str
|
||||
) -> Optional[str]:
|
||||
old_device_id = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="dehydrated_devices",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="device_id",
|
||||
allow_none=True,
|
||||
)
|
||||
if old_device_id is None:
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
table="dehydrated_devices",
|
||||
values={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"device_data": device_data,
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.db_pool.simple_update_txn(
|
||||
txn,
|
||||
table="dehydrated_devices",
|
||||
keyvalues={"user_id": user_id},
|
||||
updatevalues={"device_id": device_id, "device_data": device_data},
|
||||
)
|
||||
return old_device_id
|
||||
|
||||
async def store_dehydrated_device(
|
||||
self, user_id: str, device_id: str, device_data: JsonDict
|
||||
) -> Optional[str]:
|
||||
"""Store a dehydrated device for a user.
|
||||
|
||||
Args:
|
||||
user_id: the user that we are storing the device for
|
||||
device_data: the dehydrated device information
|
||||
initial_device_display_name: The display name to use for the device
|
||||
Returns:
|
||||
device id of the user's previous dehydrated device, if any
|
||||
"""
|
||||
return await self.db_pool.runInteraction(
|
||||
"store_dehydrated_device_txn",
|
||||
self._store_dehydrated_device_txn,
|
||||
user_id,
|
||||
device_id,
|
||||
json_encoder.encode(device_data),
|
||||
)
|
||||
|
||||
async def create_dehydration_token(
|
||||
self, user_id: str, device_id: str, login_submission: JsonDict
|
||||
) -> str:
|
||||
"""Create a token for a client to fulfill a dehydration request.
|
||||
|
||||
Args:
|
||||
user_id: the user that we are creating the token for
|
||||
device_id: the device ID for the dehydrated device. This is to
|
||||
ensure that the device still exists when the user tells us
|
||||
they want to use the dehydrated device.
|
||||
login_submission: the contents of the login request.
|
||||
Returns:
|
||||
the dehydration token
|
||||
"""
|
||||
# FIXME: expire any old tokens
|
||||
|
||||
attempts = 0
|
||||
while attempts < 5:
|
||||
token = random_string(24)
|
||||
|
||||
try:
|
||||
await self.db_pool.simple_insert(
|
||||
table="dehydration_token",
|
||||
values={
|
||||
"token": token,
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"login_submission": json_encoder.encode(login_submission),
|
||||
"creation_time": self.hs.get_clock().time_msec(),
|
||||
},
|
||||
desc="create_dehydration_token",
|
||||
)
|
||||
return token
|
||||
except self.db_pool.engine.module.IntegrityError:
|
||||
attempts += 1
|
||||
raise StoreError(500, "Couldn't generate a token.")
|
||||
|
||||
def _clear_dehydration_token_txn(self, txn, token: str, dehydrate: bool) -> dict:
|
||||
token_info = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"dehydration_token",
|
||||
{"token": token},
|
||||
["user_id", "device_id", "login_submission"],
|
||||
)
|
||||
self.db_pool.simple_delete_one_txn(
|
||||
txn, "dehydration_token", {"token": token},
|
||||
)
|
||||
token_info["login_submission"] = json.loads(token_info["login_submission"])
|
||||
|
||||
if dehydrate:
|
||||
device_id = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
"dehydrated_devices",
|
||||
keyvalues={"user_id": token_info["user_id"]},
|
||||
retcol="device_id",
|
||||
allow_none=True,
|
||||
)
|
||||
token_info["dehydrated"] = False
|
||||
if device_id == token_info["device_id"]:
|
||||
count = self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
"dehydrated_devices",
|
||||
{
|
||||
"user_id": token_info["user_id"],
|
||||
"device_id": token_info["device_id"],
|
||||
},
|
||||
)
|
||||
if count != 0:
|
||||
token_info["dehydrated"] = True
|
||||
|
||||
return token_info
|
||||
|
||||
async def clear_dehydration_token(self, token: str, dehydrate: bool) -> dict:
|
||||
"""Use a dehydration token. If the client wishes to use the dehydrated
|
||||
device, it will also remove the dehydrated device.
|
||||
|
||||
Args:
|
||||
token: the dehydration token
|
||||
dehydrate: whether the client wishes to use the dehydrated device
|
||||
Returns:
|
||||
A dict giving the information related to the token. It will have
|
||||
the following properties:
|
||||
- user_id: the user associated from the token
|
||||
- device_id: the ID of the dehydrated device
|
||||
- login_submission: the original submission to /login
|
||||
- dehydrated: (only present if the "dehydrate" parameter is True).
|
||||
Whether the dehydrated device can be used by the client.
|
||||
"""
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_users_whose_devices_changed",
|
||||
self._clear_dehydration_token_txn,
|
||||
token,
|
||||
dehydrate,
|
||||
)
|
||||
|
||||
|
||||
class DeviceBackgroundUpdateStore(SQLBaseStore):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
|
||||
@@ -271,6 +271,46 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
||||
)
|
||||
|
||||
async def set_e2e_fallback_keys(
|
||||
self, user_id: str, device_id: str, fallback_keys: dict
|
||||
):
|
||||
# fallback_keys will usually only have one item in it, so using a for
|
||||
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
|
||||
# FIXME: make sure that only one key per algorithm is uploaded
|
||||
for key_id, fallback_key in fallback_keys.items():
|
||||
algorithm, key_id = key_id.split(":", 1)
|
||||
await self.db_pool.simple_upsert(
|
||||
"e2e_fallback_keys_json",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"algorithm": algorithm
|
||||
},
|
||||
values={
|
||||
"key_id": key_id,
|
||||
"key_json": json_encoder.encode(fallback_key),
|
||||
"used": 0
|
||||
},
|
||||
desc="set_e2e_fallback_key"
|
||||
)
|
||||
|
||||
@cached(max_entries=10000)
|
||||
async def get_e2e_unused_fallback_keys(
|
||||
self, user_id: str, device_id: str
|
||||
):
|
||||
return await self.db_pool.simple_select_onecol(
|
||||
"e2e_fallback_keys_json",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"used": 0
|
||||
},
|
||||
retcol="algorithm",
|
||||
desc="get_e2e_unused_fallback_keys"
|
||||
)
|
||||
|
||||
# FIXME: delete fallbacks when user logs out
|
||||
|
||||
async def get_e2e_cross_signing_key(
|
||||
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
|
||||
) -> Optional[dict]:
|
||||
@@ -590,15 +630,29 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
||||
" LIMIT 1"
|
||||
)
|
||||
fallback_sql = (
|
||||
"SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
|
||||
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
||||
" LIMIT 1"
|
||||
)
|
||||
result = {}
|
||||
delete = []
|
||||
used_fallbacks = []
|
||||
for user_id, device_id, algorithm in query_list:
|
||||
user_result = result.setdefault(user_id, {})
|
||||
device_result = user_result.setdefault(device_id, {})
|
||||
txn.execute(sql, (user_id, device_id, algorithm))
|
||||
found = False
|
||||
for key_id, key_json in txn:
|
||||
found = True
|
||||
device_result[algorithm + ":" + key_id] = key_json
|
||||
delete.append((user_id, device_id, algorithm, key_id))
|
||||
if not found:
|
||||
txn.execute(fallback_sql, (user_id, device_id, algorithm))
|
||||
for key_id, key_json, used in txn:
|
||||
device_result[algorithm + ":" + key_id] = key_json
|
||||
if used == 0:
|
||||
used_fallbacks.append((user_id, device_id, algorithm, key_id))
|
||||
sql = (
|
||||
"DELETE FROM e2e_one_time_keys_json"
|
||||
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
||||
@@ -615,6 +669,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
for user_id, device_id, algorithm, key_id in used_fallbacks:
|
||||
self.db_pool.simple_update_txn(
|
||||
txn,
|
||||
"e2e_fallback_keys_json",
|
||||
{
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"algorithm": algorithm,
|
||||
"key_id": key_id
|
||||
},
|
||||
{
|
||||
"used": 1
|
||||
}
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_e2e_unused_fallback_keys, (user_id, device_id)
|
||||
)
|
||||
return result
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
@@ -643,6 +714,20 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
table="dehydrated_devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
)
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
table="e2e_fallback_keys_json",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_e2e_unused_fallback_keys, (user_id, device_id)
|
||||
)
|
||||
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE IF NOT EXISTS dehydrated_devices(
|
||||
user_id TEXT NOT NULL PRIMARY KEY,
|
||||
device_id TEXT NOT NULL,
|
||||
device_data TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS dehydration_token(
|
||||
token TEXT NOT NULL PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
login_submission TEXT NOT NULL,
|
||||
creation_time BIGINT NOT NULL
|
||||
);
|
||||
|
||||
-- FIXME: index on creation_time to expire old tokens
|
||||
@@ -0,0 +1,24 @@
|
||||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
|
||||
user_id TEXT NOT NULL, -- The user this fallback key is for.
|
||||
device_id TEXT NOT NULL, -- The device this fallback key is for.
|
||||
algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
|
||||
key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
|
||||
key_json TEXT NOT NULL, -- The key as a JSON blob.
|
||||
used SMALLINT NOT NULL DEFAULT 0, -- Whether the key has been used or not.
|
||||
CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
|
||||
);
|
||||
@@ -754,3 +754,68 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
|
||||
channel.json_body["error"],
|
||||
"JWT validation failed: Signature verification failed",
|
||||
)
|
||||
|
||||
|
||||
class DehydrationTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
login.register_servlets,
|
||||
logout.register_servlets,
|
||||
devices.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
self.hs = self.setup_test_homeserver()
|
||||
self.hs.config.enable_registration = True
|
||||
self.hs.config.registrations_require_3pid = []
|
||||
self.hs.config.auto_join_rooms = []
|
||||
self.hs.config.enable_registration_captcha = False
|
||||
|
||||
return self.hs
|
||||
|
||||
def test_dehydrate_and_rehydrate_device(self):
|
||||
self.register_user("kermit", "monkey")
|
||||
access_token = self.login("kermit", "monkey")
|
||||
|
||||
# dehydrate a device
|
||||
params = json.dumps({"device_data": "foobar"})
|
||||
request, channel = self.make_request(
|
||||
b"POST",
|
||||
b"/_matrix/client/unstable/org.matrix.msc2697/device/dehydrate",
|
||||
params,
|
||||
access_token=access_token,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(channel.code, 200, channel.result)
|
||||
dehydrated_device_id = channel.json_body["device_id"]
|
||||
|
||||
# Log out
|
||||
request, channel = self.make_request(
|
||||
b"POST", "/logout", access_token=access_token
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
# log in, requesting a dehydrated device
|
||||
params = json.dumps(
|
||||
{
|
||||
"type": "m.login.password",
|
||||
"user": "kermit",
|
||||
"password": "monkey",
|
||||
"org.matrix.msc2697.restore_device": True,
|
||||
}
|
||||
)
|
||||
request, channel = self.make_request("POST", "/_matrix/client/r0/login", params)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual(channel.json_body["device_data"], "foobar")
|
||||
self.assertEqual(channel.json_body["device_id"], dehydrated_device_id)
|
||||
dehydration_token = channel.json_body["dehydration_token"]
|
||||
|
||||
params = json.dumps({"rehydrate": True, "dehydration_token": dehydration_token})
|
||||
request, channel = self.make_request(
|
||||
"POST", "/_matrix/client/unstable/org.matrix.msc2697/restore_device", params
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual(channel.json_body["device_id"], dehydrated_device_id)
|
||||
|
||||
Reference in New Issue
Block a user