Merge commit '631dd06f2' into anoa/dinsic_release_1_31_0
This commit is contained in:
@@ -1 +1 @@
|
||||
Improve efficiency of large state resolutions for new rooms.
|
||||
Improve efficiency of large state resolutions.
|
||||
|
||||
1
changelog.d/9029.misc
Normal file
1
changelog.d/9029.misc
Normal file
@@ -0,0 +1 @@
|
||||
Improve efficiency of large state resolutions.
|
||||
1
changelog.d/9107.feature
Normal file
1
changelog.d/9107.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add support for multiple SSO Identity Providers.
|
||||
1
changelog.d/9114.bugfix
Normal file
1
changelog.d/9114.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix bug in federation catchup logic that caused outbound federation to be delayed for large servers after start up. Introduced in v1.21.0.
|
||||
@@ -71,7 +71,7 @@ logger = logging.getLogger("synapse_port_db")
|
||||
|
||||
BOOLEAN_COLUMNS = {
|
||||
"events": ["processed", "outlier", "contains_url"],
|
||||
"rooms": ["is_public"],
|
||||
"rooms": ["is_public", "has_auth_chain_index"],
|
||||
"event_edges": ["is_state"],
|
||||
"presence_list": ["accepted"],
|
||||
"presence_stream": ["currently_active"],
|
||||
|
||||
@@ -429,7 +429,6 @@ def setup(config_options):
|
||||
oidc = hs.get_oidc_handler()
|
||||
# Loading the provider metadata also ensures the provider config is valid.
|
||||
await oidc.load_metadata()
|
||||
await oidc.load_jwks()
|
||||
|
||||
await _base.start(hs, config.listeners)
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ from typing_extensions import TypedDict
|
||||
from twisted.web.client import readBody
|
||||
|
||||
from synapse.config import ConfigError
|
||||
from synapse.config.oidc_config import OidcProviderConfig
|
||||
from synapse.handlers.sso import MappingException, UserAttributes
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
@@ -70,6 +71,131 @@ JWK = Dict[str, str]
|
||||
JWKS = TypedDict("JWKS", {"keys": List[JWK]})
|
||||
|
||||
|
||||
class OidcHandler:
|
||||
"""Handles requests related to the OpenID Connect login flow.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._sso_handler = hs.get_sso_handler()
|
||||
|
||||
provider_conf = hs.config.oidc.oidc_provider
|
||||
# we should not have been instantiated if there is no configured provider.
|
||||
assert provider_conf is not None
|
||||
|
||||
self._token_generator = OidcSessionTokenGenerator(hs)
|
||||
|
||||
self._provider = OidcProvider(hs, self._token_generator, provider_conf)
|
||||
|
||||
async def load_metadata(self) -> None:
|
||||
"""Validate the config and load the metadata from the remote endpoint.
|
||||
|
||||
Called at startup to ensure we have everything we need.
|
||||
"""
|
||||
await self._provider.load_metadata()
|
||||
await self._provider.load_jwks()
|
||||
|
||||
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
|
||||
"""Handle an incoming request to /_synapse/oidc/callback
|
||||
|
||||
Since we might want to display OIDC-related errors in a user-friendly
|
||||
way, we don't raise SynapseError from here. Instead, we call
|
||||
``self._sso_handler.render_error`` which displays an HTML page for the error.
|
||||
|
||||
Most of the OpenID Connect logic happens here:
|
||||
|
||||
- first, we check if there was any error returned by the provider and
|
||||
display it
|
||||
- then we fetch the session cookie, decode and verify it
|
||||
- the ``state`` query parameter should match with the one stored in the
|
||||
session cookie
|
||||
|
||||
Once we know the session is legit, we then delegate to the OIDC Provider
|
||||
implementation, which will exchange the code with the provider and complete the
|
||||
login/authentication.
|
||||
|
||||
Args:
|
||||
request: the incoming request from the browser.
|
||||
"""
|
||||
|
||||
# The provider might redirect with an error.
|
||||
# In that case, just display it as-is.
|
||||
if b"error" in request.args:
|
||||
# error response from the auth server. see:
|
||||
# https://tools.ietf.org/html/rfc6749#section-4.1.2.1
|
||||
# https://openid.net/specs/openid-connect-core-1_0.html#AuthError
|
||||
error = request.args[b"error"][0].decode()
|
||||
description = request.args.get(b"error_description", [b""])[0].decode()
|
||||
|
||||
# Most of the errors returned by the provider could be due by
|
||||
# either the provider misbehaving or Synapse being misconfigured.
|
||||
# The only exception of that is "access_denied", where the user
|
||||
# probably cancelled the login flow. In other cases, log those errors.
|
||||
if error != "access_denied":
|
||||
logger.error("Error from the OIDC provider: %s %s", error, description)
|
||||
|
||||
self._sso_handler.render_error(request, error, description)
|
||||
return
|
||||
|
||||
# otherwise, it is presumably a successful response. see:
|
||||
# https://tools.ietf.org/html/rfc6749#section-4.1.2
|
||||
|
||||
# Fetch the session cookie
|
||||
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
|
||||
if session is None:
|
||||
logger.info("No session cookie found")
|
||||
self._sso_handler.render_error(
|
||||
request, "missing_session", "No session cookie found"
|
||||
)
|
||||
return
|
||||
|
||||
# Remove the cookie. There is a good chance that if the callback failed
|
||||
# once, it will fail next time and the code will already be exchanged.
|
||||
# Removing it early avoids spamming the provider with token requests.
|
||||
request.addCookie(
|
||||
SESSION_COOKIE_NAME,
|
||||
b"",
|
||||
path="/_synapse/oidc",
|
||||
expires="Thu, Jan 01 1970 00:00:00 UTC",
|
||||
httpOnly=True,
|
||||
sameSite="lax",
|
||||
)
|
||||
|
||||
# Check for the state query parameter
|
||||
if b"state" not in request.args:
|
||||
logger.info("State parameter is missing")
|
||||
self._sso_handler.render_error(
|
||||
request, "invalid_request", "State parameter is missing"
|
||||
)
|
||||
return
|
||||
|
||||
state = request.args[b"state"][0].decode()
|
||||
|
||||
# Deserialize the session token and verify it.
|
||||
try:
|
||||
session_data = self._token_generator.verify_oidc_session_token(
|
||||
session, state
|
||||
)
|
||||
except MacaroonDeserializationException as e:
|
||||
logger.exception("Invalid session")
|
||||
self._sso_handler.render_error(request, "invalid_session", str(e))
|
||||
return
|
||||
except MacaroonInvalidSignatureException as e:
|
||||
logger.exception("Could not verify session")
|
||||
self._sso_handler.render_error(request, "mismatching_session", str(e))
|
||||
return
|
||||
|
||||
if b"code" not in request.args:
|
||||
logger.info("Code parameter is missing")
|
||||
self._sso_handler.render_error(
|
||||
request, "invalid_request", "Code parameter is missing"
|
||||
)
|
||||
return
|
||||
|
||||
code = request.args[b"code"][0].decode()
|
||||
|
||||
await self._provider.handle_oidc_callback(request, session_data, code)
|
||||
|
||||
|
||||
class OidcError(Exception):
|
||||
"""Used to catch errors when calling the token_endpoint
|
||||
"""
|
||||
@@ -84,21 +210,25 @@ class OidcError(Exception):
|
||||
return self.error
|
||||
|
||||
|
||||
class OidcHandler:
|
||||
"""Handles requests related to the OpenID Connect login flow.
|
||||
class OidcProvider:
|
||||
"""Wraps the config for a single OIDC IdentityProvider
|
||||
|
||||
Provides methods for handling redirect requests and callbacks via that particular
|
||||
IdP.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
def __init__(
|
||||
self,
|
||||
hs: "HomeServer",
|
||||
token_generator: "OidcSessionTokenGenerator",
|
||||
provider: OidcProviderConfig,
|
||||
):
|
||||
self._store = hs.get_datastore()
|
||||
|
||||
self._token_generator = OidcSessionTokenGenerator(hs)
|
||||
self._token_generator = token_generator
|
||||
|
||||
self._callback_url = hs.config.oidc_callback_url # type: str
|
||||
|
||||
provider = hs.config.oidc.oidc_provider
|
||||
# we should not have been instantiated if there is no configured provider.
|
||||
assert provider is not None
|
||||
|
||||
self._scopes = provider.scopes
|
||||
self._user_profile_method = provider.user_profile_method
|
||||
self._client_auth = ClientAuth(
|
||||
@@ -552,22 +682,16 @@ class OidcHandler:
|
||||
nonce=nonce,
|
||||
)
|
||||
|
||||
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
|
||||
async def handle_oidc_callback(
|
||||
self, request: SynapseRequest, session_data: "OidcSessionData", code: str
|
||||
) -> None:
|
||||
"""Handle an incoming request to /_synapse/oidc/callback
|
||||
|
||||
Since we might want to display OIDC-related errors in a user-friendly
|
||||
way, we don't raise SynapseError from here. Instead, we call
|
||||
``self._sso_handler.render_error`` which displays an HTML page for the error.
|
||||
By this time we have already validated the session on the synapse side, and
|
||||
now need to do the provider-specific operations. This includes:
|
||||
|
||||
Most of the OpenID Connect logic happens here:
|
||||
|
||||
- first, we check if there was any error returned by the provider and
|
||||
display it
|
||||
- then we fetch the session cookie, decode and verify it
|
||||
- the ``state`` query parameter should match with the one stored in the
|
||||
session cookie
|
||||
- once we known this session is legit, exchange the code with the
|
||||
provider using the ``token_endpoint`` (see ``_exchange_code``)
|
||||
- exchange the code with the provider using the ``token_endpoint`` (see
|
||||
``_exchange_code``)
|
||||
- once we have the token, use it to either extract the UserInfo from
|
||||
the ``id_token`` (``_parse_id_token``), or use the ``access_token``
|
||||
to fetch UserInfo from the ``userinfo_endpoint``
|
||||
@@ -577,86 +701,12 @@ class OidcHandler:
|
||||
|
||||
Args:
|
||||
request: the incoming request from the browser.
|
||||
session_data: the session data, extracted from our cookie
|
||||
code: The authorization code we got from the callback.
|
||||
"""
|
||||
|
||||
# The provider might redirect with an error.
|
||||
# In that case, just display it as-is.
|
||||
if b"error" in request.args:
|
||||
# error response from the auth server. see:
|
||||
# https://tools.ietf.org/html/rfc6749#section-4.1.2.1
|
||||
# https://openid.net/specs/openid-connect-core-1_0.html#AuthError
|
||||
error = request.args[b"error"][0].decode()
|
||||
description = request.args.get(b"error_description", [b""])[0].decode()
|
||||
|
||||
# Most of the errors returned by the provider could be due by
|
||||
# either the provider misbehaving or Synapse being misconfigured.
|
||||
# The only exception of that is "access_denied", where the user
|
||||
# probably cancelled the login flow. In other cases, log those errors.
|
||||
if error != "access_denied":
|
||||
logger.error("Error from the OIDC provider: %s %s", error, description)
|
||||
|
||||
self._sso_handler.render_error(request, error, description)
|
||||
return
|
||||
|
||||
# otherwise, it is presumably a successful response. see:
|
||||
# https://tools.ietf.org/html/rfc6749#section-4.1.2
|
||||
|
||||
# Fetch the session cookie
|
||||
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
|
||||
if session is None:
|
||||
logger.info("No session cookie found")
|
||||
self._sso_handler.render_error(
|
||||
request, "missing_session", "No session cookie found"
|
||||
)
|
||||
return
|
||||
|
||||
# Remove the cookie. There is a good chance that if the callback failed
|
||||
# once, it will fail next time and the code will already be exchanged.
|
||||
# Removing it early avoids spamming the provider with token requests.
|
||||
request.addCookie(
|
||||
SESSION_COOKIE_NAME,
|
||||
b"",
|
||||
path="/_synapse/oidc",
|
||||
expires="Thu, Jan 01 1970 00:00:00 UTC",
|
||||
httpOnly=True,
|
||||
sameSite="lax",
|
||||
)
|
||||
|
||||
# Check for the state query parameter
|
||||
if b"state" not in request.args:
|
||||
logger.info("State parameter is missing")
|
||||
self._sso_handler.render_error(
|
||||
request, "invalid_request", "State parameter is missing"
|
||||
)
|
||||
return
|
||||
|
||||
state = request.args[b"state"][0].decode()
|
||||
|
||||
# Deserialize the session token and verify it.
|
||||
try:
|
||||
session_data = self._token_generator.verify_oidc_session_token(
|
||||
session, state
|
||||
)
|
||||
except MacaroonDeserializationException as e:
|
||||
logger.exception("Invalid session")
|
||||
self._sso_handler.render_error(request, "invalid_session", str(e))
|
||||
return
|
||||
except MacaroonInvalidSignatureException as e:
|
||||
logger.exception("Could not verify session")
|
||||
self._sso_handler.render_error(request, "mismatching_session", str(e))
|
||||
return
|
||||
|
||||
# Exchange the code with the provider
|
||||
if b"code" not in request.args:
|
||||
logger.info("Code parameter is missing")
|
||||
self._sso_handler.render_error(
|
||||
request, "invalid_request", "Code parameter is missing"
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug("Exchanging code")
|
||||
code = request.args[b"code"][0].decode()
|
||||
try:
|
||||
logger.debug("Exchanging code")
|
||||
token = await self._exchange_code(code)
|
||||
except OidcError as e:
|
||||
logger.exception("Could not exchange code")
|
||||
|
||||
@@ -466,9 +466,6 @@ class PersistEventsStore:
|
||||
if not state_events:
|
||||
return
|
||||
|
||||
# Map from event ID to chain ID/sequence number.
|
||||
chain_map = {} # type: Dict[str, Tuple[int, int]]
|
||||
|
||||
# We need to know the type/state_key and auth events of the events we're
|
||||
# calculating chain IDs for. We don't rely on having the full Event
|
||||
# instances as we'll potentially be pulling more events from the DB and
|
||||
@@ -479,9 +476,33 @@ class PersistEventsStore:
|
||||
event_to_auth_chain = {
|
||||
e.event_id: e.auth_event_ids() for e in state_events.values()
|
||||
}
|
||||
event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
|
||||
|
||||
self._add_chain_cover_index(
|
||||
txn, event_to_room_id, event_to_types, event_to_auth_chain
|
||||
)
|
||||
|
||||
def _add_chain_cover_index(
|
||||
self,
|
||||
txn,
|
||||
event_to_room_id: Dict[str, str],
|
||||
event_to_types: Dict[str, Tuple[str, str]],
|
||||
event_to_auth_chain: Dict[str, List[str]],
|
||||
) -> None:
|
||||
"""Calculate the chain cover index for the given events.
|
||||
|
||||
Args:
|
||||
event_to_room_id: Event ID to the room ID of the event
|
||||
event_to_types: Event ID to type and state_key of the event
|
||||
event_to_auth_chain: Event ID to list of auth event IDs of the
|
||||
event (events with no auth events can be excluded).
|
||||
"""
|
||||
|
||||
# Map from event ID to chain ID/sequence number.
|
||||
chain_map = {} # type: Dict[str, Tuple[int, int]]
|
||||
|
||||
# Set of event IDs to calculate chain ID/seq numbers for.
|
||||
events_to_calc_chain_id_for = set(state_events)
|
||||
events_to_calc_chain_id_for = set(event_to_room_id)
|
||||
|
||||
# We check if there are any events that need to be handled in the rooms
|
||||
# we're looking at. These should just be out of band memberships, where
|
||||
@@ -491,7 +512,7 @@ class PersistEventsStore:
|
||||
table="event_auth_chain_to_calculate",
|
||||
keyvalues={},
|
||||
column="room_id",
|
||||
iterable={e.room_id for e in state_events.values()},
|
||||
iterable=set(event_to_room_id.values()),
|
||||
retcols=("event_id", "type", "state_key"),
|
||||
)
|
||||
for row in rows:
|
||||
@@ -582,16 +603,17 @@ class PersistEventsStore:
|
||||
# the list of events to calculate chain IDs for next time
|
||||
# around. (Otherwise we will have already added it to the
|
||||
# table).
|
||||
event = state_events.get(event_id)
|
||||
if event:
|
||||
room_id = event_to_room_id.get(event_id)
|
||||
if room_id:
|
||||
e_type, state_key = event_to_types[event_id]
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
table="event_auth_chain_to_calculate",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
"event_id": event_id,
|
||||
"room_id": room_id,
|
||||
"type": e_type,
|
||||
"state_key": state_key,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -617,7 +639,7 @@ class PersistEventsStore:
|
||||
events_to_calc_chain_id_for, event_to_auth_chain
|
||||
):
|
||||
existing_chain_id = None
|
||||
for auth_id in event_to_auth_chain[event_id]:
|
||||
for auth_id in event_to_auth_chain.get(event_id, []):
|
||||
if event_to_types.get(event_id) == event_to_types.get(auth_id):
|
||||
existing_chain_id = chain_map[auth_id]
|
||||
break
|
||||
@@ -730,11 +752,11 @@ class PersistEventsStore:
|
||||
# auth events (A, B) to check if B is reachable from A.
|
||||
reduction = {
|
||||
a_id
|
||||
for a_id in event_to_auth_chain[event_id]
|
||||
for a_id in event_to_auth_chain.get(event_id, [])
|
||||
if chain_map[a_id][0] != chain_id
|
||||
}
|
||||
for start_auth_id, end_auth_id in itertools.permutations(
|
||||
event_to_auth_chain[event_id], r=2,
|
||||
event_to_auth_chain.get(event_id, []), r=2,
|
||||
):
|
||||
if chain_links.exists_path_from(
|
||||
chain_map[start_auth_id], chain_map[end_auth_id]
|
||||
|
||||
@@ -14,13 +14,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import EventContentFields
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.types import JsonDict
|
||||
|
||||
@@ -108,6 +108,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||
"rejected_events_metadata", self._rejected_events_metadata,
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
"chain_cover", self._chain_cover_index,
|
||||
)
|
||||
|
||||
async def _background_reindex_fields_sender(self, progress, batch_size):
|
||||
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||
max_stream_id = progress["max_stream_id_exclusive"]
|
||||
@@ -706,3 +710,187 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
return len(results)
|
||||
|
||||
async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:
|
||||
"""A background updates that iterates over all rooms and generates the
|
||||
chain cover index for them.
|
||||
"""
|
||||
|
||||
current_room_id = progress.get("current_room_id", "")
|
||||
|
||||
# Have we finished processing the current room.
|
||||
finished = progress.get("finished", True)
|
||||
|
||||
# Where we've processed up to in the room, defaults to the start of the
|
||||
# room.
|
||||
last_depth = progress.get("last_depth", -1)
|
||||
last_stream = progress.get("last_stream", -1)
|
||||
|
||||
# Have we set the `has_auth_chain_index` for the room yet.
|
||||
has_set_room_has_chain_index = progress.get(
|
||||
"has_set_room_has_chain_index", False
|
||||
)
|
||||
|
||||
if finished:
|
||||
# If we've finished with the previous room (or its our first
|
||||
# iteration) we move on to the next room.
|
||||
|
||||
def _get_next_room(txn: Cursor) -> Optional[str]:
|
||||
sql = """
|
||||
SELECT room_id FROM rooms
|
||||
WHERE room_id > ?
|
||||
AND (
|
||||
NOT has_auth_chain_index
|
||||
OR has_auth_chain_index IS NULL
|
||||
)
|
||||
ORDER BY room_id
|
||||
LIMIT 1
|
||||
"""
|
||||
txn.execute(sql, (current_room_id,))
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
return row[0]
|
||||
|
||||
return None
|
||||
|
||||
current_room_id = await self.db_pool.runInteraction(
|
||||
"_chain_cover_index", _get_next_room
|
||||
)
|
||||
if not current_room_id:
|
||||
await self.db_pool.updates._end_background_update("chain_cover")
|
||||
return 0
|
||||
|
||||
logger.debug("Adding chain cover to %s", current_room_id)
|
||||
|
||||
def _calculate_auth_chain(
|
||||
txn: Cursor, last_depth: int, last_stream: int
|
||||
) -> Tuple[int, int, int]:
|
||||
# Get the next set of events in the room (that we haven't already
|
||||
# computed chain cover for). We do this in topological order.
|
||||
|
||||
# We want to do a `(topological_ordering, stream_ordering) > (?,?)`
|
||||
# comparison, but that is not supported on older SQLite versions
|
||||
tuple_clause, tuple_args = make_tuple_comparison_clause(
|
||||
self.database_engine,
|
||||
[
|
||||
("topological_ordering", last_depth),
|
||||
("stream_ordering", last_stream),
|
||||
],
|
||||
)
|
||||
|
||||
sql = """
|
||||
SELECT
|
||||
event_id, state_events.type, state_events.state_key,
|
||||
topological_ordering, stream_ordering
|
||||
FROM events
|
||||
INNER JOIN state_events USING (event_id)
|
||||
LEFT JOIN event_auth_chains USING (event_id)
|
||||
LEFT JOIN event_auth_chain_to_calculate USING (event_id)
|
||||
WHERE events.room_id = ?
|
||||
AND event_auth_chains.event_id IS NULL
|
||||
AND event_auth_chain_to_calculate.event_id IS NULL
|
||||
AND %(tuple_cmp)s
|
||||
ORDER BY topological_ordering, stream_ordering
|
||||
LIMIT ?
|
||||
""" % {
|
||||
"tuple_cmp": tuple_clause,
|
||||
}
|
||||
|
||||
args = [current_room_id]
|
||||
args.extend(tuple_args)
|
||||
args.append(batch_size)
|
||||
|
||||
txn.execute(sql, args)
|
||||
rows = txn.fetchall()
|
||||
|
||||
# Put the results in the necessary format for
|
||||
# `_add_chain_cover_index`
|
||||
event_to_room_id = {row[0]: current_room_id for row in rows}
|
||||
event_to_types = {row[0]: (row[1], row[2]) for row in rows}
|
||||
|
||||
new_last_depth = rows[-1][3] if rows else last_depth # type: int
|
||||
new_last_stream = rows[-1][4] if rows else last_stream # type: int
|
||||
|
||||
count = len(rows)
|
||||
|
||||
# We also need to fetch the auth events for them.
|
||||
auth_events = self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="event_auth",
|
||||
column="event_id",
|
||||
iterable=event_to_room_id,
|
||||
keyvalues={},
|
||||
retcols=("event_id", "auth_id"),
|
||||
)
|
||||
|
||||
event_to_auth_chain = {} # type: Dict[str, List[str]]
|
||||
for row in auth_events:
|
||||
event_to_auth_chain.setdefault(row["event_id"], []).append(
|
||||
row["auth_id"]
|
||||
)
|
||||
|
||||
# Calculate and persist the chain cover index for this set of events.
|
||||
#
|
||||
# Annoyingly we need to gut wrench into the persit event store so that
|
||||
# we can reuse the function to calculate the chain cover for rooms.
|
||||
self.hs.get_datastores().persist_events._add_chain_cover_index(
|
||||
txn, event_to_room_id, event_to_types, event_to_auth_chain,
|
||||
)
|
||||
|
||||
return new_last_depth, new_last_stream, count
|
||||
|
||||
last_depth, last_stream, count = await self.db_pool.runInteraction(
|
||||
"_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
|
||||
)
|
||||
|
||||
total_rows_processed = count
|
||||
|
||||
if count < batch_size and not has_set_room_has_chain_index:
|
||||
# If we've done all the events in the room we flip the
|
||||
# `has_auth_chain_index` in the DB. Note that its possible for
|
||||
# further events to be persisted between the above and setting the
|
||||
# flag without having the chain cover calculated for them. This is
|
||||
# fine as a) the code gracefully handles these cases and b) we'll
|
||||
# calculate them below.
|
||||
|
||||
await self.db_pool.simple_update(
|
||||
table="rooms",
|
||||
keyvalues={"room_id": current_room_id},
|
||||
updatevalues={"has_auth_chain_index": True},
|
||||
desc="_chain_cover_index",
|
||||
)
|
||||
has_set_room_has_chain_index = True
|
||||
|
||||
# Handle any events that might have raced with us flipping the
|
||||
# bit above.
|
||||
last_depth, last_stream, count = await self.db_pool.runInteraction(
|
||||
"_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
|
||||
)
|
||||
|
||||
total_rows_processed += count
|
||||
|
||||
# Note that at this point its technically possible that more events
|
||||
# than our `batch_size` have been persisted without their chain
|
||||
# cover, so we need to continue processing this room if the last
|
||||
# count returned was equal to the `batch_size`.
|
||||
|
||||
if count < batch_size:
|
||||
# We've finished calculating the index for this room, move on to the
|
||||
# next room.
|
||||
await self.db_pool.updates._background_update_progress(
|
||||
"chain_cover", {"current_room_id": current_room_id, "finished": True},
|
||||
)
|
||||
else:
|
||||
# We still have outstanding events to calculate the index for.
|
||||
await self.db_pool.updates._background_update_progress(
|
||||
"chain_cover",
|
||||
{
|
||||
"current_room_id": current_room_id,
|
||||
"last_depth": last_depth,
|
||||
"last_stream": last_stream,
|
||||
"has_auth_chain_index": has_set_room_has_chain_index,
|
||||
"finished": False,
|
||||
},
|
||||
)
|
||||
|
||||
return total_rows_processed
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
|
||||
(5906, 'chain_cover', '{}', 'rejected_events_metadata');
|
||||
@@ -464,19 +464,17 @@ class TransactionStore(TransactionWorkerStore):
|
||||
txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
|
||||
) -> List[str]:
|
||||
q = """
|
||||
SELECT destination FROM destinations
|
||||
WHERE destination IN (
|
||||
SELECT destination FROM destination_rooms
|
||||
WHERE destination_rooms.stream_ordering >
|
||||
destinations.last_successful_stream_ordering
|
||||
)
|
||||
AND destination > ?
|
||||
AND (
|
||||
retry_last_ts IS NULL OR
|
||||
retry_last_ts + retry_interval < ?
|
||||
)
|
||||
ORDER BY destination
|
||||
LIMIT 25
|
||||
SELECT DISTINCT destination FROM destinations
|
||||
INNER JOIN destination_rooms USING (destination)
|
||||
WHERE
|
||||
stream_ordering > last_successful_stream_ordering
|
||||
AND destination > ?
|
||||
AND (
|
||||
retry_last_ts IS NULL OR
|
||||
retry_last_ts + retry_interval < ?
|
||||
)
|
||||
ORDER BY destination
|
||||
LIMIT 25
|
||||
"""
|
||||
txn.execute(
|
||||
q,
|
||||
|
||||
@@ -151,6 +151,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
|
||||
|
||||
self.handler = hs.get_oidc_handler()
|
||||
self.provider = self.handler._provider
|
||||
sso_handler = hs.get_sso_handler()
|
||||
# Mock the render error method.
|
||||
self.render_error = Mock(return_value=None)
|
||||
@@ -162,9 +163,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
return hs
|
||||
|
||||
def metadata_edit(self, values):
|
||||
return patch.dict(self.handler._provider_metadata, values)
|
||||
return patch.dict(self.provider._provider_metadata, values)
|
||||
|
||||
def assertRenderedError(self, error, error_description=None):
|
||||
self.render_error.assert_called_once()
|
||||
args = self.render_error.call_args[0]
|
||||
self.assertEqual(args[1], error)
|
||||
if error_description is not None:
|
||||
@@ -175,15 +177,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
def test_config(self):
|
||||
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
||||
self.assertEqual(self.handler._callback_url, CALLBACK_URL)
|
||||
self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
|
||||
self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
|
||||
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
|
||||
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
|
||||
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
|
||||
|
||||
@override_config({"oidc_config": {"discover": True}})
|
||||
def test_discovery(self):
|
||||
"""The handler should discover the endpoints from OIDC discovery document."""
|
||||
# This would throw if some metadata were invalid
|
||||
metadata = self.get_success(self.handler.load_metadata())
|
||||
metadata = self.get_success(self.provider.load_metadata())
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
|
||||
self.assertEqual(metadata.issuer, ISSUER)
|
||||
@@ -195,47 +197,47 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
# subsequent calls should be cached
|
||||
self.http_client.reset_mock()
|
||||
self.get_success(self.handler.load_metadata())
|
||||
self.get_success(self.provider.load_metadata())
|
||||
self.http_client.get_json.assert_not_called()
|
||||
|
||||
@override_config({"oidc_config": COMMON_CONFIG})
|
||||
def test_no_discovery(self):
|
||||
"""When discovery is disabled, it should not try to load from discovery document."""
|
||||
self.get_success(self.handler.load_metadata())
|
||||
self.get_success(self.provider.load_metadata())
|
||||
self.http_client.get_json.assert_not_called()
|
||||
|
||||
@override_config({"oidc_config": COMMON_CONFIG})
|
||||
def test_load_jwks(self):
|
||||
"""JWKS loading is done once (then cached) if used."""
|
||||
jwks = self.get_success(self.handler.load_jwks())
|
||||
jwks = self.get_success(self.provider.load_jwks())
|
||||
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
||||
self.assertEqual(jwks, {"keys": []})
|
||||
|
||||
# subsequent calls should be cached…
|
||||
self.http_client.reset_mock()
|
||||
self.get_success(self.handler.load_jwks())
|
||||
self.get_success(self.provider.load_jwks())
|
||||
self.http_client.get_json.assert_not_called()
|
||||
|
||||
# …unless forced
|
||||
self.http_client.reset_mock()
|
||||
self.get_success(self.handler.load_jwks(force=True))
|
||||
self.get_success(self.provider.load_jwks(force=True))
|
||||
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
||||
|
||||
# Throw if the JWKS uri is missing
|
||||
with self.metadata_edit({"jwks_uri": None}):
|
||||
self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
|
||||
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
|
||||
|
||||
# Return empty key set if JWKS are not used
|
||||
self.handler._scopes = [] # not asking the openid scope
|
||||
self.provider._scopes = [] # not asking the openid scope
|
||||
self.http_client.get_json.reset_mock()
|
||||
jwks = self.get_success(self.handler.load_jwks(force=True))
|
||||
jwks = self.get_success(self.provider.load_jwks(force=True))
|
||||
self.http_client.get_json.assert_not_called()
|
||||
self.assertEqual(jwks, {"keys": []})
|
||||
|
||||
@override_config({"oidc_config": COMMON_CONFIG})
|
||||
def test_validate_config(self):
|
||||
"""Provider metadatas are extensively validated."""
|
||||
h = self.handler
|
||||
h = self.provider
|
||||
|
||||
# Default test config does not throw
|
||||
h._validate_metadata()
|
||||
@@ -314,13 +316,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
"""Provider metadata validation can be disabled by config."""
|
||||
with self.metadata_edit({"issuer": "http://insecure"}):
|
||||
# This should not throw
|
||||
self.handler._validate_metadata()
|
||||
self.provider._validate_metadata()
|
||||
|
||||
def test_redirect_request(self):
|
||||
"""The redirect request has the right arguments & generates a valid session cookie."""
|
||||
req = Mock(spec=["addCookie"])
|
||||
url = self.get_success(
|
||||
self.handler.handle_redirect_request(req, b"http://client/redirect")
|
||||
self.provider.handle_redirect_request(req, b"http://client/redirect")
|
||||
)
|
||||
url = urlparse(url)
|
||||
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
|
||||
@@ -388,7 +390,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
# ensure that we are correctly testing the fallback when "get_extra_attributes"
|
||||
# is not implemented.
|
||||
mapping_provider = self.handler._user_mapping_provider
|
||||
mapping_provider = self.provider._user_mapping_provider
|
||||
with self.assertRaises(AttributeError):
|
||||
_ = mapping_provider.get_extra_attributes
|
||||
|
||||
@@ -403,9 +405,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
"username": username,
|
||||
}
|
||||
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
|
||||
self.handler._exchange_code = simple_async_mock(return_value=token)
|
||||
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
||||
self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
@@ -425,14 +427,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
expected_user_id, request, client_redirect_url, None,
|
||||
)
|
||||
self.handler._exchange_code.assert_called_once_with(code)
|
||||
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
||||
self.handler._fetch_userinfo.assert_not_called()
|
||||
self.provider._exchange_code.assert_called_once_with(code)
|
||||
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
||||
self.provider._fetch_userinfo.assert_not_called()
|
||||
self.render_error.assert_not_called()
|
||||
|
||||
# Handle mapping errors
|
||||
with patch.object(
|
||||
self.handler,
|
||||
self.provider,
|
||||
"_remote_id_from_userinfo",
|
||||
new=Mock(side_effect=MappingException()),
|
||||
):
|
||||
@@ -440,36 +442,36 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.assertRenderedError("mapping_error")
|
||||
|
||||
# Handle ID token errors
|
||||
self.handler._parse_id_token = simple_async_mock(raises=Exception())
|
||||
self.provider._parse_id_token = simple_async_mock(raises=Exception())
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("invalid_token")
|
||||
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
self.handler._exchange_code.reset_mock()
|
||||
self.handler._parse_id_token.reset_mock()
|
||||
self.handler._fetch_userinfo.reset_mock()
|
||||
self.provider._exchange_code.reset_mock()
|
||||
self.provider._parse_id_token.reset_mock()
|
||||
self.provider._fetch_userinfo.reset_mock()
|
||||
|
||||
# With userinfo fetching
|
||||
self.handler._scopes = [] # do not ask the "openid" scope
|
||||
self.provider._scopes = [] # do not ask the "openid" scope
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
expected_user_id, request, client_redirect_url, None,
|
||||
)
|
||||
self.handler._exchange_code.assert_called_once_with(code)
|
||||
self.handler._parse_id_token.assert_not_called()
|
||||
self.handler._fetch_userinfo.assert_called_once_with(token)
|
||||
self.provider._exchange_code.assert_called_once_with(code)
|
||||
self.provider._parse_id_token.assert_not_called()
|
||||
self.provider._fetch_userinfo.assert_called_once_with(token)
|
||||
self.render_error.assert_not_called()
|
||||
|
||||
# Handle userinfo fetching error
|
||||
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
|
||||
self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("fetch_error")
|
||||
|
||||
# Handle code exchange failure
|
||||
from synapse.handlers.oidc_handler import OidcError
|
||||
|
||||
self.handler._exchange_code = simple_async_mock(
|
||||
self.provider._exchange_code = simple_async_mock(
|
||||
raises=OidcError("invalid_request")
|
||||
)
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
@@ -524,7 +526,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
|
||||
)
|
||||
code = "code"
|
||||
ret = self.get_success(self.handler._exchange_code(code))
|
||||
ret = self.get_success(self.provider._exchange_code(code))
|
||||
kwargs = self.http_client.request.call_args[1]
|
||||
|
||||
self.assertEqual(ret, token)
|
||||
@@ -548,7 +550,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
from synapse.handlers.oidc_handler import OidcError
|
||||
|
||||
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
||||
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
|
||||
self.assertEqual(exc.value.error, "foo")
|
||||
self.assertEqual(exc.value.error_description, "bar")
|
||||
|
||||
@@ -558,7 +560,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
code=500, phrase=b"Internal Server Error", body=b"Not JSON",
|
||||
)
|
||||
)
|
||||
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
||||
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
|
||||
self.assertEqual(exc.value.error, "server_error")
|
||||
|
||||
# Internal server error with JSON body
|
||||
@@ -570,14 +572,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
||||
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
|
||||
self.assertEqual(exc.value.error, "internal_server_error")
|
||||
|
||||
# 4xx error without "error" field
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
|
||||
)
|
||||
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
||||
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
|
||||
self.assertEqual(exc.value.error, "server_error")
|
||||
|
||||
# 2xx error with "error" field
|
||||
@@ -586,7 +588,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
code=200, phrase=b"OK", body=b'{"error": "some_error"}',
|
||||
)
|
||||
)
|
||||
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
||||
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
|
||||
self.assertEqual(exc.value.error, "some_error")
|
||||
|
||||
@override_config(
|
||||
@@ -612,8 +614,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
"username": "foo",
|
||||
"phone": "1234567",
|
||||
}
|
||||
self.handler._exchange_code = simple_async_mock(return_value=token)
|
||||
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
||||
self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
@@ -979,9 +981,10 @@ async def _make_callback_with_userinfo(
|
||||
from synapse.handlers.oidc_handler import OidcSessionData
|
||||
|
||||
handler = hs.get_oidc_handler()
|
||||
handler._exchange_code = simple_async_mock(return_value={})
|
||||
handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||
provider = handler._provider
|
||||
provider._exchange_code = simple_async_mock(return_value={})
|
||||
provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||
|
||||
state = "state"
|
||||
session = handler._token_generator.generate_oidc_session_token(
|
||||
|
||||
@@ -20,7 +20,10 @@ from twisted.trial import unittest
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.events import EventBase
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client.v1 import login, room
|
||||
from synapse.storage.databases.main.events import _LinkMap
|
||||
from synapse.types import create_requester
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
@@ -470,3 +473,114 @@ class LinkMapTestCase(unittest.TestCase):
|
||||
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
|
||||
|
||||
self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
|
||||
|
||||
|
||||
class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
|
||||
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
room.register_servlets,
|
||||
login.register_servlets,
|
||||
]
|
||||
|
||||
def test_background_update(self):
|
||||
"""Test that the background update to calculate auth chains for historic
|
||||
rooms works correctly.
|
||||
"""
|
||||
|
||||
# Create a room
|
||||
user_id = self.register_user("foo", "pass")
|
||||
token = self.login("foo", "pass")
|
||||
room_id = self.helper.create_room_as(user_id, tok=token)
|
||||
requester = create_requester(user_id)
|
||||
|
||||
store = self.hs.get_datastore()
|
||||
|
||||
# Mark the room as not having a chain cover index
|
||||
self.get_success(
|
||||
store.db_pool.simple_update(
|
||||
table="rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
updatevalues={"has_auth_chain_index": False},
|
||||
desc="test",
|
||||
)
|
||||
)
|
||||
|
||||
# Create a fork in the DAG with different events.
|
||||
event_handler = self.hs.get_event_creation_handler()
|
||||
latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
|
||||
event, context = self.get_success(
|
||||
event_handler.create_event(
|
||||
requester,
|
||||
{
|
||||
"type": "some_state_type",
|
||||
"state_key": "",
|
||||
"content": {},
|
||||
"room_id": room_id,
|
||||
"sender": user_id,
|
||||
},
|
||||
prev_event_ids=latest_event_ids,
|
||||
)
|
||||
)
|
||||
self.get_success(
|
||||
event_handler.handle_new_client_event(requester, event, context)
|
||||
)
|
||||
state1 = list(self.get_success(context.get_current_state_ids()).values())
|
||||
|
||||
event, context = self.get_success(
|
||||
event_handler.create_event(
|
||||
requester,
|
||||
{
|
||||
"type": "some_state_type",
|
||||
"state_key": "",
|
||||
"content": {},
|
||||
"room_id": room_id,
|
||||
"sender": user_id,
|
||||
},
|
||||
prev_event_ids=latest_event_ids,
|
||||
)
|
||||
)
|
||||
self.get_success(
|
||||
event_handler.handle_new_client_event(requester, event, context)
|
||||
)
|
||||
state2 = list(self.get_success(context.get_current_state_ids()).values())
|
||||
|
||||
# Delete the chain cover info.
|
||||
|
||||
def _delete_tables(txn):
|
||||
txn.execute("DELETE FROM event_auth_chains")
|
||||
txn.execute("DELETE FROM event_auth_chain_links")
|
||||
|
||||
self.get_success(store.db_pool.runInteraction("test", _delete_tables))
|
||||
|
||||
# Insert and run the background update.
|
||||
self.get_success(
|
||||
store.db_pool.simple_insert(
|
||||
"background_updates",
|
||||
{"update_name": "chain_cover", "progress_json": "{}"},
|
||||
)
|
||||
)
|
||||
|
||||
# Ugh, have to reset this flag
|
||||
store.db_pool.updates._all_done = False
|
||||
|
||||
while not self.get_success(
|
||||
store.db_pool.updates.has_completed_background_updates()
|
||||
):
|
||||
self.get_success(
|
||||
store.db_pool.updates.do_next_background_update(100), by=0.1
|
||||
)
|
||||
|
||||
# Test that the `has_auth_chain_index` has been set
|
||||
self.assertTrue(self.get_success(store.has_auth_chain_index(room_id)))
|
||||
|
||||
# Test that calculating the auth chain difference using the newly
|
||||
# calculated chain cover works.
|
||||
self.get_success(
|
||||
store.db_pool.runInteraction(
|
||||
"test",
|
||||
store._get_auth_chain_difference_using_cover_index_txn,
|
||||
room_id,
|
||||
[state1, state2],
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user