1
0

Merge commit '66f24449d' into anoa/dinsic_release_1_21_x

* commit '66f24449d':
  Improve performance of the register endpoint (#8009)
This commit is contained in:
Andrew Morgan
2020-10-19 14:29:04 +01:00
8 changed files with 188 additions and 72 deletions

1
changelog.d/8009.misc Normal file
View File

@@ -0,0 +1 @@
Improve the performance of the register endpoint.

View File

@@ -239,14 +239,16 @@ class InteractiveAuthIncompleteError(Exception):
(This indicates we should return a 401 with 'result' as the body)
Attributes:
session_id: The ID of the ongoing interactive auth session.
result: the server response to the request, which should be
passed back to the client
"""
def __init__(self, result: "JsonDict"):
def __init__(self, session_id: str, result: "JsonDict"):
super(InteractiveAuthIncompleteError, self).__init__(
"Interactive auth not yet complete"
)
self.session_id = session_id
self.result = result

View File

@@ -162,7 +162,7 @@ class AuthHandler(BaseHandler):
request_body: Dict[str, Any],
clientip: str,
description: str,
) -> dict:
) -> Tuple[dict, str]:
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -183,9 +183,14 @@ class AuthHandler(BaseHandler):
describes the operation happening on their account.
Returns:
The parameters for this request (which may
A tuple of (params, session_id).
'params' contains the parameters for this request (which may
have been given only in a previous call).
'session_id' is the ID of this session, either passed in by the
client or assigned by this call
Raises:
InteractiveAuthIncompleteError if the client has not yet completed
any of the permitted login flows
@@ -207,7 +212,7 @@ class AuthHandler(BaseHandler):
flows = [[login_type] for login_type in self._supported_ui_auth_types]
try:
result, params, _ = await self.check_auth(
result, params, session_id = await self.check_ui_auth(
flows, request, request_body, clientip, description
)
except LoginError:
@@ -230,7 +235,7 @@ class AuthHandler(BaseHandler):
if user_id != requester.user.to_string():
raise AuthError(403, "Invalid auth")
return params
return params, session_id
def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types
@@ -240,7 +245,7 @@ class AuthHandler(BaseHandler):
"""
return self.checkers.keys()
async def check_auth(
async def check_ui_auth(
self,
flows: List[List[str]],
request: SynapseRequest,
@@ -363,7 +368,7 @@ class AuthHandler(BaseHandler):
if not authdict:
raise InteractiveAuthIncompleteError(
self._auth_dict_for_flows(flows, session.session_id)
session.session_id, self._auth_dict_for_flows(flows, session.session_id)
)
# check auth type currently being presented
@@ -410,7 +415,7 @@ class AuthHandler(BaseHandler):
ret = self._auth_dict_for_flows(flows, session.session_id)
ret["completed"] = list(creds)
ret.update(errordict)
raise InteractiveAuthIncompleteError(ret)
raise InteractiveAuthIncompleteError(session.session_id, ret)
async def add_oob_auth(
self, stagetype: str, authdict: Dict[str, Any], clientip: str

View File

@@ -26,7 +26,12 @@ if TYPE_CHECKING:
from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, ThreepidValidationError
from synapse.api.errors import (
Codes,
InteractiveAuthIncompleteError,
SynapseError,
ThreepidValidationError,
)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
@@ -253,18 +258,12 @@ class PasswordRestServlet(RestServlet):
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the new password provided to us.
if "new_password" in body:
new_password = body.pop("new_password")
new_password = body.pop("new_password", None)
if new_password is not None:
if not isinstance(new_password, str) or len(new_password) > 512:
raise SynapseError(400, "Invalid password")
self.password_policy_handler.validate_password(new_password)
# If the password is valid, hash it and store it back on the body.
# This ensures that only the hashed password is handled everywhere.
if "new_password_hash" in body:
raise SynapseError(400, "Unexpected property: new_password_hash")
body["new_password_hash"] = await self.auth_handler.hash(new_password)
# there are two possibilities here. Either the user does not have an
# access token, and needs to do a password reset; or they have one and
# need to validate their identity.
@@ -281,23 +280,52 @@ class PasswordRestServlet(RestServlet):
if requester.app_service:
params = body
else:
params = await self.auth_handler.validate_user_via_ui_auth(
requester,
try:
(
params,
session_id,
) = await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"modify your account password",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth, but
# they're not required to provide the password again.
#
# If a password is available now, hash the provided password and
# store it for later.
if new_password:
password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
e.session_id, "password_hash", password_hash
)
raise
user_id = requester.user.to_string()
else:
requester = None
try:
result, params, session_id = await self.auth_handler.check_ui_auth(
[[LoginType.EMAIL_IDENTITY]],
request,
body,
self.hs.get_ip_from_request(request),
"modify your account password",
)
user_id = requester.user.to_string()
else:
requester = None
result, params, _ = await self.auth_handler.check_auth(
[[LoginType.EMAIL_IDENTITY]],
request,
body,
self.hs.get_ip_from_request(request),
"modify your account password",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth, but
# they're not required to provide the password again.
#
# If a password is available now, hash the provided password and
# store it for later.
if new_password:
password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
e.session_id, "password_hash", password_hash
)
raise
if LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
@@ -322,12 +350,21 @@ class PasswordRestServlet(RestServlet):
logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
assert_params_in_dict(params, ["new_password_hash"])
new_password_hash = params["new_password_hash"]
# If we have a password in this request, prefer it. Otherwise, there
# must be a password hash from an earlier request.
if new_password:
password_hash = await self.auth_handler.hash(new_password)
else:
password_hash = await self.auth_handler.get_session_data(
session_id, "password_hash", None
)
if not password_hash:
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
logout_devices = params.get("logout_devices", True)
await self._set_password_handler.set_password(
user_id, new_password_hash, logout_devices, requester
user_id, password_hash, logout_devices, requester
)
if self.hs.config.shadow_server:

View File

@@ -26,6 +26,7 @@ import synapse.types
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
InteractiveAuthIncompleteError,
SynapseError,
ThreepidValidationError,
UnrecognizedRequestError,
@@ -385,6 +386,7 @@ class RegisterRestServlet(RestServlet):
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
self._registration_enabled = self.hs.config.enable_registration
self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
@@ -410,22 +412,6 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
)
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us.
desired_password_hash = None
if "password" in body:
password = body.pop("password")
if not isinstance(password, str) or len(password) > 512:
raise SynapseError(400, "Invalid password")
self.password_policy_handler.validate_password(password)
# If the password is valid, hash it and store it back on the body.
# This ensures that only the hashed password is handled everywhere.
if "password_hash" in body:
raise SynapseError(400, "Unexpected property: password_hash")
body["password_hash"] = await self.auth_handler.hash(password)
desired_password_hash = body["password_hash"]
# We don't care about usernames for this deployment. In fact, the act
# of checking whether they exist already can leak metadata about
# which users are already registered.
@@ -442,6 +428,11 @@ class RegisterRestServlet(RestServlet):
if self.auth.has_access_token(request):
appservice = await self.auth.get_appservice_by_req(request)
# We need to retrieve the password early in order to pass it to
# application service registration
# This is specific to shadow server registration of users via an AS
password = body.pop("password", None)
# fork off as soon as possible for ASes which have completely
# different registration flows to normal users
@@ -461,21 +452,28 @@ class RegisterRestServlet(RestServlet):
if isinstance(desired_username, str):
result = await self._do_appservice_registration(
desired_username,
desired_password_hash,
desired_display_name,
access_token,
body,
desired_username, password, desired_display_name, access_token, body
)
return 200, result # we throw for non 200 responses
# == Normal User Registration == (everyone else)
if not self.hs.config.enable_registration:
if not self._registration_enabled:
raise SynapseError(403, "Registration has been disabled")
# Check if this account is upgrading from a guest account.
guest_access_token = body.get("guest_access_token", None)
if "initial_device_display_name" in body and "password_hash" not in body:
# Pull out the provided password and do basic sanity checks early.
#
# Note that we remove the password from the body since the auth layer
# will store the body in the session and we don't want a plaintext
# password store there.
if password is not None:
if not isinstance(password, str) or len(password) > 512:
raise SynapseError(400, "Invalid password")
self.password_policy_handler.validate_password(password)
if "initial_device_display_name" in body and password is None:
# ignore 'initial_device_display_name' if sent without
# a password to work around a client bug where it sent
# the 'initial_device_display_name' param alone, wiping out
@@ -485,6 +483,7 @@ class RegisterRestServlet(RestServlet):
session_id = self.auth_handler.get_session_id(body)
registered_user_id = None
password_hash = None
if session_id:
# if we get a registered user id out of here, it means we previously
# registered a user for this session, so we could just return the
@@ -493,21 +492,43 @@ class RegisterRestServlet(RestServlet):
registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
# Extract the previously-hashed password from the session.
password_hash = await self.auth_handler.get_session_data(
session_id, "password_hash", None
)
auth_result, params, session_id = await self.auth_handler.check_auth(
self._registration_flows,
request,
body,
self.hs.get_ip_from_request(request),
"register a new account",
)
# Check if the user-interactive authentication flows are complete, if
# not this will raise a user-interactive auth error.
try:
auth_result, params, session_id = await self.auth_handler.check_ui_auth(
self._registration_flows,
request,
body,
self.hs.get_ip_from_request(request),
"register a new account",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth.
#
# Hash the password and store it with the session since the client
# is not required to provide the password again.
#
# If a password hash was previously stored we will not attempt to
# re-hash and store it for efficiency. This assumes the password
# does not change throughout the authentication flow, but this
# should be fine since the data is meant to be consistent.
if not password_hash and password:
password_hash = await self.auth_handler.hash(password)
await self.auth_handler.set_session_data(
e.session_id, "password_hash", password_hash
)
raise
# Check that we're not trying to register a denied 3pid.
#
# the user-facing checks will probably already have happened in
# /register/email/requestToken when we requested a 3pid, but that's not
# guaranteed.
if auth_result:
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
if login_type in auth_result:
@@ -603,8 +624,12 @@ class RegisterRestServlet(RestServlet):
# don't re-register the threepids
registered = False
else:
# NB: This may be from the auth handler and NOT from the POST
assert_params_in_dict(params, ["password_hash"])
# If we have a password in this request, prefer it. Otherwise, there
# might be a password hash from an earlier request.
if password:
password_hash = await self.auth_handler.hash(password)
if not password_hash:
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
if not self.hs.config.register_mxid_from_3pid:
desired_username = params.get("username", None)
@@ -613,7 +638,6 @@ class RegisterRestServlet(RestServlet):
pass
guest_access_token = params.get("guest_access_token", None)
new_password_hash = params.get("password_hash", None)
if desired_username is not None:
desired_username = desired_username.lower()
@@ -655,7 +679,7 @@ class RegisterRestServlet(RestServlet):
registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
password_hash=new_password_hash,
password_hash=password_hash,
guest_access_token=guest_access_token,
default_display_name=desired_display_name,
threepid=threepid,
@@ -677,8 +701,8 @@ class RegisterRestServlet(RestServlet):
params=params,
)
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
# Remember that the user account has been registered (and the user
# ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)
@@ -702,12 +726,20 @@ class RegisterRestServlet(RestServlet):
return 200, {}
async def _do_appservice_registration(
self, username, password_hash, display_name, as_token, body
self, username, password, display_name, as_token, body
):
# FIXME: appservice_register() is horribly duplicated with register()
# and they should probably just be combined together with a config flag.
if password:
# Hash the password
#
# In mainline hashing of the password was moved further on in the registration
# flow, but we need it here for the AS use-case of shadow servers
password = await self.auth_handler.hash(password)
user_id = await self.registration_handler.appservice_register(
username, as_token, password_hash, display_name
username, as_token, password, display_name
)
result = await self._create_registration_details(user_id, body)

View File

@@ -0,0 +1,23 @@
/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* A flag saying whether the user owning the profile has been deactivated
* This really belongs on the users table, not here, but the users table
* stores users by their full user_id and profiles stores them by localpart,
* so we can't easily join between the two tables. Plus, the batch number
* realy ought to represent data in this table that has changed.
*/
ALTER TABLE profiles ADD COLUMN active SMALLINT DEFAULT 1 NOT NULL;

View File

@@ -0,0 +1,16 @@
/* Copyright 2019 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE UNIQUE INDEX profile_replication_status_idx ON profile_replication_status(host);

View File

@@ -112,8 +112,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
@override_config({"enable_registration": False})
def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False
request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)