Merge branch 'rav/refactor_ui_auth' into neilj/fix_check_threepid_for_msisdns
This commit is contained in:
1
changelog.d/6037.feature
Normal file
1
changelog.d/6037.feature
Normal file
@@ -0,0 +1 @@
|
||||
Make the process for mapping SAML2 users to matrix IDs more flexible.
|
||||
1
changelog.d/6069.bugfix
Normal file
1
changelog.d/6069.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a bug which caused SAML attribute maps to be overridden by defaults.
|
||||
1
changelog.d/6092.bugfix
Normal file
1
changelog.d/6092.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix the logged number of updated items for the users_set_deactivated_flag background update.
|
||||
1
changelog.d/6099.misc
Normal file
1
changelog.d/6099.misc
Normal file
@@ -0,0 +1 @@
|
||||
Remove unused parameter to get_user_id_by_threepid.
|
||||
1
changelog.d/6105.misc
Normal file
1
changelog.d/6105.misc
Normal file
@@ -0,0 +1 @@
|
||||
Refactor the user-interactive auth handling.
|
||||
@@ -1174,6 +1174,32 @@ saml2_config:
|
||||
#
|
||||
#saml_session_lifetime: 5m
|
||||
|
||||
# The SAML attribute (after mapping via the attribute maps) to use to derive
|
||||
# the Matrix ID from. 'uid' by default.
|
||||
#
|
||||
#mxid_source_attribute: displayName
|
||||
|
||||
# The mapping system to use for mapping the saml attribute onto a matrix ID.
|
||||
# Options include:
|
||||
# * 'hexencode' (which maps unpermitted characters to '=xx')
|
||||
# * 'dotreplace' (which replaces unpermitted characters with '.').
|
||||
# The default is 'hexencode'.
|
||||
#
|
||||
#mxid_mapping: dotreplace
|
||||
|
||||
# In previous versions of synapse, the mapping from SAML attribute to MXID was
|
||||
# always calculated dynamically rather than stored in a table. For backwards-
|
||||
# compatibility, we will look for user_ids matching such a pattern before
|
||||
# creating a new account.
|
||||
#
|
||||
# This setting controls the SAML attribute which will be used for this
|
||||
# backwards-compatibility lookup. Typically it should be 'uid', but if the
|
||||
# attribute maps are changed, it may be necessary to change it.
|
||||
#
|
||||
# The default is 'uid'.
|
||||
#
|
||||
#grandfathered_mxid_source_attribute: upn
|
||||
|
||||
|
||||
|
||||
# Enable CAS for registration and login.
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,11 +13,47 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
|
||||
from synapse.python_dependencies import DependencyException, check_requirements
|
||||
from synapse.types import (
|
||||
map_username_to_mxid_localpart,
|
||||
mxid_localpart_allowed_characters,
|
||||
)
|
||||
from synapse.util.module_loader import load_python_module
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
|
||||
def _dict_merge(merge_dict, into_dict):
|
||||
"""Do a deep merge of two dicts
|
||||
|
||||
Recursively merges `merge_dict` into `into_dict`:
|
||||
* For keys where both `merge_dict` and `into_dict` have a dict value, the values
|
||||
are recursively merged
|
||||
* For all other keys, the values in `into_dict` (if any) are overwritten with
|
||||
the value from `merge_dict`.
|
||||
|
||||
Args:
|
||||
merge_dict (dict): dict to merge
|
||||
into_dict (dict): target dict
|
||||
"""
|
||||
for k, v in merge_dict.items():
|
||||
if k not in into_dict:
|
||||
into_dict[k] = v
|
||||
continue
|
||||
|
||||
current_val = into_dict[k]
|
||||
|
||||
if isinstance(v, dict) and isinstance(current_val, dict):
|
||||
_dict_merge(v, current_val)
|
||||
continue
|
||||
|
||||
# otherwise we just overwrite
|
||||
into_dict[k] = v
|
||||
|
||||
|
||||
class SAML2Config(Config):
|
||||
def read_config(self, config, **kwargs):
|
||||
self.saml2_enabled = False
|
||||
@@ -36,21 +73,40 @@ class SAML2Config(Config):
|
||||
|
||||
self.saml2_enabled = True
|
||||
|
||||
import saml2.config
|
||||
self.saml2_mxid_source_attribute = saml2_config.get(
|
||||
"mxid_source_attribute", "uid"
|
||||
)
|
||||
|
||||
self.saml2_sp_config = saml2.config.SPConfig()
|
||||
self.saml2_sp_config.load(self._default_saml_config_dict())
|
||||
self.saml2_sp_config.load(saml2_config.get("sp_config", {}))
|
||||
self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
|
||||
"grandfathered_mxid_source_attribute", "uid"
|
||||
)
|
||||
|
||||
saml2_config_dict = self._default_saml_config_dict()
|
||||
_dict_merge(
|
||||
merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
|
||||
)
|
||||
|
||||
config_path = saml2_config.get("config_path", None)
|
||||
if config_path is not None:
|
||||
self.saml2_sp_config.load_file(config_path)
|
||||
mod = load_python_module(config_path)
|
||||
_dict_merge(merge_dict=mod.CONFIG, into_dict=saml2_config_dict)
|
||||
|
||||
import saml2.config
|
||||
|
||||
self.saml2_sp_config = saml2.config.SPConfig()
|
||||
self.saml2_sp_config.load(saml2_config_dict)
|
||||
|
||||
# session lifetime: in milliseconds
|
||||
self.saml2_session_lifetime = self.parse_duration(
|
||||
saml2_config.get("saml_session_lifetime", "5m")
|
||||
)
|
||||
|
||||
mapping = saml2_config.get("mxid_mapping", "hexencode")
|
||||
try:
|
||||
self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
|
||||
except KeyError:
|
||||
raise ConfigError("%s is not a known mxid_mapping" % (mapping,))
|
||||
|
||||
def _default_saml_config_dict(self):
|
||||
import saml2
|
||||
|
||||
@@ -58,6 +114,13 @@ class SAML2Config(Config):
|
||||
if public_baseurl is None:
|
||||
raise ConfigError("saml2_config requires a public_baseurl to be set")
|
||||
|
||||
required_attributes = {"uid", self.saml2_mxid_source_attribute}
|
||||
|
||||
optional_attributes = {"displayName"}
|
||||
if self.saml2_grandfathered_mxid_source_attribute:
|
||||
optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
|
||||
optional_attributes -= required_attributes
|
||||
|
||||
metadata_url = public_baseurl + "_matrix/saml2/metadata.xml"
|
||||
response_url = public_baseurl + "_matrix/saml2/authn_response"
|
||||
return {
|
||||
@@ -69,8 +132,9 @@ class SAML2Config(Config):
|
||||
(response_url, saml2.BINDING_HTTP_POST)
|
||||
]
|
||||
},
|
||||
"required_attributes": ["uid"],
|
||||
"optional_attributes": ["mail", "surname", "givenname"],
|
||||
"required_attributes": list(required_attributes),
|
||||
"optional_attributes": list(optional_attributes),
|
||||
# "name_id_format": saml2.saml.NAMEID_FORMAT_PERSISTENT,
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -146,6 +210,52 @@ class SAML2Config(Config):
|
||||
# The default is 5 minutes.
|
||||
#
|
||||
#saml_session_lifetime: 5m
|
||||
|
||||
# The SAML attribute (after mapping via the attribute maps) to use to derive
|
||||
# the Matrix ID from. 'uid' by default.
|
||||
#
|
||||
#mxid_source_attribute: displayName
|
||||
|
||||
# The mapping system to use for mapping the saml attribute onto a matrix ID.
|
||||
# Options include:
|
||||
# * 'hexencode' (which maps unpermitted characters to '=xx')
|
||||
# * 'dotreplace' (which replaces unpermitted characters with '.').
|
||||
# The default is 'hexencode'.
|
||||
#
|
||||
#mxid_mapping: dotreplace
|
||||
|
||||
# In previous versions of synapse, the mapping from SAML attribute to MXID was
|
||||
# always calculated dynamically rather than stored in a table. For backwards-
|
||||
# compatibility, we will look for user_ids matching such a pattern before
|
||||
# creating a new account.
|
||||
#
|
||||
# This setting controls the SAML attribute which will be used for this
|
||||
# backwards-compatibility lookup. Typically it should be 'uid', but if the
|
||||
# attribute maps are changed, it may be necessary to change it.
|
||||
#
|
||||
# The default is 'uid'.
|
||||
#
|
||||
#grandfathered_mxid_source_attribute: upn
|
||||
""" % {
|
||||
"config_dir_path": config_dir_path
|
||||
}
|
||||
|
||||
|
||||
DOT_REPLACE_PATTERN = re.compile(
|
||||
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
|
||||
)
|
||||
|
||||
|
||||
def dot_replace_for_mxid(username: str) -> str:
|
||||
username = username.lower()
|
||||
username = DOT_REPLACE_PATTERN.sub(".", username)
|
||||
|
||||
# regular mxids aren't allowed to start with an underscore either
|
||||
username = re.sub("^_", "", username)
|
||||
return username
|
||||
|
||||
|
||||
MXID_MAPPER_MAP = {
|
||||
"hexencode": map_username_to_mxid_localpart,
|
||||
"dotreplace": dot_replace_for_mxid,
|
||||
}
|
||||
|
||||
@@ -21,10 +21,8 @@ import unicodedata
|
||||
import attr
|
||||
import bcrypt
|
||||
import pymacaroons
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.web.client import PartialDownloadError
|
||||
|
||||
import synapse.util.stringutils as stringutils
|
||||
from synapse.api.constants import LoginType
|
||||
@@ -38,7 +36,8 @@ from synapse.api.errors import (
|
||||
UserDeactivatedError,
|
||||
)
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.config.emailconfig import ThreepidBehaviour
|
||||
from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
|
||||
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
||||
from synapse.logging.context import defer_to_thread
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.types import UserID
|
||||
@@ -58,13 +57,12 @@ class AuthHandler(BaseHandler):
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
super(AuthHandler, self).__init__(hs)
|
||||
self.checkers = {
|
||||
LoginType.RECAPTCHA: self._check_recaptcha,
|
||||
LoginType.EMAIL_IDENTITY: self._check_email_identity,
|
||||
LoginType.MSISDN: self._check_msisdn,
|
||||
LoginType.DUMMY: self._check_dummy_auth,
|
||||
LoginType.TERMS: self._check_terms_auth,
|
||||
}
|
||||
|
||||
self.checkers = {} # type: dict[str, UserInteractiveAuthChecker]
|
||||
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
|
||||
inst = auth_checker_class(hs)
|
||||
self.checkers[inst.AUTH_TYPE] = inst
|
||||
|
||||
self.bcrypt_rounds = hs.config.bcrypt_rounds
|
||||
|
||||
# This is not a cache per se, but a store of all current sessions that
|
||||
@@ -292,7 +290,7 @@ class AuthHandler(BaseHandler):
|
||||
sess["creds"] = {}
|
||||
creds = sess["creds"]
|
||||
|
||||
result = yield self.checkers[stagetype](authdict, clientip)
|
||||
result = yield self.checkers[stagetype].check_auth(authdict, clientip)
|
||||
if result:
|
||||
creds[stagetype] = result
|
||||
self._save_session(sess)
|
||||
@@ -363,7 +361,7 @@ class AuthHandler(BaseHandler):
|
||||
login_type = authdict["type"]
|
||||
checker = self.checkers.get(login_type)
|
||||
if checker is not None:
|
||||
res = yield checker(authdict, clientip=clientip)
|
||||
res = yield checker.check_auth(authdict, clientip=clientip)
|
||||
return res
|
||||
|
||||
# build a v1-login-style dict out of the authdict and fall back to the
|
||||
@@ -376,130 +374,6 @@ class AuthHandler(BaseHandler):
|
||||
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
|
||||
return canonical_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_recaptcha(self, authdict, clientip, **kwargs):
|
||||
try:
|
||||
user_response = authdict["response"]
|
||||
except KeyError:
|
||||
# Client tried to provide captcha but didn't give the parameter:
|
||||
# bad request.
|
||||
raise LoginError(
|
||||
400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Submitting recaptcha response %s with remoteip %s", user_response, clientip
|
||||
)
|
||||
|
||||
# TODO: get this from the homeserver rather than creating a new one for
|
||||
# each request
|
||||
try:
|
||||
client = self.hs.get_simple_http_client()
|
||||
resp_body = yield client.post_urlencoded_get_json(
|
||||
self.hs.config.recaptcha_siteverify_api,
|
||||
args={
|
||||
"secret": self.hs.config.recaptcha_private_key,
|
||||
"response": user_response,
|
||||
"remoteip": clientip,
|
||||
},
|
||||
)
|
||||
except PartialDownloadError as pde:
|
||||
# Twisted is silly
|
||||
data = pde.response
|
||||
resp_body = json.loads(data)
|
||||
|
||||
if "success" in resp_body:
|
||||
# Note that we do NOT check the hostname here: we explicitly
|
||||
# intend the CAPTCHA to be presented by whatever client the
|
||||
# user is using, we just care that they have completed a CAPTCHA.
|
||||
logger.info(
|
||||
"%s reCAPTCHA from hostname %s",
|
||||
"Successful" if resp_body["success"] else "Failed",
|
||||
resp_body.get("hostname"),
|
||||
)
|
||||
if resp_body["success"]:
|
||||
return True
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
def _check_email_identity(self, authdict, **kwargs):
|
||||
return self._check_threepid("email", authdict, **kwargs)
|
||||
|
||||
def _check_msisdn(self, authdict, **kwargs):
|
||||
return self._check_threepid("msisdn", authdict)
|
||||
|
||||
def _check_dummy_auth(self, authdict, **kwargs):
|
||||
return defer.succeed(True)
|
||||
|
||||
def _check_terms_auth(self, authdict, **kwargs):
|
||||
return defer.succeed(True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_threepid(self, medium, authdict, **kwargs):
|
||||
if "threepid_creds" not in authdict:
|
||||
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
|
||||
|
||||
threepid_creds = authdict["threepid_creds"]
|
||||
|
||||
identity_handler = self.hs.get_handlers().identity_handler
|
||||
|
||||
logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
|
||||
|
||||
# msisdns are currently always ThreepidBehaviour.REMOTE
|
||||
if medium == "msisdn":
|
||||
if self.hs.config.account_threepid_delegate_msisdn:
|
||||
threepid = yield identity_handler.threepid_from_creds(
|
||||
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
|
||||
)
|
||||
else:
|
||||
raise SynapseError(
|
||||
400, "SMS delegation is not enabled on this homeserver"
|
||||
)
|
||||
elif medium == "email":
|
||||
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
|
||||
if medium == "email":
|
||||
threepid = yield identity_handler.threepid_from_creds(
|
||||
self.hs.config.account_threepid_delegate_email, threepid_creds
|
||||
)
|
||||
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||
row = yield self.store.get_threepid_validation_session(
|
||||
medium,
|
||||
threepid_creds["client_secret"],
|
||||
sid=threepid_creds["sid"],
|
||||
validated=True,
|
||||
)
|
||||
|
||||
threepid = (
|
||||
{
|
||||
"medium": row["medium"],
|
||||
"address": row["address"],
|
||||
"validated_at": row["validated_at"],
|
||||
}
|
||||
if row
|
||||
else None
|
||||
)
|
||||
|
||||
if row:
|
||||
# Valid threepid returned, delete from the db
|
||||
yield self.store.delete_threepid_session(threepid_creds["sid"])
|
||||
else:
|
||||
raise SynapseError(400, "Email is not enabled on this homeserver")
|
||||
else:
|
||||
raise SynapseError(400, "Unrecognized threepid medium: %s" % (medium,))
|
||||
if not threepid:
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
if threepid["medium"] != medium:
|
||||
raise LoginError(
|
||||
401,
|
||||
"Expecting threepid of type '%s', got '%s'"
|
||||
% (medium, threepid["medium"]),
|
||||
errcode=Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
threepid["threepid_creds"] = authdict["threepid_creds"]
|
||||
|
||||
return threepid
|
||||
|
||||
def _get_params_recaptcha(self):
|
||||
return {"public_key": self.hs.config.recaptcha_public_key}
|
||||
|
||||
|
||||
@@ -21,6 +21,8 @@ from saml2.client import Saml2Client
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.servlet import parse_string
|
||||
from synapse.rest.client.v1.login import SSOAuthHandler
|
||||
from synapse.types import UserID, map_username_to_mxid_localpart
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -29,12 +31,26 @@ class SamlHandler:
|
||||
def __init__(self, hs):
|
||||
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
||||
self._sso_auth_handler = SSOAuthHandler(hs)
|
||||
self._registration_handler = hs.get_registration_handler()
|
||||
|
||||
self._clock = hs.get_clock()
|
||||
self._datastore = hs.get_datastore()
|
||||
self._hostname = hs.hostname
|
||||
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
|
||||
self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
|
||||
self._grandfathered_mxid_source_attribute = (
|
||||
hs.config.saml2_grandfathered_mxid_source_attribute
|
||||
)
|
||||
self._mxid_mapper = hs.config.saml2_mxid_mapper
|
||||
|
||||
# identifier for the external_ids table
|
||||
self._auth_provider_id = "saml"
|
||||
|
||||
# a map from saml session id to Saml2SessionData object
|
||||
self._outstanding_requests_dict = {}
|
||||
|
||||
self._clock = hs.get_clock()
|
||||
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
|
||||
# a lock on the mappings
|
||||
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
|
||||
|
||||
def handle_redirect_request(self, client_redirect_url):
|
||||
"""Handle an incoming request to /login/sso/redirect
|
||||
@@ -60,7 +76,7 @@ class SamlHandler:
|
||||
# this shouldn't happen!
|
||||
raise Exception("prepare_for_authenticate didn't return a Location header")
|
||||
|
||||
def handle_saml_response(self, request):
|
||||
async def handle_saml_response(self, request):
|
||||
"""Handle an incoming request to /_matrix/saml2/authn_response
|
||||
|
||||
Args:
|
||||
@@ -77,6 +93,10 @@ class SamlHandler:
|
||||
# the dict.
|
||||
self.expire_sessions()
|
||||
|
||||
user_id = await self._map_saml_response_to_user(resp_bytes)
|
||||
self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
|
||||
|
||||
async def _map_saml_response_to_user(self, resp_bytes):
|
||||
try:
|
||||
saml2_auth = self._saml_client.parse_authn_request_response(
|
||||
resp_bytes,
|
||||
@@ -91,18 +111,88 @@ class SamlHandler:
|
||||
logger.warning("SAML2 response was not signed")
|
||||
raise SynapseError(400, "SAML2 response was not signed")
|
||||
|
||||
if "uid" not in saml2_auth.ava:
|
||||
logger.info("SAML2 response: %s", saml2_auth.origxml)
|
||||
logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
|
||||
|
||||
try:
|
||||
remote_user_id = saml2_auth.ava["uid"][0]
|
||||
except KeyError:
|
||||
logger.warning("SAML2 response lacks a 'uid' attestation")
|
||||
raise SynapseError(400, "uid not in SAML2 response")
|
||||
|
||||
try:
|
||||
mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
|
||||
except KeyError:
|
||||
logger.warning(
|
||||
"SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
|
||||
)
|
||||
raise SynapseError(
|
||||
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
|
||||
)
|
||||
|
||||
self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
|
||||
|
||||
username = saml2_auth.ava["uid"][0]
|
||||
displayName = saml2_auth.ava.get("displayName", [None])[0]
|
||||
|
||||
return self._sso_auth_handler.on_successful_auth(
|
||||
username, request, relay_state, user_display_name=displayName
|
||||
)
|
||||
with (await self._mapping_lock.queue(self._auth_provider_id)):
|
||||
# first of all, check if we already have a mapping for this user
|
||||
logger.info(
|
||||
"Looking for existing mapping for user %s:%s",
|
||||
self._auth_provider_id,
|
||||
remote_user_id,
|
||||
)
|
||||
registered_user_id = await self._datastore.get_user_by_external_id(
|
||||
self._auth_provider_id, remote_user_id
|
||||
)
|
||||
if registered_user_id is not None:
|
||||
logger.info("Found existing mapping %s", registered_user_id)
|
||||
return registered_user_id
|
||||
|
||||
# backwards-compatibility hack: see if there is an existing user with a
|
||||
# suitable mapping from the uid
|
||||
if (
|
||||
self._grandfathered_mxid_source_attribute
|
||||
and self._grandfathered_mxid_source_attribute in saml2_auth.ava
|
||||
):
|
||||
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
|
||||
user_id = UserID(
|
||||
map_username_to_mxid_localpart(attrval), self._hostname
|
||||
).to_string()
|
||||
logger.info(
|
||||
"Looking for existing account based on mapped %s %s",
|
||||
self._grandfathered_mxid_source_attribute,
|
||||
user_id,
|
||||
)
|
||||
|
||||
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
|
||||
if users:
|
||||
registered_user_id = list(users.keys())[0]
|
||||
logger.info("Grandfathering mapping to %s", registered_user_id)
|
||||
await self._datastore.record_user_external_id(
|
||||
self._auth_provider_id, remote_user_id, registered_user_id
|
||||
)
|
||||
return registered_user_id
|
||||
|
||||
# figure out a new mxid for this user
|
||||
base_mxid_localpart = self._mxid_mapper(mxid_source)
|
||||
|
||||
suffix = 0
|
||||
while True:
|
||||
localpart = base_mxid_localpart + (str(suffix) if suffix else "")
|
||||
if not await self._datastore.get_users_by_id_case_insensitive(
|
||||
UserID(localpart, self._hostname).to_string()
|
||||
):
|
||||
break
|
||||
suffix += 1
|
||||
logger.info("Allocating mxid for new user with localpart %s", localpart)
|
||||
|
||||
registered_user_id = await self._registration_handler.register_user(
|
||||
localpart=localpart, default_display_name=displayName
|
||||
)
|
||||
await self._datastore.record_user_external_id(
|
||||
self._auth_provider_id, remote_user_id, registered_user_id
|
||||
)
|
||||
return registered_user_id
|
||||
|
||||
def expire_sessions(self):
|
||||
expire_before = self._clock.time_msec() - self._saml2_session_lifetime
|
||||
|
||||
22
synapse/handlers/ui_auth/__init__.py
Normal file
22
synapse/handlers/ui_auth/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This module implements user-interactive auth verification.
|
||||
|
||||
TODO: move more stuff out of AuthHandler in here.
|
||||
|
||||
"""
|
||||
|
||||
from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401
|
||||
222
synapse/handlers/ui_auth/checkers.py
Normal file
222
synapse/handlers/ui_auth/checkers.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.web.client import PartialDownloadError
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||
from synapse.config.emailconfig import ThreepidBehaviour
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserInteractiveAuthChecker:
|
||||
"""Abstract base class for an interactive auth checker"""
|
||||
|
||||
def __init__(self, hs):
|
||||
pass
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_auth(self, authdict, clientip):
|
||||
"""Given the authentication dict from the client, attempt to check this step
|
||||
|
||||
Args:
|
||||
authdict (dict): authentication dictionary from the client
|
||||
clientip (str): The IP address of the client.
|
||||
|
||||
Raises:
|
||||
SynapseError if authentication failed
|
||||
|
||||
Returns:
|
||||
Deferred: the result of authentication (to pass back to the client?)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DummyAuthChecker(UserInteractiveAuthChecker):
|
||||
AUTH_TYPE = LoginType.DUMMY
|
||||
|
||||
def check_auth(self, authdict, clientip):
|
||||
return defer.succeed(True)
|
||||
|
||||
|
||||
class TermsAuthChecker(UserInteractiveAuthChecker):
|
||||
AUTH_TYPE = LoginType.TERMS
|
||||
|
||||
def check_auth(self, authdict, clientip):
|
||||
return defer.succeed(True)
|
||||
|
||||
|
||||
class RecaptchaAuthChecker(UserInteractiveAuthChecker):
|
||||
AUTH_TYPE = LoginType.RECAPTCHA
|
||||
|
||||
def __init__(self, hs):
|
||||
super().__init__(hs)
|
||||
self._http_client = hs.get_simple_http_client()
|
||||
self._url = hs.config.recaptcha_siteverify_api
|
||||
self._secret = hs.config.recaptcha_private_key
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_auth(self, authdict, clientip):
|
||||
try:
|
||||
user_response = authdict["response"]
|
||||
except KeyError:
|
||||
# Client tried to provide captcha but didn't give the parameter:
|
||||
# bad request.
|
||||
raise LoginError(
|
||||
400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Submitting recaptcha response %s with remoteip %s", user_response, clientip
|
||||
)
|
||||
|
||||
# TODO: get this from the homeserver rather than creating a new one for
|
||||
# each request
|
||||
try:
|
||||
resp_body = yield self._http_client.post_urlencoded_get_json(
|
||||
self._url,
|
||||
args={
|
||||
"secret": self._secret,
|
||||
"response": user_response,
|
||||
"remoteip": clientip,
|
||||
},
|
||||
)
|
||||
except PartialDownloadError as pde:
|
||||
# Twisted is silly
|
||||
data = pde.response
|
||||
resp_body = json.loads(data)
|
||||
|
||||
if "success" in resp_body:
|
||||
# Note that we do NOT check the hostname here: we explicitly
|
||||
# intend the CAPTCHA to be presented by whatever client the
|
||||
# user is using, we just care that they have completed a CAPTCHA.
|
||||
logger.info(
|
||||
"%s reCAPTCHA from hostname %s",
|
||||
"Successful" if resp_body["success"] else "Failed",
|
||||
resp_body.get("hostname"),
|
||||
)
|
||||
if resp_body["success"]:
|
||||
return True
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
|
||||
class _BaseThreepidAuthChecker:
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_threepid(self, medium, authdict, **kwargs):
|
||||
if "threepid_creds" not in authdict:
|
||||
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
|
||||
|
||||
threepid_creds = authdict["threepid_creds"]
|
||||
|
||||
identity_handler = self.hs.get_handlers().identity_handler
|
||||
|
||||
logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
|
||||
|
||||
# msisdns are currently always ThreepidBehaviour.REMOTE
|
||||
if medium == "msisdn":
|
||||
if self.hs.config.account_threepid_delegate_msisdn:
|
||||
threepid = yield identity_handler.threepid_from_creds(
|
||||
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
|
||||
)
|
||||
else:
|
||||
raise SynapseError(
|
||||
400, "SMS delegation is not enabled on this homeserver"
|
||||
)
|
||||
elif medium == "email":
|
||||
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
|
||||
if medium == "email":
|
||||
threepid = yield identity_handler.threepid_from_creds(
|
||||
self.hs.config.account_threepid_delegate_email, threepid_creds
|
||||
)
|
||||
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||
row = yield self.store.get_threepid_validation_session(
|
||||
medium,
|
||||
threepid_creds["client_secret"],
|
||||
sid=threepid_creds["sid"],
|
||||
validated=True,
|
||||
)
|
||||
|
||||
threepid = (
|
||||
{
|
||||
"medium": row["medium"],
|
||||
"address": row["address"],
|
||||
"validated_at": row["validated_at"],
|
||||
}
|
||||
if row
|
||||
else None
|
||||
)
|
||||
|
||||
if row:
|
||||
# Valid threepid returned, delete from the db
|
||||
yield self.store.delete_threepid_session(threepid_creds["sid"])
|
||||
else:
|
||||
raise SynapseError(400, "Email is not enabled on this homeserver")
|
||||
else:
|
||||
raise SynapseError(400, "Unrecognized threepid medium: %s" % (medium,))
|
||||
if not threepid:
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
if threepid["medium"] != medium:
|
||||
raise LoginError(
|
||||
401,
|
||||
"Expecting threepid of type '%s', got '%s'"
|
||||
% (medium, threepid["medium"]),
|
||||
errcode=Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
threepid["threepid_creds"] = authdict["threepid_creds"]
|
||||
|
||||
return threepid
|
||||
|
||||
|
||||
class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
|
||||
AUTH_TYPE = LoginType.EMAIL_IDENTITY
|
||||
|
||||
def __init__(self, hs):
|
||||
UserInteractiveAuthChecker.__init__(self, hs)
|
||||
_BaseThreepidAuthChecker.__init__(self, hs)
|
||||
|
||||
def check_auth(self, authdict, clientip):
|
||||
return self._check_threepid("email", authdict)
|
||||
|
||||
|
||||
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
|
||||
AUTH_TYPE = LoginType.MSISDN
|
||||
|
||||
def __init__(self, hs):
|
||||
UserInteractiveAuthChecker.__init__(self, hs)
|
||||
_BaseThreepidAuthChecker.__init__(self, hs)
|
||||
|
||||
def check_auth(self, authdict, clientip):
|
||||
return self._check_threepid("msisdn", authdict)
|
||||
|
||||
|
||||
INTERACTIVE_AUTH_CHECKERS = [
|
||||
DummyAuthChecker,
|
||||
TermsAuthChecker,
|
||||
RecaptchaAuthChecker,
|
||||
EmailIdentityAuthChecker,
|
||||
MsisdnAuthChecker,
|
||||
]
|
||||
"""A list of UserInteractiveAuthChecker classes"""
|
||||
@@ -29,6 +29,7 @@ from synapse.http.servlet import (
|
||||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||
from synapse.rest.well_known import WellKnownBuilder
|
||||
from synapse.types import UserID, map_username_to_mxid_localpart
|
||||
@@ -507,6 +508,19 @@ class SSOAuthHandler(object):
|
||||
localpart=localpart, default_display_name=user_display_name
|
||||
)
|
||||
|
||||
self.complete_sso_login(registered_user_id, request, client_redirect_url)
|
||||
|
||||
def complete_sso_login(
|
||||
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
|
||||
):
|
||||
"""Having figured out a mxid for this user, complete the HTTP request
|
||||
|
||||
Args:
|
||||
registered_user_id:
|
||||
request:
|
||||
client_redirect_url:
|
||||
"""
|
||||
|
||||
login_token = self._macaroon_gen.generate_short_term_login_token(
|
||||
registered_user_id
|
||||
)
|
||||
|
||||
@@ -218,7 +218,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
duration_ms = time_stop - time_start
|
||||
|
||||
logger.info(
|
||||
"Updating %r. Updated %r items in %rms."
|
||||
"Running background update %r. Processed %r items in %rms."
|
||||
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
|
||||
update_name,
|
||||
items_updated,
|
||||
|
||||
@@ -22,6 +22,7 @@ from six import iterkeys
|
||||
from six.moves import range
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import Deferred
|
||||
|
||||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
||||
@@ -384,6 +385,26 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
|
||||
return self.runInteraction("get_users_by_id_case_insensitive", f)
|
||||
|
||||
async def get_user_by_external_id(
|
||||
self, auth_provider: str, external_id: str
|
||||
) -> str:
|
||||
"""Look up a user by their external auth id
|
||||
|
||||
Args:
|
||||
auth_provider: identifier for the remote auth provider
|
||||
external_id: id on that system
|
||||
|
||||
Returns:
|
||||
str|None: the mxid of the user, or None if they are not known
|
||||
"""
|
||||
return await self._simple_select_one_onecol(
|
||||
table="user_external_ids",
|
||||
keyvalues={"auth_provider": auth_provider, "external_id": external_id},
|
||||
retcol="user_id",
|
||||
allow_none=True,
|
||||
desc="get_user_by_external_id",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def count_all_users(self):
|
||||
"""Counts all users registered on the homeserver."""
|
||||
@@ -495,7 +516,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_id_by_threepid(self, medium, address, require_verified=False):
|
||||
def get_user_id_by_threepid(self, medium, address):
|
||||
"""Returns user id from threepid
|
||||
|
||||
Args:
|
||||
@@ -844,7 +865,7 @@ class RegistrationStore(
|
||||
rows = self.cursor_to_dict(txn)
|
||||
|
||||
if not rows:
|
||||
return True
|
||||
return True, 0
|
||||
|
||||
rows_processed_nb = 0
|
||||
|
||||
@@ -860,18 +881,18 @@ class RegistrationStore(
|
||||
)
|
||||
|
||||
if batch_size > len(rows):
|
||||
return True
|
||||
return True, len(rows)
|
||||
else:
|
||||
return False
|
||||
return False, len(rows)
|
||||
|
||||
end = yield self.runInteraction(
|
||||
end, nb_processed = yield self.runInteraction(
|
||||
"users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
|
||||
)
|
||||
|
||||
if end:
|
||||
yield self._end_background_update("users_set_deactivated_flag")
|
||||
|
||||
return batch_size
|
||||
return nb_processed
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
|
||||
@@ -1032,6 +1053,26 @@ class RegistrationStore(
|
||||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||
|
||||
def record_user_external_id(
|
||||
self, auth_provider: str, external_id: str, user_id: str
|
||||
) -> Deferred:
|
||||
"""Record a mapping from an external user id to a mxid
|
||||
|
||||
Args:
|
||||
auth_provider: identifier for the remote auth provider
|
||||
external_id: id on that system
|
||||
user_id: complete mxid that it is mapped to
|
||||
"""
|
||||
return self._simple_insert(
|
||||
table="user_external_ids",
|
||||
values={
|
||||
"auth_provider": auth_provider,
|
||||
"external_id": external_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
desc="record_user_external_id",
|
||||
)
|
||||
|
||||
def user_set_password_hash(self, user_id, password_hash):
|
||||
"""
|
||||
NB. This does *not* evict any cache because the one use for this
|
||||
|
||||
24
synapse/storage/schema/delta/56/user_external_ids.sql
Normal file
24
synapse/storage/schema/delta/56/user_external_ids.sql
Normal file
@@ -0,0 +1,24 @@
|
||||
/* Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*
|
||||
* a table which records mappings from external auth providers to mxids
|
||||
*/
|
||||
CREATE TABLE IF NOT EXISTS user_external_ids (
|
||||
auth_provider TEXT NOT NULL,
|
||||
external_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
UNIQUE (auth_provider, external_id)
|
||||
);
|
||||
@@ -14,12 +14,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
|
||||
from synapse.config._base import ConfigError
|
||||
|
||||
|
||||
def load_module(provider):
|
||||
""" Loads a module with its config
|
||||
""" Loads a synapse module with its config
|
||||
Take a dict with keys 'module' (the module name) and 'config'
|
||||
(the config dict).
|
||||
|
||||
@@ -38,3 +39,20 @@ def load_module(provider):
|
||||
raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
|
||||
|
||||
return provider_class, provider_config
|
||||
|
||||
|
||||
def load_python_module(location: str):
|
||||
"""Load a python module, and return a reference to its global namespace
|
||||
|
||||
Args:
|
||||
location (str): path to the module
|
||||
|
||||
Returns:
|
||||
python module object
|
||||
"""
|
||||
spec = importlib.util.spec_from_file_location(location, location)
|
||||
if spec is None:
|
||||
raise Exception("Unable to load module at %s" % (location,))
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
@@ -18,11 +18,22 @@ from twisted.internet.defer import succeed
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
||||
from synapse.rest.client.v2_alpha import auth, register
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class DummyRecaptchaChecker(UserInteractiveAuthChecker):
|
||||
def __init__(self, hs):
|
||||
super().__init__(hs)
|
||||
self.recaptcha_attempts = []
|
||||
|
||||
def check_auth(self, authdict, clientip):
|
||||
self.recaptcha_attempts.append((authdict, clientip))
|
||||
return succeed(True)
|
||||
|
||||
|
||||
class FallbackAuthTests(unittest.HomeserverTestCase):
|
||||
|
||||
servlets = [
|
||||
@@ -44,15 +55,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.recaptcha_checker = DummyRecaptchaChecker(hs)
|
||||
auth_handler = hs.get_auth_handler()
|
||||
|
||||
self.recaptcha_attempts = []
|
||||
|
||||
def _recaptcha(authdict, clientip):
|
||||
self.recaptcha_attempts.append((authdict, clientip))
|
||||
return succeed(True)
|
||||
|
||||
auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha
|
||||
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
|
||||
|
||||
@unittest.INFO
|
||||
def test_fallback_captcha(self):
|
||||
@@ -89,8 +94,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
|
||||
self.assertEqual(request.code, 200)
|
||||
|
||||
# The recaptcha handler is called with the response given
|
||||
self.assertEqual(len(self.recaptcha_attempts), 1)
|
||||
self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
|
||||
attempts = self.recaptcha_checker.recaptcha_attempts
|
||||
self.assertEqual(len(attempts), 1)
|
||||
self.assertEqual(attempts[0][0]["response"], "a")
|
||||
|
||||
# also complete the dummy auth
|
||||
request, channel = self.make_request(
|
||||
|
||||
Reference in New Issue
Block a user