
Merge commit '631dd06f2' into anoa/dinsic_release_1_31_0

Andrew Morgan
2021-04-22 16:22:43 +01:00
13 changed files with 569 additions and 175 deletions


@@ -1 +1 @@
-Improve efficiency of large state resolutions for new rooms.
+Improve efficiency of large state resolutions.

changelog.d/9029.misc Normal file

@@ -0,0 +1 @@
Improve efficiency of large state resolutions.

changelog.d/9107.feature Normal file

@@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

changelog.d/9114.bugfix Normal file

@@ -0,0 +1 @@
Fix bug in federation catchup logic that caused outbound federation to be delayed for large servers after startup. Introduced in v1.21.0.


@@ -71,7 +71,7 @@ logger = logging.getLogger("synapse_port_db")
BOOLEAN_COLUMNS = {
"events": ["processed", "outlier", "contains_url"],
"rooms": ["is_public"],
"rooms": ["is_public", "has_auth_chain_index"],
"event_edges": ["is_state"],
"presence_list": ["accepted"],
"presence_stream": ["currently_active"],


@@ -429,7 +429,6 @@ def setup(config_options):
oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid.
await oidc.load_metadata()
-await oidc.load_jwks()
await _base.start(hs, config.listeners)


@@ -35,6 +35,7 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
+from synapse.config.oidc_config import OidcProviderConfig
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
@@ -70,6 +71,131 @@ JWK = Dict[str, str]
JWKS = TypedDict("JWKS", {"keys": List[JWK]})
class OidcHandler:
"""Handles requests related to the OpenID Connect login flow.
"""
def __init__(self, hs: "HomeServer"):
self._sso_handler = hs.get_sso_handler()
provider_conf = hs.config.oidc.oidc_provider
# we should not have been instantiated if there is no configured provider.
assert provider_conf is not None
self._token_generator = OidcSessionTokenGenerator(hs)
self._provider = OidcProvider(hs, self._token_generator, provider_conf)
async def load_metadata(self) -> None:
"""Validate the config and load the metadata from the remote endpoint.
Called at startup to ensure we have everything we need.
"""
await self._provider.load_metadata()
await self._provider.load_jwks()
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
``self._sso_handler.render_error`` which displays an HTML page for the error.
Most of the OpenID Connect logic happens here:
- first, we check if there was any error returned by the provider and
display it
- then we fetch the session cookie, decode and verify it
- the ``state`` query parameter should match with the one stored in the
session cookie
Once we know the session is legit, we then delegate to the OIDC Provider
implementation, which will exchange the code with the provider and complete the
login/authentication.
Args:
request: the incoming request from the browser.
"""
# The provider might redirect with an error.
# In that case, just display it as-is.
if b"error" in request.args:
# error response from the auth server. see:
# https://tools.ietf.org/html/rfc6749#section-4.1.2.1
# https://openid.net/specs/openid-connect-core-1_0.html#AuthError
error = request.args[b"error"][0].decode()
description = request.args.get(b"error_description", [b""])[0].decode()
# Most of the errors returned by the provider could be due to
# either the provider misbehaving or Synapse being misconfigured.
# The only exception to that is "access_denied", where the user
# probably cancelled the login flow. In other cases, log those errors.
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)
self._sso_handler.render_error(request, error, description)
return
# otherwise, it is presumably a successful response. see:
# https://tools.ietf.org/html/rfc6749#section-4.1.2
# Fetch the session cookie
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
self._sso_handler.render_error(
request, "missing_session", "No session cookie found"
)
return
# Remove the cookie. There is a good chance that if the callback failed
# once, it will fail next time and the code will already be exchanged.
# Removing it early avoids spamming the provider with token requests.
request.addCookie(
SESSION_COOKIE_NAME,
b"",
path="/_synapse/oidc",
expires="Thu, Jan 01 1970 00:00:00 UTC",
httpOnly=True,
sameSite="lax",
)
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "State parameter is missing"
)
return
state = request.args[b"state"][0].decode()
# Deserialize the session token and verify it.
try:
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
self._sso_handler.render_error(request, "mismatching_session", str(e))
return
if b"code" not in request.args:
logger.info("Code parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "Code parameter is missing"
)
return
code = request.args[b"code"][0].decode()
await self._provider.handle_oidc_callback(request, session_data, code)
class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint
"""
@@ -84,21 +210,25 @@ class OidcError(Exception):
return self.error
-class OidcHandler:
-"""Handles requests related to the OpenID Connect login flow.
+class OidcProvider:
+"""Wraps the config for a single OIDC IdentityProvider
+Provides methods for handling redirect requests and callbacks via that particular
+IdP.
"""
-def __init__(self, hs: "HomeServer"):
+def __init__(
+self,
+hs: "HomeServer",
+token_generator: "OidcSessionTokenGenerator",
+provider: OidcProviderConfig,
+):
self._store = hs.get_datastore()
-self._token_generator = OidcSessionTokenGenerator(hs)
+self._token_generator = token_generator
self._callback_url = hs.config.oidc_callback_url # type: str
-provider = hs.config.oidc.oidc_provider
-# we should not have been instantiated if there is no configured provider.
-assert provider is not None
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
self._client_auth = ClientAuth(
@@ -552,22 +682,16 @@ class OidcHandler:
nonce=nonce,
)
-async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+async def handle_oidc_callback(
+self, request: SynapseRequest, session_data: "OidcSessionData", code: str
+) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
-Since we might want to display OIDC-related errors in a user-friendly
-way, we don't raise SynapseError from here. Instead, we call
-``self._sso_handler.render_error`` which displays an HTML page for the error.
+By this time we have already validated the session on the synapse side, and
+now need to do the provider-specific operations. This includes:
-Most of the OpenID Connect logic happens here:
-- first, we check if there was any error returned by the provider and
-display it
-- then we fetch the session cookie, decode and verify it
-- the ``state`` query parameter should match with the one stored in the
-session cookie
-- once we known this session is legit, exchange the code with the
-provider using the ``token_endpoint`` (see ``_exchange_code``)
+- exchange the code with the provider using the ``token_endpoint`` (see
+``_exchange_code``)
- once we have the token, use it to either extract the UserInfo from
the ``id_token`` (``_parse_id_token``), or use the ``access_token``
to fetch UserInfo from the ``userinfo_endpoint``
@@ -577,86 +701,12 @@ class OidcHandler:
Args:
request: the incoming request from the browser.
session_data: the session data, extracted from our cookie
code: The authorization code we got from the callback.
"""
-# The provider might redirect with an error.
-# In that case, just display it as-is.
-if b"error" in request.args:
-# error response from the auth server. see:
-# https://tools.ietf.org/html/rfc6749#section-4.1.2.1
-# https://openid.net/specs/openid-connect-core-1_0.html#AuthError
-error = request.args[b"error"][0].decode()
-description = request.args.get(b"error_description", [b""])[0].decode()
-# Most of the errors returned by the provider could be due by
-# either the provider misbehaving or Synapse being misconfigured.
-# The only exception of that is "access_denied", where the user
-# probably cancelled the login flow. In other cases, log those errors.
-if error != "access_denied":
-logger.error("Error from the OIDC provider: %s %s", error, description)
-self._sso_handler.render_error(request, error, description)
-return
-# otherwise, it is presumably a successful response. see:
-# https://tools.ietf.org/html/rfc6749#section-4.1.2
-# Fetch the session cookie
-session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
-if session is None:
-logger.info("No session cookie found")
-self._sso_handler.render_error(
-request, "missing_session", "No session cookie found"
-)
-return
-# Remove the cookie. There is a good chance that if the callback failed
-# once, it will fail next time and the code will already be exchanged.
-# Removing it early avoids spamming the provider with token requests.
-request.addCookie(
-SESSION_COOKIE_NAME,
-b"",
-path="/_synapse/oidc",
-expires="Thu, Jan 01 1970 00:00:00 UTC",
-httpOnly=True,
-sameSite="lax",
-)
-# Check for the state query parameter
-if b"state" not in request.args:
-logger.info("State parameter is missing")
-self._sso_handler.render_error(
-request, "invalid_request", "State parameter is missing"
-)
-return
-state = request.args[b"state"][0].decode()
-# Deserialize the session token and verify it.
-try:
-session_data = self._token_generator.verify_oidc_session_token(
-session, state
-)
-except MacaroonDeserializationException as e:
-logger.exception("Invalid session")
-self._sso_handler.render_error(request, "invalid_session", str(e))
-return
-except MacaroonInvalidSignatureException as e:
-logger.exception("Could not verify session")
-self._sso_handler.render_error(request, "mismatching_session", str(e))
-return
-# Exchange the code with the provider
-if b"code" not in request.args:
-logger.info("Code parameter is missing")
-self._sso_handler.render_error(
-request, "invalid_request", "Code parameter is missing"
-)
-return
-logger.debug("Exchanging code")
-code = request.args[b"code"][0].decode()
try:
+logger.debug("Exchanging code")
token = await self._exchange_code(code)
except OidcError as e:
logger.exception("Could not exchange code")


@@ -466,9 +466,6 @@ class PersistEventsStore:
if not state_events:
return
-# Map from event ID to chain ID/sequence number.
-chain_map = {} # type: Dict[str, Tuple[int, int]]
# We need to know the type/state_key and auth events of the events we're
# calculating chain IDs for. We don't rely on having the full Event
# instances as we'll potentially be pulling more events from the DB and
@@ -479,9 +476,33 @@ class PersistEventsStore:
event_to_auth_chain = {
e.event_id: e.auth_event_ids() for e in state_events.values()
}
event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
self._add_chain_cover_index(
txn, event_to_room_id, event_to_types, event_to_auth_chain
)
def _add_chain_cover_index(
self,
txn,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
) -> None:
"""Calculate the chain cover index for the given events.
Args:
event_to_room_id: Event ID to the room ID of the event
event_to_types: Event ID to type and state_key of the event
event_to_auth_chain: Event ID to list of auth event IDs of the
event (events with no auth events can be excluded).
"""
# Map from event ID to chain ID/sequence number.
chain_map = {} # type: Dict[str, Tuple[int, int]]
# Set of event IDs to calculate chain ID/seq numbers for.
-events_to_calc_chain_id_for = set(state_events)
+events_to_calc_chain_id_for = set(event_to_room_id)
# We check if there are any events that need to be handled in the rooms
# we're looking at. These should just be out of band memberships, where
@@ -491,7 +512,7 @@ class PersistEventsStore:
table="event_auth_chain_to_calculate",
keyvalues={},
column="room_id",
-iterable={e.room_id for e in state_events.values()},
+iterable=set(event_to_room_id.values()),
retcols=("event_id", "type", "state_key"),
)
for row in rows:
@@ -582,16 +603,17 @@ class PersistEventsStore:
# the list of events to calculate chain IDs for next time
# around. (Otherwise we will have already added it to the
# table).
-event = state_events.get(event_id)
-if event:
+room_id = event_to_room_id.get(event_id)
+if room_id:
+e_type, state_key = event_to_types[event_id]
self.db_pool.simple_insert_txn(
txn,
table="event_auth_chain_to_calculate",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
"event_id": event_id,
"room_id": room_id,
"type": e_type,
"state_key": state_key,
},
)
@@ -617,7 +639,7 @@ class PersistEventsStore:
events_to_calc_chain_id_for, event_to_auth_chain
):
existing_chain_id = None
-for auth_id in event_to_auth_chain[event_id]:
+for auth_id in event_to_auth_chain.get(event_id, []):
if event_to_types.get(event_id) == event_to_types.get(auth_id):
existing_chain_id = chain_map[auth_id]
break
@@ -730,11 +752,11 @@ class PersistEventsStore:
# auth events (A, B) to check if B is reachable from A.
reduction = {
a_id
-for a_id in event_to_auth_chain[event_id]
+for a_id in event_to_auth_chain.get(event_id, [])
if chain_map[a_id][0] != chain_id
}
for start_auth_id, end_auth_id in itertools.permutations(
-event_to_auth_chain[event_id], r=2,
+event_to_auth_chain.get(event_id, []), r=2,
):
if chain_links.exists_path_from(
chain_map[start_auth_id], chain_map[end_auth_id]
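The chain cover assigns each event a `(chain id, sequence number)` pair and records links between chains, so reachability checks avoid walking the full auth DAG. A self-contained sketch of the reduction step above, with made-up chain data and a simplified stand-in for `_LinkMap.exists_path_from`:

import itertools

# hypothetical data: event id -> (chain id, sequence number)
chain_map = {"$a": (1, 1), "$b": (1, 2), "$c": (2, 1)}
auth_ids = list(chain_map)

def exists_path_from(src, dst):
    # simplified: within one chain, a lower sequence number reaches
    # every higher one; cross-chain links are ignored here
    return src[0] == dst[0] and src[1] <= dst[1]

# drop auth events already reachable from another auth event, so only
# the remaining "frontier" needs new links persisted
reduction = set(auth_ids)
for start, end in itertools.permutations(auth_ids, r=2):
    if start in reduction and exists_path_from(chain_map[start], chain_map[end]):
        reduction.discard(end)

print(reduction)  # {'$a', '$c'}: '$b' is reachable from '$a'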


@@ -14,13 +14,13 @@
# limitations under the License.
import logging
-from typing import List, Tuple
+from typing import Dict, List, Optional, Tuple
from synapse.api.constants import EventContentFields
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
from synapse.storage.types import Cursor
from synapse.types import JsonDict
@@ -108,6 +108,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
"rejected_events_metadata", self._rejected_events_metadata,
)
self.db_pool.updates.register_background_update_handler(
"chain_cover", self._chain_cover_index,
)
async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -706,3 +710,187 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
return len(results)
async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:
"""A background updates that iterates over all rooms and generates the
chain cover index for them.
"""
current_room_id = progress.get("current_room_id", "")
# Have we finished processing the current room.
finished = progress.get("finished", True)
# Where we've processed up to in the room, defaults to the start of the
# room.
last_depth = progress.get("last_depth", -1)
last_stream = progress.get("last_stream", -1)
# Have we set the `has_auth_chain_index` for the room yet.
has_set_room_has_chain_index = progress.get(
"has_set_room_has_chain_index", False
)
if finished:
# If we've finished with the previous room (or it's our first
# iteration) we move on to the next room.
def _get_next_room(txn: Cursor) -> Optional[str]:
sql = """
SELECT room_id FROM rooms
WHERE room_id > ?
AND (
NOT has_auth_chain_index
OR has_auth_chain_index IS NULL
)
ORDER BY room_id
LIMIT 1
"""
txn.execute(sql, (current_room_id,))
row = txn.fetchone()
if row:
return row[0]
return None
current_room_id = await self.db_pool.runInteraction(
"_chain_cover_index", _get_next_room
)
if not current_room_id:
await self.db_pool.updates._end_background_update("chain_cover")
return 0
logger.debug("Adding chain cover to %s", current_room_id)
def _calculate_auth_chain(
txn: Cursor, last_depth: int, last_stream: int
) -> Tuple[int, int, int]:
# Get the next set of events in the room (that we haven't already
# computed chain cover for). We do this in topological order.
# We want to do a `(topological_ordering, stream_ordering) > (?,?)`
# comparison, but that is not supported on older SQLite versions
tuple_clause, tuple_args = make_tuple_comparison_clause(
self.database_engine,
[
("topological_ordering", last_depth),
("stream_ordering", last_stream),
],
)
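# `tuple_clause` is expected to render as a row-value comparison,
# `(topological_ordering, stream_ordering) > (?, ?)`, where supported,
# or roughly the equivalent expansion
# `topological_ordering > ? OR (topological_ordering = ? AND stream_ordering > ?)`
# elsewhere, with `tuple_args` ordered to match the placeholders.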
sql = """
SELECT
event_id, state_events.type, state_events.state_key,
topological_ordering, stream_ordering
FROM events
INNER JOIN state_events USING (event_id)
LEFT JOIN event_auth_chains USING (event_id)
LEFT JOIN event_auth_chain_to_calculate USING (event_id)
WHERE events.room_id = ?
AND event_auth_chains.event_id IS NULL
AND event_auth_chain_to_calculate.event_id IS NULL
AND %(tuple_cmp)s
ORDER BY topological_ordering, stream_ordering
LIMIT ?
""" % {
"tuple_cmp": tuple_clause,
}
args = [current_room_id]
args.extend(tuple_args)
args.append(batch_size)
txn.execute(sql, args)
rows = txn.fetchall()
# Put the results in the necessary format for
# `_add_chain_cover_index`
event_to_room_id = {row[0]: current_room_id for row in rows}
event_to_types = {row[0]: (row[1], row[2]) for row in rows}
new_last_depth = rows[-1][3] if rows else last_depth # type: int
new_last_stream = rows[-1][4] if rows else last_stream # type: int
count = len(rows)
# We also need to fetch the auth events for them.
auth_events = self.db_pool.simple_select_many_txn(
txn,
table="event_auth",
column="event_id",
iterable=event_to_room_id,
keyvalues={},
retcols=("event_id", "auth_id"),
)
event_to_auth_chain = {} # type: Dict[str, List[str]]
for row in auth_events:
event_to_auth_chain.setdefault(row["event_id"], []).append(
row["auth_id"]
)
# Calculate and persist the chain cover index for this set of events.
#
# Annoyingly we need to gut wrench into the persist event store so that
# we can reuse the function to calculate the chain cover for rooms.
self.hs.get_datastores().persist_events._add_chain_cover_index(
txn, event_to_room_id, event_to_types, event_to_auth_chain,
)
return new_last_depth, new_last_stream, count
last_depth, last_stream, count = await self.db_pool.runInteraction(
"_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
)
total_rows_processed = count
if count < batch_size and not has_set_room_has_chain_index:
# If we've done all the events in the room we flip the
# `has_auth_chain_index` in the DB. Note that it's possible for
# further events to be persisted between the above and setting the
# flag without having the chain cover calculated for them. This is
# fine as a) the code gracefully handles these cases and b) we'll
# calculate them below.
await self.db_pool.simple_update(
table="rooms",
keyvalues={"room_id": current_room_id},
updatevalues={"has_auth_chain_index": True},
desc="_chain_cover_index",
)
has_set_room_has_chain_index = True
# Handle any events that might have raced with us flipping the
# bit above.
last_depth, last_stream, count = await self.db_pool.runInteraction(
"_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
)
total_rows_processed += count
# Note that at this point it's technically possible that more events
# than our `batch_size` have been persisted without their chain
# cover, so we need to continue processing this room if the last
# count returned was equal to the `batch_size`.
if count < batch_size:
# We've finished calculating the index for this room, move on to the
# next room.
await self.db_pool.updates._background_update_progress(
"chain_cover", {"current_room_id": current_room_id, "finished": True},
)
else:
# We still have outstanding events to calculate the index for.
await self.db_pool.updates._background_update_progress(
"chain_cover",
{
"current_room_id": current_room_id,
"last_depth": last_depth,
"last_stream": last_stream,
"has_auth_chain_index": has_set_room_has_chain_index,
"finished": False,
},
)
return total_rows_processed
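Everything needed to resume lives in the progress dict (`current_room_id`, `last_depth`, `last_stream`, `finished`), so a restart picks up mid-room. A stripped-down sketch of that checkpointing shape, with hypothetical `process_batch` and `save_progress` helpers (illustrative, not Synapse's code):

async def resumable_update(progress: dict, batch_size: int) -> int:
    room_id = progress.get("current_room_id", "")
    last_depth = progress.get("last_depth", -1)
    last_stream = progress.get("last_stream", -1)

    # hypothetical: process one batch, returning the new checkpoint
    last_depth, last_stream, count = await process_batch(
        room_id, last_depth, last_stream, batch_size
    )

    if count < batch_size:
        # room exhausted: the next run moves on to the next room
        await save_progress({"current_room_id": room_id, "finished": True})
    else:
        # room not done: record exactly where to resume
        await save_progress({
            "current_room_id": room_id,
            "last_depth": last_depth,
            "last_stream": last_stream,
            "finished": False,
        })
    return count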


@@ -0,0 +1,17 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
(5906, 'chain_cover', '{}', 'rejected_events_metadata');


@@ -464,19 +464,17 @@ class TransactionStore(TransactionWorkerStore):
txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
) -> List[str]:
q = """
-SELECT destination FROM destinations
-WHERE destination IN (
-SELECT destination FROM destination_rooms
-WHERE destination_rooms.stream_ordering >
-destinations.last_successful_stream_ordering
-)
-AND destination > ?
-AND (
-retry_last_ts IS NULL OR
-retry_last_ts + retry_interval < ?
-)
-ORDER BY destination
-LIMIT 25
+SELECT DISTINCT destination FROM destinations
+INNER JOIN destination_rooms USING (destination)
+WHERE
+stream_ordering > last_successful_stream_ordering
+AND destination > ?
+AND (
+retry_last_ts IS NULL OR
+retry_last_ts + retry_interval < ?
+)
+ORDER BY destination
+LIMIT 25
"""
txn.execute(
q,
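The rewritten query pages with a keyset (`destination > ?`, `ORDER BY destination LIMIT 25`), so a caller can sweep every lagging destination in order. A sketch of a driving loop under that contract, assuming hypothetical `get_batch` and `wake_destination` helpers:

async def sweep_lagging_destinations(now_ms: int) -> None:
    after = ""  # keyset cursor: destinations sort lexicographically
    while True:
        batch = await get_batch(now_ms, after)  # hypothetical wrapper running the SQL above
        if not batch:
            break
        for destination in batch:
            await wake_destination(destination)  # hypothetical kick-off
        after = batch[-1]  # resume strictly after the last one seen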


@@ -151,6 +151,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
self.handler = hs.get_oidc_handler()
self.provider = self.handler._provider
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
@@ -162,9 +163,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
return hs
def metadata_edit(self, values):
-return patch.dict(self.handler._provider_metadata, values)
+return patch.dict(self.provider._provider_metadata, values)
def assertRenderedError(self, error, error_description=None):
self.render_error.assert_called_once()
args = self.render_error.call_args[0]
self.assertEqual(args[1], error)
if error_description is not None:
@@ -175,15 +177,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
-self.assertEqual(self.handler._callback_url, CALLBACK_URL)
-self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
-self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+self.assertEqual(self.provider._callback_url, CALLBACK_URL)
+self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
+self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {"discover": True}})
def test_discovery(self):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
-metadata = self.get_success(self.handler.load_metadata())
+metadata = self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.assertEqual(metadata.issuer, ISSUER)
@@ -195,47 +197,47 @@ class OidcHandlerTestCase(HomeserverTestCase):
# subsequent calls should be cached
self.http_client.reset_mock()
-self.get_success(self.handler.load_metadata())
+self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
def test_no_discovery(self):
"""When discovery is disabled, it should not try to load from discovery document."""
-self.get_success(self.handler.load_metadata())
+self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
def test_load_jwks(self):
"""JWKS loading is done once (then cached) if used."""
-jwks = self.get_success(self.handler.load_jwks())
+jwks = self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI)
self.assertEqual(jwks, {"keys": []})
# subsequent calls should be cached…
self.http_client.reset_mock()
-self.get_success(self.handler.load_jwks())
+self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_not_called()
# …unless forced
self.http_client.reset_mock()
-self.get_success(self.handler.load_jwks(force=True))
+self.get_success(self.provider.load_jwks(force=True))
self.http_client.get_json.assert_called_once_with(JWKS_URI)
# Throw if the JWKS uri is missing
with self.metadata_edit({"jwks_uri": None}):
-self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
+self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
# Return empty key set if JWKS are not used
-self.handler._scopes = [] # not asking the openid scope
+self.provider._scopes = [] # not asking the openid scope
self.http_client.get_json.reset_mock()
-jwks = self.get_success(self.handler.load_jwks(force=True))
+jwks = self.get_success(self.provider.load_jwks(force=True))
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})
@override_config({"oidc_config": COMMON_CONFIG})
def test_validate_config(self):
"""Provider metadatas are extensively validated."""
-h = self.handler
+h = self.provider
# Default test config does not throw
h._validate_metadata()
@@ -314,13 +316,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
-self.handler._validate_metadata()
+self.provider._validate_metadata()
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["addCookie"])
url = self.get_success(
-self.handler.handle_redirect_request(req, b"http://client/redirect")
+self.provider.handle_redirect_request(req, b"http://client/redirect")
)
url = urlparse(url)
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
@@ -388,7 +390,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# ensure that we are correctly testing the fallback when "get_extra_attributes"
# is not implemented.
-mapping_provider = self.handler._user_mapping_provider
+mapping_provider = self.provider._user_mapping_provider
with self.assertRaises(AttributeError):
_ = mapping_provider.get_extra_attributes
@@ -403,9 +405,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": username,
}
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
-self.handler._exchange_code = simple_async_mock(return_value=token)
-self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
-self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+self.provider._exchange_code = simple_async_mock(return_value=token)
+self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -425,14 +427,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id, request, client_redirect_url, None,
)
-self.handler._exchange_code.assert_called_once_with(code)
-self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
-self.handler._fetch_userinfo.assert_not_called()
+self.provider._exchange_code.assert_called_once_with(code)
+self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
+self.provider._fetch_userinfo.assert_not_called()
self.render_error.assert_not_called()
# Handle mapping errors
with patch.object(
-self.handler,
+self.provider,
"_remote_id_from_userinfo",
new=Mock(side_effect=MappingException()),
):
@@ -440,36 +442,36 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("mapping_error")
# Handle ID token errors
-self.handler._parse_id_token = simple_async_mock(raises=Exception())
+self.provider._parse_id_token = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
auth_handler.complete_sso_login.reset_mock()
-self.handler._exchange_code.reset_mock()
-self.handler._parse_id_token.reset_mock()
-self.handler._fetch_userinfo.reset_mock()
+self.provider._exchange_code.reset_mock()
+self.provider._parse_id_token.reset_mock()
+self.provider._fetch_userinfo.reset_mock()
# With userinfo fetching
-self.handler._scopes = [] # do not ask the "openid" scope
+self.provider._scopes = [] # do not ask the "openid" scope
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id, request, client_redirect_url, None,
)
-self.handler._exchange_code.assert_called_once_with(code)
-self.handler._parse_id_token.assert_not_called()
-self.handler._fetch_userinfo.assert_called_once_with(token)
+self.provider._exchange_code.assert_called_once_with(code)
+self.provider._parse_id_token.assert_not_called()
+self.provider._fetch_userinfo.assert_called_once_with(token)
self.render_error.assert_not_called()
# Handle userinfo fetching error
-self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
# Handle code exchange failure
from synapse.handlers.oidc_handler import OidcError
-self.handler._exchange_code = simple_async_mock(
+self.provider._exchange_code = simple_async_mock(
raises=OidcError("invalid_request")
)
self.get_success(self.handler.handle_oidc_callback(request))
@@ -524,7 +526,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
)
code = "code"
-ret = self.get_success(self.handler._exchange_code(code))
+ret = self.get_success(self.provider._exchange_code(code))
kwargs = self.http_client.request.call_args[1]
self.assertEqual(ret, token)
@@ -548,7 +550,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
from synapse.handlers.oidc_handler import OidcError
-exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "foo")
self.assertEqual(exc.value.error_description, "bar")
@@ -558,7 +560,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=500, phrase=b"Internal Server Error", body=b"Not JSON",
)
)
-exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# Internal server error with JSON body
@@ -570,14 +572,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
)
-exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "internal_server_error")
# 4xx error without "error" field
self.http_client.request = simple_async_mock(
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
)
-exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# 2xx error with "error" field
@@ -586,7 +588,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=200, phrase=b"OK", body=b'{"error": "some_error"}',
)
)
-exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "some_error")
@override_config(
@@ -612,8 +614,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "foo",
"phone": "1234567",
}
-self.handler._exchange_code = simple_async_mock(return_value=token)
-self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+self.provider._exchange_code = simple_async_mock(return_value=token)
+self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -979,9 +981,10 @@ async def _make_callback_with_userinfo(
from synapse.handlers.oidc_handler import OidcSessionData
handler = hs.get_oidc_handler()
-handler._exchange_code = simple_async_mock(return_value={})
-handler._parse_id_token = simple_async_mock(return_value=userinfo)
-handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+provider = handler._provider
+provider._exchange_code = simple_async_mock(return_value={})
+provider._parse_id_token = simple_async_mock(return_value=userinfo)
+provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
state = "state"
session = handler._token_generator.generate_oidc_session_token(


@@ -20,7 +20,10 @@ from twisted.trial import unittest
from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.storage.databases.main.events import _LinkMap
from synapse.types import create_requester
from tests.unittest import HomeserverTestCase
@@ -470,3 +473,114 @@ class LinkMapTestCase(unittest.TestCase):
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]
def test_background_update(self):
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
# Create a room
user_id = self.register_user("foo", "pass")
token = self.login("foo", "pass")
room_id = self.helper.create_room_as(user_id, tok=token)
requester = create_requester(user_id)
store = self.hs.get_datastore()
# Mark the room as not having a chain cover index
self.get_success(
store.db_pool.simple_update(
table="rooms",
keyvalues={"room_id": room_id},
updatevalues={"has_auth_chain_index": False},
desc="test",
)
)
# Create a fork in the DAG with different events.
event_handler = self.hs.get_event_creation_handler()
latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
event, context = self.get_success(
event_handler.create_event(
requester,
{
"type": "some_state_type",
"state_key": "",
"content": {},
"room_id": room_id,
"sender": user_id,
},
prev_event_ids=latest_event_ids,
)
)
self.get_success(
event_handler.handle_new_client_event(requester, event, context)
)
state1 = list(self.get_success(context.get_current_state_ids()).values())
event, context = self.get_success(
event_handler.create_event(
requester,
{
"type": "some_state_type",
"state_key": "",
"content": {},
"room_id": room_id,
"sender": user_id,
},
prev_event_ids=latest_event_ids,
)
)
self.get_success(
event_handler.handle_new_client_event(requester, event, context)
)
state2 = list(self.get_success(context.get_current_state_ids()).values())
# Delete the chain cover info.
def _delete_tables(txn):
txn.execute("DELETE FROM event_auth_chains")
txn.execute("DELETE FROM event_auth_chain_links")
self.get_success(store.db_pool.runInteraction("test", _delete_tables))
# Insert and run the background update.
self.get_success(
store.db_pool.simple_insert(
"background_updates",
{"update_name": "chain_cover", "progress_json": "{}"},
)
)
# Ugh, have to reset this flag
store.db_pool.updates._all_done = False
while not self.get_success(
store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Test that the `has_auth_chain_index` has been set
self.assertTrue(self.get_success(store.has_auth_chain_index(room_id)))
# Test that calculating the auth chain difference using the newly
# calculated chain cover works.
self.get_success(
store.db_pool.runInteraction(
"test",
store._get_auth_chain_difference_using_cover_index_txn,
room_id,
[state1, state2],
)
)