Compare commits

...

9 Commits

Author SHA1 Message Date
Azrenbeth
862d820d44 moved more saml imports behind checks
(tried everything with pysaml2 uninstalled so should now work!)
2021-08-25 11:49:21 +01:00
Azrenbeth
e7d6c061e4 Added ANOTHER check for saml being enabled before loading it 2021-08-25 11:22:04 +01:00
Azrenbeth
08d386a37a no longer assert default was loaded, and better test for if using default 2021-08-25 10:48:47 +01:00
Azrenbeth
d9e48a4caa Don't import saml stuff if saml not enabled 2021-08-24 16:42:03 +01:00
Azrenbeth
057628b018 Added (unneccesary?) parentheses to try and make python3.6 happy 2021-08-24 16:29:07 +01:00
Azrenbeth
d1a0a27056 created changelog 2021-08-24 14:38:22 +01:00
Azrenbeth
a0ace792e0 Wrote docs for saml user mapping provider callbacks 2021-08-24 14:38:22 +01:00
Azrenbeth
162738feb6 Updated tests to use new module system 2021-08-24 14:38:22 +01:00
Azrenbeth
2b3e4e856f Port the saml mapping providers to new module interface 2021-08-24 14:38:22 +01:00
13 changed files with 654 additions and 152 deletions

View File

@@ -0,0 +1 @@
Port the SAML user mapping providers module interface to the new generic interface.

View File

@@ -329,6 +329,76 @@ For example, if the user `@alice:example.org` is passed to this method, and the
should receive presence updates sent by Bob and Charlie, regardless of whether these users
share a room.
#### Saml User Mapping Provider Callbacks
Saml user mapping provider callbacks are used to work out how to map
attributes of an SSO response to Matrix-specific user attributes.
As an example, a SSO service may return the email address
"john.smith@example.com" for a user and Synapse will need to figure out how
to turn that into a displayname when creating a Matrix user for this individual.
It may choose `John Smith`, or `Smith, John [Example.com]` or any number of
variations.
A module with mapping provider functionality must register all of the following:
```python
saml_attributes: Tuple[Set[str], Set[str]]
```
A tuple of two sets, the first being the SAML auth response attributes that are
required for the module to function, and the second set being the attributes which
can be used if available, but are not necessary.
```python
async def get_remote_user_id(
self,
saml_response: "saml2.response.AuthnResponse",
client_redirect_url: Optional[str],
) -> str
```
This callback is used to extract the *remote* user id for a user. It is provided with a SAML
auth response object to extract the information from, and the URL that the client will
be redirected to after authentication - which may be None. It should return an
immutable identifier for the user (Commonly the `uid` field of the response). The module should
return a unique identifier for each user. If no mapping can be made then it should raise a
`synapse.ModuleApi.errors.MappingException`.
```python
async def saml_response_to_user_attributes(
self,
saml_response: "saml2.response.AuthnResponse",
failures: int,
client_redirect_url: str,
) -> dict
```
This callback is used to extract certain attributes for a new user.
It is provided with a SAML auth response object to extract the information from, a number representing
the number of times the returned matrix user id mapping has failed and the URL that the client will
be redirected to after authentication.
It should return a dict which will be used by Synapse to build a new user.
The following keys are allowed:
* `mxid_localpart` - A string, the local part of the matrix user ID for the new user.
If this is `None`, the user is prompted to pick their own username.
This is only used during a user's first login. Once a localpart has
been associated with a remote user ID (see `get_remote_user_id`) it
cannot be updated.
* `displayname` - The displayname of the new user. If not provided, it will default to
the value of `mxid_localpart`.
* `emails` - A list of emails for the new user. If not provided, it will default to an empty list.
For example, if this method returns `john.doe` as the value of `mxid_localpart` in the returned
dict, and that is already taken on the homeserver, this method will be called again with the
same parameters but with `failures=1`. The method should then return a different `mxid_localpart`
value, such as `john.doe1`.
If no mapping can be made then it should raise a `synapse.ModuleApi.errors.MappingException`.
Alternatively it may raise a `synapse.ModuleApi.errors.RedirectException` to redirect the user to another
page which prompts for additional information. After which, it is the module's responsibility
to either redirect back to `client_redirect_url` (including any additional information)
or to complete registration using methods from the ModuleApi. TODO: explain what this means in more detail
### Porting an existing module that uses the old interface
In order to port a module that uses Synapse's old module interface, its author needs to:

View File

