1
0

Port the saml mapping providers to new module interface

This commit is contained in:
Azrenbeth
2021-08-24 11:33:02 +01:00
parent d12ba52f17
commit 2b3e4e856f
9 changed files with 389 additions and 147 deletions

View File

@@ -1544,7 +1544,9 @@ saml2_config:
#
# Default values will be used for the 'entityid' and 'service' settings,
# so it is not normally necessary to specify them unless you need to
# override them.
# override them. Note that setting 'service.sp.required_attributes' or
# 'service.sp.optional_attributes' here will override anything configured
# by a module that registers saml2 user mapping provider callbacks
#
sp_config:
# Point this to the IdP's metadata. You must provide either a local
@@ -1622,18 +1624,14 @@ saml2_config:
#
#saml_session_lifetime: 5m
# An external module can be provided here as a custom solution to
# mapping attributes returned from a saml provider onto a matrix user.
# Setting for the default mapping provider which maps attributes returned
# from a saml provider onto a matrix user. Custom solutions can be used by
# adding a module that provides these features to the 'modules' config
# section, in which case the following section will be ignored.
#
user_mapping_provider:
# The custom module's class. Uncomment to use a custom module.
#
#module: mapping_provider.SamlMappingProvider
# Custom configuration values for the module. Below options are
# intended for the built-in provider, they should be changed if
# using a custom module. This section will be passed as a Python
# dictionary to the module's `parse_config` method.
# intended for the built-in provider.
#
config:
# The SAML attribute (after mapping via the attribute maps) to use

View File

@@ -40,6 +40,7 @@ from synapse.crypto import context_factory
from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.handlers.saml import load_default_or_legacy_saml2_mapping_provider
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
@@ -372,6 +373,11 @@ async def start(hs: "HomeServer"):
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
load_legacy_presence_router(hs)
# 'module_has_registered' is true if a module calls 'register_saml2_user_mapping_provider_callbacks'
# Only one mapping provider can be set, so only load default (or legacy configured one) if this is
# still false
if not hs.get_saml2_user_mapping_provider().module_has_registered:
load_default_or_legacy_saml2_mapping_provider(hs)
# If we've configured an expiry time for caches, start the background job now.
setup_expire_lru_cache_entries(hs)

View File

