1
0

run black

This commit is contained in:
Hubert Chathi
2020-07-27 16:48:32 -04:00
parent 52ddb79781
commit e60a99deb6
4 changed files with 52 additions and 28 deletions

View File

@@ -492,12 +492,17 @@ class DeviceHandler(DeviceWorkerHandler):
await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
async def store_dehydrated_device(
self, user_id: str, device_data: str,
initial_device_display_name: Optional[str] = None) -> str:
self,
user_id: str,
device_data: str,
initial_device_display_name: Optional[str] = None,
) -> str:
device_id = await self.check_device_registered(
user_id, None, initial_device_display_name,
)
old_device_id = await self.store.store_dehydrated_device(user_id, device_id, device_data)
old_device_id = await self.store.store_dehydrated_device(
user_id, device_id, device_data
)
if old_device_id is not None:
await self.delete_device(user_id, old_device_id)
return device_id
@@ -505,8 +510,12 @@ class DeviceHandler(DeviceWorkerHandler):
async def get_dehydrated_device(self, user_id: str) -> Tuple[str, str]:
return await self.store.get_dehydrated_device(user_id)
async def get_dehydration_token(self, user_id: str, device_id: str, login_submission: JsonDict) -> str:
return await self.store.create_dehydration_token(user_id, device_id, json.dumps(login_submission))
async def get_dehydration_token(
self, user_id: str, device_id: str, login_submission: JsonDict
) -> str:
return await self.store.create_dehydration_token(
user_id, device_id, json.dumps(login_submission)
)
async def rehydrate_device(self, token: str) -> dict:
# FIXME: if can't find token, return 404
@@ -519,9 +528,13 @@ class DeviceHandler(DeviceWorkerHandler):
if token_info["dehydrated"]:
# create access token for dehydrated device
initial_display_name = None # FIXME: get display name from login submission?
initial_display_name = (
None # FIXME: get display name from login submission?
)
device_id, access_token = await registration_handler.register_device(
token_info.get("user_id"), token_info.get("device_id"), initial_display_name
token_info.get("user_id"),
token_info.get("device_id"),
initial_display_name,
)
return {

View File

@@ -341,9 +341,14 @@ class LoginRestServlet(RestServlet):
user_id = canonical_uid
if login_submission.get("org.matrix.msc2697.restore_device"):
device_id, dehydrated_device = await self.device_handler.get_dehydrated_device(user_id)
(
device_id,
dehydrated_device,
) = await self.device_handler.get_dehydrated_device(user_id)
if dehydrated_device:
token = await self.device_handler.get_dehydration_token(user_id, device_id, login_submission)
token = await self.device_handler.get_dehydration_token(
user_id, device_id, login_submission
)
result = {
"user_id": user_id,
"home_server": self.hs.hostname,
@@ -430,9 +435,19 @@ class RestoreDeviceServlet(RestServlet):
submission = parse_json_object_from_request(request)
if submission.get("rehydrate"):
return 200, await self.device_handler.rehydrate_device(submission.get("dehydration_token"))
return (
200,
await self.device_handler.rehydrate_device(
submission.get("dehydration_token")
),
)
else:
return 200, await self.device_handler.cancel_rehydrate(submission.get("dehydration_token"))
return (
200,
await self.device_handler.cancel_rehydrate(
submission.get("dehydration_token")
),
)
class StoreDeviceServlet(RestServlet):

View File

@@ -79,7 +79,10 @@ class KeyUploadServlet(RestServlet):
# passing the device_id here is deprecated; however, we allow it
# for now for compatibility with older clients.
if requester.device_id is not None and device_id != requester.device_id:
dehydrated_device_id, _ = await self.device_handler.get_dehydrated_device(user_id)
(
dehydrated_device_id,
_,
) = await self.device_handler.get_dehydrated_device(user_id)
if device_id != dehydrated_device_id:
set_tag("error", True)
log_kv(

View File

@@ -738,7 +738,7 @@ class DeviceWorkerStore(SQLBaseStore):
return (row["device_id"], row["device_data"]) if row else (None, None)
def _store_dehydrated_device_txn(
self, txn, user_id: str, device_id: str, device_data: str
self, txn, user_id: str, device_id: str, device_data: str
) -> Optional[str]:
old_device_id = self.db.simple_select_one_onecol_txn(
txn,
@@ -762,24 +762,23 @@ class DeviceWorkerStore(SQLBaseStore):
txn,
table="dehydrated_devices",
keyvalues={"user_id", user_id},
updatevalues={
"device_id": device_id,
"device_data": device_data,
},
updatevalues={"device_id": device_id, "device_data": device_data,},
)
return old_device_id
async def store_dehydrated_device(
self, user_id: str, device_id: str, device_data: str
self, user_id: str, device_id: str, device_data: str
) -> Optional[str]:
return await self.db.runInteraction(
"store_dehydrated_device_txn",
self._store_dehydrated_device_txn,
user_id, device_id, device_data,
user_id,
device_id,
device_data,
)
async def create_dehydration_token(
self, user_id: str, device_id: str, login_submission: str
self, user_id: str, device_id: str, login_submission: str
) -> str:
# FIXME: expire any old tokens
@@ -808,17 +807,11 @@ class DeviceWorkerStore(SQLBaseStore):
token_info = self.db.simple_select_one_txn(
txn,
"dehydration_token",
{
"token": token,
},
{"token": token,},
["user_id", "device_id", "login_submission"],
)
self.db.simple_delete_one_txn(
txn,
"dehydration_token",
{
"token": token,
},
txn, "dehydration_token", {"token": token,},
)
if dehydrate: