1
0

Merge commit '01333681b' into anoa/dinsic_release_1_31_0

This commit is contained in:
Andrew Morgan
2021-04-16 15:06:19 +01:00
26 changed files with 345 additions and 297 deletions

1
changelog.d/8911.feature Normal file
View File

@@ -0,0 +1 @@
Add support for allowing users to pick their own user ID during a single-sign-on login.

View File

@@ -1 +1 @@
Improve structured logging tests.
Various clean-ups to the structured logging and logging context code.

1
changelog.d/8935.misc Normal file
View File

@@ -0,0 +1 @@
Various clean-ups to the structured logging and logging context code.

1
changelog.d/8937.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix bug introduced in Synapse v1.24.0 which would cause an exception on startup if both `enabled` and `localdb_enabled` were set to `False` in the `password_config` setting of the configuration file.

1
changelog.d/8938.feature Normal file
View File

@@ -0,0 +1 @@
Add support for allowing users to pick their own user ID during a single-sign-on login.

1
changelog.d/8943.misc Normal file
View File

@@ -0,0 +1 @@
Add type hints to push module.

View File

@@ -206,7 +206,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
# filter options, but care must when using e.g. MemoryHandler to buffer
# writes.
log_context_filter = LoggingContextFilter(request="")
log_context_filter = LoggingContextFilter()
log_metadata_filter = MetadataFilter({"server_name": config.server_name})
old_factory = logging.getLogRecordFactory()

View File

@@ -198,27 +198,25 @@ class AuthHandler(BaseHandler):
self._password_enabled = hs.config.password_enabled
self._password_localdb_enabled = hs.config.password_localdb_enabled
# we keep this as a list despite the O(N^2) implication so that we can
# keep PASSWORD first and avoid confusing clients which pick the first
# type in the list. (NB that the spec doesn't require us to do so and
# clients which favour types that they don't understand over those that
# they do are technically broken)
# start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = []
login_types = set()
if self._password_localdb_enabled:
login_types.append(LoginType.PASSWORD)
login_types.add(LoginType.PASSWORD)
for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"):
for t in provider.get_supported_login_types().keys():
if t not in login_types:
login_types.append(t)
login_types.update(provider.get_supported_login_types().keys())
if not self._password_enabled:
login_types.remove(LoginType.PASSWORD)
login_types.discard(LoginType.PASSWORD)
self._supported_login_types = login_types
# Some clients just pick the first type in the list. In this case, we want
# them to use PASSWORD (rather than token or whatever), so we want to make sure
# that comes first, where it's present.
self._supported_login_types = []
if LoginType.PASSWORD in login_types:
self._supported_login_types.append(LoginType.PASSWORD)
login_types.remove(LoginType.PASSWORD)
self._supported_login_types.extend(login_types)
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.

View File

@@ -163,6 +163,29 @@ class SamlHandler(BaseHandler):
return
logger.debug("SAML2 response: %s", saml2_auth.origxml)
await self._handle_authn_response(request, saml2_auth, relay_state)
async def _handle_authn_response(
self,
request: SynapseRequest,
saml2_auth: saml2.response.AuthnResponse,
relay_state: str,
) -> None:
"""Handle an AuthnResponse, having parsed it from the request params
Assumes that the signature on the response object has been checked. Maps
the user onto an MXID, registering them if necessary, and returns a response
to the browser.
Args:
request: the incoming request from the browser. We'll respond to it with an
HTML page or a redirect
saml2_auth: the parsed AuthnResponse object
relay_state: the RelayState query param, which encodes the URI to rediret
back to
"""
for assertion in saml2_auth.assertions:
# kibana limits the length of a log field, whereas this is all rather
# useful, so split it up.

View File

@@ -128,8 +128,7 @@ class SynapseRequest(Request):
# create a LogContext for this request
request_id = self.get_request_id()
logcontext = self.logcontext = LoggingContext(request_id)
logcontext.request = request_id
self.logcontext = LoggingContext(request_id, request=request_id)
# override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string)

View File

@@ -203,10 +203,6 @@ class _Sentinel:
def copy_to(self, record):
pass
def copy_to_twisted_log_entry(self, record):
record["request"] = None
record["scope"] = None
def start(self, rusage: "Optional[resource._RUsage]"):
pass
@@ -372,13 +368,6 @@ class LoggingContext:
# we also track the current scope:
record.scope = self.scope
def copy_to_twisted_log_entry(self, record) -> None:
"""
Copy logging fields from this context to a Twisted log record.
"""
record["request"] = self.request
record["scope"] = self.scope
def start(self, rusage: "Optional[resource._RUsage]") -> None:
"""
Record that this logcontext is currently running.
@@ -542,13 +531,10 @@ class LoggingContext:
class LoggingContextFilter(logging.Filter):
"""Logging filter that adds values from the current logging context to each
record.
Args:
**defaults: Default values to avoid formatters complaining about
missing fields
"""
def __init__(self, **defaults) -> None:
self.defaults = defaults
def __init__(self, request: str = ""):
self._default_request = request
def filter(self, record) -> Literal[True]:
"""Add each fields from the logging contexts to the record.
@@ -556,14 +542,14 @@ class LoggingContextFilter(logging.Filter):
True to include the record in the log output.
"""
context = current_context()
for key, value in self.defaults.items():
setattr(record, key, value)
record.request = self._default_request
# context should never be None, but if it somehow ends up being, then
# we end up in a death spiral of infinite loops, so let's check, for
# robustness' sake.
if context is not None:
context.copy_to(record)
# Logging is interested in the request.
record.request = context.request
return True

View File

@@ -199,8 +199,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
_background_process_start_count.labels(desc).inc()
_background_process_in_flight_count.labels(desc).inc()
with BackgroundProcessLoggingContext(desc) as context:
context.request = "%s-%i" % (desc, count)
with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context:
try:
ctx = noop_context_manager()
if bg_start_span:
@@ -244,8 +243,8 @@ class BackgroundProcessLoggingContext(LoggingContext):
__slots__ = ["_proc"]
def __init__(self, name: str):
super().__init__(name)
def __init__(self, name: str, request: Optional[str] = None):
super().__init__(name, request=request)
self._proc = _BackgroundProcess(name, self)

View File

@@ -14,7 +14,7 @@
# limitations under the License.
import abc
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict
from synapse.types import RoomStreamToken
@@ -36,12 +36,21 @@ class Pusher(metaclass=abc.ABCMeta):
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
# should honour this rather than just looking for anything higher
# because of potential out-of-order event serialisation. This starts
# off as None though as we don't know any better.
self.max_stream_ordering = None # type: Optional[int]
# because of potential out-of-order event serialisation.
self.max_stream_ordering = self.store.get_room_max_stream_ordering()
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector
# clock components.
max_stream_ordering = max_token.stream
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
self._start_processing()
@abc.abstractmethod
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
def _start_processing(self):
"""Start processing push notifications."""
raise NotImplementedError()
@abc.abstractmethod

View File

@@ -22,7 +22,6 @@ from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher
from synapse.push.mailer import Mailer
from synapse.types import RoomStreamToken
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@@ -93,20 +92,6 @@ class EmailPusher(Pusher):
pass
self.timed_call = None
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector
# clock components.
max_stream_ordering = max_token.stream
if self.max_stream_ordering:
self.max_stream_ordering = max(
max_stream_ordering, self.max_stream_ordering
)
else:
self.max_stream_ordering = max_stream_ordering
self._start_processing()
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the
@@ -172,7 +157,6 @@ class EmailPusher(Pusher):
being run.
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
assert self.max_stream_ordering is not None
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
self.user_id, start, self.max_stream_ordering
)

View File

@@ -26,7 +26,6 @@ from synapse.events import EventBase
from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfigException
from synapse.types import RoomStreamToken
from . import push_rule_evaluator, push_tools
@@ -122,17 +121,6 @@ class HttpPusher(Pusher):
if should_check_for_notifs:
self._start_processing()
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector
# clock components.
max_stream_ordering = max_token.stream
self.max_stream_ordering = max(
max_stream_ordering, self.max_stream_ordering or 0
)
self._start_processing()
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# Note that the min here shouldn't be relied upon to be accurate.
@@ -192,10 +180,7 @@ class HttpPusher(Pusher):
Never call this directly: use _process which will only allow this to
run once per pusher.
"""
fn = self.store.get_unread_push_actions_for_user_in_range_for_http
assert self.max_stream_ordering is not None
unprocessed = await fn(
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)

View File