@@ -1544,7 +1544,9 @@ saml2_config:
#
# Default values will be used for the 'entityid' and 'service' settings,
# so it is not normally necessary to specify them unless you need to
# override them.
# override them. Note that setting 'service.sp.required_attributes' or
# 'service.sp.optional_attributes' here will override anything configured
# by a module that registers saml2 user mapping provider callbacks
#
sp_config:
# Point this to the IdP's metadata. You must provide either a local
@@ -1622,18 +1624,14 @@ saml2_config:
#
#saml_session_lifetime: 5m
# An external module can be provided here as a custom solution to
# mapping attributes returned from a saml provider onto a matrix user.
# Setting for the default mapping provider which maps attributes returned
# from a saml provider onto a matrix user. Custom solutions can be used by
# adding a module that provides these features to the 'modules' config
# section, in which case the following section will be ignored.
#
user_mapping_provider:
# The custom module's class. Uncomment to use a custom module.
#
#module: mapping_provider.SamlMappingProvider
# Custom configuration values for the module. Below options are
# intended for the built-in provider, they should be changed if
# using a custom module. This section will be passed as a Python
# dictionary to the module's `parse_config` method.
# intended for the built-in provider.
#
config:
# The SAML attribute (after mapping via the attribute maps) to use

View File

@@ -1,3 +1,9 @@
<h2 style="color:red">
Parts of this section of the Synapse documentation are now deprecated. For up to date
documentation on setting up or writing a saml mapping provider module, please
see <a href="modules.md">this page</a>.
</h2>
# SSO Mapping Providers
A mapping provider is a Python class (loaded via a Python module) that

View File

@@ -372,6 +372,16 @@ async def start(hs: "HomeServer"):
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
load_legacy_presence_router(hs)
# 'module_has_registered' is true if a module calls 'register_saml2_user_mapping_provider_callbacks'
# Only one mapping provider can be set, so only load default (or legacy configured one) if this is
# still false
if (
hs.config.saml2.saml2_enabled
and not hs.get_saml2_user_mapping_provider().module_has_registered
):
from synapse.handlers.saml import load_default_or_legacy_saml2_mapping_provider
load_default_or_legacy_saml2_mapping_provider(hs)
# If we've configured an expiry time for caches, start the background job now.
setup_expire_lru_cache_entries(hs)

View File

