1
0

Merge branch 'rav/refactor_ui_auth' into neilj/fix_check_threepid_for_msisdns

This commit is contained in:
Richard van der Hoff
2019-09-25 11:24:22 +01:00
17 changed files with 621 additions and 169 deletions

1
changelog.d/6037.feature Normal file
View File

@@ -0,0 +1 @@
Make the process for mapping SAML2 users to matrix IDs more flexible.

1
changelog.d/6069.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix a bug which caused SAML attribute maps to be overridden by defaults.

1
changelog.d/6092.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix the logged number of updated items for the users_set_deactivated_flag background update.

1
changelog.d/6099.misc Normal file
View File

@@ -0,0 +1 @@
Remove unused parameter to get_user_id_by_threepid.

1
changelog.d/6105.misc Normal file
View File

@@ -0,0 +1 @@
Refactor the user-interactive auth handling.

View File

@@ -1174,6 +1174,32 @@ saml2_config:
#
#saml_session_lifetime: 5m
# The SAML attribute (after mapping via the attribute maps) to use to derive
# the Matrix ID from. 'uid' by default.
#
#mxid_source_attribute: displayName
# The mapping system to use for mapping the saml attribute onto a matrix ID.
# Options include:
# * 'hexencode' (which maps unpermitted characters to '=xx')
# * 'dotreplace' (which replaces unpermitted characters with '.').
# The default is 'hexencode'.
#
#mxid_mapping: dotreplace
# In previous versions of synapse, the mapping from SAML attribute to MXID was
# always calculated dynamically rather than stored in a table. For backwards-
# compatibility, we will look for user_ids matching such a pattern before
# creating a new account.
#
# This setting controls the SAML attribute which will be used for this
# backwards-compatibility lookup. Typically it should be 'uid', but if the
# attribute maps are changed, it may be necessary to change it.
#
# The default is 'uid'.
#
#grandfathered_mxid_source_attribute: upn
# Enable CAS for registration and login.

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,11 +13,47 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.types import (
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
from synapse.util.module_loader import load_python_module
from ._base import Config, ConfigError
def _dict_merge(merge_dict, into_dict):
"""Do a deep merge of two dicts
Recursively merges `merge_dict` into `into_dict`:
* For keys where both `merge_dict` and `into_dict` have a dict value, the values
are recursively merged
* For all other keys, the values in `into_dict` (if any) are overwritten with
the value from `merge_dict`.
Args:
merge_dict (dict): dict to merge
into_dict (dict): target dict
"""
for k, v in merge_dict.items():
if k not in into_dict:
into_dict[k] = v
continue
current_val = into_dict[k]
if isinstance(v, dict) and isinstance(current_val, dict):
_dict_merge(v, current_val)
continue
# otherwise we just overwrite
into_dict[k] = v
class SAML2Config(Config):
def read_config(self, config, **kwargs):
self.saml2_enabled = False
@@ -36,21 +73,40 @@ class SAML2Config(Config):
self.saml2_enabled = True
import saml2.config
self.saml2_mxid_source_attribute = saml2_config.get(
"mxid_source_attribute", "uid"
)
self.saml2_sp_config = saml2.config.SPConfig()
self.saml2_sp_config.load(self._default_saml_config_dict())
self.saml2_sp_config.load(saml2_config.get("sp_config", {}))
self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
"grandfathered_mxid_source_attribute", "uid"
)
saml2_config_dict = self._default_saml_config_dict()
_dict_merge(
merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
)
config_path = saml2_config.get("config_path", None)
if config_path is not None:
self.saml2_sp_config.load_file(config_path)
mod = load_python_module(config_path)
_dict_merge(merge_dict=mod.CONFIG, into_dict=saml2_config_dict)
import saml2.config
self.saml2_sp_config = saml2.config.SPConfig()
self.saml2_sp_config.load(saml2_config_dict)
# session lifetime: in milliseconds
self.saml2_session_lifetime = self.parse_duration(
saml2_config.get("saml_session_lifetime", "5m")
)
mapping = saml2_config.get("mxid_mapping", "hexencode")
try:
self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
except KeyError:
raise ConfigError("%s is not a known mxid_mapping" % (mapping,))
def _default_saml_config_dict(self):
import saml2
@@ -58,6 +114,13 @@ class SAML2Config(Config):
if public_baseurl is None:
raise ConfigError("saml2_config requires a public_baseurl to be set")
required_attributes = {"uid", self.saml2_mxid_source_attribute}
optional_attributes = {"displayName"}
if self.saml2_grandfathered_mxid_source_attribute:
optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
optional_attributes -= required_attributes
metadata_url = public_baseurl + "_matrix/saml2/metadata.xml"
response_url = public_baseurl + "_matrix/saml2/authn_response"
return {
@@ -69,8 +132,9 @@ class SAML2Config(Config):
(response_url, saml2.BINDING_HTTP_POST)
]
},
"required_attributes": ["uid"],
"optional_attributes": ["mail", "surname", "givenname"],
"required_attributes": list(required_attributes),
"optional_attributes": list(optional_attributes),
# "name_id_format": saml2.saml.NAMEID_FORMAT_PERSISTENT,
}
},
}
@@ -146,6 +210,52 @@ class SAML2Config(Config):
# The default is 5 minutes.
#
#saml_session_lifetime: 5m
# The SAML attribute (after mapping via the attribute maps) to use to derive
# the Matrix ID from. 'uid' by default.
#
#mxid_source_attribute: displayName
# The mapping system to use for mapping the saml attribute onto a matrix ID.
# Options include:
# * 'hexencode' (which maps unpermitted characters to '=xx')
# * 'dotreplace' (which replaces unpermitted characters with '.').
# The default is 'hexencode'.
#
#mxid_mapping: dotreplace
# In previous versions of synapse, the mapping from SAML attribute to MXID was
# always calculated dynamically rather than stored in a table. For backwards-
# compatibility, we will look for user_ids matching such a pattern before
# creating a new account.
#
# This setting controls the SAML attribute which will be used for this
# backwards-compatibility lookup. Typically it should be 'uid', but if the
# attribute maps are changed, it may be necessary to change it.
#
# The default is 'uid'.
#
#grandfathered_mxid_source_attribute: upn
""" % {
"config_dir_path": config_dir_path
}
DOT_REPLACE_PATTERN = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
)
def dot_replace_for_mxid(username: str) -> str:
username = username.lower()
username = DOT_REPLACE_PATTERN.sub(".", username)
# regular mxids aren't allowed to start with an underscore either
username = re.sub("^_", "", username)
return username
MXID_MAPPER_MAP = {
"hexencode": map_username_to_mxid_localpart,
"dotreplace": dot_replace_for_mxid,
}

View File

@@ -21,10 +21,8 @@ import unicodedata
import attr
import bcrypt
import pymacaroons
from canonicaljson import json
from twisted.internet import defer
from twisted.web.client import PartialDownloadError
import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType
@@ -38,7 +36,8 @@ from synapse.api.errors import (
UserDeactivatedError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.logging.context import defer_to_thread
from synapse.module_api import ModuleApi
from synapse.types import UserID
@@ -58,13 +57,12 @@ class AuthHandler(BaseHandler):
hs (synapse.server.HomeServer):
"""
super(AuthHandler, self).__init__(hs)
self.checkers = {
LoginType.RECAPTCHA: self._check_recaptcha,
LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.MSISDN: self._check_msisdn,
LoginType.DUMMY: self._check_dummy_auth,
LoginType.TERMS: self._check_terms_auth,
}
self.checkers = {} # type: dict[str, UserInteractiveAuthChecker]
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
self.checkers[inst.AUTH_TYPE] = inst
self.bcrypt_rounds = hs.config.bcrypt_rounds
# This is not a cache per se, but a store of all current sessions that
@@ -292,7 +290,7 @@ class AuthHandler(BaseHandler):
sess["creds"] = {}
creds = sess["creds"]
result = yield self.checkers[stagetype](authdict, clientip)
result = yield self.checkers[stagetype].check_auth(authdict, clientip)
if result:
creds[stagetype] = result
self._save_session(sess)
@@ -363,7 +361,7 @@ class AuthHandler(BaseHandler):
login_type = authdict["type"]
checker = self.checkers.get(login_type)
if checker is not None:
res = yield checker(authdict, clientip=clientip)
res = yield checker.check_auth(authdict, clientip=clientip)
return res
# build a v1-login-style dict out of the authdict and fall back to the
@@ -376,130 +374,6 @@ class AuthHandler(BaseHandler):
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
return canonical_id
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip, **kwargs):
try:
user_response = authdict["response"]
except KeyError:
# Client tried to provide captcha but didn't give the parameter:
# bad request.
raise LoginError(
400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
)
logger.info(
"Submitting recaptcha response %s with remoteip %s", user_response, clientip
)
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
client = self.hs.get_simple_http_client()
resp_body = yield client.post_urlencoded_get_json(
self.hs.config.recaptcha_siteverify_api,
args={
"secret": self.hs.config.recaptcha_private_key,
"response": user_response,
"remoteip": clientip,
},
)
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = json.loads(data)
if "success" in resp_body:
# Note that we do NOT check the hostname here: we explicitly
# intend the CAPTCHA to be presented by whatever client the
# user is using, we just care that they have completed a CAPTCHA.
logger.info(
"%s reCAPTCHA from hostname %s",
"Successful" if resp_body["success"] else "Failed",
resp_body.get("hostname"),
)
if resp_body["success"]:
return True
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
def _check_email_identity(self, authdict, **kwargs):
return self._check_threepid("email", authdict, **kwargs)
def _check_msisdn(self, authdict, **kwargs):
return self._check_threepid("msisdn", authdict)
def _check_dummy_auth(self, authdict, **kwargs):
return defer.succeed(True)
def _check_terms_auth(self, authdict, **kwargs):
return defer.succeed(True)
@defer.inlineCallbacks
def _check_threepid(self, medium, authdict, **kwargs):
if "threepid_creds" not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
threepid_creds = authdict["threepid_creds"]
identity_handler = self.hs.get_handlers().identity_handler
logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
# msisdns are currently always ThreepidBehaviour.REMOTE
if medium == "msisdn":
if self.hs.config.account_threepid_delegate_msisdn:
threepid = yield identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
)
else:
raise SynapseError(
400, "SMS delegation is not enabled on this homeserver"
)
elif medium == "email":
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
if medium == "email":
threepid = yield identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
row = yield self.store.get_threepid_validation_session(
medium,
threepid_creds["client_secret"],
sid=threepid_creds["sid"],
validated=True,
)
threepid = (
{
"medium": row["medium"],
"address": row["address"],
"validated_at": row["validated_at"],
}
if row
else None
)
if row:
# Valid threepid returned, delete from the db
yield self.store.delete_threepid_session(threepid_creds["sid"])
else:
raise SynapseError(400, "Email is not enabled on this homeserver")
else:
raise SynapseError(400, "Unrecognized threepid medium: %s" % (medium,))
if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
if threepid["medium"] != medium:
raise LoginError(
401,
"Expecting threepid of type '%s', got '%s'"
% (medium, threepid["medium"]),
errcode=Codes.UNAUTHORIZED,
)
threepid["threepid_creds"] = authdict["threepid_creds"]
return threepid
def _get_params_recaptcha(self):
return {"public_key": self.hs.config.recaptcha_public_key}