@@ -18,6 +18,7 @@ from typing import Any, List
from synapse.config.sso import SsoAttributeRequirement
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util import dict_merge
from synapse.util.module_loader import load_module, load_python_module
from ._base import Config, ConfigError
@@ -33,34 +34,6 @@ LEGACY_USER_MAPPING_PROVIDER = (
)
def _dict_merge(merge_dict, into_dict):
"""Do a deep merge of two dicts
Recursively merges `merge_dict` into `into_dict`:
* For keys where both `merge_dict` and `into_dict` have a dict value, the values
are recursively merged
* For all other keys, the values in `into_dict` (if any) are overwritten with
the value from `merge_dict`.
Args:
merge_dict (dict): dict to merge
into_dict (dict): target dict
"""
for k, v in merge_dict.items():
if k not in into_dict:
into_dict[k] = v
continue
current_val = into_dict[k]
if isinstance(v, dict) and isinstance(current_val, dict):
_dict_merge(v, current_val)
continue
# otherwise we just overwrite
into_dict[k] = v
class SAML2Config(Config):
section = "saml2"
@@ -99,11 +72,15 @@ class SAML2Config(Config):
ump_dict = saml2_config.get("user_mapping_provider") or {}
# Use the default user mapping provider if not set
# NOTE this is the legacy way of using custom modules
# New style-modules should be placed in the 'modules:' config section
ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
if ump_dict.get("module") == LEGACY_USER_MAPPING_PROVIDER:
ump_dict["module"] = DEFAULT_USER_MAPPING_PROVIDER
# Ensure a config is present
# This is the config for the default mapping provider, or the legacy
# way of configuring a custom module
ump_dict["config"] = ump_dict.get("config") or {}
if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
@@ -132,59 +109,30 @@ class SAML2Config(Config):
self.saml2_user_mapping_provider_config,
) = load_module(ump_dict, ("saml2_config", "user_mapping_provider"))
# Ensure loaded user mapping module has defined all necessary methods
# Note parse_config() is already checked during the call to load_module
required_methods = [
"get_saml_attributes",
"saml_response_to_user_attributes",
"get_remote_user_id",
]
missing_methods = [
method
for method in required_methods
if not hasattr(self.saml2_user_mapping_provider_class, method)
]
if missing_methods:
raise ConfigError(
"Class specified by saml2_config."
"user_mapping_provider.module is missing required "
"methods: %s" % (", ".join(missing_methods),)
)
# Get the desired saml auth response attributes from the module
saml2_config_dict = self._default_saml_config_dict(
*self.saml2_user_mapping_provider_class.get_saml_attributes(
self.saml2_user_mapping_provider_config
)
)
_dict_merge(
merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
# This is only the *base* config since a custom user mapping provider can change
# the values of 'service.sp.required_attributes' and 'service.sp.optional_attributes'
self.base_sp_config = self._default_sp_config_dict()
dict_merge(
merge_dict=saml2_config.get("sp_config", {}), into_dict=self.base_sp_config
)
config_path = saml2_config.get("config_path", None)
if config_path is not None:
mod = load_python_module(config_path)
config = getattr(mod, "CONFIG", None)
if config is None:
sp_config_path = saml2_config.get("config_path", None)
if sp_config_path is not None:
mod = load_python_module(sp_config_path)
sp_config_from_file = getattr(mod, "CONFIG", None)
if sp_config_from_file is None:
raise ConfigError(
"Config path specified by saml2_config.config_path does not "
"have a CONFIG property."
)
_dict_merge(merge_dict=config, into_dict=saml2_config_dict)
import saml2.config
self.saml2_sp_config = saml2.config.SPConfig()
self.saml2_sp_config.load(saml2_config_dict)
dict_merge(merge_dict=sp_config_from_file, into_dict=self.base_sp_config)
# session lifetime: in milliseconds
self.saml2_session_lifetime = self.parse_duration(
saml2_config.get("saml_session_lifetime", "15m")
)
def _default_saml_config_dict(
self, required_attributes: set, optional_attributes: set
):
def _default_sp_config_dict(self):
"""Generate a configuration dictionary with required and optional attributes that
will be needed to process new user registration
@@ -203,10 +151,6 @@ class SAML2Config(Config):
if public_baseurl is None:
raise ConfigError("saml2_config requires a public_baseurl to be set")
if self.saml2_grandfathered_mxid_source_attribute:
optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
optional_attributes -= required_attributes
metadata_url = public_baseurl + "_synapse/client/saml2/metadata.xml"
response_url = public_baseurl + "_synapse/client/saml2/authn_response"
return {
@@ -218,8 +162,6 @@ class SAML2Config(Config):
(response_url, saml2.BINDING_HTTP_POST)
]
},
"required_attributes": list(required_attributes),
"optional_attributes": list(optional_attributes),
# "name_id_format": saml2.saml.NAMEID_FORMAT_PERSISTENT,
}
},
@@ -257,7 +199,9 @@ class SAML2Config(Config):
#
# Default values will be used for the 'entityid' and 'service' settings,
# so it is not normally necessary to specify them unless you need to
# override them.
# override them. Note that setting 'service.sp.required_attributes' or
# 'service.sp.optional_attributes' here will override anything configured
# by a module that registers saml2 user mapping provider callbacks
#
sp_config:
# Point this to the IdP's metadata. You must provide either a local
@@ -335,18 +279,14 @@ class SAML2Config(Config):
#
#saml_session_lifetime: 5m
# An external module can be provided here as a custom solution to
# mapping attributes returned from a saml provider onto a matrix user.
# Setting for the default mapping provider which maps attributes returned
# from a saml provider onto a matrix user. Custom solutions can be used by
# adding a module that provides these features to the 'modules' config
# section, in which case the following section will be ignored.
#
user_mapping_provider:
# The custom module's class. Uncomment to use a custom module.
#
#module: mapping_provider.SamlMappingProvider
# Custom configuration values for the module. Below options are
# intended for the built-in provider, they should be changed if
# using a custom module. This section will be passed as a Python
# dictionary to the module's `parse_config` method.
# intended for the built-in provider.
#
config:
# The SAML attribute (after mapping via the attribute maps) to use

View File

@@ -13,15 +13,16 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional, Set, Tuple
import attr
import saml2
import saml2.response
from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.api.errors import RedirectException
from synapse.config import ConfigError
from synapse.config.saml2 import DEFAULT_USER_MAPPING_PROVIDER
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
@@ -32,6 +33,8 @@ from synapse.types import (
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
from synapse.util import dict_merge
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
@@ -54,7 +57,50 @@ class Saml2SessionData:
class SamlHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
# If support for legacy saml2_mapping_providers is dropped then this
# is where the DefaultSamlMappingProvider should be loaded
self._user_mapping_provider = hs.get_saml2_user_mapping_provider()
# At this point either a module will have registered user mapping provider
# callbacks or the default will have been registered.
assert self._user_mapping_provider.module_has_registered
# Merge the required and optional saml_attributes registered by the mapping
# provider with the base sp config. NOTE: If there are conflicts then the
# module's expected attributes are overwritten by the base sp_config. This is
# how it worked with legacy modules.
(
required_attributes,
optional_attributes,
) = self._user_mapping_provider.get_saml_attributes()
# Required for backwards compatability
if hs.config.saml2_grandfathered_mxid_source_attribute:
optional_attributes.add(hs.config.saml2_grandfathered_mxid_source_attribute)
optional_attributes -= required_attributes
sp_config_dict = {
"service": {
"sp": {
"required_attributes": list(required_attributes),
"optional_attributes": list(optional_attributes),
}
},
}
# Merged this way around for backwards compatability
dict_merge(
merge_dict=hs.config.saml2.base_sp_config,
into_dict=sp_config_dict,
)
self.saml2_sp_config = saml2.config.SPConfig()
self.saml2_sp_config.load(sp_config_dict)
self._saml_client = Saml2Client(self.saml2_sp_config)
self._saml_idp_entityid = hs.config.saml2_idp_entityid
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
@@ -64,12 +110,6 @@ class SamlHandler(BaseHandler):
self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
self._error_template = hs.config.sso_error_template
# plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
hs.config.saml2_user_mapping_provider_config,
ModuleApi(hs, hs.get_auth_handler()),
)
# identifier for the external_ids table
self.idp_id = "saml"
@@ -222,7 +262,9 @@ class SamlHandler(BaseHandler):
# first check if we're doing a UIA
if current_session and current_session.ui_auth_session_id:
try:
remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
remote_user_id = await self._user_mapping_provider.get_remote_user_id(
saml2_auth, None
)
except MappingException as e:
logger.exception("Failed to extract remote user id from SAML response")
self._sso_handler.render_error(request, "mapping_error", str(e))
@@ -273,7 +315,7 @@ class SamlHandler(BaseHandler):
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
remote_user_id = self._remote_id_from_saml_response(
remote_user_id = await self._user_mapping_provider.get_remote_user_id(
saml2_auth, client_redirect_url
)
@@ -286,7 +328,7 @@ class SamlHandler(BaseHandler):
This is backwards compatibility for abstraction for the SSO handler.
"""
# Call the mapping provider.
result = self._user_mapping_provider.saml_response_to_user_attributes(
result = await self._user_mapping_provider.saml_response_to_user_attributes(
saml2_auth, failures, client_redirect_url
)
# Remap some of the results.
@@ -331,35 +373,6 @@ class SamlHandler(BaseHandler):
grandfather_existing_users,
)
def _remote_id_from_saml_response(
self,
saml2_auth: saml2.response.AuthnResponse,
client_redirect_url: Optional[str],
) -> str:
"""Extract the unique remote id from a SAML2 AuthnResponse
Args:
saml2_auth: The parsed SAML2 response.
client_redirect_url: The redirect URL passed in by the client.
Returns:
remote user id
Raises:
MappingException if there was an error extracting the user id
"""
# It's not obvious why we need to pass in the redirect URI to the mapping
# provider, but we do :/
remote_user_id = self._user_mapping_provider.get_remote_user_id(
saml2_auth, client_redirect_url
)
if not remote_user_id:
raise MappingException(
"Failed to extract remote user id from SAML response"
)
return remote_user_id
def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
@@ -398,6 +411,15 @@ class SamlConfig:
mxid_mapper = attr.ib()
# The type definition for the user mapping provider callbacks
GET_REMOTE_USER_ID_CALLBACK = Callable[
[saml2.response.AuthnResponse, Optional[str]], Awaitable[str]
]
SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK = Callable[
[saml2.response.AuthnResponse, int, str], Awaitable[Dict]
]
class DefaultSamlMappingProvider:
__version__ = "0.0.1"
@@ -411,12 +433,19 @@ class DefaultSamlMappingProvider:
self._mxid_source_attribute = parsed_config.mxid_source_attribute
self._mxid_mapper = parsed_config.mxid_mapper
self._grandfathered_mxid_source_attribute = (
module_api._hs.config.saml2_grandfathered_mxid_source_attribute
module_api.register_saml2_user_mapping_provider_callbacks(
get_remote_user_id=self.get_remote_user_id,
saml_response_to_user_attributes=self.saml_response_to_user_attributes,
saml_attributes=(
{"uid", self._mxid_source_attribute},
{"displayName", "email"},
),
)
def get_remote_user_id(
self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
async def get_remote_user_id(
self,
saml_response: saml2.response.AuthnResponse,
client_redirect_url: Optional[str],
) -> str:
"""Extracts the remote user id from the SAML response"""
try:
@@ -425,7 +454,7 @@ class DefaultSamlMappingProvider:
logger.warning("SAML2 response lacks a 'uid' attestation")
raise MappingException("'uid' not in SAML2 response")
def saml_response_to_user_attributes(
async def saml_response_to_user_attributes(
self,
saml_response: saml2.response.AuthnResponse,
failures: int,
@@ -454,8 +483,8 @@ class DefaultSamlMappingProvider:
"SAML2 response lacks a '%s' attestation",
self._mxid_source_attribute,
)
raise SynapseError(
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
raise MappingException(
"%s not in SAML2 response" % (self._mxid_source_attribute,)
)
# Use the configured mapper for this mxid_source
@@ -501,8 +530,235 @@ class DefaultSamlMappingProvider:
return SamlConfig(mxid_source_attribute, mxid_mapper)
@staticmethod
def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]:
def load_default_or_legacy_saml2_mapping_provider(hs: "HomeServer"):
"""Wrapper that loads a saml2 mapping provider either from the default module or
configured using the legacy configuration. Legacy modules then have their callbacks
registered
"""
if hs.config.saml2.saml2_user_mapping_provider_class is None:
# This should be an impossible position to be in
raise RuntimeError("No default saml2 user mapping provider is set")
module = hs.config.saml2.saml2_user_mapping_provider_class
config = hs.config.saml2.saml2_user_mapping_provider_config
api = hs.get_module_api()
mapping_provider = module(config, api)
# if we were loading the default provider, then it has already registered its callbacks!
# so we can stop here
if module == DEFAULT_USER_MAPPING_PROVIDER:
return
# The required hooks. If a custom module doesn't implement all of these then raise an error
required_mapping_provider_methods = {
"get_saml_attributes",
"saml_response_to_user_attributes",
"get_remote_user_id",
}
missing_methods = [
method
for method in required_mapping_provider_methods
if not hasattr(module, method)
]
if missing_methods:
raise RuntimeError(
"Class specified by saml2_config."
" user_mapping_provider.module is missing required"
" methods: %s" % (", ".join(missing_methods),)
)
# New modules have to proactively register this instead of just the callback
saml_attributes = mapping_provider.get_saml_attributes(config)
mapping_provider_methods = {
"saml_response_to_user_attributes",
"get_remote_user_id",
}
# Methods that the module provides should be async, but this wasn't the case
# in the old module system, so we wrap them if needed
def async_wrapper(f: Callable) -> Callable[..., Awaitable]:
def run(*args, **kwargs):
return maybe_awaitable(f(*args, **kwargs))
return run
# Register the hooks through the module API.
hooks = {
hook: async_wrapper(getattr(mapping_provider, hook, None))
for hook in mapping_provider_methods
}
api.register_saml2_user_mapping_provider_callbacks(
saml_attributes=saml_attributes, **hooks
)
class Saml2UserMappingProvider:
def __init__(self, hs: "HomeServer"):
"""The SAML user mapping provider
Args:
parsed_config: Module configuration
module_api: module api proxy
"""
# self._mxid_source_attribute = parsed_config.mxid_source_attribute
# self._mxid_mapper = parsed_config.mxid_mapper
self.get_remote_user_id_callback: Optional[GET_REMOTE_USER_ID_CALLBACK] = None
self.saml_response_to_user_attributes_callback: Optional[
SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK
] = None
self.saml_attributes: Tuple[Set[str], Set[str]] = set(), set()
self.module_has_registered = False
def register_saml2_user_mapping_provider_callbacks(
self,
get_remote_user_id: GET_REMOTE_USER_ID_CALLBACK,
saml_response_to_user_attributes: SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK,
saml_attributes: Tuple[Set[str], Set[str]],
):
"""Called by modules to register callbacks and saml_attributes"""
# Only one module can register callbacks
if self.module_has_registered:
raise RuntimeError(
"Multiple modules have attempted to register as saml mapping providers"
)
self.module_has_registered = True
self.get_remote_user_id_callback = get_remote_user_id
self.saml_response_to_user_attributes_callback = (
saml_response_to_user_attributes
)
self.saml_attributes = saml_attributes
async def get_remote_user_id(
self,
saml_response: saml2.response.AuthnResponse,
client_redirect_url: Optional[str],
) -> str:
"""Extracts the remote user id from the SAML response
Args:
saml2_auth: The parsed SAML2 response.
client_redirect_url: The redirect URL passed in by the client. This may
be None.
Returns:
remote user id
Raises:
MappingException: if there was an error extracting the user id
Any other exception: for backwards compatability
"""
# If no module has registered callbacks then raise an error
if not self.module_has_registered:
raise RuntimeError("No Saml2 mapping provider has been registered")
assert self.get_remote_user_id_callback is not None
try:
result = await self.get_remote_user_id_callback(
saml_response, client_redirect_url
)
except MappingException:
# Mapping providers are allowed to issue a mapping exception
# if a remote user id cannot be generated.
raise
except Exception as e:
logger.warning(
f"Something went wrong when calling custom module callback for get_remote_user_id: {e}"
)
# for compatablity with legacy modules, need to raise this exception as is:
raise e
# # If the module raises some other sort of exception then don't display that to the user
# raise MappingException(
# "Failed to extract remote user id from SAML response"
# )
if not isinstance(result, str):
logger.warning( # type: ignore[unreachable]
f"Wrong type returned by module callback for get_remote_user_id: {result}, expected str"
)
# Don't overshare to the user, as something has clearly gone wrong
raise MappingException(
"Failed to extract remote user id from SAML response"
)
return result
async def saml_response_to_user_attributes(
self,
saml_response: saml2.response.AuthnResponse,
failures: int,
client_redirect_url: str,
) -> dict:
"""Maps some text from a SAML response to attributes of a new user
Args:
saml_response: A SAML auth response object
failures: How many times a call to this function with this
saml_response has resulted in a failure
client_redirect_url: where the client wants to redirect to
Returns:
dict: A dict containing new user attributes. Possible keys:
* mxid_localpart (str): Required. The localpart of the user's mxid
* displayname (str): The displayname of the user
* emails (list[str]): Any emails for the user
Raises:
MappingException: if something goes wrong while processing the response
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
Any other exception: for backwards compatability
"""
# If no module has registered callbacks then raise an error
if not self.module_has_registered:
raise RuntimeError("No Saml2 mapping provider has been registered")
assert self.saml_response_to_user_attributes_callback is not None
try:
result = await self.saml_response_to_user_attributes_callback(
saml_response, failures, client_redirect_url
)
except (RedirectException, MappingException):
# Mapping providers are allowed to issue a redirect (e.g. to ask
# the user for more information) and can issue a mapping exception
# if a name cannot be generated.
raise
except Exception as e:
logger.warning(
f"Something went wrong when calling custom module callback for saml_response_to_user_attributes: {e}"
)
# for compatablity with legacy modules, need to raise this exception as is:
raise e
# # If the module raises some other sort of exception then don't display that to the user
# raise MappingException(
# "Unable to map from SAML2 response to user attributes"
# )
if not isinstance(result, dict):
logger.warning( # type: ignore[unreachable]
f"Wrong type returned by module callback for get_remote_user_id: {result}, expected dict"
)
# Don't overshare to the user, as something has clearly gone wrong
raise MappingException(
"Unable to map from SAML2 response to user attributes"
)
return result
def get_saml_attributes(self) -> Tuple[Set[str], Set[str]]:
"""Returns the required attributes of a SAML
Args:
@@ -514,4 +770,4 @@ class DefaultSamlMappingProvider:
second set consists of those attributes which can be used if
available, but are not necessary
"""
return {"uid", config.mxid_source_attribute}, {"displayName", "email"}
return self.saml_attributes

View File

@@ -117,6 +117,7 @@ class ModuleApi:
self._account_validity_handler = hs.get_account_validity_handler()
self._third_party_event_rules = hs.get_third_party_event_rules()
self._presence_router = hs.get_presence_router()
self._saml2_user_mapping_provider = hs.get_saml2_user_mapping_provider()
#################################################################################
# The following methods should only be called during the module's initialisation.
@@ -141,6 +142,13 @@ class ModuleApi:
"""Registers callbacks for presence router capabilities."""
return self._presence_router.register_presence_router_callbacks
@property
def register_saml2_user_mapping_provider_callbacks(self):
"""Registers callbacks for presence router capabilities."""
return (
self._saml2_user_mapping_provider.register_saml2_user_mapping_provider_callbacks
)
def register_web_resource(self, path: str, resource: IResource):
"""Registers a web resource to be served at the given path.

View File

@@ -20,3 +20,4 @@ from synapse.api.errors import ( # noqa: F401
SynapseError,
)
from synapse.config._base import ConfigError # noqa: F401
from synapse.handlers.sso import MappingException # noqa: F401

View File

@@ -25,7 +25,7 @@ class SAML2MetadataResource(Resource):
def __init__(self, hs):
Resource.__init__(self)
self.sp_config = hs.config.saml2_sp_config
self.sp_config = hs.get_saml_handler().saml2_sp_config
def render_GET(self, request):
metadata_xml = saml2.metadata.create_metadata_string(

View File

@@ -100,6 +100,7 @@ from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.room_summary import RoomSummaryHandler
from synapse.handlers.saml import Saml2UserMappingProvider
from synapse.handlers.search import SearchHandler
from synapse.handlers.send_email import SendEmailHandler
from synapse.handlers.set_password import SetPasswordHandler
@@ -729,6 +730,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return SamlHandler(self)
@cache_in_self
def get_saml2_user_mapping_provider(self) -> "Saml2UserMappingProvider":
return Saml2UserMappingProvider(self)
@cache_in_self
def get_oidc_handler(self) -> "OidcHandler":
from synapse.handlers.oidc import OidcHandler

View File

@@ -213,3 +213,31 @@ def re_word_boundary(r: str) -> str:
# we can't use \b as it chokes on unicode. however \W seems to be okay
# as shorthand for [^0-9A-Za-z_].
return r"(^|\W)%s(\W|$)" % (r,)
def dict_merge(merge_dict, into_dict):
"""Do a deep merge of two dicts
Recursively merges `merge_dict` into `into_dict`:
* For keys where both `merge_dict` and `into_dict` have a dict value, the values
are recursively merged
* For all other keys, the values in `into_dict` (if any) are overwritten with
the value from `merge_dict`.
Args:
merge_dict (dict): dict to merge
into_dict (dict): target dict
"""
for k, v in merge_dict.items():
if k not in into_dict:
into_dict[k] = v
continue
current_val = into_dict[k]
if isinstance(v, dict) and isinstance(current_val, dict):
dict_merge(v, current_val)
continue
# otherwise we just overwrite
into_dict[k] = v