@@ -129,9 +129,8 @@ class PusherPool:
)
# create the pusher setting last_stream_ordering to the current maximum
# stream ordering in event_push_actions, so it will process
# pushes from this point onwards.
last_stream_ordering = await self.store.get_latest_push_action_stream_ordering()
# stream ordering, so it will process pushes from this point onwards.
last_stream_ordering = self.store.get_room_max_stream_ordering()
await self.store.add_pusher(
user_id=user_id,

View File

@@ -172,8 +172,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
ctx_name = "replication-conn-%s" % self.conn_id
self._logging_context = BackgroundProcessLoggingContext(ctx_name)
self._logging_context.request = ctx_name
self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
def connectionMade(self):
logger.info("[%s] Connection established", self.id())

View File

@@ -894,16 +894,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions
async def get_latest_push_action_stream_ordering(self):
def f(txn):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone()
result = await self.db_pool.runInteraction(
"get_latest_push_action_stream_ordering", f
)
return result[0] or 0
def _remove_old_push_actions_before_txn(
self, txn, room_id, user_id, stream_ordering
):

View File

@@ -126,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room_version,
)
with LoggingContext(request="send_rejected"):
with LoggingContext("send_rejected"):
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
self.get_success(d)
@@ -178,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room_version,
)
with LoggingContext(request="send_rejected"):
with LoggingContext("send_rejected"):
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
self.get_success(d)
@@ -198,7 +198,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
# the auth code requires that a signature exists, but doesn't check that
# signature... go figure.
join_event.signatures[other_server] = {"x": "y"}
with LoggingContext(request="send_join"):
with LoggingContext("send_join"):
d = run_in_background(
self.handler.on_send_join_request, other_server, join_event
)

View File