View File

@@ -21,6 +21,8 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_string
from synapse.rest.client.v1.login import SSOAuthHandler
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@@ -29,12 +31,26 @@ class SamlHandler:
def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._sso_auth_handler = SSOAuthHandler(hs)
self._registration_handler = hs.get_registration_handler()
self._clock = hs.get_clock()
self._datastore = hs.get_datastore()
self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
)
self._mxid_mapper = hs.config.saml2_mxid_mapper
# identifier for the external_ids table
self._auth_provider_id = "saml"
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {}
self._clock = hs.get_clock()
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
def handle_redirect_request(self, client_redirect_url):
"""Handle an incoming request to /login/sso/redirect
@@ -60,7 +76,7 @@ class SamlHandler:
# this shouldn't happen!
raise Exception("prepare_for_authenticate didn't return a Location header")
def handle_saml_response(self, request):
async def handle_saml_response(self, request):
"""Handle an incoming request to /_matrix/saml2/authn_response
Args:
@@ -77,6 +93,10 @@ class SamlHandler:
# the dict.
self.expire_sessions()
user_id = await self._map_saml_response_to_user(resp_bytes)
self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(self, resp_bytes):
try:
saml2_auth = self._saml_client.parse_authn_request_response(
resp_bytes,
@@ -91,18 +111,88 @@ class SamlHandler:
logger.warning("SAML2 response was not signed")
raise SynapseError(400, "SAML2 response was not signed")
if "uid" not in saml2_auth.ava:
logger.info("SAML2 response: %s", saml2_auth.origxml)
logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
try:
remote_user_id = saml2_auth.ava["uid"][0]
except KeyError:
logger.warning("SAML2 response lacks a 'uid' attestation")
raise SynapseError(400, "uid not in SAML2 response")
try:
mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
except KeyError:
logger.warning(
"SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
)
raise SynapseError(
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
)
self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
username = saml2_auth.ava["uid"][0]
displayName = saml2_auth.ava.get("displayName", [None])[0]
return self._sso_auth_handler.on_successful_auth(
username, request, relay_state, user_display_name=displayName
)
with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
logger.info(
"Looking for existing mapping for user %s:%s",
self._auth_provider_id,
remote_user_id,
)
registered_user_id = await self._datastore.get_user_by_external_id(
self._auth_provider_id, remote_user_id
)
if registered_user_id is not None:
logger.info("Found existing mapping %s", registered_user_id)
return registered_user_id
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
self._grandfathered_mxid_source_attribute
and self._grandfathered_mxid_source_attribute in saml2_auth.ava
):
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
user_id = UserID(
map_username_to_mxid_localpart(attrval), self._hostname
).to_string()
logger.info(
"Looking for existing account based on mapped %s %s",
self._grandfathered_mxid_source_attribute,
user_id,
)
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
if users:
registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id)
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
# figure out a new mxid for this user
base_mxid_localpart = self._mxid_mapper(mxid_source)
suffix = 0
while True:
localpart = base_mxid_localpart + (str(suffix) if suffix else "")
if not await self._datastore.get_users_by_id_case_insensitive(
UserID(localpart, self._hostname).to_string()
):
break
suffix += 1
logger.info("Allocating mxid for new user with localpart %s", localpart)
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=displayName
)
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
def expire_sessions(self):
expire_before = self._clock.time_msec() - self._saml2_session_lifetime