@@ -18,6 +18,7 @@ from typing import Any, List
from synapse.config.sso import SsoAttributeRequirement
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util import dict_merge
from synapse.util.module_loader import load_module, load_python_module
from ._base import Config, ConfigError
@@ -33,34 +34,6 @@ LEGACY_USER_MAPPING_PROVIDER = (
)
def _dict_merge(merge_dict, into_dict):
"""Do a deep merge of two dicts
Recursively merges `merge_dict` into `into_dict`:
* For keys where both `merge_dict` and `into_dict` have a dict value, the values
are recursively merged
* For all other keys, the values in `into_dict` (if any) are overwritten with
the value from `merge_dict`.
Args:
merge_dict (dict): dict to merge
into_dict (dict): target dict
"""
for k, v in merge_dict.items():
if k not in into_dict:
into_dict[k] = v
continue
current_val = into_dict[k]
if isinstance(v, dict) and isinstance(current_val, dict):
_dict_merge(v, current_val)
continue
# otherwise we just overwrite
into_dict[k] = v
class SAML2Config(Config):
section = "saml2"
@@ -99,11 +72,15 @@ class SAML2Config(Config):
ump_dict = saml2_config.get("user_mapping_provider") or {}
# Use the default user mapping provider if not set
# NOTE this is the legacy way of using custom modules
# New style-modules should be placed in the 'modules:' config section
ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
if ump_dict.get("module") == LEGACY_USER_MAPPING_PROVIDER:
ump_dict["module"] = DEFAULT_USER_MAPPING_PROVIDER
# Ensure a config is present
# This is the config for the default mapping provider, or the legacy
# way of configuring a custom module
ump_dict["config"] = ump_dict.get("config") or {}
if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
@@ -132,59 +109,30 @@ class SAML2Config(Config):
self.saml2_user_mapping_provider_config,
) = load_module(ump_dict, ("saml2_config", "user_mapping_provider"))
# Ensure loaded user mapping module has defined all necessary methods
# Note parse_config() is already checked during the call to load_module
required_methods = [
"get_saml_attributes",
"saml_response_to_user_attributes",
"get_remote_user_id",
]
missing_methods = [
method
for method in required_methods
if not hasattr(self.saml2_user_mapping_provider_class, method)
]
if missing_methods:
raise ConfigError(
"Class specified by saml2_config."
"user_mapping_provider.module is missing required "
"methods: %s" % (", ".join(missing_methods),)
)
# Get the desired saml auth response attributes from the module
saml2_config_dict = self._default_saml_config_dict(
*self.saml2_user_mapping_provider_class.get_saml_attributes(
self.saml2_user_mapping_provider_config
)
)
_dict_merge(
merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
# This is only the *base* config since a custom user mapping provider can change
# the values of 'service.sp.required_attributes' and 'service.sp.optional_attributes'
self.base_sp_config = self._default_sp_config_dict()
dict_merge(
merge_dict=saml2_config.get("sp_config", {}), into_dict=self.base_sp_config
)
config_path = saml2_config.get("config_path", None)
if config_path is not None:
mod = load_python_module(config_path)
config = getattr(mod, "CONFIG", None)
if config is None:
sp_config_path = saml2_config.get("config_path", None)
if sp_config_path is not None:
mod = load_python_module(sp_config_path)
sp_config_from_file = getattr(mod, "CONFIG", None)
if sp_config_from_file is None:
raise ConfigError(
"Config path specified by saml2_config.config_path does not "
"have a CONFIG property."
)
_dict_merge(merge_dict=config, into_dict=saml2_config_dict)
import saml2.config
self.saml2_sp_config = saml2.config.SPConfig()
self.saml2_sp_config.load(saml2_config_dict)
dict_merge(merge_dict=sp_config_from_file, into_dict=self.base_sp_config)
# session lifetime: in milliseconds
self.saml2_session_lifetime = self.parse_duration(
saml2_config.get("saml_session_lifetime", "15m")
)
def _default_saml_config_dict(
self, required_attributes: set, optional_attributes: set
):
def _default_sp_config_dict(self):
"""Generate a configuration dictionary with required and optional attributes that
will be needed to process new user registration
@@ -203,10 +151,6 @@ class SAML2Config(Config):
if public_baseurl is None:
raise ConfigError("saml2_config requires a public_baseurl to be set")
if self.saml2_grandfathered_mxid_source_attribute:
optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
optional_attributes -= required_attributes
metadata_url = public_baseurl + "_synapse/client/saml2/metadata.xml"
response_url = public_baseurl + "_synapse/client/saml2/authn_response"
return {
@@ -218,8 +162,6 @@ class SAML2Config(Config):
(response_url, saml2.BINDING_HTTP_POST)
]
},
"required_attributes": list(required_attributes),
"optional_attributes": list(optional_attributes),
# "name_id_format": saml2.saml.NAMEID_FORMAT_PERSISTENT,
}
},
@@ -257,7 +199,9 @@ class SAML2Config(Config):
#
# Default values will be used for the 'entityid' and 'service' settings,
# so it is not normally necessary to specify them unless you need to
# override them.
# override them. Note that setting 'service.sp.required_attributes' or
# 'service.sp.optional_attributes' here will override anything configured
# by a module that registers saml2 user mapping provider callbacks
#
sp_config:
# Point this to the IdP's metadata. You must provide either a local
@@ -335,18 +279,14 @@ class SAML2Config(Config):
#
#saml_session_lifetime: 5m
# An external module can be provided here as a custom solution to
# mapping attributes returned from a saml provider onto a matrix user.
# Setting for the default mapping provider which maps attributes returned
# from a saml provider onto a matrix user. Custom solutions can be used by
# adding a module that provides these features to the 'modules' config
# section, in which case the following section will be ignored.
#
user_mapping_provider:
# The custom module's class. Uncomment to use a custom module.
#
#module: mapping_provider.SamlMappingProvider
# Custom configuration values for the module. Below options are
# intended for the built-in provider, they should be changed if
# using a custom module. This section will be passed as a Python
# dictionary to the module's `parse_config` method.
# intended for the built-in provider.
#
config:
# The SAML attribute (after mapping via the attribute maps) to use

View File

@@ -13,15 +13,16 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional, Set, Tuple
import attr
import saml2
import saml2.response
from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.api.errors import RedirectException
from synapse.config import ConfigError
from synapse.config.saml2 import DEFAULT_USER_MAPPING_PROVIDER
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
@@ -32,6 +33,8 @@ from synapse.types import (
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
from synapse.util import dict_merge
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
@@ -54,7 +57,52 @@ class Saml2SessionData:
class SamlHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
# If support for legacy saml2_mapping_providers is dropped then this
# is where the DefaultSamlMappingProvider should be loaded
self._user_mapping_provider = hs.get_saml2_user_mapping_provider()
# At this point either a module will have registered user mapping provider
# callbacks or the default will have been registered.
# however if for some reason that hasn't happened (e.g. testing) load the default
if not self._user_mapping_provider.module_has_registered:
load_default_or_legacy_saml2_mapping_provider(hs)
# Merge the required and optional saml_attributes registered by the mapping
# provider with the base sp config. NOTE: If there are conflicts then the
# module's expected attributes are overwritten by the base sp_config. This is
# how it worked with legacy modules.
(
required_attributes,
optional_attributes,
) = self._user_mapping_provider.get_saml_attributes()
# Required for backwards compatability
if hs.config.saml2_grandfathered_mxid_source_attribute:
optional_attributes.add(hs.config.saml2_grandfathered_mxid_source_attribute)
optional_attributes -= required_attributes
sp_config_dict = {
"service": {
"sp": {
"required_attributes": list(required_attributes),
"optional_attributes": list(optional_attributes),
}
},
}
# Merged this way around for backwards compatability
dict_merge(
merge_dict=hs.config.saml2.base_sp_config,
into_dict=sp_config_dict,
)
self.saml2_sp_config = saml2.config.SPConfig()
self.saml2_sp_config.load(sp_config_dict)
self._saml_client = Saml2Client(self.saml2_sp_config)
self._saml_idp_entityid = hs.config.saml2_idp_entityid
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
@@ -64,12 +112,6 @@ class SamlHandler(BaseHandler):
self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
self._error_template = hs.config.sso_error_template
# plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
hs.config.saml2_user_mapping_provider_config,
ModuleApi(hs, hs.get_auth_handler()),
)
# identifier for the external_ids table
self.idp_id = "saml"
@@ -222,7 +264,9 @@ class SamlHandler(BaseHandler):
# first check if we're doing a UIA
if current_session and current_session.ui_auth_session_id:
try:
remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
remote_user_id = await self._user_mapping_provider.get_remote_user_id(
saml2_auth, None
)
except MappingException as e:
logger.exception("Failed to extract remote user id from SAML response")
self._sso_handler.render_error(request, "mapping_error", str(e))
@@ -273,7 +317,7 @@ class SamlHandler(BaseHandler):
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
remote_user_id = self._remote_id_from_saml_response(
remote_user_id = await self._user_mapping_provider.get_remote_user_id(
saml2_auth, client_redirect_url
)
@@ -286,7 +330,7 @@ class SamlHandler(BaseHandler):
This is backwards compatibility for abstraction for the SSO handler.
"""
# Call the mapping provider.
result = self._user_mapping_provider.saml_response_to_user_attributes(
result = await self._user_mapping_provider.saml_response_to_user_attributes(
saml2_auth, failures, client_redirect_url
)
# Remap some of the results.
@@ -331,35 +375,6 @@ class SamlHandler(BaseHandler):
grandfather_existing_users,
)
def _remote_id_from_saml_response(
self,
saml2_auth: saml2.response.AuthnResponse,
client_redirect_url: Optional[str],
) -> str:
"""Extract the unique remote id from a SAML2 AuthnResponse
Args:
saml2_auth: The parsed SAML2 response.
client_redirect_url: The redirect URL passed in by the client.
Returns:
remote user id
Raises:
MappingException if there was an error extracting the user id
"""
# It's not obvious why we need to pass in the redirect URI to the mapping
# provider, but we do :/
remote_user_id = self._user_mapping_provider.get_remote_user_id(
saml2_auth, client_redirect_url
)
if not remote_user_id:
raise MappingException(
"Failed to extract remote user id from SAML response"
)
return remote_user_id
def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
@@ -398,6 +413,15 @@ class SamlConfig:
mxid_mapper = attr.ib()
# The type definition for the user mapping provider callbacks
GET_REMOTE_USER_ID_CALLBACK = Callable[
[saml2.response.AuthnResponse, Optional[str]], Awaitable[str]
]
SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK = Callable[
[saml2.response.AuthnResponse, int, str], Awaitable[Dict]
]
class DefaultSamlMappingProvider:
__version__ = "0.0.1"
@@ -411,12 +435,19 @@ class DefaultSamlMappingProvider:
self._mxid_source_attribute = parsed_config.mxid_source_attribute
self._mxid_mapper = parsed_config.mxid_mapper
self._grandfathered_mxid_source_attribute = (
module_api._hs.config.saml2_grandfathered_mxid_source_attribute
module_api.register_saml2_user_mapping_provider_callbacks(
get_remote_user_id=self.get_remote_user_id,
saml_response_to_user_attributes=self.saml_response_to_user_attributes,
saml_attributes=(
{"uid", self._mxid_source_attribute},
{"displayName", "email"},
),
)
def get_remote_user_id(
self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
async def get_remote_user_id(
self,
saml_response: saml2.response.AuthnResponse,
client_redirect_url: Optional[str],
) -> str:
"""Extracts the remote user id from the SAML response"""
try:
@@ -425,7 +456,7 @@ class DefaultSamlMappingProvider:
logger.warning("SAML2 response lacks a 'uid' attestation")
raise MappingException("'uid' not in SAML2 response")
def saml_response_to_user_attributes(
async def saml_response_to_user_attributes(
self,
saml_response: saml2.response.AuthnResponse,
failures: int,
@@ -454,8 +485,8 @@ class DefaultSamlMappingProvider:
"SAML2 response lacks a '%s' attestation",
self._mxid_source_attribute,
)
raise SynapseError(
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
raise MappingException(
"%s not in SAML2 response" % (self._mxid_source_attribute,)
)
# Use the configured mapper for this mxid_source
@@ -501,8 +532,235 @@ class DefaultSamlMappingProvider:
return SamlConfig(mxid_source_attribute, mxid_mapper)
@staticmethod
def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]:
def load_default_or_legacy_saml2_mapping_provider(hs: "HomeServer"):
"""Wrapper that loads a saml2 mapping provider either from the default module or
configured using the legacy configuration. Legacy modules then have their callbacks
registered
"""
if hs.config.saml2.saml2_user_mapping_provider_class is None:
# This should be an impossible position to be in
raise RuntimeError("No default saml2 user mapping provider is set")
module = hs.config.saml2.saml2_user_mapping_provider_class
config = hs.config.saml2.saml2_user_mapping_provider_config
api = hs.get_module_api()
mapping_provider = module(config, api)
# if we were loading the default provider, then it has already registered its callbacks!
# so we can stop here
if module.__module__ + "." + module.__qualname__ == DEFAULT_USER_MAPPING_PROVIDER:
return
# The required hooks. If a custom module doesn't implement all of these then raise an error
required_mapping_provider_methods = {
"get_saml_attributes",
"saml_response_to_user_attributes",
"get_remote_user_id",
}
missing_methods = [
method
for method in required_mapping_provider_methods
if not hasattr(module, method)
]
if missing_methods:
raise RuntimeError(
"Class specified by saml2_config."
" user_mapping_provider.module is missing required"
" methods: %s" % (", ".join(missing_methods),)
)
# New modules have to proactively register this instead of just the callback
saml_attributes = mapping_provider.get_saml_attributes(config)
mapping_provider_methods = {
"saml_response_to_user_attributes",
"get_remote_user_id",
}
# Methods that the module provides should be async, but this wasn't the case
# in the old module system, so we wrap them if needed
def async_wrapper(f: Callable) -> Callable[..., Awaitable]:
def run(*args, **kwargs):
return maybe_awaitable(f(*args, **kwargs))
return run
# Register the hooks through the module API.
hooks = {
hook: async_wrapper(getattr(mapping_provider, hook, None))
for hook in mapping_provider_methods
}
api.register_saml2_user_mapping_provider_callbacks(
saml_attributes=saml_attributes, **hooks
)
class Saml2UserMappingProvider:
def __init__(self, hs: "HomeServer"):
"""The SAML user mapping provider
Args:
parsed_config: Module configuration
module_api: module api proxy
"""
# self._mxid_source_attribute = parsed_config.mxid_source_attribute
# self._mxid_mapper = parsed_config.mxid_mapper
self.get_remote_user_id_callback: Optional[GET_REMOTE_USER_ID_CALLBACK] = None
self.saml_response_to_user_attributes_callback: Optional[
SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK
] = None
self.saml_attributes: Tuple[Set[str], Set[str]] = (set(), set())
self.module_has_registered = False
def register_saml2_user_mapping_provider_callbacks(
self,
get_remote_user_id: GET_REMOTE_USER_ID_CALLBACK,
saml_response_to_user_attributes: SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK,
saml_attributes: Tuple[Set[str], Set[str]],
):
"""Called by modules to register callbacks and saml_attributes"""
# Only one module can register callbacks
if self.module_has_registered:
raise RuntimeError(
"Multiple modules have attempted to register as saml mapping providers"
)
self.module_has_registered = True
self.get_remote_user_id_callback = get_remote_user_id
self.saml_response_to_user_attributes_callback = (
saml_response_to_user_attributes
)
self.saml_attributes = saml_attributes
async def get_remote_user_id(
self,
saml_response: saml2.response.AuthnResponse,
client_redirect_url: Optional[str],
) -> str:
"""Extracts the remote user id from the SAML response
Args:
saml2_auth: The parsed SAML2 response.
client_redirect_url: The redirect URL passed in by the client. This may
be None.
Returns:
remote user id
Raises:
MappingException: if there was an error extracting the user id
Any other exception: for backwards compatability
"""
# If no module has registered callbacks then raise an error
if not self.module_has_registered:
raise RuntimeError("No Saml2 mapping provider has been registered")
assert self.get_remote_user_id_callback is not None
try:
result = await self.get_remote_user_id_callback(
saml_response, client_redirect_url
)
except MappingException:
# Mapping providers are allowed to issue a mapping exception
# if a remote user id cannot be generated.
raise
except Exception as e:
logger.warning(
f"Something went wrong when calling custom module callback for get_remote_user_id: {e}"
)
# for compatablity with legacy modules, need to raise this exception as is:
raise e
# # If the module raises some other sort of exception then don't display that to the user
# raise MappingException(
# "Failed to extract remote user id from SAML response"
# )
if not isinstance(result, str):
logger.warning( # type: ignore[unreachable]
f"Wrong type returned by module callback for get_remote_user_id: {result}, expected str"
)
# Don't overshare to the user, as something has clearly gone wrong
raise MappingException(
"Failed to extract remote user id from SAML response"
)
return result
async def saml_response_to_user_attributes(
self,
saml_response: saml2.response.AuthnResponse,
failures: int,
client_redirect_url: str,
) -> dict:
"""Maps some text from a SAML response to attributes of a new user
Args:
saml_response: A SAML auth response object
failures: How many times a call to this function with this
saml_response has resulted in a failure
client_redirect_url: where the client wants to redirect to
Returns:
dict: A dict containing new user attributes. Possible keys:
* mxid_localpart (str): Required. The localpart of the user's mxid
* displayname (str): The displayname of the user
* emails (list[str]): Any emails for the user
Raises:
MappingException: if something goes wrong while processing the response
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
Any other exception: for backwards compatability
"""
# If no module has registered callbacks then raise an error
if not self.module_has_registered:
raise RuntimeError("No Saml2 mapping provider has been registered")
assert self.saml_response_to_user_attributes_callback is not None
try:
result = await self.saml_response_to_user_attributes_callback(
saml_response, failures, client_redirect_url
)
except (RedirectException, MappingException):
# Mapping providers are allowed to issue a redirect (e.g. to ask
# the user for more information) and can issue a mapping exception
# if a name cannot be generated.
raise
except Exception as e:
logger.warning(
f"Something went wrong when calling custom module callback for saml_response_to_user_attributes: {e}"
)
# for compatablity with legacy modules, need to raise this exception as is:
raise e
# # If the module raises some other sort of exception then don't display that to the user
# raise MappingException(
# "Unable to map from SAML2 response to user attributes"
# )
if not isinstance(result, dict):
logger.warning( # type: ignore[unreachable]
f"Wrong type returned by module callback for get_remote_user_id: {result}, expected dict"
)
# Don't overshare to the user, as something has clearly gone wrong
raise MappingException(
"Unable to map from SAML2 response to user attributes"
)
return result
def get_saml_attributes(self) -> Tuple[Set[str], Set[str]]:
"""Returns the required attributes of a SAML
Args:
@@ -514,4 +772,4 @@ class DefaultSamlMappingProvider:
second set consists of those attributes which can be used if
available, but are not necessary
"""
return {"uid", config.mxid_source_attribute}, {"displayName", "email"}
return self.saml_attributes

View File

@@ -117,6 +117,8 @@ class ModuleApi:
self._account_validity_handler = hs.get_account_validity_handler()
self._third_party_event_rules = hs.get_third_party_event_rules()
self._presence_router = hs.get_presence_router()
if hs.config.saml2.saml2_enabled:
self._saml2_user_mapping_provider = hs.get_saml2_user_mapping_provider()
#################################################################################
# The following methods should only be called during the module's initialisation.
@@ -141,6 +143,17 @@ class ModuleApi:
"""Registers callbacks for presence router capabilities."""
return self._presence_router.register_presence_router_callbacks
@property
def register_saml2_user_mapping_provider_callbacks(self):
"""Registers callbacks for presence router capabilities."""
if not self._hs.config.saml2.saml2_enabled:
raise RuntimeError(
"Saml2 is not enabled, so cannot register saml2 usr mapping provider callbacks"
)
return (
self._saml2_user_mapping_provider.register_saml2_user_mapping_provider_callbacks
)
def register_web_resource(self, path: str, resource: IResource):
"""Registers a web resource to be served at the given path.

View File

@@ -20,3 +20,4 @@ from synapse.api.errors import ( # noqa: F401
SynapseError,
)
from synapse.config._base import ConfigError # noqa: F401
from synapse.handlers.sso import MappingException # noqa: F401

View File

@@ -25,7 +25,7 @@ class SAML2MetadataResource(Resource):
def __init__(self, hs):
Resource.__init__(self)
self.sp_config = hs.config.saml2_sp_config
self.sp_config = hs.get_saml_handler().saml2_sp_config
def render_GET(self, request):
metadata_xml = saml2.metadata.create_metadata_string(

View File

@@ -143,7 +143,7 @@ if TYPE_CHECKING:
from txredisapi import RedisProtocol
from synapse.handlers.oidc import OidcHandler
from synapse.handlers.saml import SamlHandler
from synapse.handlers.saml import Saml2UserMappingProvider, SamlHandler
T = TypeVar("T", bound=Callable[..., Any])
@@ -729,6 +729,12 @@ class HomeServer(metaclass=abc.ABCMeta):
return SamlHandler(self)
@cache_in_self
def get_saml2_user_mapping_provider(self) -> "Saml2UserMappingProvider":
from synapse.handlers.saml import Saml2UserMappingProvider
return Saml2UserMappingProvider(self)
@cache_in_self
def get_oidc_handler(self) -> "OidcHandler":
from synapse.handlers.oidc import OidcHandler

View File

@@ -213,3 +213,31 @@ def re_word_boundary(r: str) -> str:
# we can't use \b as it chokes on unicode. however \W seems to be okay
# as shorthand for [^0-9A-Za-z_].
return r"(^|\W)%s(\W|$)" % (r,)
def dict_merge(merge_dict, into_dict):
"""Do a deep merge of two dicts
Recursively merges `merge_dict` into `into_dict`:
* For keys where both `merge_dict` and `into_dict` have a dict value, the values
are recursively merged
* For all other keys, the values in `into_dict` (if any) are overwritten with
the value from `merge_dict`.
Args:
merge_dict (dict): dict to merge
into_dict (dict): target dict
"""
for k, v in merge_dict.items():
if k not in into_dict:
into_dict[k] = v
continue
current_val = into_dict[k]
if isinstance(v, dict) and isinstance(current_val, dict):
dict_merge(v, current_val)
continue
# otherwise we just overwrite
into_dict[k] = v

View File

@@ -27,6 +27,8 @@ try:
import saml2.config
from saml2.sigver import SigverError
from synapse.handlers.saml import load_default_or_legacy_saml2_mapping_provider
has_saml2 = True
# pysaml2 can be installed and imported, but might not be able to find xmlsec1.
@@ -51,7 +53,7 @@ class FakeAuthnResponse:
in_response_to = attr.ib(type=Optional[str], default=None)
class TestMappingProvider:
class LegacyTestMappingProvider:
def __init__(self, config, module):
pass
@@ -73,6 +75,31 @@ class TestMappingProvider:
return {"mxid_localpart": localpart, "displayname": None}
class LegacyTestRedirectMappingProvider(LegacyTestMappingProvider):
def saml_response_to_user_attributes(
self, saml_response, failures, client_redirect_url
):
raise RedirectException(b"https://custom-saml-redirect/")
class TestMappingProvider:
def __init__(self, config, api):
api.register_saml2_user_mapping_provider_callbacks(
get_remote_user_id=self.get_remote_user_id,
saml_response_to_user_attributes=self.saml_response_to_user_attributes,
saml_attributes=({"uid"}, {"displayName"}),
)
async def get_remote_user_id(self, saml_response, client_redirect_url):
return saml_response.ava["uid"]
async def saml_response_to_user_attributes(
self, saml_response, failures, client_redirect_url
):
localpart = saml_response.ava["username"] + (str(failures) if failures else "")
return {"mxid_localpart": localpart, "displayname": None}
class TestRedirectMappingProvider(TestMappingProvider):
def saml_response_to_user_attributes(
self, saml_response, failures, client_redirect_url
@@ -88,7 +115,6 @@ class SamlHandlerTestCase(HomeserverTestCase):
"sp_config": {"metadata": {}},
# Disable grandfathering.
"grandfathered_mxid_source_attribute": None,
"user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
}
# Update this config with what's in the default config so that
@@ -101,6 +127,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver()
module_api = hs.get_module_api()
for module, config in hs.config.modules.loaded_modules:
module(config=config, api=module_api)
if not hs.get_saml2_user_mapping_provider().module_has_registered:
load_default_or_legacy_saml2_mapping_provider(hs)
self.handler = hs.get_saml_handler()
# Reduce the number of attempts when generating MXIDs.
@@ -114,7 +147,31 @@ class SamlHandlerTestCase(HomeserverTestCase):
elif not has_xmlsec1:
skip = "Requires xmlsec1"
@override_config(
{
"saml2_config": {
"user_mapping_provider": {
"module": __name__ + ".LegacyTestMappingProvider"
},
}
}
)
def test_map_saml_response_to_user_legacy(self):
self.map_saml_response_to_user_body()
@override_config(
{
"modules": [
{
"module": __name__ + ".TestMappingProvider",
}
]
}
)
def test_map_saml_response_to_user(self):
self.map_saml_response_to_user_body()
def map_saml_response_to_user_body(self):
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
# stub out the auth handler
@@ -133,8 +190,35 @@ class SamlHandlerTestCase(HomeserverTestCase):
"@test_user:test", "saml", request, "redirect_uri", None, new_user=True
)
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@override_config(
{
"saml2_config": {
"user_mapping_provider": {
"module": __name__ + ".LegacyTestMappingProvider"
},
"grandfathered_mxid_source_attribute": "mxid",
}
}
)
def test_map_saml_response_to_existing_user_legacy(self):
self.map_saml_response_to_existing_user_body()
@override_config(
{
"modules": [
{
"module": __name__ + ".TestMappingProvider",
}
],
"saml2_config": {
"grandfathered_mxid_source_attribute": "mxid",
},
}
)
def test_map_saml_response_to_existing_user(self):
self.map_saml_response_to_existing_user_body()
def map_saml_response_to_existing_user_body(self):
"""Existing users can log in with SAML account."""
store = self.hs.get_datastore()
self.get_success(
@@ -168,7 +252,31 @@ class SamlHandlerTestCase(HomeserverTestCase):
"@test_user:test", "saml", request, "", None, new_user=False
)
@override_config(
{
"saml2_config": {
"user_mapping_provider": {
"module": __name__ + ".LegacyTestMappingProvider"
},
}
}
)
def test_map_saml_response_to_invalid_localpart_legacy(self):
self.map_saml_response_to_invalid_localpart_body()
@override_config(
{
"modules": [
{
"module": __name__ + ".TestMappingProvider",
}
]
}
)
def test_map_saml_response_to_invalid_localpart(self):
self.map_saml_response_to_invalid_localpart_body()
def map_saml_response_to_invalid_localpart_body(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
# stub out the auth handler
@@ -189,7 +297,31 @@ class SamlHandlerTestCase(HomeserverTestCase):
)
auth_handler.complete_sso_login.assert_not_called()
@override_config(
{
"saml2_config": {
"user_mapping_provider": {
"module": __name__ + ".LegacyTestMappingProvider"
},
}
}
)
def test_map_saml_response_to_user_retries_legacy(self):
self.map_saml_response_to_user_retries_body()
@override_config(
{
"modules": [
{
"module": __name__ + ".TestMappingProvider",
}
]
}
)
def test_map_saml_response_to_user_retries(self):
self.map_saml_response_to_user_retries_body()
def map_saml_response_to_user_retries_body(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
# stub out the auth handler and error renderer
@@ -242,12 +374,27 @@ class SamlHandlerTestCase(HomeserverTestCase):
{
"saml2_config": {
"user_mapping_provider": {
"module": __name__ + ".TestRedirectMappingProvider"
"module": __name__ + ".LegacyTestRedirectMappingProvider"
},
}
}
)
def test_map_saml_response_redirect_legacy(self):
self.map_saml_response_redirect_body()
@override_config(
{
"modules": [
{
"module": __name__ + ".TestRedirectMappingProvider",
}
]
}
)
def test_map_saml_response_redirect(self):
self.map_saml_response_redirect_body()
def map_saml_response_redirect_body(self):
"""Test a mapping provider that raises a RedirectException"""
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
@@ -260,6 +407,27 @@ class SamlHandlerTestCase(HomeserverTestCase):
@override_config(
{
"saml2_config": {
"user_mapping_provider": {
"module": __name__ + ".LegacyTestMappingProvider"
},
"attribute_requirements": [
{"attribute": "userGroup", "value": "staff"},
{"attribute": "department", "value": "sales"},
],
},
}
)
def test_attribute_requirements_legacy(self):
self.attribute_requirements_body()
@override_config(
{
"modules": [
{
"module": __name__ + ".TestMappingProvider",
}
],
"saml2_config": {
"attribute_requirements": [
{"attribute": "userGroup", "value": "staff"},
@@ -269,6 +437,9 @@ class SamlHandlerTestCase(HomeserverTestCase):
}
)
def test_attribute_requirements(self):
self.attribute_requirements_body()
def attribute_requirements_body(self):
"""The required attributes must be met from the SAML response."""
# stub out the auth handler