@@ -15,7 +15,7 @@
import json
from urllib.parse import parse_qs, urlparse
from mock import Mock, patch
from mock import ANY, Mock, patch
import pymacaroons
@@ -23,7 +23,7 @@ from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException
from synapse.types import UserID
from tests.test_utils import FakeResponse
from tests.test_utils import FakeResponse, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
# These are a few constants that are used as config parameters in the tests.
@@ -82,16 +82,6 @@ class TestMappingProviderFailures(TestMappingProvider):
}
def simple_async_mock(return_value=None, raises=None):
# AsyncMock is not available in python3.5, this mimics part of its behaviour
async def cb(*args, **kwargs):
if raises:
raise raises
return return_value
return Mock(side_effect=cb)
async def get_json(url):
# Mock get_json calls to handle jwks & oidc discovery endpoints
if url == WELL_KNOWN:
@@ -160,6 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(args[2], error_description)
# Reset the render_error mock
self.render_error.reset_mock()
return args
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
@@ -374,26 +365,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
"id_token": "id_token",
"access_token": "access_token",
}
username = "bar"
userinfo = {
"sub": "foo",
"preferred_username": "bar",
"username": username,
}
user_id = "@foo:domain.org"
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
request = Mock(
spec=[
"args",
"getCookie",
"addCookie",
"requestHeaders",
"getClientIP",
"get_user_agent",
]
)
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
code = "code"
state = "state"
@@ -401,64 +383,54 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url = "http://client/redirect"
user_agent = "Browser"
ip_address = "10.0.0.1"
request.getCookie.return_value = self.handler._generate_oidc_session_token(
session = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
request.args = {}
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
request.getClientIP.return_value = ip_address
request.get_user_agent.return_value = user_agent
request = self._build_callback_request(
code, state, session, user_agent=user_agent, ip_address=ip_address
)
self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
user_id, request, client_redirect_url, {},
auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id, request, client_redirect_url, {},
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
self.handler._map_userinfo_to_user.assert_called_once_with(
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_not_called()
self.render_error.assert_not_called()
# Handle mapping errors
self.handler._map_userinfo_to_user = simple_async_mock(
raises=MappingException()
)
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error")
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
with patch.object(
self.handler,
"_remote_id_from_userinfo",
new=Mock(side_effect=MappingException()),
):
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error")
# Handle ID token errors
self.handler._parse_id_token = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
self.handler._auth_handler.complete_sso_login.reset_mock()
auth_handler.complete_sso_login.reset_mock()
self.handler._exchange_code.reset_mock()
self.handler._parse_id_token.reset_mock()
self.handler._map_userinfo_to_user.reset_mock()
self.handler._fetch_userinfo.reset_mock()
# With userinfo fetching
self.handler._scopes = [] # do not ask the "openid" scope
self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
user_id, request, client_redirect_url, {},
auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id, request, client_redirect_url, {},
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
self.handler._map_userinfo_to_user.assert_called_once_with(
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_called_once_with(token)
self.render_error.assert_not_called()
@@ -609,72 +581,55 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
userinfo = {
"sub": "foo",
"username": "foo",
"phone": "1234567",
}
user_id = "@foo:domain.org"
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
request = Mock(
spec=[
"args",
"getCookie",
"addCookie",
"requestHeaders",
"getClientIP",
"get_user_agent",
]
)
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
state = "state"
client_redirect_url = "http://client/redirect"
request.getCookie.return_value = self.handler._generate_oidc_session_token(
session = self.handler._generate_oidc_session_token(
state=state,
nonce="nonce",
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
request.args = {}
request.args[b"code"] = [b"code"]
request.args[b"state"] = [state.encode("utf-8")]
request.getClientIP.return_value = "10.0.0.1"
request.get_user_agent.return_value = "Browser"
request = self._build_callback_request("code", state, session)
self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
user_id, request, client_redirect_url, {"phone": "1234567"},
auth_handler.complete_sso_login.assert_called_once_with(
"@foo:test", request, client_redirect_url, {"phone": "1234567"},
)
def test_map_userinfo_to_user(self):
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
userinfo = {
"sub": "test_user",
"username": "test_user",
}
# The token doesn't matter with the default user mapping provider.
token = {}
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
self._make_callback_with_userinfo(userinfo)
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", ANY, ANY, {}
)
self.assertEqual(mxid, "@test_user:test")
auth_handler.complete_sso_login.reset_mock()
# Some providers return an integer ID.
userinfo = {
"sub": 1234,
"username": "test_user_2",
}
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
self._make_callback_with_userinfo(userinfo)
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user_2:test", ANY, ANY, {}
)
self.assertEqual(mxid, "@test_user_2:test")
auth_handler.complete_sso_login.reset_mock()
# Test if the mxid is already taken
store = self.hs.get_datastore()
@@ -683,14 +638,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None)
)
userinfo = {"sub": "test3", "username": "test_user_3"}
e = self.get_failure(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
),
MappingException,
)
self.assertEqual(
str(e.value), "Mapping provider does not support de-duplicating Matrix IDs",
self._make_callback_with_userinfo(userinfo)
auth_handler.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error",
"Mapping provider does not support de-duplicating Matrix IDs",
)
@override_config({"oidc_config": {"allow_existing_users": True}})
@@ -702,26 +654,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user.to_string(), password_hash=None)
)
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# Map a user via SSO.
userinfo = {
"sub": "test",
"username": "test_user",
}
token = {}
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
self._make_callback_with_userinfo(userinfo)
auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, {},
)
self.assertEqual(mxid, "@test_user:test")
auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid.
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
self._make_callback_with_userinfo(userinfo)
auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, {},
)
self.assertEqual(mxid, "@test_user:test")
auth_handler.complete_sso_login.reset_mock()
# Note that a second SSO user can be mapped to the same Matrix ID. (This
# requires a unique sub, but something that maps to the same matrix ID,
@@ -732,13 +684,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test1",
"username": "test_user",
}
token = {}
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
self._make_callback_with_userinfo(userinfo)
auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, {},
)
self.assertEqual(mxid, "@test_user:test")
auth_handler.complete_sso_login.reset_mock()
# Register some non-exact matching cases.
user2 = UserID.from_string("@TEST_user_2:test")
@@ -755,14 +705,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test2",
"username": "TEST_USER_2",
}
e = self.get_failure(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
),
MappingException,
)
self._make_callback_with_userinfo(userinfo)
auth_handler.complete_sso_login.assert_not_called()
args = self.assertRenderedError("mapping_error")
self.assertTrue(
str(e.value).startswith(
args[2].startswith(
"Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
)
)
@@ -773,28 +720,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user2.to_string(), password_hash=None)
)
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
self._make_callback_with_userinfo(userinfo)
auth_handler.complete_sso_login.assert_called_once_with(
"@TEST_USER_2:test", ANY, ANY, {},
)
self.assertEqual(mxid, "@TEST_USER_2:test")
def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
userinfo = {
"sub": "test2",
"username": "föö",
}
token = {}
e = self.get_failure(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
),
MappingException,
)
self.assertEqual(str(e.value), "localpart is invalid: föö")
self._make_callback_with_userinfo({"sub": "test2", "username": "föö"})
self.assertRenderedError("mapping_error", "localpart is invalid: föö")
@override_config(
{
@@ -807,6 +741,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_map_userinfo_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
@@ -815,14 +752,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test",
"username": "test_user",
}
token = {}
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
)
self._make_callback_with_userinfo(userinfo)
# test_user is already taken, so test_user1 gets registered instead.
self.assertEqual(mxid, "@test_user1:test")
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user1:test", ANY, ANY, {},
)
auth_handler.complete_sso_login.reset_mock()
# Register all of the potential mxids for a particular OIDC username.
self.get_success(
@@ -838,12 +774,70 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "tester",
}
e = self.get_failure(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
),
MappingException,
self._make_callback_with_userinfo(userinfo)
auth_handler.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
)
self.assertEqual(
str(e.value), "Unable to generate a Matrix ID from the SSO response"
def _make_callback_with_userinfo(
self, userinfo: dict, client_redirect_url: str = "http://client/redirect"
) -> None:
self.handler._exchange_code = simple_async_mock(return_value={})
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
state = "state"
session = self.handler._generate_oidc_session_token(
state=state,
nonce="nonce",
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
request = self._build_callback_request("code", state, session)
self.get_success(self.handler.handle_oidc_callback(request))
def _build_callback_request(
self,
code: str,
state: str,
session: str,
user_agent: str = "Browser",
ip_address: str = "10.0.0.1",
):
"""Builds a fake SynapseRequest to mock the browser callback
Returns a Mock object which looks like the SynapseRequest we get from a browser
after SSO (before we return to the client)
Args:
code: the authorization code which would have been returned by the OIDC
provider
state: the "state" param which would have been passed around in the
query param. Should be the same as was embedded in the session in
_build_oidc_session.
session: the "session" which would have been passed around in the cookie.
user_agent: the user-agent to present
ip_address: the IP address to pretend the request came from
"""
request = Mock(
spec=[
"args",
"getCookie",
"addCookie",
"requestHeaders",
"getClientIP",
"get_user_agent",
]
)
request.getCookie.return_value = session
request.args = {}
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
request.getClientIP.return_value = ip_address
request.get_user_agent.return_value = user_agent
return request

View File

@@ -430,6 +430,29 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
@override_config(
{
**providers_config(CustomAuthProvider),
"password_config": {"enabled": False, "localdb_enabled": False},
}
)
def test_custom_auth_password_disabled_localdb_enabled(self):
"""Check the localdb_enabled == enabled == False
Regression test for https://github.com/matrix-org/synapse/issues/8914: check
that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't
cause an exception.
"""
self.register_user("localuser", "localpass")
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
@override_config(
{
**providers_config(PasswordCustomAuthProvider),

View File

@@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from mock import Mock
import attr
from synapse.api.errors import RedirectException
from synapse.handlers.sso import MappingException
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
# Check if we have the dependencies to run the tests.
@@ -44,6 +48,8 @@ BASE_URL = "https://synapse/"
@attr.s
class FakeAuthnResponse:
ava = attr.ib(type=dict)
assertions = attr.ib(type=list, factory=list)
in_response_to = attr.ib(type=Optional[str], default=None)
class TestMappingProvider:
@@ -111,15 +117,22 @@ class SamlHandlerTestCase(HomeserverTestCase):
def test_map_saml_response_to_user(self):
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
# The redirect_url doesn't matter with the default user mapping provider.
redirect_url = ""
mxid = self.get_success(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
request = _mock_request()
self.get_success(
self.handler._handle_authn_response(request, saml_response, "redirect_uri")
)
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri"
)
self.assertEqual(mxid, "@test_user:test")
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
def test_map_saml_response_to_existing_user(self):
@@ -129,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase):
store.register_user(user_id="@test_user:test", password_hash=None)
)
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# Map a user via SSO.
saml_response = FakeAuthnResponse(
{"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
)
redirect_url = ""
mxid = self.get_success(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
request = _mock_request()
self.get_success(
self.handler._handle_authn_response(request, saml_response, "")
)
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, ""
)
self.assertEqual(mxid, "@test_user:test")
# Subsequent calls should map to the same mxid.
mxid = self.get_success(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
auth_handler.complete_sso_login.reset_mock()
self.get_success(
self.handler._handle_authn_response(request, saml_response, "")
)
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, ""
)
self.assertEqual(mxid, "@test_user:test")
def test_map_saml_response_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# mock out the error renderer too
sso_handler = self.hs.get_sso_handler()
sso_handler.render_error = Mock(return_value=None)
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
redirect_url = ""
e = self.get_failure(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
),
MappingException,
request = _mock_request()
self.get_success(
self.handler._handle_authn_response(request, saml_response, ""),
)
self.assertEqual(str(e.value), "localpart is invalid: föö")
sso_handler.render_error.assert_called_once_with(
request, "mapping_error", "localpart is invalid: föö"
)
auth_handler.complete_sso_login.assert_not_called()
def test_map_saml_response_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
# stub out the auth handler and error renderer
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
sso_handler = self.hs.get_sso_handler()
sso_handler.render_error = Mock(return_value=None)
# register a user to occupy the first-choice MXID
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
)
# send the fake SAML response
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
redirect_url = ""
mxid = self.get_success(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
request = _mock_request()
self.get_success(
self.handler._handle_authn_response(request, saml_response, ""),
)
# test_user is already taken, so test_user1 gets registered instead.
self.assertEqual(mxid, "@test_user1:test")
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user1:test", request, ""
)
auth_handler.complete_sso_login.reset_mock()
# Register all of the potential mxids for a particular SAML username.
self.get_success(
@@ -188,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase):
# Now attempt to map to a username, this will fail since all potential usernames are taken.
saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
e = self.get_failure(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
),
MappingException,
self.get_success(
self.handler._handle_authn_response(request, saml_response, ""),
)
self.assertEqual(
str(e.value), "Unable to generate a Matrix ID from the SSO response"
sso_handler.render_error.assert_called_once_with(
request,
"mapping_error",
"Unable to generate a Matrix ID from the SSO response",
)
auth_handler.complete_sso_login.assert_not_called()
@override_config(
{
@@ -208,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase):
}
)
def test_map_saml_response_redirect(self):
"""Test a mapping provider that raises a RedirectException"""
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
redirect_url = ""
request = _mock_request()
e = self.get_failure(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
),
self.handler._handle_authn_response(request, saml_response, ""),
RedirectException,
)
self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
def _mock_request():
"""Returns a mock which will stand in as a SynapseRequest"""
return Mock(spec=["getClientIP", "get_user_agent"])

View File

@@ -117,11 +117,10 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"""
handler = logging.StreamHandler(self.output)
handler.setFormatter(JsonFormatter())
handler.addFilter(LoggingContextFilter(request=""))
handler.addFilter(LoggingContextFilter())
logger = self.get_logger(handler)
with LoggingContext() as context_one:
context_one.request = "test"
with LoggingContext(request="test"):
logger.info("Hello there, %s!", "wally")
log = self.get_log_line()
@@ -132,9 +131,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"level",
"namespace",
"request",
"scope",
]
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")
self.assertEqual(log["request"], "test")
self.assertIsNone(log["scope"])

View File

@@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
with LoggingContext(request="lying_event"):
with LoggingContext():
failure = self.get_failure(
self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True

View File

@@ -22,6 +22,8 @@ import warnings
from asyncio import Future
from typing import Any, Awaitable, Callable, TypeVar
from mock import Mock
import attr
from twisted.python.failure import Failure
@@ -87,6 +89,16 @@ def setup_awaitable_errors() -> Callable[[], None]:
return cleanup
def simple_async_mock(return_value=None, raises=None) -> Mock:
# AsyncMock is not available in python3.5, this mimics part of its behaviour
async def cb(*args, **kwargs):
if raises:
raise raises
return return_value
return Mock(side_effect=cb)
@attr.s
class FakeResponse:
"""A fake twisted.web.IResponse object

View File

@@ -48,7 +48,7 @@ def setup_logging():
handler = ToTwistedHandler()
formatter = logging.Formatter(log_format)
handler.setFormatter(formatter)
handler.addFilter(LoggingContextFilter(request=""))
handler.addFilter(LoggingContextFilter())
root_logger.addHandler(handler)
log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")