Merge commit '01333681b' into anoa/dinsic_release_1_31_0
This commit is contained in:
1
changelog.d/8911.feature
Normal file
1
changelog.d/8911.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add support for allowing users to pick their own user ID during a single-sign-on login.
|
||||
@@ -1 +1 @@
|
||||
Improve structured logging tests.
|
||||
Various clean-ups to the structured logging and logging context code.
|
||||
|
||||
1
changelog.d/8935.misc
Normal file
1
changelog.d/8935.misc
Normal file
@@ -0,0 +1 @@
|
||||
Various clean-ups to the structured logging and logging context code.
|
||||
1
changelog.d/8937.bugfix
Normal file
1
changelog.d/8937.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix bug introduced in Synapse v1.24.0 which would cause an exception on startup if both `enabled` and `localdb_enabled` were set to `False` in the `password_config` setting of the configuration file.
|
||||
1
changelog.d/8938.feature
Normal file
1
changelog.d/8938.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add support for allowing users to pick their own user ID during a single-sign-on login.
|
||||
1
changelog.d/8943.misc
Normal file
1
changelog.d/8943.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add type hints to push module.
|
||||
@@ -206,7 +206,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
|
||||
# filter options, but care must when using e.g. MemoryHandler to buffer
|
||||
# writes.
|
||||
|
||||
log_context_filter = LoggingContextFilter(request="")
|
||||
log_context_filter = LoggingContextFilter()
|
||||
log_metadata_filter = MetadataFilter({"server_name": config.server_name})
|
||||
old_factory = logging.getLogRecordFactory()
|
||||
|
||||
|
||||
@@ -198,27 +198,25 @@ class AuthHandler(BaseHandler):
|
||||
self._password_enabled = hs.config.password_enabled
|
||||
self._password_localdb_enabled = hs.config.password_localdb_enabled
|
||||
|
||||
# we keep this as a list despite the O(N^2) implication so that we can
|
||||
# keep PASSWORD first and avoid confusing clients which pick the first
|
||||
# type in the list. (NB that the spec doesn't require us to do so and
|
||||
# clients which favour types that they don't understand over those that
|
||||
# they do are technically broken)
|
||||
|
||||
# start out by assuming PASSWORD is enabled; we will remove it later if not.
|
||||
login_types = []
|
||||
login_types = set()
|
||||
if self._password_localdb_enabled:
|
||||
login_types.append(LoginType.PASSWORD)
|
||||
login_types.add(LoginType.PASSWORD)
|
||||
|
||||
for provider in self.password_providers:
|
||||
if hasattr(provider, "get_supported_login_types"):
|
||||
for t in provider.get_supported_login_types().keys():
|
||||
if t not in login_types:
|
||||
login_types.append(t)
|
||||
login_types.update(provider.get_supported_login_types().keys())
|
||||
|
||||
if not self._password_enabled:
|
||||
login_types.remove(LoginType.PASSWORD)
|
||||
login_types.discard(LoginType.PASSWORD)
|
||||
|
||||
self._supported_login_types = login_types
|
||||
# Some clients just pick the first type in the list. In this case, we want
|
||||
# them to use PASSWORD (rather than token or whatever), so we want to make sure
|
||||
# that comes first, where it's present.
|
||||
self._supported_login_types = []
|
||||
if LoginType.PASSWORD in login_types:
|
||||
self._supported_login_types.append(LoginType.PASSWORD)
|
||||
login_types.remove(LoginType.PASSWORD)
|
||||
self._supported_login_types.extend(login_types)
|
||||
|
||||
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
||||
# as per `rc_login.failed_attempts`.
|
||||
|
||||
@@ -163,6 +163,29 @@ class SamlHandler(BaseHandler):
|
||||
return
|
||||
|
||||
logger.debug("SAML2 response: %s", saml2_auth.origxml)
|
||||
|
||||
await self._handle_authn_response(request, saml2_auth, relay_state)
|
||||
|
||||
async def _handle_authn_response(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
saml2_auth: saml2.response.AuthnResponse,
|
||||
relay_state: str,
|
||||
) -> None:
|
||||
"""Handle an AuthnResponse, having parsed it from the request params
|
||||
|
||||
Assumes that the signature on the response object has been checked. Maps
|
||||
the user onto an MXID, registering them if necessary, and returns a response
|
||||
to the browser.
|
||||
|
||||
Args:
|
||||
request: the incoming request from the browser. We'll respond to it with an
|
||||
HTML page or a redirect
|
||||
saml2_auth: the parsed AuthnResponse object
|
||||
relay_state: the RelayState query param, which encodes the URI to rediret
|
||||
back to
|
||||
"""
|
||||
|
||||
for assertion in saml2_auth.assertions:
|
||||
# kibana limits the length of a log field, whereas this is all rather
|
||||
# useful, so split it up.
|
||||
|
||||
@@ -128,8 +128,7 @@ class SynapseRequest(Request):
|
||||
|
||||
# create a LogContext for this request
|
||||
request_id = self.get_request_id()
|
||||
logcontext = self.logcontext = LoggingContext(request_id)
|
||||
logcontext.request = request_id
|
||||
self.logcontext = LoggingContext(request_id, request=request_id)
|
||||
|
||||
# override the Server header which is set by twisted
|
||||
self.setHeader("Server", self.site.server_version_string)
|
||||
|
||||
@@ -203,10 +203,6 @@ class _Sentinel:
|
||||
def copy_to(self, record):
|
||||
pass
|
||||
|
||||
def copy_to_twisted_log_entry(self, record):
|
||||
record["request"] = None
|
||||
record["scope"] = None
|
||||
|
||||
def start(self, rusage: "Optional[resource._RUsage]"):
|
||||
pass
|
||||
|
||||
@@ -372,13 +368,6 @@ class LoggingContext:
|
||||
# we also track the current scope:
|
||||
record.scope = self.scope
|
||||
|
||||
def copy_to_twisted_log_entry(self, record) -> None:
|
||||
"""
|
||||
Copy logging fields from this context to a Twisted log record.
|
||||
"""
|
||||
record["request"] = self.request
|
||||
record["scope"] = self.scope
|
||||
|
||||
def start(self, rusage: "Optional[resource._RUsage]") -> None:
|
||||
"""
|
||||
Record that this logcontext is currently running.
|
||||
@@ -542,13 +531,10 @@ class LoggingContext:
|
||||
class LoggingContextFilter(logging.Filter):
|
||||
"""Logging filter that adds values from the current logging context to each
|
||||
record.
|
||||
Args:
|
||||
**defaults: Default values to avoid formatters complaining about
|
||||
missing fields
|
||||
"""
|
||||
|
||||
def __init__(self, **defaults) -> None:
|
||||
self.defaults = defaults
|
||||
def __init__(self, request: str = ""):
|
||||
self._default_request = request
|
||||
|
||||
def filter(self, record) -> Literal[True]:
|
||||
"""Add each fields from the logging contexts to the record.
|
||||
@@ -556,14 +542,14 @@ class LoggingContextFilter(logging.Filter):
|
||||
True to include the record in the log output.
|
||||
"""
|
||||
context = current_context()
|
||||
for key, value in self.defaults.items():
|
||||
setattr(record, key, value)
|
||||
record.request = self._default_request
|
||||
|
||||
# context should never be None, but if it somehow ends up being, then
|
||||
# we end up in a death spiral of infinite loops, so let's check, for
|
||||
# robustness' sake.
|
||||
if context is not None:
|
||||
context.copy_to(record)
|
||||
# Logging is interested in the request.
|
||||
record.request = context.request
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -199,8 +199,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
|
||||
_background_process_start_count.labels(desc).inc()
|
||||
_background_process_in_flight_count.labels(desc).inc()
|
||||
|
||||
with BackgroundProcessLoggingContext(desc) as context:
|
||||
context.request = "%s-%i" % (desc, count)
|
||||
with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context:
|
||||
try:
|
||||
ctx = noop_context_manager()
|
||||
if bg_start_span:
|
||||
@@ -244,8 +243,8 @@ class BackgroundProcessLoggingContext(LoggingContext):
|
||||
|
||||
__slots__ = ["_proc"]
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
def __init__(self, name: str, request: Optional[str] = None):
|
||||
super().__init__(name, request=request)
|
||||
|
||||
self._proc = _BackgroundProcess(name, self)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
from synapse.types import RoomStreamToken
|
||||
|
||||
@@ -36,12 +36,21 @@ class Pusher(metaclass=abc.ABCMeta):
|
||||
# This is the highest stream ordering we know it's safe to process.
|
||||
# When new events arrive, we'll be given a window of new events: we
|
||||
# should honour this rather than just looking for anything higher
|
||||
# because of potential out-of-order event serialisation. This starts
|
||||
# off as None though as we don't know any better.
|
||||
self.max_stream_ordering = None # type: Optional[int]
|
||||
# because of potential out-of-order event serialisation.
|
||||
self.max_stream_ordering = self.store.get_room_max_stream_ordering()
|
||||
|
||||
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
|
||||
# We just use the minimum stream ordering and ignore the vector clock
|
||||
# component. This is safe to do as long as we *always* ignore the vector
|
||||
# clock components.
|
||||
max_stream_ordering = max_token.stream
|
||||
|
||||
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
|
||||
self._start_processing()
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
|
||||
def _start_processing(self):
|
||||
"""Start processing push notifications."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -22,7 +22,6 @@ from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.push import Pusher
|
||||
from synapse.push.mailer import Mailer
|
||||
from synapse.types import RoomStreamToken
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
@@ -93,20 +92,6 @@ class EmailPusher(Pusher):
|
||||
pass
|
||||
self.timed_call = None
|
||||
|
||||
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
|
||||
# We just use the minimum stream ordering and ignore the vector clock
|
||||
# component. This is safe to do as long as we *always* ignore the vector
|
||||
# clock components.
|
||||
max_stream_ordering = max_token.stream
|
||||
|
||||
if self.max_stream_ordering:
|
||||
self.max_stream_ordering = max(
|
||||
max_stream_ordering, self.max_stream_ordering
|
||||
)
|
||||
else:
|
||||
self.max_stream_ordering = max_stream_ordering
|
||||
self._start_processing()
|
||||
|
||||
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
|
||||
# We could wake up and cancel the timer but there tend to be quite a
|
||||
# lot of read receipts so it's probably less work to just let the
|
||||
@@ -172,7 +157,6 @@ class EmailPusher(Pusher):
|
||||
being run.
|
||||
"""
|
||||
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
|
||||
assert self.max_stream_ordering is not None
|
||||
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
|
||||
self.user_id, start, self.max_stream_ordering
|
||||
)
|
||||
|
||||
@@ -26,7 +26,6 @@ from synapse.events import EventBase
|
||||
from synapse.logging import opentracing
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.push import Pusher, PusherConfigException
|
||||
from synapse.types import RoomStreamToken
|
||||
|
||||
from . import push_rule_evaluator, push_tools
|
||||
|
||||
@@ -122,17 +121,6 @@ class HttpPusher(Pusher):
|
||||
if should_check_for_notifs:
|
||||
self._start_processing()
|
||||
|
||||
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
|
||||
# We just use the minimum stream ordering and ignore the vector clock
|
||||
# component. This is safe to do as long as we *always* ignore the vector
|
||||
# clock components.
|
||||
max_stream_ordering = max_token.stream
|
||||
|
||||
self.max_stream_ordering = max(
|
||||
max_stream_ordering, self.max_stream_ordering or 0
|
||||
)
|
||||
self._start_processing()
|
||||
|
||||
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
|
||||
# Note that the min here shouldn't be relied upon to be accurate.
|
||||
|
||||
@@ -192,10 +180,7 @@ class HttpPusher(Pusher):
|
||||
Never call this directly: use _process which will only allow this to
|
||||
run once per pusher.
|
||||
"""
|
||||
|
||||
fn = self.store.get_unread_push_actions_for_user_in_range_for_http
|
||||
assert self.max_stream_ordering is not None
|
||||
unprocessed = await fn(
|
||||
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
|
||||
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
||||
)
|
||||
|
||||
|
||||
@@ -129,9 +129,8 @@ class PusherPool:
|
||||
)
|
||||
|
||||
# create the pusher setting last_stream_ordering to the current maximum
|
||||
# stream ordering in event_push_actions, so it will process
|
||||
# pushes from this point onwards.
|
||||
last_stream_ordering = await self.store.get_latest_push_action_stream_ordering()
|
||||
# stream ordering, so it will process pushes from this point onwards.
|
||||
last_stream_ordering = self.store.get_room_max_stream_ordering()
|
||||
|
||||
await self.store.add_pusher(
|
||||
user_id=user_id,
|
||||
|
||||
@@ -172,8 +172,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
# a logcontext which we use for processing incoming commands. We declare it as a
|
||||
# background process so that the CPU stats get reported to prometheus.
|
||||
ctx_name = "replication-conn-%s" % self.conn_id
|
||||
self._logging_context = BackgroundProcessLoggingContext(ctx_name)
|
||||
self._logging_context.request = ctx_name
|
||||
self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
|
||||
|
||||
def connectionMade(self):
|
||||
logger.info("[%s] Connection established", self.id())
|
||||
|
||||
@@ -894,16 +894,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
|
||||
return push_actions
|
||||
|
||||
async def get_latest_push_action_stream_ordering(self):
|
||||
def f(txn):
|
||||
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
|
||||
return txn.fetchone()
|
||||
|
||||
result = await self.db_pool.runInteraction(
|
||||
"get_latest_push_action_stream_ordering", f
|
||||
)
|
||||
return result[0] or 0
|
||||
|
||||
def _remove_old_push_actions_before_txn(
|
||||
self, txn, room_id, user_id, stream_ordering
|
||||
):
|
||||
|
||||
@@ -126,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||
room_version,
|
||||
)
|
||||
|
||||
with LoggingContext(request="send_rejected"):
|
||||
with LoggingContext("send_rejected"):
|
||||
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
|
||||
self.get_success(d)
|
||||
|
||||
@@ -178,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||
room_version,
|
||||
)
|
||||
|
||||
with LoggingContext(request="send_rejected"):
|
||||
with LoggingContext("send_rejected"):
|
||||
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
|
||||
self.get_success(d)
|
||||
|
||||
@@ -198,7 +198,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||
# the auth code requires that a signature exists, but doesn't check that
|
||||
# signature... go figure.
|
||||
join_event.signatures[other_server] = {"x": "y"}
|
||||
with LoggingContext(request="send_join"):
|
||||
with LoggingContext("send_join"):
|
||||
d = run_in_background(
|
||||
self.handler.on_send_join_request, other_server, join_event
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
import json
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from mock import Mock, patch
|
||||
from mock import ANY, Mock, patch
|
||||
|
||||
import pymacaroons
|
||||
|
||||
@@ -23,7 +23,7 @@ from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
|
||||
from synapse.handlers.sso import MappingException
|
||||
from synapse.types import UserID
|
||||
|
||||
from tests.test_utils import FakeResponse
|
||||
from tests.test_utils import FakeResponse, simple_async_mock
|
||||
from tests.unittest import HomeserverTestCase, override_config
|
||||
|
||||
# These are a few constants that are used as config parameters in the tests.
|
||||
@@ -82,16 +82,6 @@ class TestMappingProviderFailures(TestMappingProvider):
|
||||
}
|
||||
|
||||
|
||||
def simple_async_mock(return_value=None, raises=None):
|
||||
# AsyncMock is not available in python3.5, this mimics part of its behaviour
|
||||
async def cb(*args, **kwargs):
|
||||
if raises:
|
||||
raise raises
|
||||
return return_value
|
||||
|
||||
return Mock(side_effect=cb)
|
||||
|
||||
|
||||
async def get_json(url):
|
||||
# Mock get_json calls to handle jwks & oidc discovery endpoints
|
||||
if url == WELL_KNOWN:
|
||||
@@ -160,6 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.assertEqual(args[2], error_description)
|
||||
# Reset the render_error mock
|
||||
self.render_error.reset_mock()
|
||||
return args
|
||||
|
||||
def test_config(self):
|
||||
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
||||
@@ -374,26 +365,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
"id_token": "id_token",
|
||||
"access_token": "access_token",
|
||||
}
|
||||
username = "bar"
|
||||
userinfo = {
|
||||
"sub": "foo",
|
||||
"preferred_username": "bar",
|
||||
"username": username,
|
||||
}
|
||||
user_id = "@foo:domain.org"
|
||||
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
|
||||
self.handler._exchange_code = simple_async_mock(return_value=token)
|
||||
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
|
||||
self.handler._auth_handler.complete_sso_login = simple_async_mock()
|
||||
request = Mock(
|
||||
spec=[
|
||||
"args",
|
||||
"getCookie",
|
||||
"addCookie",
|
||||
"requestHeaders",
|
||||
"getClientIP",
|
||||
"get_user_agent",
|
||||
]
|
||||
)
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
code = "code"
|
||||
state = "state"
|
||||
@@ -401,64 +383,54 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
client_redirect_url = "http://client/redirect"
|
||||
user_agent = "Browser"
|
||||
ip_address = "10.0.0.1"
|
||||
request.getCookie.return_value = self.handler._generate_oidc_session_token(
|
||||
session = self.handler._generate_oidc_session_token(
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
client_redirect_url=client_redirect_url,
|
||||
ui_auth_session_id=None,
|
||||
)
|
||||
|
||||
request.args = {}
|
||||
request.args[b"code"] = [code.encode("utf-8")]
|
||||
request.args[b"state"] = [state.encode("utf-8")]
|
||||
|
||||
request.getClientIP.return_value = ip_address
|
||||
request.get_user_agent.return_value = user_agent
|
||||
request = self._build_callback_request(
|
||||
code, state, session, user_agent=user_agent, ip_address=ip_address
|
||||
)
|
||||
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
|
||||
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
|
||||
user_id, request, client_redirect_url, {},
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
expected_user_id, request, client_redirect_url, {},
|
||||
)
|
||||
self.handler._exchange_code.assert_called_once_with(code)
|
||||
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
||||
self.handler._map_userinfo_to_user.assert_called_once_with(
|
||||
userinfo, token, user_agent, ip_address
|
||||
)
|
||||
self.handler._fetch_userinfo.assert_not_called()
|
||||
self.render_error.assert_not_called()
|
||||
|
||||
# Handle mapping errors
|
||||
self.handler._map_userinfo_to_user = simple_async_mock(
|
||||
raises=MappingException()
|
||||
)
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("mapping_error")
|
||||
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
|
||||
with patch.object(
|
||||
self.handler,
|
||||
"_remote_id_from_userinfo",
|
||||
new=Mock(side_effect=MappingException()),
|
||||
):
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("mapping_error")
|
||||
|
||||
# Handle ID token errors
|
||||
self.handler._parse_id_token = simple_async_mock(raises=Exception())
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("invalid_token")
|
||||
|
||||
self.handler._auth_handler.complete_sso_login.reset_mock()
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
self.handler._exchange_code.reset_mock()
|
||||
self.handler._parse_id_token.reset_mock()
|
||||
self.handler._map_userinfo_to_user.reset_mock()
|
||||
self.handler._fetch_userinfo.reset_mock()
|
||||
|
||||
# With userinfo fetching
|
||||
self.handler._scopes = [] # do not ask the "openid" scope
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
|
||||
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
|
||||
user_id, request, client_redirect_url, {},
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
expected_user_id, request, client_redirect_url, {},
|
||||
)
|
||||
self.handler._exchange_code.assert_called_once_with(code)
|
||||
self.handler._parse_id_token.assert_not_called()
|
||||
self.handler._map_userinfo_to_user.assert_called_once_with(
|
||||
userinfo, token, user_agent, ip_address
|
||||
)
|
||||
self.handler._fetch_userinfo.assert_called_once_with(token)
|
||||
self.render_error.assert_not_called()
|
||||
|
||||
@@ -609,72 +581,55 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
userinfo = {
|
||||
"sub": "foo",
|
||||
"username": "foo",
|
||||
"phone": "1234567",
|
||||
}
|
||||
user_id = "@foo:domain.org"
|
||||
self.handler._exchange_code = simple_async_mock(return_value=token)
|
||||
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
|
||||
self.handler._auth_handler.complete_sso_login = simple_async_mock()
|
||||
request = Mock(
|
||||
spec=[
|
||||
"args",
|
||||
"getCookie",
|
||||
"addCookie",
|
||||
"requestHeaders",
|
||||
"getClientIP",
|
||||
"get_user_agent",
|
||||
]
|
||||
)
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
state = "state"
|
||||
client_redirect_url = "http://client/redirect"
|
||||
request.getCookie.return_value = self.handler._generate_oidc_session_token(
|
||||
session = self.handler._generate_oidc_session_token(
|
||||
state=state,
|
||||
nonce="nonce",
|
||||
client_redirect_url=client_redirect_url,
|
||||
ui_auth_session_id=None,
|
||||
)
|
||||
|
||||
request.args = {}
|
||||
request.args[b"code"] = [b"code"]
|
||||
request.args[b"state"] = [state.encode("utf-8")]
|
||||
|
||||
request.getClientIP.return_value = "10.0.0.1"
|
||||
request.get_user_agent.return_value = "Browser"
|
||||
request = self._build_callback_request("code", state, session)
|
||||
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
|
||||
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
|
||||
user_id, request, client_redirect_url, {"phone": "1234567"},
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@foo:test", request, client_redirect_url, {"phone": "1234567"},
|
||||
)
|
||||
|
||||
def test_map_userinfo_to_user(self):
|
||||
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
userinfo = {
|
||||
"sub": "test_user",
|
||||
"username": "test_user",
|
||||
}
|
||||
# The token doesn't matter with the default user mapping provider.
|
||||
token = {}
|
||||
mxid = self.get_success(
|
||||
self.handler._map_userinfo_to_user(
|
||||
userinfo, token, "user-agent", "10.10.10.10"
|
||||
)
|
||||
self._make_callback_with_userinfo(userinfo)
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user:test", ANY, ANY, {}
|
||||
)
|
||||
self.assertEqual(mxid, "@test_user:test")
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
|
||||
# Some providers return an integer ID.
|
||||
userinfo = {
|
||||
"sub": 1234,
|
||||
"username": "test_user_2",
|
||||
}
|
||||
mxid = self.get_success(
|
||||
self.handler._map_userinfo_to_user(
|
||||
userinfo, token, "user-agent", "10.10.10.10"
|
||||
)
|
||||
self._make_callback_with_userinfo(userinfo)
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user_2:test", ANY, ANY, {}
|
||||
)
|
||||
self.assertEqual(mxid, "@test_user_2:test")
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
|
||||
# Test if the mxid is already taken
|
||||
store = self.hs.get_datastore()
|
||||
@@ -683,14 +638,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
store.register_user(user_id=user3.to_string(), password_hash=None)
|
||||
)
|
||||
userinfo = {"sub": "test3", "username": "test_user_3"}
|
||||
e = self.get_failure(
|
||||
self.handler._map_userinfo_to_user(
|
||||
userinfo, token, "user-agent", "10.10.10.10"
|
||||
),
|
||||
MappingException,
|
||||
)
|
||||
self.assertEqual(
|
||||
str(e.value), "Mapping provider does not support de-duplicating Matrix IDs",
|
||||
self._make_callback_with_userinfo(userinfo)
|
||||
auth_handler.complete_sso_login.assert_not_called()
|
||||
self.assertRenderedError(
|
||||
"mapping_error",
|
||||
"Mapping provider does not support de-duplicating Matrix IDs",
|
||||
)
|
||||
|
||||
@override_config({"oidc_config": {"allow_existing_users": True}})
|
||||
@@ -702,26 +654,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
store.register_user(user_id=user.to_string(), password_hash=None)
|
||||
)
|
||||
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
# Map a user via SSO.
|
||||
userinfo = {
|
||||
"sub": "test",
|
||||
"username": "test_user",
|
||||
}
|
||||
token = {}
|
||||
mxid = self.get_success(
|
||||
self.handler._map_userinfo_to_user(
|
||||
userinfo, token, "user-agent", "10.10.10.10"
|
||||
)
|
||||
self._make_callback_with_userinfo(userinfo)
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
user.to_string(), ANY, ANY, {},
|
||||
)
|
||||
self.assertEqual(mxid, "@test_user:test")
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
|
||||
# Subsequent calls should map to the same mxid.
|
||||
mxid = self.get_success(
|
||||
self.handler._map_userinfo_to_user(
|
||||
userinfo, token, "user-agent", "10.10.10.10"
|
||||
)
|
||||
self._make_callback_with_userinfo(userinfo)
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
user.to_string(), ANY, ANY, {},
|
||||
)
|
||||
self.assertEqual(mxid, "@test_user:test")
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
|
||||
# Note that a second SSO user can be mapped to the same Matrix ID. (This
|
||||
# requires a unique sub, but something that maps to the same matrix ID,
|
||||
@@ -732,13 +684,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
"sub": "test1",
|
||||
"username": "test_user",
|
||||
}
|
||||
token = {}
|
||||
mxid = self.get_success(
|
||||
self.handler._map_userinfo_to_user(
|
||||
userinfo, token, "user-agent", "10.10.10.10"
|
||||
)
|
||||
self._make_callback_with_userinfo(userinfo)
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
user.to_string(), ANY, ANY, {},
|
||||
)
|
||||
self.assertEqual(mxid, "@test_user:test")
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
|
||||
# Register some non-exact matching cases.
|
||||
user2 = UserID.from_string("@TEST_user_2:test")
|
||||
@@ -755,14 +705,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
"sub": "test2",
|
||||
"username": "TEST_USER_2",
|
||||
}
|
||||
e = self.get_failure(
|
||||
self.handler._map_userinfo_to_user(
|
||||
userinfo, token, "user-agent", "10.10.10.10"
|
||||
),
|
||||
MappingException,
|
||||
)
|
||||
self._make_callback_with_userinfo(userinfo)
|
||||
auth_handler.complete_sso_login.assert_not_called()
|
||||
args = self.assertRenderedError("mapping_error")
|
||||
self.assertTrue(
|
||||
str(e.value).startswith(
|
||||
args[2].startswith(
|
||||
"Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
|
||||
)
|
||||
)
|
||||
@@ -773,28 +720,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
store.register_user(user_id=user2.to_string(), password_hash=None)
|
||||
)
|
||||
|
||||
mxid = self.get_success(
|
||||
self.handler._map_userinfo_to_user(
|
||||
userinfo, token, "user-agent", "10.10.10.10"
|
||||
)
|
||||
self._make_callback_with_userinfo(userinfo)
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@TEST_USER_2:test", ANY, ANY, {},
|
||||
)
|
||||
self.assertEqual(mxid, "@TEST_USER_2:test")
|
||||
|
||||
def test_map_userinfo_to_invalid_localpart(self):
|
||||
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
||||
userinfo = {
|
||||
"sub": "test2",
|
||||
"username": "föö",
|
||||
}
|
||||
token = {}
|
||||
|
||||
e = self.get_failure(
|
||||
self.handler._map_userinfo_to_user(
|
||||
userinfo, token, "user-agent", "10.10.10.10"
|
||||
),
|
||||
MappingException,
|
||||
)
|
||||
self.assertEqual(str(e.value), "localpart is invalid: föö")
|
||||
self._make_callback_with_userinfo({"sub": "test2", "username": "föö"})
|
||||
self.assertRenderedError("mapping_error", "localpart is invalid: föö")
|
||||
|
||||
@override_config(
|
||||
{
|
||||
@@ -807,6 +741,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
def test_map_userinfo_to_user_retries(self):
|
||||
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
store = self.hs.get_datastore()
|
||||
self.get_success(
|
||||
store.register_user(user_id="@test_user:test", password_hash=None)
|
||||
@@ -815,14 +752,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
"sub": "test",
|
||||
"username": "test_user",
|
||||
}
|
||||
token = {}
|
||||
mxid = self.get_success(
|
||||
self.handler._map_userinfo_to_user(
|
||||
userinfo, token, "user-agent", "10.10.10.10"
|
||||
)
|
||||
)
|
||||
self._make_callback_with_userinfo(userinfo)
|
||||
|
||||
# test_user is already taken, so test_user1 gets registered instead.
|
||||
self.assertEqual(mxid, "@test_user1:test")
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user1:test", ANY, ANY, {},
|
||||
)
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
|
||||
# Register all of the potential mxids for a particular OIDC username.
|
||||
self.get_success(
|
||||
@@ -838,12 +774,70 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
"sub": "tester",
|
||||
"username": "tester",
|
||||
}
|
||||
e = self.get_failure(
|
||||
self.handler._map_userinfo_to_user(
|
||||
userinfo, token, "user-agent", "10.10.10.10"
|
||||
),
|
||||
MappingException,
|
||||
self._make_callback_with_userinfo(userinfo)
|
||||
auth_handler.complete_sso_login.assert_not_called()
|
||||
self.assertRenderedError(
|
||||
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
|
||||
)
|
||||
self.assertEqual(
|
||||
str(e.value), "Unable to generate a Matrix ID from the SSO response"
|
||||
|
||||
def _make_callback_with_userinfo(
|
||||
self, userinfo: dict, client_redirect_url: str = "http://client/redirect"
|
||||
) -> None:
|
||||
self.handler._exchange_code = simple_async_mock(return_value={})
|
||||
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
state = "state"
|
||||
session = self.handler._generate_oidc_session_token(
|
||||
state=state,
|
||||
nonce="nonce",
|
||||
client_redirect_url=client_redirect_url,
|
||||
ui_auth_session_id=None,
|
||||
)
|
||||
request = self._build_callback_request("code", state, session)
|
||||
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
|
||||
def _build_callback_request(
|
||||
self,
|
||||
code: str,
|
||||
state: str,
|
||||
session: str,
|
||||
user_agent: str = "Browser",
|
||||
ip_address: str = "10.0.0.1",
|
||||
):
|
||||
"""Builds a fake SynapseRequest to mock the browser callback
|
||||
|
||||
Returns a Mock object which looks like the SynapseRequest we get from a browser
|
||||
after SSO (before we return to the client)
|
||||
|
||||
Args:
|
||||
code: the authorization code which would have been returned by the OIDC
|
||||
provider
|
||||
state: the "state" param which would have been passed around in the
|
||||
query param. Should be the same as was embedded in the session in
|
||||
_build_oidc_session.
|
||||
session: the "session" which would have been passed around in the cookie.
|
||||
user_agent: the user-agent to present
|
||||
ip_address: the IP address to pretend the request came from
|
||||
"""
|
||||
request = Mock(
|
||||
spec=[
|
||||
"args",
|
||||
"getCookie",
|
||||
"addCookie",
|
||||
"requestHeaders",
|
||||
"getClientIP",
|
||||
"get_user_agent",
|
||||
]
|
||||
)
|
||||
|
||||
request.getCookie.return_value = session
|
||||
request.args = {}
|
||||
request.args[b"code"] = [code.encode("utf-8")]
|
||||
request.args[b"state"] = [state.encode("utf-8")]
|
||||
request.getClientIP.return_value = ip_address
|
||||
request.get_user_agent.return_value = user_agent
|
||||
return request
|
||||
|
||||
@@ -430,6 +430,29 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
mock_password_provider.check_auth.assert_not_called()
|
||||
|
||||
@override_config(
|
||||
{
|
||||
**providers_config(CustomAuthProvider),
|
||||
"password_config": {"enabled": False, "localdb_enabled": False},
|
||||
}
|
||||
)
|
||||
def test_custom_auth_password_disabled_localdb_enabled(self):
|
||||
"""Check the localdb_enabled == enabled == False
|
||||
|
||||
Regression test for https://github.com/matrix-org/synapse/issues/8914: check
|
||||
that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't
|
||||
cause an exception.
|
||||
"""
|
||||
self.register_user("localuser", "localpass")
|
||||
|
||||
flows = self._get_login_flows()
|
||||
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
|
||||
|
||||
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
||||
channel = self._send_password_login("localuser", "localpass")
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
mock_password_provider.check_auth.assert_not_called()
|
||||
|
||||
@override_config(
|
||||
{
|
||||
**providers_config(PasswordCustomAuthProvider),
|
||||
|
||||
@@ -12,11 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from mock import Mock
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.errors import RedirectException
|
||||
from synapse.handlers.sso import MappingException
|
||||
|
||||
from tests.test_utils import simple_async_mock
|
||||
from tests.unittest import HomeserverTestCase, override_config
|
||||
|
||||
# Check if we have the dependencies to run the tests.
|
||||
@@ -44,6 +48,8 @@ BASE_URL = "https://synapse/"
|
||||
@attr.s
|
||||
class FakeAuthnResponse:
|
||||
ava = attr.ib(type=dict)
|
||||
assertions = attr.ib(type=list, factory=list)
|
||||
in_response_to = attr.ib(type=Optional[str], default=None)
|
||||
|
||||
|
||||
class TestMappingProvider:
|
||||
@@ -111,15 +117,22 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
def test_map_saml_response_to_user(self):
|
||||
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
|
||||
|
||||
# stub out the auth handler
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
# send a mocked-up SAML response to the callback
|
||||
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
|
||||
# The redirect_url doesn't matter with the default user mapping provider.
|
||||
redirect_url = ""
|
||||
mxid = self.get_success(
|
||||
self.handler._map_saml_response_to_user(
|
||||
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||
)
|
||||
request = _mock_request()
|
||||
self.get_success(
|
||||
self.handler._handle_authn_response(request, saml_response, "redirect_uri")
|
||||
)
|
||||
|
||||
# check that the auth handler got called as expected
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user:test", request, "redirect_uri"
|
||||
)
|
||||
self.assertEqual(mxid, "@test_user:test")
|
||||
|
||||
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
||||
def test_map_saml_response_to_existing_user(self):
|
||||
@@ -129,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
store.register_user(user_id="@test_user:test", password_hash=None)
|
||||
)
|
||||
|
||||
# stub out the auth handler
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
# Map a user via SSO.
|
||||
saml_response = FakeAuthnResponse(
|
||||
{"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
|
||||
)
|
||||
redirect_url = ""
|
||||
mxid = self.get_success(
|
||||
self.handler._map_saml_response_to_user(
|
||||
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||
)
|
||||
request = _mock_request()
|
||||
self.get_success(
|
||||
self.handler._handle_authn_response(request, saml_response, "")
|
||||
)
|
||||
|
||||
# check that the auth handler got called as expected
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user:test", request, ""
|
||||
)
|
||||
self.assertEqual(mxid, "@test_user:test")
|
||||
|
||||
# Subsequent calls should map to the same mxid.
|
||||
mxid = self.get_success(
|
||||
self.handler._map_saml_response_to_user(
|
||||
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||
)
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
self.get_success(
|
||||
self.handler._handle_authn_response(request, saml_response, "")
|
||||
)
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user:test", request, ""
|
||||
)
|
||||
self.assertEqual(mxid, "@test_user:test")
|
||||
|
||||
def test_map_saml_response_to_invalid_localpart(self):
|
||||
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
||||
|
||||
# stub out the auth handler
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
# mock out the error renderer too
|
||||
sso_handler = self.hs.get_sso_handler()
|
||||
sso_handler.render_error = Mock(return_value=None)
|
||||
|
||||
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
|
||||
redirect_url = ""
|
||||
e = self.get_failure(
|
||||
self.handler._map_saml_response_to_user(
|
||||
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||
),
|
||||
MappingException,
|
||||
request = _mock_request()
|
||||
self.get_success(
|
||||
self.handler._handle_authn_response(request, saml_response, ""),
|
||||
)
|
||||
self.assertEqual(str(e.value), "localpart is invalid: föö")
|
||||
sso_handler.render_error.assert_called_once_with(
|
||||
request, "mapping_error", "localpart is invalid: föö"
|
||||
)
|
||||
auth_handler.complete_sso_login.assert_not_called()
|
||||
|
||||
def test_map_saml_response_to_user_retries(self):
|
||||
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
||||
|
||||
# stub out the auth handler and error renderer
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
sso_handler = self.hs.get_sso_handler()
|
||||
sso_handler.render_error = Mock(return_value=None)
|
||||
|
||||
# register a user to occupy the first-choice MXID
|
||||
store = self.hs.get_datastore()
|
||||
self.get_success(
|
||||
store.register_user(user_id="@test_user:test", password_hash=None)
|
||||
)
|
||||
|
||||
# send the fake SAML response
|
||||
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
|
||||
redirect_url = ""
|
||||
mxid = self.get_success(
|
||||
self.handler._map_saml_response_to_user(
|
||||
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||
)
|
||||
request = _mock_request()
|
||||
self.get_success(
|
||||
self.handler._handle_authn_response(request, saml_response, ""),
|
||||
)
|
||||
|
||||
# test_user is already taken, so test_user1 gets registered instead.
|
||||
self.assertEqual(mxid, "@test_user1:test")
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user1:test", request, ""
|
||||
)
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
|
||||
# Register all of the potential mxids for a particular SAML username.
|
||||
self.get_success(
|
||||
@@ -188,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
# Now attempt to map to a username, this will fail since all potential usernames are taken.
|
||||
saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
|
||||
e = self.get_failure(
|
||||
self.handler._map_saml_response_to_user(
|
||||
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||
),
|
||||
MappingException,
|
||||
self.get_success(
|
||||
self.handler._handle_authn_response(request, saml_response, ""),
|
||||
)
|
||||
self.assertEqual(
|
||||
str(e.value), "Unable to generate a Matrix ID from the SSO response"
|
||||
sso_handler.render_error.assert_called_once_with(
|
||||
request,
|
||||
"mapping_error",
|
||||
"Unable to generate a Matrix ID from the SSO response",
|
||||
)
|
||||
auth_handler.complete_sso_login.assert_not_called()
|
||||
|
||||
@override_config(
|
||||
{
|
||||
@@ -208,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
)
|
||||
def test_map_saml_response_redirect(self):
|
||||
"""Test a mapping provider that raises a RedirectException"""
|
||||
|
||||
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
|
||||
redirect_url = ""
|
||||
request = _mock_request()
|
||||
e = self.get_failure(
|
||||
self.handler._map_saml_response_to_user(
|
||||
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||
),
|
||||
self.handler._handle_authn_response(request, saml_response, ""),
|
||||
RedirectException,
|
||||
)
|
||||
self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
|
||||
|
||||
|
||||
def _mock_request():
|
||||
"""Returns a mock which will stand in as a SynapseRequest"""
|
||||
return Mock(spec=["getClientIP", "get_user_agent"])
|
||||
|
||||
@@ -117,11 +117,10 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
|
||||
"""
|
||||
handler = logging.StreamHandler(self.output)
|
||||
handler.setFormatter(JsonFormatter())
|
||||
handler.addFilter(LoggingContextFilter(request=""))
|
||||
handler.addFilter(LoggingContextFilter())
|
||||
logger = self.get_logger(handler)
|
||||
|
||||
with LoggingContext() as context_one:
|
||||
context_one.request = "test"
|
||||
with LoggingContext(request="test"):
|
||||
logger.info("Hello there, %s!", "wally")
|
||||
|
||||
log = self.get_log_line()
|
||||
@@ -132,9 +131,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
|
||||
"level",
|
||||
"namespace",
|
||||
"request",
|
||||
"scope",
|
||||
]
|
||||
self.assertCountEqual(log.keys(), expected_log_keys)
|
||||
self.assertEqual(log["log"], "Hello there, wally!")
|
||||
self.assertEqual(log["request"], "test")
|
||||
self.assertIsNone(log["scope"])
|
||||
|
||||
@@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||
}
|
||||
)
|
||||
|
||||
with LoggingContext(request="lying_event"):
|
||||
with LoggingContext():
|
||||
failure = self.get_failure(
|
||||
self.handler.on_receive_pdu(
|
||||
"test.serv", lying_event, sent_to_us_directly=True
|
||||
|
||||
@@ -22,6 +22,8 @@ import warnings
|
||||
from asyncio import Future
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
|
||||
from mock import Mock
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.python.failure import Failure
|
||||
@@ -87,6 +89,16 @@ def setup_awaitable_errors() -> Callable[[], None]:
|
||||
return cleanup
|
||||
|
||||
|
||||
def simple_async_mock(return_value=None, raises=None) -> Mock:
|
||||
# AsyncMock is not available in python3.5, this mimics part of its behaviour
|
||||
async def cb(*args, **kwargs):
|
||||
if raises:
|
||||
raise raises
|
||||
return return_value
|
||||
|
||||
return Mock(side_effect=cb)
|
||||
|
||||
|
||||
@attr.s
|
||||
class FakeResponse:
|
||||
"""A fake twisted.web.IResponse object
|
||||
|
||||
@@ -48,7 +48,7 @@ def setup_logging():
|
||||
handler = ToTwistedHandler()
|
||||
formatter = logging.Formatter(log_format)
|
||||
handler.setFormatter(formatter)
|
||||
handler.addFilter(LoggingContextFilter(request=""))
|
||||
handler.addFilter(LoggingContextFilter())
|
||||
root_logger.addHandler(handler)
|
||||
|
||||
log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")
|
||||
|
||||
Reference in New Issue
Block a user