View File

@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module implements user-interactive auth verification.
TODO: move more stuff out of AuthHandler in here.
"""
from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401

View File

@@ -0,0 +1,222 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from canonicaljson import json
from twisted.internet import defer
from twisted.web.client import PartialDownloadError
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.config.emailconfig import ThreepidBehaviour
logger = logging.getLogger(__name__)
class UserInteractiveAuthChecker:
"""Abstract base class for an interactive auth checker"""
def __init__(self, hs):
pass
@defer.inlineCallbacks
def check_auth(self, authdict, clientip):
"""Given the authentication dict from the client, attempt to check this step
Args:
authdict (dict): authentication dictionary from the client
clientip (str): The IP address of the client.
Raises:
SynapseError if authentication failed
Returns:
Deferred: the result of authentication (to pass back to the client?)
"""
raise NotImplementedError()
class DummyAuthChecker(UserInteractiveAuthChecker):
AUTH_TYPE = LoginType.DUMMY
def check_auth(self, authdict, clientip):
return defer.succeed(True)
class TermsAuthChecker(UserInteractiveAuthChecker):
AUTH_TYPE = LoginType.TERMS
def check_auth(self, authdict, clientip):
return defer.succeed(True)
class RecaptchaAuthChecker(UserInteractiveAuthChecker):
AUTH_TYPE = LoginType.RECAPTCHA
def __init__(self, hs):
super().__init__(hs)
self._http_client = hs.get_simple_http_client()
self._url = hs.config.recaptcha_siteverify_api
self._secret = hs.config.recaptcha_private_key
@defer.inlineCallbacks
def check_auth(self, authdict, clientip):
try:
user_response = authdict["response"]
except KeyError:
# Client tried to provide captcha but didn't give the parameter:
# bad request.
raise LoginError(
400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
)
logger.info(
"Submitting recaptcha response %s with remoteip %s", user_response, clientip
)
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
resp_body = yield self._http_client.post_urlencoded_get_json(
self._url,
args={
"secret": self._secret,
"response": user_response,
"remoteip": clientip,
},
)
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = json.loads(data)
if "success" in resp_body:
# Note that we do NOT check the hostname here: we explicitly
# intend the CAPTCHA to be presented by whatever client the
# user is using, we just care that they have completed a CAPTCHA.
logger.info(
"%s reCAPTCHA from hostname %s",
"Successful" if resp_body["success"] else "Failed",
resp_body.get("hostname"),
)
if resp_body["success"]:
return True
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
class _BaseThreepidAuthChecker:
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
@defer.inlineCallbacks
def _check_threepid(self, medium, authdict, **kwargs):
if "threepid_creds" not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
threepid_creds = authdict["threepid_creds"]
identity_handler = self.hs.get_handlers().identity_handler
logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
# msisdns are currently always ThreepidBehaviour.REMOTE
if medium == "msisdn":
if self.hs.config.account_threepid_delegate_msisdn:
threepid = yield identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
)
else:
raise SynapseError(
400, "SMS delegation is not enabled on this homeserver"
)
elif medium == "email":
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
if medium == "email":
threepid = yield identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
row = yield self.store.get_threepid_validation_session(
medium,
threepid_creds["client_secret"],
sid=threepid_creds["sid"],
validated=True,
)
threepid = (
{
"medium": row["medium"],
"address": row["address"],
"validated_at": row["validated_at"],
}
if row
else None
)
if row:
# Valid threepid returned, delete from the db
yield self.store.delete_threepid_session(threepid_creds["sid"])
else:
raise SynapseError(400, "Email is not enabled on this homeserver")
else:
raise SynapseError(400, "Unrecognized threepid medium: %s" % (medium,))
if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
if threepid["medium"] != medium:
raise LoginError(
401,
"Expecting threepid of type '%s', got '%s'"
% (medium, threepid["medium"]),
errcode=Codes.UNAUTHORIZED,
)
threepid["threepid_creds"] = authdict["threepid_creds"]
return threepid
class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
AUTH_TYPE = LoginType.EMAIL_IDENTITY
def __init__(self, hs):
UserInteractiveAuthChecker.__init__(self, hs)
_BaseThreepidAuthChecker.__init__(self, hs)
def check_auth(self, authdict, clientip):
return self._check_threepid("email", authdict)
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
AUTH_TYPE = LoginType.MSISDN
def __init__(self, hs):
UserInteractiveAuthChecker.__init__(self, hs)
_BaseThreepidAuthChecker.__init__(self, hs)
def check_auth(self, authdict, clientip):
return self._check_threepid("msisdn", authdict)
INTERACTIVE_AUTH_CHECKERS = [
DummyAuthChecker,
TermsAuthChecker,
RecaptchaAuthChecker,
EmailIdentityAuthChecker,
MsisdnAuthChecker,
]
"""A list of UserInteractiveAuthChecker classes"""

View File

@@ -29,6 +29,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
@@ -507,6 +508,19 @@ class SSOAuthHandler(object):
localpart=localpart, default_display_name=user_display_name
)
self.complete_sso_login(registered_user_id, request, client_redirect_url)
def complete_sso_login(
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
):
"""Having figured out a mxid for this user, complete the HTTP request
Args:
registered_user_id:
request:
client_redirect_url:
"""
login_token = self._macaroon_gen.generate_short_term_login_token(
registered_user_id
)

View File

@@ -218,7 +218,7 @@ class BackgroundUpdateStore(SQLBaseStore):
duration_ms = time_stop - time_start
logger.info(
"Updating %r. Updated %r items in %rms."
"Running background update %r. Processed %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
update_name,
items_updated,

View File

@@ -22,6 +22,7 @@ from six import iterkeys
from six.moves import range
from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -384,6 +385,26 @@ class RegistrationWorkerStore(SQLBaseStore):
return self.runInteraction("get_users_by_id_case_insensitive", f)
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
) -> str:
"""Look up a user by their external auth id
Args:
auth_provider: identifier for the remote auth provider
external_id: id on that system
Returns:
str|None: the mxid of the user, or None if they are not known
"""
return await self._simple_select_one_onecol(
table="user_external_ids",
keyvalues={"auth_provider": auth_provider, "external_id": external_id},
retcol="user_id",
allow_none=True,
desc="get_user_by_external_id",
)
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
@@ -495,7 +516,7 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@defer.inlineCallbacks
def get_user_id_by_threepid(self, medium, address, require_verified=False):
def get_user_id_by_threepid(self, medium, address):
"""Returns user id from threepid
Args:
@@ -844,7 +865,7 @@ class RegistrationStore(
rows = self.cursor_to_dict(txn)
if not rows:
return True
return True, 0
rows_processed_nb = 0
@@ -860,18 +881,18 @@ class RegistrationStore(
)
if batch_size > len(rows):
return True
return True, len(rows)
else:
return False
return False, len(rows)
end = yield self.runInteraction(
end, nb_processed = yield self.runInteraction(
"users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
)
if end:
yield self._end_background_update("users_set_deactivated_flag")
return batch_size
return nb_processed
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
@@ -1032,6 +1053,26 @@ class RegistrationStore(
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
) -> Deferred:
"""Record a mapping from an external user id to a mxid
Args:
auth_provider: identifier for the remote auth provider
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
return self._simple_insert(
table="user_external_ids",
values={
"auth_provider": auth_provider,
"external_id": external_id,
"user_id": user_id,
},
desc="record_user_external_id",
)
def user_set_password_hash(self, user_id, password_hash):
"""
NB. This does *not* evict any cache because the one use for this

View File

@@ -0,0 +1,24 @@
/* Copyright 2019 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* a table which records mappings from external auth providers to mxids
*/
CREATE TABLE IF NOT EXISTS user_external_ids (
auth_provider TEXT NOT NULL,
external_id TEXT NOT NULL,
user_id TEXT NOT NULL,
UNIQUE (auth_provider, external_id)
);

View File

@@ -14,12 +14,13 @@
# limitations under the License.
import importlib
import importlib.util
from synapse.config._base import ConfigError
def load_module(provider):
""" Loads a module with its config
""" Loads a synapse module with its config
Take a dict with keys 'module' (the module name) and 'config'
(the config dict).
@@ -38,3 +39,20 @@ def load_module(provider):
raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
return provider_class, provider_config
def load_python_module(location: str):
"""Load a python module, and return a reference to its global namespace
Args:
location (str): path to the module
Returns:
python module object
"""
spec = importlib.util.spec_from_file_location(location, location)
if spec is None:
raise Exception("Unable to load module at %s" % (location,))
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod

View File

@@ -18,11 +18,22 @@ from twisted.internet.defer import succeed
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client.v2_alpha import auth, register
from tests import unittest
class DummyRecaptchaChecker(UserInteractiveAuthChecker):
def __init__(self, hs):
super().__init__(hs)
self.recaptcha_attempts = []
def check_auth(self, authdict, clientip):
self.recaptcha_attempts.append((authdict, clientip))
return succeed(True)
class FallbackAuthTests(unittest.HomeserverTestCase):
servlets = [
@@ -44,15 +55,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor, clock, hs):
self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler()
self.recaptcha_attempts = []
def _recaptcha(authdict, clientip):
self.recaptcha_attempts.append((authdict, clientip))
return succeed(True)
auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
@unittest.INFO
def test_fallback_captcha(self):
@@ -89,8 +94,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
self.assertEqual(request.code, 200)
# The recaptcha handler is called with the response given
self.assertEqual(len(self.recaptcha_attempts), 1)
self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
attempts = self.recaptcha_checker.recaptcha_attempts
self.assertEqual(len(attempts), 1)
self.assertEqual(attempts[0][0]["response"], "a")
# also complete the dummy auth
request, channel = self.make_request(