
Merge commit 'cf7d3c90d' into dinsic

Andrew Morgan committed 2021-04-16 12:33:45 +01:00
92 changed files with 1580 additions and 555 deletions

changelog.d/8802.doc Normal file

@@ -0,0 +1 @@
Fix the "Event persist rate" section of the included grafana dashboard by adding missing prometheus rules.

changelog.d/8827.bugfix Normal file

@@ -0,0 +1 @@
Fix a bug where we might not correctly calculate the current state for rooms with multiple extremities.

changelog.d/8837.bugfix Normal file

@@ -0,0 +1 @@
Fix a long-standing bug in the register admin endpoint (`/_synapse/admin/v1/register`) when the `mac` field was not provided. The endpoint now properly returns a 400 error. Contributed by @edwargix.

changelog.d/8853.feature Normal file

@@ -0,0 +1 @@
Add optional HTTP authentication to replication endpoints.

changelog.d/8858.bugfix Normal file

@@ -0,0 +1 @@
Fix a long-standing bug on Synapse instances supporting Single-Sign-On, where users would be prompted to enter their password to confirm certain actions, even though they had not set a password.

changelog.d/8861.misc Normal file

@@ -0,0 +1 @@
Remove some unnecessary stubbing from unit tests.

changelog.d/8862.bugfix Normal file

@@ -0,0 +1 @@
Fix a long-standing bug where a 500 error would be returned if the `Content-Length` header was not provided to the upload media resource.

changelog.d/8864.misc Normal file

@@ -0,0 +1 @@
Remove unused `FakeResponse` class from unit tests.

changelog.d/8865.bugfix Normal file

@@ -0,0 +1 @@
Add additional validation to pusher URLs to be compliant with the specification.

changelog.d/8867.bugfix Normal file

@@ -0,0 +1 @@
Fix the error code that is returned when a user tries to register on a homeserver on which new-user registration has been disabled.

changelog.d/8872.bugfix Normal file

@@ -0,0 +1 @@
Fix a bug where `PUT /_synapse/admin/v2/users/<user_id>` failed to create a new user when `avatar_url` was specified. This bug was introduced in Synapse v1.9.0.

changelog.d/8873.doc Normal file

@@ -0,0 +1 @@
Fix an error in the documentation for the SAML username mapping provider.

changelog.d/8874.feature Normal file

@@ -0,0 +1 @@
Improve the error messages printed as a result of configuration problems for extension modules.

changelog.d/8879.misc Normal file

@@ -0,0 +1 @@
Pass `room_id` to `get_auth_chain_difference`.

changelog.d/8880.misc Normal file

@@ -0,0 +1 @@
Add type hints to push module.

changelog.d/8881.misc Normal file

@@ -0,0 +1 @@
Simplify logic for handling user-interactive-auth via single-sign-on servers.

changelog.d/8882.misc Normal file

@@ -0,0 +1 @@
Add type hints to push module.

changelog.d/8883.bugfix Normal file

@@ -0,0 +1 @@
Fix a 500 error when attempting to preview an empty HTML file.

changelog.d/8887.feature Normal file

@@ -0,0 +1 @@
Add `X-Robots-Tag` header to stop web crawlers from indexing media.

changelog.d/8891.doc Normal file

@@ -0,0 +1 @@
Clarify comments around template directories in `sample_config.yaml`.

contrib/prometheus/synapse-v2.rules

@@ -58,3 +58,21 @@ groups:
labels:
type: "PDU"
expr: 'synapse_federation_transaction_queue_pending_pdus + 0'
- record: synapse_storage_events_persisted_by_source_type
expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_type="remote"})
labels:
type: remote
- record: synapse_storage_events_persisted_by_source_type
expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_entity="*client*",origin_type="local"})
labels:
type: local
- record: synapse_storage_events_persisted_by_source_type
expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_entity!="*client*",origin_type="local"})
labels:
type: bridges
- record: synapse_storage_events_persisted_by_event_type
expr: sum without(origin_entity, origin_type) (synapse_storage_events_persisted_events_sep)
- record: synapse_storage_events_persisted_by_origin
expr: sum without(type) (synapse_storage_events_persisted_events_sep)
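
The new recording rules can be sanity-checked once Prometheus has picked them up. A minimal sketch, assuming a Prometheus server reachable at `localhost:9090`:

```python
# Query one of the new recording rules via the Prometheus HTTP API to verify
# it is being evaluated. The server address is an assumption.
import requests

resp = requests.get(
    "http://localhost:9090/api/v1/query",
    params={"query": "synapse_storage_events_persisted_by_source_type"},
)
resp.raise_for_status()
for series in resp.json()["data"]["result"]:
    # Each series carries a `type` label of remote, local or bridges,
    # per the rules above.
    print(series["metric"].get("type"), series["value"][1])
```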

docker/Dockerfile-dhvirtualenv

@@ -69,7 +69,8 @@ RUN apt-get update -qq -o Acquire::Languages=none \
python3-setuptools \
python3-venv \
sqlite3 \
libpq-dev
libpq-dev \
xmlsec1
COPY --from=builder /dh-virtualenv_1.2~dev-1_all.deb /

docs/sample_config.yaml

@@ -2059,11 +2059,8 @@ sso:
# - https://my.custom.client/
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
# If not set, or the files named below are not found within the template
# directory, default templates from within the Synapse package will be used.
#
# Synapse will look for the following templates in this directory:
#
@@ -2293,9 +2290,8 @@ email:
#validation_token_lifetime: 15m
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# Do not uncomment this setting unless you want to customise the templates.
# If not set, or the files named below are not found within the template
# directory, default templates from within the Synapse package will be used.
#
# Synapse will look for the following templates in this directory:
#
@@ -2779,6 +2775,13 @@ opentracing:
#
#run_background_tasks_on: worker1
# A shared secret used by the replication APIs to authenticate HTTP requests
# from workers.
#
# By default this is unused and traffic is not authenticated.
#
#worker_replication_secret: ""
# Configuration for Redis when using workers. This *must* be enabled when
# using workers (unless using old style direct TCP configuration).

docs/sso_mapping_providers.md

@@ -116,11 +116,13 @@ comment these options out and use those specified by the module instead.
A custom mapping provider must specify the following methods:
* `__init__(self, parsed_config)`
* `__init__(self, parsed_config, module_api)`
- Arguments:
- `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by
the module here.
- `module_api` - a `synapse.module_api.ModuleApi` object which provides the
stable API available for extension modules.
* `parse_config(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:
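
For reference, a minimal skeleton of a mapping provider using the new two-argument constructor documented above; the class name and config handling are illustrative only:

```python
# Hypothetical mapping provider skeleton matching the documented interface.
from synapse.module_api import ModuleApi

class MyMappingProvider:
    def __init__(self, parsed_config, module_api: ModuleApi):
        # `parsed_config` is whatever parse_config returned.
        self._config = parsed_config
        self._module_api = module_api

    @staticmethod
    def parse_config(config: dict):
        # Validate the raw config dict and return the object passed to __init__.
        return config
```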

docs/workers.md

@@ -89,7 +89,8 @@ shared configuration file.
Normally, only a couple of changes are needed to make an existing configuration
file suitable for use with workers. First, you need to enable an "HTTP replication
listener" for the main process; and secondly, you need to enable redis-based
replication. For example:
replication. Optionally, a shared secret can be used to authenticate HTTP
traffic between workers. For example:
```yaml
@@ -103,6 +104,9 @@ listeners:
resources:
- names: [replication]
# Add a random shared secret to authenticate traffic.
worker_replication_secret: ""
redis:
enabled: true
```
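
With `worker_replication_secret` set, replication requests must carry the secret. A rough sketch of an authenticated request follows; the bearer scheme and the endpoint path here are assumptions for illustration, not Synapse's exact wire format:

```python
# Hypothetical example of hitting a replication endpoint with the shared secret.
import requests

SECRET = "some-random-shared-secret"  # must match worker_replication_secret

resp = requests.post(
    "http://localhost:9093/_synapse/replication/example",  # hypothetical path
    headers={"Authorization": f"Bearer {SECRET}"},
    json={},
)
print(resp.status_code)
```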

mypy.ini

@@ -43,6 +43,7 @@ files =
synapse/handlers/room_member.py,
synapse/handlers/room_member_worker.py,
synapse/handlers/saml_handler.py,
synapse/handlers/sso.py,
synapse/handlers/sync.py,
synapse/handlers/ui_auth,
synapse/http/client.py,
@@ -55,6 +56,10 @@ files =
synapse/metrics,
synapse/module_api,
synapse/notifier.py,
synapse/push/emailpusher.py,
synapse/push/httppusher.py,
synapse/push/mailer.py,
synapse/push/pusher.py,
synapse/push/pusherpool.py,
synapse/push/push_rule_evaluator.py,
synapse/replication,

synapse/app/homeserver.py

@@ -19,7 +19,7 @@ import gc
import logging
import os
import sys
from typing import Iterable
from typing import Iterable, Iterator
from twisted.application import service
from twisted.internet import defer, reactor
@@ -90,7 +90,7 @@ class SynapseHomeServer(HomeServer):
tls = listener_config.tls
site_tag = listener_config.http_options.tag
if site_tag is None:
site_tag = port
site_tag = str(port)
# We always include a health resource.
resources = {"/health": HealthResource()}
@@ -107,7 +107,10 @@ class SynapseHomeServer(HomeServer):
logger.debug("Configuring additional resources: %r", additional_resources)
module_api = self.get_module_api()
for path, resmodule in additional_resources.items():
handler_cls, config = load_module(resmodule)
handler_cls, config = load_module(
resmodule,
("listeners", site_tag, "additional_resources", "<%s>" % (path,)),
)
handler = handler_cls(config, module_api)
if IResource.providedBy(handler):
resource = handler
@@ -342,7 +345,10 @@ def setup(config_options):
"Synapse Homeserver", config_options
)
except ConfigError as e:
sys.stderr.write("\nERROR: %s\n" % (e,))
sys.stderr.write("\n")
for f in format_config_error(e):
sys.stderr.write(f)
sys.stderr.write("\n")
sys.exit(1)
if not config:
@@ -445,6 +451,38 @@ def setup(config_options):
return hs
def format_config_error(e: ConfigError) -> Iterator[str]:
"""
Formats a config error neatly
The idea is to format the immediate error, plus the "causes" of those errors,
hopefully in a way that makes sense to the user. For example:
Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template':
Failed to parse config for module 'JinjaOidcMappingProvider':
invalid jinja template:
unexpected end of template, expected 'end of print statement'.
Args:
e: the error to be formatted
Returns: An iterator which yields string fragments to be formatted
"""
yield "Error in configuration"
if e.path:
yield " at '%s'" % (".".join(e.path),)
yield ":\n %s" % (e.msg,)
e = e.__cause__
indent = 1
while e:
indent += 1
yield ":\n%s%s" % (" " * indent, str(e))
e = e.__cause__
class SynapseService(service.Service):
"""
A twisted Service class that will start synapse. Used to run synapse
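
To illustrate `format_config_error`, a small sketch of rendering a chained error with the helper defined above (the module and template names are illustrative):

```python
# Build a ConfigError with a cause and print it the way setup() now does.
err = ConfigError(
    "Failed to parse config for module 'JinjaOidcMappingProvider'",
    path=("oidc_config", "user_mapping_provider", "config"),
)
err.__cause__ = ValueError("invalid jinja template")

print("".join(format_config_error(err)))
# Error in configuration at 'oidc_config.user_mapping_provider.config':
#  Failed to parse config for module 'JinjaOidcMappingProvider':
#   invalid jinja template
```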

synapse/config/_base.py

@@ -24,7 +24,7 @@ from collections import OrderedDict
from hashlib import sha256
from io import open as io_open
from textwrap import dedent
from typing import Any, Callable, List, MutableMapping, Optional
from typing import Any, Callable, Iterable, List, MutableMapping, Optional
import attr
import jinja2
@@ -33,7 +33,17 @@ import yaml
class ConfigError(Exception):
pass
"""Represents a problem parsing the configuration
Args:
msg: A textual description of the error.
path: Where appropriate, an indication of where in the configuration
the problem lies.
"""
def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
self.msg = msg
self.path = path
# We split these messages out to allow packages to override with package

synapse/config/_base.pyi

@@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import Any, Iterable, List, Optional
from synapse.config import (
account_validity,
@@ -37,7 +37,10 @@ from synapse.config import (
workers,
)
class ConfigError(Exception): ...
class ConfigError(Exception):
def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
self.msg = msg
self.path = path
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
MISSING_REPORT_STATS_SPIEL: str

synapse/config/_util.py

@@ -38,14 +38,27 @@ def validate_config(
try:
jsonschema.validate(config, json_schema)
except jsonschema.ValidationError as e:
# copy `config_path` before modifying it.
path = list(config_path)
for p in list(e.path):
if isinstance(p, int):
path.append("<item %i>" % p)
else:
path.append(str(p))
raise json_error_to_config_error(e, config_path)
raise ConfigError(
"Unable to parse configuration: %s at %s" % (e.message, ".".join(path))
)
def json_error_to_config_error(
e: jsonschema.ValidationError, config_path: Iterable[str]
) -> ConfigError:
"""Converts a json validation error to a user-readable ConfigError
Args:
e: the exception to be converted
config_path: the path within the config file. This will be used as a basis
for the error message.
Returns:
a ConfigError
"""
# copy `config_path` before modifying it.
path = list(config_path)
for p in list(e.path):
if isinstance(p, int):
path.append("<item %i>" % p)
else:
path.append(str(p))
return ConfigError(e.message, path)
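
A sketch of the resulting behaviour, assuming `validate_config` takes `(json_schema, config, config_path)` as shown above:

```python
# A bad list item is reported with an "<item N>" path element.
schema = {
    "type": "object",
    "properties": {"providers": {"type": "array", "items": {"type": "string"}}},
}
try:
    validate_config(schema, {"providers": ["ok", 5]}, ("spam_checker",))
except ConfigError as e:
    print(e.msg)   # 5 is not of type 'string'
    print(e.path)  # ['spam_checker', 'providers', '<item 1>']
```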

synapse/config/emailconfig.py

@@ -390,9 +390,8 @@ class EmailConfig(Config):
#validation_token_lifetime: 15m
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# Do not uncomment this setting unless you want to customise the templates.
# If not set, or the files named below are not found within the template
# directory, default templates from within the Synapse package will be used.
#
# Synapse will look for the following templates in this directory:
#

synapse/config/oidc_config.py

@@ -66,7 +66,7 @@ class OIDCConfig(Config):
(
self.oidc_user_mapping_provider_class,
self.oidc_user_mapping_provider_config,
) = load_module(ump_config)
) = load_module(ump_config, ("oidc_config", "user_mapping_provider"))
# Ensure loaded user mapping module has defined all necessary methods
required_methods = [

synapse/config/password_auth_providers.py

@@ -36,7 +36,7 @@ class PasswordAuthProviderConfig(Config):
providers.append({"module": LDAP_PROVIDER, "config": ldap_config})
providers.extend(config.get("password_providers") or [])
for provider in providers:
for i, provider in enumerate(providers):
mod_name = provider["module"]
# This is for backwards compat when the ldap auth provider resided
@@ -45,7 +45,8 @@ class PasswordAuthProviderConfig(Config):
mod_name = LDAP_PROVIDER
(provider_class, provider_config) = load_module(
{"module": mod_name, "config": provider["config"]}
{"module": mod_name, "config": provider["config"]},
("password_providers", "<item %i>" % i),
)
self.password_providers.append((provider_class, provider_config))

synapse/config/repository.py

@@ -148,7 +148,7 @@ class ContentRepositoryConfig(Config):
# them to be started.
self.media_storage_providers = [] # type: List[tuple]
for provider_config in storage_providers:
for i, provider_config in enumerate(storage_providers):
# We special case the module "file_system" so as not to need to
# expose FileStorageProviderBackend
if provider_config["module"] == "file_system":
@@ -157,7 +157,9 @@ class ContentRepositoryConfig(Config):
".FileStorageProviderBackend"
)
provider_class, parsed_config = load_module(provider_config)
provider_class, parsed_config = load_module(
provider_config, ("media_storage_providers", "<item %i>" % i)
)
wrapper_config = MediaStorageProviderConfig(
provider_config.get("store_local", False),

synapse/config/room_directory.py

@@ -180,7 +180,7 @@ class _RoomDirectoryRule:
self._alias_regex = glob_to_regex(alias)
self._room_id_regex = glob_to_regex(room_id)
except Exception as e:
raise ConfigError("Failed to parse glob into regex: %s", e)
raise ConfigError("Failed to parse glob into regex") from e
def matches(self, user_id, room_id, aliases):
"""Tests if this rule matches the given user_id, room_id and aliases.

synapse/config/saml2_config.py

@@ -125,7 +125,7 @@ class SAML2Config(Config):
(
self.saml2_user_mapping_provider_class,
self.saml2_user_mapping_provider_config,
) = load_module(ump_dict)
) = load_module(ump_dict, ("saml2_config", "user_mapping_provider"))
# Ensure loaded user mapping module has defined all necessary methods
# Note parse_config() is already checked during the call to load_module

synapse/config/spam_checker.py

@@ -33,13 +33,14 @@ class SpamCheckerConfig(Config):
# spam checker, and thus was simply a dictionary with module
# and config keys. Support this old behaviour by checking
# to see if the option resolves to a dictionary
self.spam_checkers.append(load_module(spam_checkers))
self.spam_checkers.append(load_module(spam_checkers, ("spam_checker",)))
elif isinstance(spam_checkers, list):
for spam_checker in spam_checkers:
for i, spam_checker in enumerate(spam_checkers):
config_path = ("spam_checker", "<item %i>" % i)
if not isinstance(spam_checker, dict):
raise ConfigError("spam_checker syntax is incorrect")
raise ConfigError("expected a mapping", config_path)
self.spam_checkers.append(load_module(spam_checker))
self.spam_checkers.append(load_module(spam_checker, config_path))
else:
raise ConfigError("spam_checker syntax is incorrect")

synapse/config/sso.py

@@ -93,11 +93,8 @@ class SSOConfig(Config):
# - https://my.custom.client/
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
# If not set, or the files named below are not found within the template
# directory, default templates from within the Synapse package will be used.
#
# Synapse will look for the following templates in this directory:
#

synapse/config/third_party_event_rules.py

@@ -26,7 +26,9 @@ class ThirdPartyRulesConfig(Config):
provider = config.get("third_party_event_rules", None)
if provider is not None:
self.third_party_event_rules = load_module(provider)
self.third_party_event_rules = load_module(
provider, ("third_party_event_rules",)
)
def generate_config_section(self, **kwargs):
return """\

synapse/config/workers.py

@@ -85,6 +85,9 @@ class WorkerConfig(Config):
# The port on the main synapse for HTTP replication endpoint
self.worker_replication_http_port = config.get("worker_replication_http_port")
# The shared secret used for authentication when connecting to the main synapse.
self.worker_replication_secret = config.get("worker_replication_secret", None)
self.worker_name = config.get("worker_name", self.worker_app)
self.worker_main_http_uri = config.get("worker_main_http_uri", None)
@@ -185,6 +188,13 @@ class WorkerConfig(Config):
# data). If not provided this defaults to the main process.
#
#run_background_tasks_on: worker1
# A shared secret used by the replication APIs to authenticate HTTP requests
# from workers.
#
# By default this is unused and traffic is not authenticated.
#
#worker_replication_secret: ""
"""
def read_arguments(self, args):

synapse/federation/transport/server.py

@@ -1546,7 +1546,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
Args:
hs (synapse.server.HomeServer): homeserver
resource (TransportLayerServer): resource class to register to
resource (JsonResource): resource class to register to
authenticator (Authenticator): authenticator to use
ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
servlet_groups (list[str], optional): List of servlet groups to register.

synapse/handlers/_base.py

@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
class BaseHandler:
"""
Common base class for the event handlers.
Deprecated: new code should not use this. Instead, Handler classes should define the
fields they actually need. The utility methods should either be factored out to
standalone helper functions, or to different Handler classes.
"""
def __init__(self, hs: "HomeServer"):
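
A minimal sketch of the pattern this deprecation note recommends, with a (hypothetical) handler defining only the fields it needs instead of inheriting from `BaseHandler`:

```python
# Pull dependencies explicitly from the HomeServer rather than via BaseHandler.
class ExampleHandler:
    def __init__(self, hs: "HomeServer"):
        self._store = hs.get_datastore()
        self._clock = hs.get_clock()
```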

synapse/handlers/auth.py

@@ -36,6 +36,8 @@ import attr
import bcrypt
import pymacaroons
from twisted.web.http import Request
from synapse.api.constants import LoginType
from synapse.api.errors import (
AuthError,
@@ -193,9 +195,7 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
self._sso_enabled = (
hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
)
self._password_localdb_enabled = hs.config.password_localdb_enabled
# we keep this as a list despite the O(N^2) implication so that we can
# keep PASSWORD first and avoid confusing clients which pick the first
@@ -205,7 +205,7 @@ class AuthHandler(BaseHandler):
# start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = []
if hs.config.password_localdb_enabled:
if self._password_localdb_enabled:
login_types.append(LoginType.PASSWORD)
for provider in self.password_providers:
@@ -219,14 +219,6 @@ class AuthHandler(BaseHandler):
self._supported_login_types = login_types
# Login types and UI Auth types have a heavy overlap, but are not
# necessarily identical. Login types have SSO (and other login types)
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServlet.on_GET.
ui_auth_types = login_types.copy()
if self._sso_enabled:
ui_auth_types.append(LoginType.SSO)
self._supported_ui_auth_types = ui_auth_types
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter(
@@ -339,7 +331,10 @@ class AuthHandler(BaseHandler):
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
# build a list of supported flows
flows = [[login_type] for login_type in self._supported_ui_auth_types]
supported_ui_auth_types = await self._get_available_ui_auth_types(
requester.user
)
flows = [[login_type] for login_type in supported_ui_auth_types]
try:
result, params, session_id = await self.check_ui_auth(
@@ -351,7 +346,7 @@ class AuthHandler(BaseHandler):
raise
# find the completed login type
for login_type in self._supported_ui_auth_types:
for login_type in supported_ui_auth_types:
if login_type not in result:
continue
@@ -367,6 +362,41 @@ class AuthHandler(BaseHandler):
return params, session_id
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
"""Get a list of the authentication types this user can use
"""
ui_auth_types = set()
# if the HS supports password auth, and the user has a non-null password, we
# support password auth
if self._password_localdb_enabled and self._password_enabled:
lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
if lookupres:
_, password_hash = lookupres
if password_hash:
ui_auth_types.add(LoginType.PASSWORD)
# also allow auth from password providers
for provider in self.password_providers:
for t in provider.get_supported_login_types().keys():
if t == LoginType.PASSWORD and not self._password_enabled:
continue
ui_auth_types.add(t)
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
# from sso to mxid.
if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
if await self.store.get_external_ids_by_user(user.to_string()):
ui_auth_types.add(LoginType.SSO)
# Our CAS impl does not (yet) correctly register users in user_external_ids,
# so always offer that if it's available.
if self.hs.config.cas.cas_enabled:
ui_auth_types.add(LoginType.SSO)
return ui_auth_types
def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types
@@ -1019,7 +1049,7 @@ class AuthHandler(BaseHandler):
if result:
return result
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True
# we've already checked that there is a (valid) password field
@@ -1293,15 +1323,14 @@ class AuthHandler(BaseHandler):
)
async def complete_sso_ui_auth(
self, registered_user_id: str, session_id: str, request: SynapseRequest,
self, registered_user_id: str, session_id: str, request: Request,
):
"""Having figured out a mxid for this user, complete the HTTP request
Args:
registered_user_id: The registered user ID to complete SSO login for.
session_id: The ID of the user-interactive auth session.
request: The request to complete.
client_redirect_url: The URL to which to redirect the user at the end of the
process.
"""
# Mark the stage of the authentication as successful.
# Save the user who authenticated with SSO, this will be used to ensure
@@ -1317,7 +1346,7 @@ class AuthHandler(BaseHandler):
async def complete_sso_login(
self,
registered_user_id: str,
request: SynapseRequest,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
):
@@ -1345,7 +1374,7 @@ class AuthHandler(BaseHandler):
def _complete_sso_login(
self,
registered_user_id: str,
request: SynapseRequest,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
):
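
For clarity, the flow list built from `_get_available_ui_auth_types` is one single-stage flow per supported type, e.g.:

```python
# Illustration only: each supported UI-auth login type becomes its own flow.
supported_ui_auth_types = ["m.login.password", "m.login.sso"]
flows = [[login_type] for login_type in supported_ui_auth_types]
# => [["m.login.password"], ["m.login.sso"]]
```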

synapse/handlers/identity.py

@@ -47,7 +47,7 @@ class IdentityHandler(BaseHandler):
super().__init__(hs)
# An HTTP client for contacting trusted URLs.
self.http_client = hs.get_simple_http_client()
self.http_client = SimpleHttpClient(hs)
# An HTTP client for contacting identity servers specified by clients.
self.blacklisting_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.federation_ip_range_blacklist

synapse/handlers/oidc_handler.py

@@ -674,6 +674,21 @@ class OidcHandler(BaseHandler):
self._sso_handler.render_error(request, "invalid_token", str(e))
return
# first check if we're doing a UIA
if ui_auth_session_id:
try:
remote_user_id = self._remote_id_from_userinfo(userinfo)
except Exception as e:
logger.exception("Could not extract remote user id")
self._sso_handler.render_error(request, "mapping_error", str(e))
return
return await self._sso_handler.complete_sso_ui_auth_request(
self._auth_provider_id, remote_user_id, ui_auth_session_id, request
)
# otherwise, it's a login
# Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("")
ip_address = self.hs.get_ip_from_request(request)
@@ -698,14 +713,9 @@ class OidcHandler(BaseHandler):
extra_attributes = await get_extra_attributes(userinfo, token)
# and finally complete the login
if ui_auth_session_id:
await self._auth_handler.complete_sso_ui_auth(
user_id, ui_auth_session_id, request
)
else:
await self._auth_handler.complete_sso_login(
user_id, request, client_redirect_url, extra_attributes
)
await self._auth_handler.complete_sso_login(
user_id, request, client_redirect_url, extra_attributes
)
def _generate_oidc_session_token(
self,
@@ -856,14 +866,11 @@ class OidcHandler(BaseHandler):
The mxid of the user
"""
try:
remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
remote_user_id = self._remote_id_from_userinfo(userinfo)
except Exception as e:
raise MappingException(
"Failed to extract subject from OIDC response: %s" % (e,)
)
# Some OIDC providers use integer IDs, but Synapse expects external IDs
# to be strings.
remote_user_id = str(remote_user_id)
# Older mapping providers don't accept the `failures` argument, so we
# try and detect support.
@@ -933,6 +940,19 @@ class OidcHandler(BaseHandler):
grandfather_existing_users,
)
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
"""Extract the unique remote id from an OIDC UserInfo block
Args:
userinfo: An object representing the user given by the OIDC provider
Returns:
remote user id
"""
remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
# Some OIDC providers use integer IDs, but Synapse expects external IDs
# to be strings.
return str(remote_user_id)
UserAttributeDict = TypedDict(
"UserAttributeDict", {"localpart": str, "display_name": Optional[str]}

synapse/handlers/saml_handler.py

@@ -183,6 +183,24 @@ class SamlHandler(BaseHandler):
saml2_auth.in_response_to, None
)
# first check if we're doing a UIA
if current_session and current_session.ui_auth_session_id:
try:
remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
except MappingException as e:
logger.exception("Failed to extract remote user id from SAML response")
self._sso_handler.render_error(request, "mapping_error", str(e))
return
return await self._sso_handler.complete_sso_ui_auth_request(
self._auth_provider_id,
remote_user_id,
current_session.ui_auth_session_id,
request,
)
# otherwise, we're handling a login request.
# Ensure that the attributes of the logged in user meet the required
# attributes.
for requirement in self._saml2_attribute_requirements:
@@ -206,14 +224,7 @@ class SamlHandler(BaseHandler):
self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Complete the interactive auth session or the login.
if current_session and current_session.ui_auth_session_id:
await self._auth_handler.complete_sso_ui_auth(
user_id, current_session.ui_auth_session_id, request
)
else:
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(
self,
@@ -239,16 +250,10 @@ class SamlHandler(BaseHandler):
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
remote_user_id = self._user_mapping_provider.get_remote_user_id(
remote_user_id = self._remote_id_from_saml_response(
saml2_auth, client_redirect_url
)
if not remote_user_id:
raise MappingException(
"Failed to extract remote user id from SAML response"
)
async def saml_response_to_remapped_user_attributes(
failures: int,
) -> UserAttributes:
@@ -304,6 +309,35 @@ class SamlHandler(BaseHandler):
grandfather_existing_users,
)
def _remote_id_from_saml_response(
self,
saml2_auth: saml2.response.AuthnResponse,
client_redirect_url: Optional[str],
) -> str:
"""Extract the unique remote id from a SAML2 AuthnResponse
Args:
saml2_auth: The parsed SAML2 response.
client_redirect_url: The redirect URL passed in by the client.
Returns:
remote user id
Raises:
MappingException if there was an error extracting the user id
"""
# It's not obvious why we need to pass in the redirect URI to the mapping
# provider, but we do :/
remote_user_id = self._user_mapping_provider.get_remote_user_id(
saml2_auth, client_redirect_url
)
if not remote_user_id:
raise MappingException(
"Failed to extract remote user id from SAML response"
)
return remote_user_id
def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()

synapse/handlers/sso.py

@@ -17,8 +17,9 @@ from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
import attr
from twisted.web.http import Request
from synapse.api.errors import RedirectException
from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html
from synapse.types import UserID, contains_invalid_mxid_characters
@@ -42,14 +43,16 @@ class UserAttributes:
emails = attr.ib(type=List[str], default=attr.Factory(list))
class SsoHandler(BaseHandler):
class SsoHandler:
# The number of attempts to ask the mapping provider for when generating an MXID.
_MAP_USERNAME_RETRIES = 1000
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._store = hs.get_datastore()
self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
self._error_template = hs.config.sso_error_template
self._auth_handler = hs.get_auth_handler()
def render_error(
self, request, error: str, error_description: Optional[str] = None
@@ -95,7 +98,7 @@ class SsoHandler(BaseHandler):
)
# Check if we already have a mapping for this user.
previously_registered_user_id = await self.store.get_user_by_external_id(
previously_registered_user_id = await self._store.get_user_by_external_id(
auth_provider_id, remote_user_id,
)
@@ -181,7 +184,7 @@ class SsoHandler(BaseHandler):
previously_registered_user_id = await grandfather_existing_users()
if previously_registered_user_id:
# Future logins should also match this user ID.
await self.store.record_user_external_id(
await self._store.record_user_external_id(
auth_provider_id, remote_user_id, previously_registered_user_id
)
return previously_registered_user_id
@@ -214,8 +217,8 @@ class SsoHandler(BaseHandler):
)
# Check if this mxid already exists
user_id = UserID(attributes.localpart, self.server_name).to_string()
if not await self.store.get_users_by_id_case_insensitive(user_id):
user_id = UserID(attributes.localpart, self._server_name).to_string()
if not await self._store.get_users_by_id_case_insensitive(user_id):
# This mxid is free
break
else:
@@ -238,7 +241,47 @@ class SsoHandler(BaseHandler):
user_agent_ips=[(user_agent, ip_address)],
)
await self.store.record_user_external_id(
await self._store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
async def complete_sso_ui_auth_request(
self,
auth_provider_id: str,
remote_user_id: str,
ui_auth_session_id: str,
request: Request,
) -> None:
"""
Given an SSO ID, retrieve the user ID for it and complete UIA.
Note that this requires that the user is mapped in the "user_external_ids"
table. This will be the case if they have ever logged in via SAML or OIDC in
recentish synapse versions, but may not be for older users.
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
remote_user_id: The unique identifier from the SSO provider.
ui_auth_session_id: The ID of the user-interactive auth session.
request: The request to complete.
"""
user_id = await self.get_sso_user_by_remote_user_id(
auth_provider_id, remote_user_id,
)
if not user_id:
logger.warning(
"Remote user %s/%s has not previously logged in here: UIA will fail",
auth_provider_id,
remote_user_id,
)
# Let the UIA flow handle this the same as if they presented creds for a
# different user.
user_id = ""
await self._auth_handler.complete_sso_ui_auth(
user_id, ui_auth_session_id, request
)

synapse/push/__init__.py

@@ -13,7 +13,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import TYPE_CHECKING, Any, Dict, Optional
from synapse.types import RoomStreamToken
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
class Pusher(metaclass=abc.ABCMeta):
def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
self.hs = hs
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pusher_id = pusherdict["id"]
self.user_id = pusherdict["user_name"]
self.app_id = pusherdict["app_id"]
self.pushkey = pusherdict["pushkey"]
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
# should honour this rather than just looking for anything higher
# because of potential out-of-order event serialisation. This starts
# off as None though as we don't know any better.
self.max_stream_ordering = None # type: Optional[int]
@abc.abstractmethod
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
raise NotImplementedError()
@abc.abstractmethod
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
raise NotImplementedError()
@abc.abstractmethod
def on_started(self, have_notifs: bool) -> None:
"""Called when this pusher has been started.
Args:
have_notifs: Whether we should immediately
check for push to send. Set to False only if it's known there
is nothing to send
"""
raise NotImplementedError()
@abc.abstractmethod
def on_stop(self) -> None:
raise NotImplementedError()
class PusherConfigException(Exception):
def __init__(self, msg):
super().__init__(msg)
"""An error occurred when creating a pusher."""

synapse/push/emailpusher.py

@@ -14,12 +14,19 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from twisted.internet.base import DelayedCall
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher
from synapse.push.mailer import Mailer
from synapse.types import RoomStreamToken
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
# The amount of time we always wait before ever emailing about a notification
@@ -46,7 +53,7 @@ THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000
INCLUDE_ALL_UNREAD_NOTIFS = False
class EmailPusher:
class EmailPusher(Pusher):
"""
A pusher that sends email notifications about events (approximately)
when they happen.
@@ -54,37 +61,31 @@ class EmailPusher:
factor out the common parts
"""
def __init__(self, hs, pusherdict, mailer):
self.hs = hs
def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any], mailer: Mailer):
super().__init__(hs, pusherdict)
self.mailer = mailer
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pusher_id = pusherdict["id"]
self.user_id = pusherdict["user_name"]
self.app_id = pusherdict["app_id"]
self.email = pusherdict["pushkey"]
self.last_stream_ordering = pusherdict["last_stream_ordering"]
self.timed_call = None
self.throttle_params = None
# See httppusher
self.max_stream_ordering = None
self.timed_call = None # type: Optional[DelayedCall]
self.throttle_params = {} # type: Dict[str, Dict[str, int]]
self._inited = False
self._is_processing = False
def on_started(self, should_check_for_notifs):
def on_started(self, should_check_for_notifs: bool) -> None:
"""Called when this pusher has been started.
Args:
should_check_for_notifs (bool): Whether we should immediately
should_check_for_notifs: Whether we should immediately
check for push to send. Set to False only if it's known there
is nothing to send
"""
if should_check_for_notifs and self.mailer is not None:
self._start_processing()
def on_stop(self):
def on_stop(self) -> None:
if self.timed_call:
try:
self.timed_call.cancel()
@@ -92,7 +93,7 @@ class EmailPusher:
pass
self.timed_call = None
def on_new_notifications(self, max_token: RoomStreamToken):
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector
# clock components.
@@ -106,23 +107,23 @@ class EmailPusher:
self.max_stream_ordering = max_stream_ordering
self._start_processing()
def on_new_receipts(self, min_stream_id, max_stream_id):
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the
# timer fire
pass
def on_timer(self):
def on_timer(self) -> None:
self.timed_call = None
self._start_processing()
def _start_processing(self):
def _start_processing(self) -> None:
if self._is_processing:
return
run_as_background_process("emailpush.process", self._process)
def _pause_processing(self):
def _pause_processing(self) -> None:
"""Used by tests to temporarily pause processing of events.
Asserts that it's not currently processing.
@@ -130,25 +131,26 @@ class EmailPusher:
assert not self._is_processing
self._is_processing = True
def _resume_processing(self):
def _resume_processing(self) -> None:
"""Used by tests to resume processing of events after pausing.
"""
assert self._is_processing
self._is_processing = False
self._start_processing()
async def _process(self):
async def _process(self) -> None:
# we should never get here if we are already processing
assert not self._is_processing
try:
self._is_processing = True
if self.throttle_params is None:
if not self._inited:
# this is our first loop: load up the throttle params
self.throttle_params = await self.store.get_throttle_params_by_room(
self.pusher_id
)
self._inited = True
# if the max ordering changes while we're running _unsafe_process,
# call it again, and so on until we've caught up.
@@ -163,17 +165,19 @@ class EmailPusher:
finally:
self._is_processing = False
async def _unsafe_process(self):
async def _unsafe_process(self) -> None:
"""
Main logic of the push loop without the wrapper function that sets
up logging, measures and guards against multiple instances of it
being run.
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
fn = self.store.get_unread_push_actions_for_user_in_range_for_email
unprocessed = await fn(self.user_id, start, self.max_stream_ordering)
assert self.max_stream_ordering is not None
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
self.user_id, start, self.max_stream_ordering
)
soonest_due_at = None
soonest_due_at = None # type: Optional[int]
if not unprocessed:
await self.save_last_stream_ordering_and_success(self.max_stream_ordering)
@@ -230,7 +234,9 @@ class EmailPusher:
self.seconds_until(soonest_due_at), self.on_timer
)
async def save_last_stream_ordering_and_success(self, last_stream_ordering):
async def save_last_stream_ordering_and_success(
self, last_stream_ordering: Optional[int]
) -> None:
if last_stream_ordering is None:
# This happens if we haven't yet processed anything
return
@@ -248,28 +254,30 @@ class EmailPusher:
# lets just stop and return.
self.on_stop()
def seconds_until(self, ts_msec):
def seconds_until(self, ts_msec: int) -> float:
secs = (ts_msec - self.clock.time_msec()) / 1000
return max(secs, 0)
def get_room_throttle_ms(self, room_id):
def get_room_throttle_ms(self, room_id: str) -> int:
if room_id in self.throttle_params:
return self.throttle_params[room_id]["throttle_ms"]
else:
return 0
def get_room_last_sent_ts(self, room_id):
def get_room_last_sent_ts(self, room_id: str) -> int:
if room_id in self.throttle_params:
return self.throttle_params[room_id]["last_sent_ts"]
else:
return 0
def room_ready_to_notify_at(self, room_id):
def room_ready_to_notify_at(self, room_id: str) -> int:
"""
Determines whether throttling should prevent us from sending an email
for the given room
Returns: The timestamp when we are next allowed to send an email notif
for this room
Returns:
The timestamp when we are next allowed to send an email notif
for this room
"""
last_sent_ts = self.get_room_last_sent_ts(room_id)
throttle_ms = self.get_room_throttle_ms(room_id)
@@ -277,7 +285,9 @@ class EmailPusher:
may_send_at = last_sent_ts + throttle_ms
return may_send_at
async def sent_notif_update_throttle(self, room_id, notified_push_action):
async def sent_notif_update_throttle(
self, room_id: str, notified_push_action: dict
) -> None:
# We have sent a notification, so update the throttle accordingly.
# If the event that triggered the notif happened more than
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a
@@ -315,7 +325,7 @@ class EmailPusher:
self.pusher_id, room_id, self.throttle_params[room_id]
)
async def send_notification(self, push_actions, reason):
async def send_notification(self, push_actions: List[dict], reason: dict) -> None:
logger.info("Sending notif email for user %r", self.user_id)
await self.mailer.send_notification_mail(
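
A worked example of the throttle arithmetic above (`room_ready_to_notify_at`):

```python
# With a 10-minute throttle, the next email for the room is allowed
# throttle_ms after the last one was sent.
last_sent_ts = 1_600_000_000_000          # ms since epoch
throttle_ms = 10 * 60 * 1000              # 600_000 ms
may_send_at = last_sent_ts + throttle_ms  # 1_600_000_600_000
```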

synapse/push/httppusher.py

@@ -14,19 +14,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib.parse
from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
from prometheus_client import Counter
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException
from synapse.push import Pusher, PusherConfigException
from synapse.types import RoomStreamToken
from . import push_rule_evaluator, push_tools
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
http_push_processed_counter = Counter(
@@ -50,24 +56,18 @@ http_badges_failed_counter = Counter(
)
class HttpPusher:
class HttpPusher(Pusher):
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
MAX_BACKOFF_SEC = 60 * 60
# This one's in ms because we compare it against the clock
GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
def __init__(self, hs, pusherdict):
self.hs = hs
self.store = self.hs.get_datastore()
def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
super().__init__(hs, pusherdict)
self.storage = self.hs.get_storage()
self.clock = self.hs.get_clock()
self.state_handler = self.hs.get_state_handler()
self.user_id = pusherdict["user_name"]
self.app_id = pusherdict["app_id"]
self.app_display_name = pusherdict["app_display_name"]
self.device_display_name = pusherdict["device_display_name"]
self.pushkey = pusherdict["pushkey"]
self.pushkey_ts = pusherdict["ts"]
self.data = pusherdict["data"]
self.last_stream_ordering = pusherdict["last_stream_ordering"]
@@ -77,13 +77,6 @@ class HttpPusher:
self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
# should honour this rather than just looking for anything higher
# because of potential out-of-order event serialisation. This starts
# off as None though as we don't know any better.
self.max_stream_ordering = None
if "data" not in pusherdict:
raise PusherConfigException("No 'data' key for HTTP pusher")
self.data = pusherdict["data"]
@@ -97,26 +90,39 @@ class HttpPusher:
if self.data is None:
raise PusherConfigException("data can not be null for HTTP pusher")
# Validate that there's a URL and it is of the proper form.
if "url" not in self.data:
raise PusherConfigException("'url' required in data for HTTP pusher")
self.url = self.data["url"]
url = self.data["url"]
if not isinstance(url, str):
raise PusherConfigException("'url' must be a string")
url_parts = urllib.parse.urlparse(url)
# Note that the specification also says the scheme must be HTTPS, but
# it isn't up to the homeserver to verify that.
if url_parts.path != "/_matrix/push/v1/notify":
raise PusherConfigException(
"'url' must have a path of '/_matrix/push/v1/notify'"
)
self.url = url
self.http_client = hs.get_proxied_blacklisted_http_client()
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url["url"]
def on_started(self, should_check_for_notifs):
def on_started(self, should_check_for_notifs: bool) -> None:
"""Called when this pusher has been started.
Args:
should_check_for_notifs (bool): Whether we should immediately
should_check_for_notifs: Whether we should immediately
check for push to send. Set to False only if it's known there
is nothing to send
"""
if should_check_for_notifs:
self._start_processing()
def on_new_notifications(self, max_token: RoomStreamToken):
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector
# clock components.
@@ -127,14 +133,14 @@ class HttpPusher:
)
self._start_processing()
def on_new_receipts(self, min_stream_id, max_stream_id):
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# Note that the min here shouldn't be relied upon to be accurate.
# We could check the receipts are actually m.read receipts here,
# but currently that's the only type of receipt anyway...
run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
async def _update_badge(self):
async def _update_badge(self) -> None:
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. perhaps we can remove it.
badge = await push_tools.get_badge_count(
@@ -144,10 +150,10 @@ class HttpPusher:
)
await self._send_badge(badge)
def on_timer(self):
def on_timer(self) -> None:
self._start_processing()
def on_stop(self):
def on_stop(self) -> None:
if self.timed_call:
try:
self.timed_call.cancel()
@@ -155,13 +161,13 @@ class HttpPusher:
pass
self.timed_call = None
def _start_processing(self):
def _start_processing(self) -> None:
if self._is_processing:
return
run_as_background_process("httppush.process", self._process)
async def _process(self):
async def _process(self) -> None:
# we should never get here if we are already processing
assert not self._is_processing
@@ -180,7 +186,7 @@ class HttpPusher:
finally:
self._is_processing = False
async def _unsafe_process(self):
async def _unsafe_process(self) -> None:
"""
Looks for unsent notifications and dispatches them, in order
Never call this directly: use _process which will only allow this to
@@ -188,6 +194,7 @@ class HttpPusher:
"""
fn = self.store.get_unread_push_actions_for_user_in_range_for_http
assert self.max_stream_ordering is not None
unprocessed = await fn(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
@@ -257,17 +264,12 @@ class HttpPusher:
)
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
pusher_still_exists = await self.store.update_pusher_last_stream_ordering(
await self.store.update_pusher_last_stream_ordering(
self.app_id,
self.pushkey,
self.user_id,
self.last_stream_ordering,
)
if not pusher_still_exists:
# The pusher has been deleted while we were processing, so
# lets just stop and return.
self.on_stop()
return
self.failing_since = None
await self.store.update_pusher_failing_since(
@@ -283,7 +285,7 @@ class HttpPusher:
)
break
async def _process_one(self, push_action):
async def _process_one(self, push_action: dict) -> bool:
if "notify" not in push_action["actions"]:
return True
@@ -314,7 +316,9 @@ class HttpPusher:
await self.hs.remove_pusher(self.app_id, pk, self.user_id)
return True
async def _build_notification_dict(self, event, tweaks, badge):
async def _build_notification_dict(
self, event: EventBase, tweaks: Dict[str, bool], badge: int
) -> Dict[str, Any]:
priority = "low"
if (
event.type == EventTypes.Encrypted
@@ -344,9 +348,7 @@ class HttpPusher:
}
return d
ctx = await push_tools.get_context_for_event(
self.storage, self.state_handler, event, self.user_id
)
ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id)
d = {
"notification": {
@@ -386,7 +388,9 @@ class HttpPusher:
return d
async def dispatch_push(self, event, tweaks, badge):
async def dispatch_push(
self, event: EventBase, tweaks: Dict[str, bool], badge: int
) -> Union[bool, Iterable[str]]:
notification_dict = await self._build_notification_dict(event, tweaks, badge)
if not notification_dict:
return []
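
The new pusher URL validation can be reproduced standalone; a sketch using a plain `ValueError` in place of `PusherConfigException`:

```python
import urllib.parse

def check_push_url(url: object) -> None:
    # Mirrors the checks added above: a string URL whose path is the
    # spec-mandated notify endpoint.
    if not isinstance(url, str):
        raise ValueError("'url' must be a string")
    if urllib.parse.urlparse(url).path != "/_matrix/push/v1/notify":
        raise ValueError("'url' must have a path of '/_matrix/push/v1/notify'")

check_push_url("https://push.example.com/_matrix/push/v1/notify")  # ok
```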

synapse/push/mailer.py

@@ -19,7 +19,7 @@ import logging
import urllib.parse
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Iterable, List, TypeVar
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
import bleach
import jinja2
@@ -27,16 +27,20 @@ import jinja2
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import StoreError
from synapse.config.emailconfig import EmailSubjectConfig
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable
from synapse.push.presentable_names import (
calculate_room_name,
descriptor_from_member_events,
name_from_member_event,
)
from synapse.types import UserID
from synapse.types import StateMap, UserID
from synapse.util.async_helpers import concurrently_execute
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -93,7 +97,13 @@ ALLOWED_ATTRS = {
class Mailer:
def __init__(self, hs, app_name, template_html, template_text):
def __init__(
self,
hs: "HomeServer",
app_name: str,
template_html: jinja2.Template,
template_text: jinja2.Template,
):
self.hs = hs
self.template_html = template_html
self.template_text = template_text
@@ -108,17 +118,19 @@ class Mailer:
logger.info("Created Mailer for app_name %s" % app_name)
async def send_password_reset_mail(self, email_address, token, client_secret, sid):
async def send_password_reset_mail(
self, email_address: str, token: str, client_secret: str, sid: str
) -> None:
"""Send an email with a password reset link to a user
Args:
email_address (str): Email address we're sending the password
email_address: Email address we're sending the password
reset to
token (str): Unique token generated by the server to verify
token: Unique token generated by the server to verify
the email was received
client_secret (str): Unique token generated by the client to
client_secret: Unique token generated by the client to
group together multiple email sending attempts
sid (str): The generated session ID
sid: The generated session ID
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
@@ -136,17 +148,19 @@ class Mailer:
template_vars,
)
async def send_registration_mail(self, email_address, token, client_secret, sid):
async def send_registration_mail(
self, email_address: str, token: str, client_secret: str, sid: str
) -> None:
"""Send an email with a registration confirmation link to a user
Args:
email_address (str): Email address we're sending the registration
email_address: Email address we're sending the registration
link to
token (str): Unique token generated by the server to verify
token: Unique token generated by the server to verify
the email was received
client_secret (str): Unique token generated by the client to
client_secret: Unique token generated by the client to
group together multiple email sending attempts
sid (str): The generated session ID
sid: The generated session ID
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
@@ -164,18 +178,20 @@ class Mailer:
template_vars,
)
async def send_add_threepid_mail(self, email_address, token, client_secret, sid):
async def send_add_threepid_mail(
self, email_address: str, token: str, client_secret: str, sid: str
) -> None:
"""Send an email with a validation link to a user for adding a 3pid to their account
Args:
email_address (str): Email address we're sending the validation link to
email_address: Email address we're sending the validation link to
token (str): Unique token generated by the server to verify the email was received
token: Unique token generated by the server to verify the email was received
client_secret (str): Unique token generated by the client to group together
client_secret: Unique token generated by the client to group together
multiple email sending attempts
sid (str): The generated session ID
sid: The generated session ID
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
@@ -194,8 +210,13 @@ class Mailer:
)
async def send_notification_mail(
self, app_id, user_id, email_address, push_actions, reason
):
self,
app_id: str,
user_id: str,
email_address: str,
push_actions: Iterable[Dict[str, Any]],
reason: Dict[str, Any],
) -> None:
"""Send email regarding a user's room notifications"""
rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
@@ -203,7 +224,7 @@ class Mailer:
[pa["event_id"] for pa in push_actions]
)
notifs_by_room = {}
notifs_by_room = {} # type: Dict[str, List[Dict[str, Any]]]
for pa in push_actions:
notifs_by_room.setdefault(pa["room_id"], []).append(pa)
@@ -262,7 +283,9 @@ class Mailer:
await self.send_email(email_address, summary_text, template_vars)
async def send_email(self, email_address, subject, extra_template_vars):
async def send_email(
self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
) -> None:
"""Send an email with the given information and template text"""
try:
from_string = self.hs.config.email_notif_from % {"app": self.app_name}
@@ -315,8 +338,13 @@ class Mailer:
)
async def get_room_vars(
self, room_id, user_id, notifs, notif_events, room_state_ids
):
self,
room_id: str,
user_id: str,
notifs: Iterable[Dict[str, Any]],
notif_events: Dict[str, EventBase],
room_state_ids: StateMap[str],
) -> Dict[str, Any]:
# Check if one of the notifs is an invite event for the user.
is_invite = False
for n in notifs:
@@ -334,7 +362,7 @@ class Mailer:
"notifs": [],
"invite": is_invite,
"link": self.make_room_link(room_id),
}
} # type: Dict[str, Any]
if not is_invite:
for n in notifs:
@@ -365,7 +393,13 @@ class Mailer:
return room_vars
async def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
async def get_notif_vars(
self,
notif: Dict[str, Any],
user_id: str,
notif_event: EventBase,
room_state_ids: StateMap[str],
) -> Dict[str, Any]:
results = await self.store.get_events_around(
notif["room_id"],
notif["event_id"],
@@ -391,7 +425,9 @@ class Mailer:
return ret
async def get_message_vars(self, notif, event, room_state_ids):
async def get_message_vars(
self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str]
) -> Optional[Dict[str, Any]]:
if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
return None
@@ -432,7 +468,9 @@ class Mailer:
return ret
def add_text_message_vars(self, messagevars, event):
def add_text_message_vars(
self, messagevars: Dict[str, Any], event: EventBase
) -> None:
msgformat = event.content.get("format")
messagevars["format"] = msgformat
@@ -445,15 +483,18 @@ class Mailer:
elif body:
messagevars["body_text_html"] = safe_text(body)
return messagevars
def add_image_message_vars(self, messagevars, event):
def add_image_message_vars(
self, messagevars: Dict[str, Any], event: EventBase
) -> None:
messagevars["image_url"] = event.content["url"]
return messagevars
async def make_summary_text(
self, notifs_by_room, room_state_ids, notif_events, user_id, reason
self,
notifs_by_room: Dict[str, List[Dict[str, Any]]],
room_state_ids: Dict[str, StateMap[str]],
notif_events: Dict[str, EventBase],
user_id: str,
reason: Dict[str, Any],
):
if len(notifs_by_room) == 1:
# Only one room has new stuff
@@ -580,7 +621,7 @@ class Mailer:
"app": self.app_name,
}
def make_room_link(self, room_id):
def make_room_link(self, room_id: str) -> str:
if self.hs.config.email_riot_base_url:
base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
elif self.app_name == "Vector":
@@ -590,7 +631,7 @@ class Mailer:
base_url = "https://matrix.to/#"
return "%s/%s" % (base_url, room_id)
def make_notif_link(self, notif):
def make_notif_link(self, notif: Dict[str, str]) -> str:
if self.hs.config.email_riot_base_url:
return "%s/#/room/%s/%s" % (
self.hs.config.email_riot_base_url,
@@ -606,7 +647,9 @@ class Mailer:
else:
return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
def make_unsubscribe_link(self, user_id, app_id, email_address):
def make_unsubscribe_link(
self, user_id: str, app_id: str, email_address: str
) -> str:
params = {
"access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
"app_id": app_id,
@@ -620,7 +663,7 @@ class Mailer:
)
def safe_markup(raw_html):
def safe_markup(raw_html: str) -> jinja2.Markup:
return jinja2.Markup(
bleach.linkify(
bleach.clean(
@@ -635,7 +678,7 @@ def safe_markup(raw_html):
)
def safe_text(raw_text):
def safe_text(raw_text: str) -> jinja2.Markup:
"""
Process text: treat it as HTML but escape any tags (i.e. just escape the
HTML) then linkify it.
@@ -655,7 +698,7 @@ def deduped_ordered_list(it: Iterable[T]) -> List[T]:
return ret
def string_ordinal_total(s):
def string_ordinal_total(s: str) -> int:
tot = 0
for c in s:
tot += ord(c)
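As a quick sanity check (standalone; the hunk is truncated before the final `return tot`, which is assumed here): `string_ordinal_total` simply sums the code points of the string.
def string_ordinal_total(s: str) -> int:
    tot = 0
    for c in s:
        tot += ord(c)
    return tot

assert string_ordinal_total("") == 0
assert string_ordinal_total("ab") == 97 + 98  # ord("a") + ord("b")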

View File

@@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
from synapse.storage.databases.main import DataStore
@@ -46,7 +49,9 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
return badge
async def get_context_for_event(storage: Storage, state_handler, ev, user_id):
async def get_context_for_event(
storage: Storage, ev: EventBase, user_id: str
) -> Dict[str, str]:
ctx = {}
room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id)

View File

@@ -14,25 +14,31 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from synapse.push import Pusher
from synapse.push.emailpusher import EmailPusher
from synapse.push.httppusher import HttpPusher
from synapse.push.mailer import Mailer
from .httppusher import HttpPusher
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class PusherFactory:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.config = hs.config
self.pusher_types = {"http": HttpPusher}
self.pusher_types = {
"http": HttpPusher
} # type: Dict[str, Callable[[HomeServer, dict], Pusher]]
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs:
self.mailers = {} # app_name -> Mailer
self.mailers = {} # type: Dict[str, Mailer]
self._notif_template_html = hs.config.email_notif_template_html
self._notif_template_text = hs.config.email_notif_template_text
@@ -41,7 +47,7 @@ class PusherFactory:
logger.info("defined email pusher type")
def create_pusher(self, pusherdict):
def create_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
kind = pusherdict["kind"]
f = self.pusher_types.get(kind, None)
if not f:
@@ -49,7 +55,9 @@ class PusherFactory:
logger.debug("creating %s pusher for %r", kind, pusherdict)
return f(self.hs, pusherdict)
def _create_email_pusher(self, _hs, pusherdict):
def _create_email_pusher(
self, _hs: "HomeServer", pusherdict: Dict[str, Any]
) -> EmailPusher:
app_name = self._app_name_from_pusherdict(pusherdict)
mailer = self.mailers.get(app_name)
if not mailer:
@@ -62,7 +70,7 @@ class PusherFactory:
self.mailers[app_name] = mailer
return EmailPusher(self.hs, pusherdict, mailer)
def _app_name_from_pusherdict(self, pusherdict):
def _app_name_from_pusherdict(self, pusherdict: Dict[str, Any]) -> str:
data = pusherdict["data"]
if isinstance(data, dict):
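For context, a minimal sketch of the kind of `pusherdict` these factory methods consume; the field names are inferred from the surrounding hunks rather than spelled out here, so treat it as illustrative only:
# Hypothetical row from the pushers table.
pusherdict = {
    "id": 1,                     # used when logging pusher start failures
    "kind": "http",              # dispatch key for PusherFactory.pusher_types
    "app_id": "m.http",
    "pushkey": "a@example.com",  # app_id:pushkey identifies the pusher
    "user_name": "@user:test",   # checked against the pusher shard config
    "data": {"url": "http://example.com/_matrix/push/v1/notify"},
}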

View File

@@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Union
from typing import TYPE_CHECKING, Any, Dict, Optional
from prometheus_client import Gauge
@@ -23,9 +23,7 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.push import PusherConfigException
from synapse.push.emailpusher import EmailPusher
from synapse.push.httppusher import HttpPusher
from synapse.push import Pusher, PusherConfigException
from synapse.push.pusher import PusherFactory
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import concurrently_execute
@@ -77,7 +75,7 @@ class PusherPool:
self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
# map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
def start(self):
"""Starts the pushers off in a background process.
@@ -99,11 +97,11 @@ class PusherPool:
lang,
data,
profile_tag="",
):
) -> Optional[Pusher]:
"""Creates a new pusher and adds it to the pool
Returns:
EmailPusher|HttpPusher
The newly created pusher.
"""
time_now_msec = self.clock.time_msec()
@@ -267,17 +265,19 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_receipts")
async def start_pusher_by_id(self, app_id, pushkey, user_id):
async def start_pusher_by_id(
self, app_id: str, pushkey: str, user_id: str
) -> Optional[Pusher]:
"""Look up the details for the given pusher, and start it
Returns:
EmailPusher|HttpPusher|None: The pusher started, if any
The pusher started, if any
"""
if not self._should_start_pushers:
return
return None
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return
return None
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
@@ -303,19 +303,19 @@ class PusherPool:
logger.info("Started pushers")
async def _start_pusher(self, pusherdict):
async def _start_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
"""Start the given pusher
Args:
pusherdict (dict): dict with the values pulled from the db table
pusherdict: dict with the values pulled from the db table
Returns:
EmailPusher|HttpPusher
The newly created pusher or None.
"""
if not self._pusher_shard_config.should_handle(
self._instance_name, pusherdict["user_name"]
):
return
return None
try:
p = self.pusher_factory.create_pusher(pusherdict)
@@ -328,15 +328,15 @@ class PusherPool:
pusherdict.get("pushkey"),
e,
)
return
return None
except Exception:
logger.exception(
"Couldn't start pusher id %i: caught Exception", pusherdict["id"],
)
return
return None
if not p:
return
return None
appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])

View File

@@ -106,6 +106,25 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
assert self.METHOD in ("PUT", "POST", "GET")
self._replication_secret = None
if hs.config.worker.worker_replication_secret:
self._replication_secret = hs.config.worker.worker_replication_secret
def _check_auth(self, request) -> None:
# Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
raise RuntimeError("Missing Authorization header.")
if len(auth_headers) > 1:
raise RuntimeError("Too many Authorization headers.")
parts = auth_headers[0].split(b" ")
if parts[0] == b"Bearer" and len(parts) == 2:
received_secret = parts[1].decode("ascii")
if self._replication_secret == received_secret:
# Success!
return
raise RuntimeError("Invalid Authorization header.")
@abc.abstractmethod
async def _serialize_payload(**kwargs):
"""Static method that is called when creating a request.
@@ -150,6 +169,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)
replication_secret = None
if hs.config.worker.worker_replication_secret:
replication_secret = hs.config.worker.worker_replication_secret.encode(
"ascii"
)
@trace(opname="outgoing_replication_request")
@outgoing_gauge.track_inprogress()
async def send_request(instance_name="master", **kwargs):
@@ -202,6 +227,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# the master, and so whether we should clean up or not.
while True:
headers = {} # type: Dict[bytes, List[bytes]]
# Add an authorization header, if configured.
if replication_secret:
headers[b"Authorization"] = [b"Bearer " + replication_secret]
inject_active_span_byte_dict(headers, None, check_destination=False)
try:
result = await request_func(uri, data, headers=headers)
@@ -236,21 +264,19 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
"""
url_args = list(self.PATH_ARGS)
handler = self._handle_request
method = self.METHOD
if self.CACHE:
handler = self._cached_handler # type: ignore
url_args.append("txn_id")
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
http_server.register_paths(
method, [pattern], handler, self.__class__.__name__,
method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
)
def _cached_handler(self, request, txn_id, **kwargs):
def _check_auth_and_handle(self, request, **kwargs):
"""Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response.
@@ -258,6 +284,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# We just use the txn_id here, but we probably also want to use the
# other PATH_ARGS as well.
assert self.CACHE
# Check the authorization headers before handling the request.
if self._replication_secret:
self._check_auth(request)
return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs)
if self.CACHE:
txn_id = kwargs.pop("txn_id")
return self.response_cache.wrap(
txn_id, self._handle_request, request, **kwargs
)
return self._handle_request(request, **kwargs)
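Taken together, the outgoing half of this change sends `Authorization: Bearer <secret>` and `_check_auth` compares the received secret against the locally configured one. A minimal standalone sketch of that round-trip (not Synapse code):
def build_auth_header(secret: str) -> bytes:
    # What the outgoing replication request attaches.
    return b"Bearer " + secret.encode("ascii")

def check_auth_header(header: bytes, expected_secret: str) -> bool:
    # Mirrors the parsing in _check_auth above.
    parts = header.split(b" ")
    if parts[0] == b"Bearer" and len(parts) == 2:
        return parts[1].decode("ascii") == expected_secret
    return False

assert check_auth_header(build_auth_header("my-secret"), "my-secret")
assert not check_auth_header(b"Bearer wrong-secret", "my-secret")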

View File

@@ -320,9 +320,9 @@ class UserRestServletV2(RestServlet):
data={},
)
if "avatar_url" in body and type(body["avatar_url"]) == str:
if "avatar_url" in body and isinstance(body["avatar_url"], str):
await self.profile_handler.set_avatar_url(
user_id, requester, body["avatar_url"], True
target_user, requester, body["avatar_url"], True
)
ret = await self.admin_handler.get_user(target_user)
@@ -420,6 +420,9 @@ class UserRegisterServlet(RestServlet):
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
raise SynapseError(400, "Invalid user type")
if "mac" not in body:
raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON)
got_mac = body["mac"]
want_mac_builder = hmac.new(

View File

@@ -458,7 +458,7 @@ class RegisterRestServlet(RestServlet):
# == Normal User Registration == (everyone else)
if not self._registration_enabled:
raise SynapseError(403, "Registration has been disabled")
raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
# Check if this account is upgrading from a guest account.
guest_access_token = body.get("guest_access_token", None)

View File

@@ -155,6 +155,11 @@ def add_file_headers(request, media_type, file_size, upload_name):
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
request.setHeader(b"Content-Length", b"%d" % (file_size,))
# Tell web crawlers to not index, archive, or follow links in media. This
# should help to prevent things in the media repo from showing up in web
# search results.
request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex")
# separators as defined in RFC2616. SP and HT are handled separately.
# see _can_encode_filename_as_token.

View File

@@ -675,7 +675,11 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("No media removed from url cache")
def decode_and_calc_og(body, media_uri, request_encoding=None):
def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
# If there's no body, nothing useful is going to be found.
if not body:
return {}
from lxml import etree
try:

View File

@@ -44,7 +44,7 @@ class UploadResource(DirectServeJsonResource):
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
content_length = request.getHeader(b"Content-Length").decode("ascii")
content_length = request.getHeader("Content-Length")
if content_length is None:
raise SynapseError(msg="Request must specify a Content-Length", code=400)
if int(content_length) > self.max_upload_size:

View File

@@ -381,6 +381,28 @@ class HomeServer(metaclass=abc.ABCMeta):
)
return MatrixFederationHttpClient(self, tls_client_options_factory)
@cache_in_self
def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient:
"""
An HTTP client that uses configured HTTP(S) proxies and blacklists IPs
based on the IP range blacklist.
"""
return SimpleHttpClient(
self,
ip_blacklist=self.config.ip_range_blacklist,
use_proxy=True,
)
@cache_in_self
def get_federation_http_client(self) -> MatrixFederationHttpClient:
"""
An HTTP client for federation.
"""
tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
self.config
)
return MatrixFederationHttpClient(self, tls_client_options_factory)
@cache_in_self
def get_room_creation_handler(self) -> RoomCreationHandler:
return RoomCreationHandler(self)

View File

@@ -783,7 +783,7 @@ class StateResolutionStore:
)
def get_auth_chain_difference(
self, state_sets: List[Set[str]]
self, room_id: str, state_sets: List[Set[str]]
) -> Awaitable[Set[str]]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
@@ -796,4 +796,4 @@ class StateResolutionStore:
An awaitable that resolves to a set of event IDs.
"""
return self.store.get_auth_chain_difference(state_sets)
return self.store.get_auth_chain_difference(room_id, state_sets)

View File

@@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
from synapse.types import Collection, MutableStateMap, StateMap
from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -97,7 +97,9 @@ async def resolve_events_with_store(
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
auth_diff = await _get_auth_chain_difference(
room_id, state_sets, event_map, state_res_store
)
full_conflicted_set = set(
itertools.chain(
@@ -236,6 +238,7 @@ async def _get_power_level_for_sender(
async def _get_auth_chain_difference(
room_id: str,
state_sets: Sequence[StateMap[str]],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
@@ -252,9 +255,90 @@ async def _get_auth_chain_difference(
Set of event IDs
"""
# The `StateResolutionStore.get_auth_chain_difference` function assumes that
# all events passed to it (and their auth chains) have been persisted
# previously. This is not the case for any events in the `event_map`, and so
# we need to manually handle those events.
#
# We do this by:
# 1. calculating the auth chain difference for the state sets based on the
# events in `event_map` alone
# 2. replacing any events in the state_sets that are also in `event_map`
# with their auth events (recursively), and then calling
# `store.get_auth_chain_difference` as normal
# 3. adding the results of 1 and 2 together.
# Map from each event ID in `event_map` to the set containing that event,
# the part of its auth chain that lies within `event_map`, and the direct
# auth event IDs of all of those (which may themselves be persisted).
events_to_auth_chain = {} # type: Dict[str, Set[str]]
for event in event_map.values():
chain = {event.event_id}
events_to_auth_chain[event.event_id] = chain
to_search = [event]
while to_search:
for auth_id in to_search.pop().auth_event_ids():
chain.add(auth_id)
auth_event = event_map.get(auth_id)
if auth_event:
to_search.append(auth_event)
# We now a) calculate the auth chain difference for the unpersisted events
# and b) work out the state sets to pass to the store.
#
# Note: If the `event_map` is empty (which is the common case), we can do a
# much simpler calculation.
if event_map:
# The list of state sets to pass to the store, where each state set is a set
# of the event ids making up the state. This is similar to `state_sets`,
# except that (a) we only have event ids, not the complete
# ((type, state_key)->event_id) mappings; and (b) we have stripped out
# unpersisted events and replaced them with the persisted events in
# their auth chain.
state_sets_ids = [] # type: List[Set[str]]
# For each state set, the unpersisted event IDs reachable (by their auth
# chain) from the events in that set.
unpersisted_set_ids = [] # type: List[Set[str]]
for state_set in state_sets:
set_ids = set() # type: Set[str]
state_sets_ids.append(set_ids)
unpersisted_ids = set() # type: Set[str]
unpersisted_set_ids.append(unpersisted_ids)
for event_id in state_set.values():
event_chain = events_to_auth_chain.get(event_id)
if event_chain is not None:
# We have an event in `event_map`. We add all the auth
# events that it references (that aren't also in `event_map`).
set_ids.update(e for e in event_chain if e not in event_map)
# We also add the full chain of unpersisted event IDs
# referenced by this state set, so that we can work out the
# auth chain difference of the unpersisted events.
unpersisted_ids.update(e for e in event_chain if e in event_map)
else:
set_ids.add(event_id)
# The auth chain difference of the unpersisted events of the state sets
# is calculated by taking the difference between their union and their
# intersection.
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
difference_from_event_map = union - intersection # type: Collection[str]
else:
difference_from_event_map = ()
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
difference = await state_res_store.get_auth_chain_difference(
[set(state_set.values()) for state_set in state_sets]
room_id, state_sets_ids
)
difference.update(difference_from_event_map)
return difference
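A small worked example of the union-minus-intersection step above (standalone, with made-up event IDs):
# Two state sets reach these unpersisted events via their auth chains:
unpersisted_set_ids = [{"$C", "$D"}, {"$C"}]

union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])

# Only "$D" appears in some-but-not-all sets, so only it joins the
# auth chain difference.
assert union - intersection == {"$D"}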

View File

@@ -137,7 +137,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results)
async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
async def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]]
) -> Set[str]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).

View File

@@ -566,6 +566,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_user_by_external_id",
)
async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
"""Look up external ids for the given user
Args:
mxid: the MXID to be looked up
Returns:
Tuples of (auth_provider, external_id)
"""
res = await self.db_pool.simple_select_list(
table="user_external_ids",
keyvalues={"user_id": mxid},
retcols=("auth_provider", "external_id"),
desc="get_external_ids_by_user",
)
return [(r["auth_provider"], r["external_id"]) for r in res]
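A hypothetical call, to show the return shape (provider names and IDs made up):
# await store.get_external_ids_by_user("@user:test")
#   -> [("oidc", "remote-user-id"), ("saml", "saml-user-id")]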
async def count_all_users(self):
"""Counts all users registered on the homeserver."""
@@ -1066,6 +1083,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
self.db_pool.updates.register_background_index_update(
"user_external_ids_user_id_idx",
index_name="user_external_ids_user_id_idx",
table="user_external_ids",
columns=["user_id"],
unique=False,
)
async def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.

View File

@@ -0,0 +1,17 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5825, 'user_external_ids_user_id_idx', '{}');

View File

@@ -15,28 +15,56 @@
import importlib
import importlib.util
import itertools
from typing import Any, Iterable, Tuple, Type
import jsonschema
from synapse.config._base import ConfigError
from synapse.config._util import json_error_to_config_error
def load_module(provider):
def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
""" Loads a synapse module with its config
Take a dict with keys 'module' (the module name) and 'config'
(the config dict).
Args:
provider: a dict with keys 'module' (the module name) and 'config'
(the config dict).
config_path: the path within the config file. This will be used as a basis
for any error message.
Returns:
Tuple of (provider class, parsed config object)
"""
modulename = provider.get("module")
if not isinstance(modulename, str):
raise ConfigError(
"expected a string", path=itertools.chain(config_path, ("module",))
)
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
module, clz = provider["module"].rsplit(".", 1)
module, clz = modulename.rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
module_config = provider.get("config")
try:
provider_config = provider_class.parse_config(provider.get("config"))
provider_config = provider_class.parse_config(module_config)
except jsonschema.ValidationError as e:
raise json_error_to_config_error(e, itertools.chain(config_path, ("config",)))
except ConfigError as e:
raise _wrap_config_error(
"Failed to parse config for module %r" % (modulename,),
prefix=itertools.chain(config_path, ("config",)),
e=e,
)
except Exception as e:
raise ConfigError("Failed to parse config for %r: %s" % (provider["module"], e))
raise ConfigError(
"Failed to parse config for module %r" % (modulename,),
path=itertools.chain(config_path, ("config",)),
) from e
return provider_class, provider_config
@@ -56,3 +84,27 @@ def load_python_module(location: str):
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore
return mod
def _wrap_config_error(
msg: str, prefix: Iterable[str], e: ConfigError
) -> "ConfigError":
"""Wrap a relative ConfigError with a new path
This is useful when we have a ConfigError with a relative path due to a problem
parsing part of the config, and we now need to set it in context.
"""
path = prefix
if e.path:
path = itertools.chain(prefix, e.path)
e1 = ConfigError(msg, path)
# ideally we would set the 'cause' of the new exception to the original exception;
# however now that we have merged the path into our own, the stringification of
# e will be incorrect, so instead we create a new exception with just the "msg"
# part.
e1.__cause__ = Exception(e.msg)
e1.__cause__.__cause__ = e.__cause__
return e1
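A usage sketch of the wrapping behaviour, assuming (as the code above relies on) that `ConfigError` exposes `.msg` and `.path`; the config path shown is illustrative:
from synapse.config._base import ConfigError

inner = ConfigError("expected a string", path=["module"])
wrapped = _wrap_config_error(
    "Failed to parse config for module 'x.Y'",
    prefix=["password_providers", "<item 0>", "config"],
    e=inner,
)
# The wrapped error carries the merged path...
assert list(wrapped.path) == ["password_providers", "<item 0>", "config", "module"]
# ...and chains the original message as its cause.
assert str(wrapped.__cause__) == "expected a string"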

View File

@@ -16,8 +16,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
import jsonschema
from twisted.internet import defer
@@ -28,7 +26,7 @@ from synapse.api.filtering import Filter
from synapse.events import make_event_from_dict
from tests import unittest
from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver
from tests.utils import setup_test_homeserver
user_localpart = "test_user"
@@ -42,21 +40,9 @@ def MockEvent(**kwargs):
class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.mock_federation_resource = MockHttpResource()
self.mock_http_client = Mock(spec=[])
self.mock_http_client.put_json = DeferredMockCallable()
hs = yield setup_test_homeserver(
self.addCleanup,
federation_http_client=self.mock_http_client,
keyring=Mock(),
)
hs = setup_test_homeserver(self.addCleanup)
self.filtering = hs.get_filtering()
self.datastore = hs.get_datastore()
def test_errors_on_invalid_filters(self):

View File

@@ -17,30 +17,15 @@ from urllib.parse import parse_qs, urlparse
from mock import Mock, patch
import attr
import pymacaroons
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException
from synapse.types import UserID
from tests.test_utils import FakeResponse
from tests.unittest import HomeserverTestCase, override_config
@attr.s
class FakeResponse:
code = attr.ib()
body = attr.ib()
phrase = attr.ib()
def deliverBody(self, protocol):
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))
# These are a few constants that are used as config parameters in the tests.
ISSUER = "https://issuer/"
CLIENT_ID = "test-client-id"

View File

@@ -44,8 +44,6 @@ class ProfileTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
self.addCleanup,
federation_http_client=None,
resource_for_federation=Mock(),
federation_client=self.mock_federation,
federation_server=Mock(),
federation_registry=self.mock_registry,

View File

@@ -580,7 +580,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Mock Synapse's http json post method to check for the internal bind call
post_json_get_json = Mock(return_value=make_awaitable(None))
self.hs.get_simple_http_client().post_json_get_json = post_json_get_json
self.hs.get_identity_handler().http_client.post_json_get_json = post_json_get_json
# Retrieve a UIA session ID
channel = self.uia_register(

View File

@@ -15,18 +15,20 @@
import json
from typing import Dict
from mock import ANY, Mock, call
from twisted.internet import defer
from twisted.web.resource import Resource
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
from synapse.types import UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import register_federation_servlets
# Some local users to test with
U_APPLE = UserID.from_string("@apple:test")
@@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
servlets = [register_federation_servlets]
def make_homeserver(self, reactor, clock):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
@@ -77,6 +77,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
return hs
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_matrix/federation"] = TransportLayerServer(self.hs)
return d
def prepare(self, reactor, clock, hs):
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event

View File

@@ -18,6 +18,7 @@ from twisted.internet.defer import Deferred
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfigException
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import receipts
@@ -34,6 +35,11 @@ class HTTPPusherTests(HomeserverTestCase):
user_id = True
hijack_auth = False
def default_config(self):
config = super().default_config()
config["start_pushers"] = True
return config
def make_homeserver(self, reactor, clock):
self.push_attempts = []
@@ -46,15 +52,49 @@ class HTTPPusherTests(HomeserverTestCase):
m.post_json_get_json = post_json_get_json
config = self.default_config()
config["start_pushers"] = True
hs = self.setup_test_homeserver(
config=config, proxied_blacklisted_http_client=m
)
hs = self.setup_test_homeserver(proxied_blacklisted_http_client=m)
return hs
def test_invalid_configuration(self):
"""Invalid push configurations should be rejected."""
# Register the user who gets notified
user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")
# Register the pusher
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_tuple.token_id
def test_data(data):
self.get_failure(
self.hs.get_pusherpool().add_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
app_id="m.http",
app_display_name="HTTP Push Notifications",
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
data=data,
),
PusherConfigException,
)
# Data must be provided with a URL.
test_data(None)
test_data({})
test_data({"url": 1})
# A bare domain name isn't accepted.
test_data({"url": "example.com"})
# A URL without a path isn't accepted.
test_data({"url": "http://example.com"})
# A URL with an incorrect path isn't accepted.
test_data({"url": "http://example.com/foo"})
def test_sends_http(self):
"""
The HTTP pusher will send pushes for each message to a HTTP endpoint
@@ -84,7 +124,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
data={"url": "example.com"},
data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -119,7 +159,9 @@ class HTTPPusherTests(HomeserverTestCase):
# One push was attempted to be sent -- it'll be the first message
self.assertEqual(len(self.push_attempts), 1)
self.assertEqual(self.push_attempts[0][1], "example.com")
self.assertEqual(
self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
)
self.assertEqual(
self.push_attempts[0][2]["notification"]["content"]["body"], "Hi!"
)
@@ -139,7 +181,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Now it'll try and send the second push message, which will be the second one
self.assertEqual(len(self.push_attempts), 2)
self.assertEqual(self.push_attempts[1][1], "example.com")
self.assertEqual(
self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
)
self.assertEqual(
self.push_attempts[1][2]["notification"]["content"]["body"], "There!"
)
@@ -196,7 +240,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
data={"url": "example.com"},
data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -232,7 +276,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority
self.assertEqual(len(self.push_attempts), 1)
self.assertEqual(self.push_attempts[0][1], "example.com")
self.assertEqual(
self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
)
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Add yet another person — we want to make this room not a 1:1
@@ -270,7 +316,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
self.assertEqual(self.push_attempts[1][1], "example.com")
self.assertEqual(
self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
)
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
def test_sends_high_priority_for_one_to_one_only(self):
@@ -312,7 +360,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
data={"url": "example.com"},
data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -328,7 +376,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority — this is a one-to-one room
self.assertEqual(len(self.push_attempts), 1)
self.assertEqual(self.push_attempts[0][1], "example.com")
self.assertEqual(
self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
)
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Yet another user joins
@@ -347,7 +397,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
self.assertEqual(self.push_attempts[1][1], "example.com")
self.assertEqual(
self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
)
# check that this is high-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
@@ -394,7 +446,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
data={"url": "example.com"},
data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -410,7 +462,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority
self.assertEqual(len(self.push_attempts), 1)
self.assertEqual(self.push_attempts[0][1], "example.com")
self.assertEqual(
self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
)
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Send another event, this time with no mention
@@ -419,7 +473,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
self.assertEqual(self.push_attempts[1][1], "example.com")
self.assertEqual(
self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
)
# check that this is high-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
@@ -467,7 +523,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
data={"url": "example.com"},
data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -487,7 +543,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority
self.assertEqual(len(self.push_attempts), 1)
self.assertEqual(self.push_attempts[0][1], "example.com")
self.assertEqual(
self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
)
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Send another event, this time as someone without the power of @room
@@ -498,7 +556,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
self.assertEqual(self.push_attempts[1][1], "example.com")
self.assertEqual(
self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
)
# check that this is high-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
@@ -572,7 +632,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
data={"url": "example.com"},
data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -591,7 +651,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it
self.assertEqual(len(self.push_attempts), 1)
self.assertEqual(self.push_attempts[0][1], "example.com")
self.assertEqual(
self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
)
# Check that the unread count for the room is 0
#

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
import attr
@@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
from twisted.web.resource import Resource
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
@@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
)
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
servlets = [
streams.register_servlets,
]
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
@@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_synapse/replication"] = ReplicationRestResource(self.hs)
return d
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"

View File

@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Tuple
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, make_request
from tests.unittest import override_config
logger = logging.getLogger(__name__)
class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
"""Test the authentication of HTTP calls between workers."""
servlets = [register.register_servlets]
def make_homeserver(self, reactor, clock):
config = self.default_config()
# This isn't a real configuration option; it is only used to give the main
# homeserver and the worker homeserver different options.
main_replication_secret = config.pop("main_replication_secret", None)
if main_replication_secret:
config["worker_replication_secret"] = main_replication_secret
return self.setup_test_homeserver(config=config)
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.client_reader"
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
return config
def _test_register(self) -> Tuple[SynapseRequest, FakeChannel]:
"""Run the actual test:
1. Create a worker homeserver.
2. Start registration by providing a user/password.
3. Complete registration by providing dummy auth (this hits the main synapse).
4. Return the final request.
"""
worker_hs = self.make_worker_hs("synapse.app.client_reader")
site = self._hs_to_site[worker_hs]
request_1, channel_1 = make_request(
self.reactor,
site,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
) # type: SynapseRequest, FakeChannel
self.assertEqual(request_1.code, 401)
# Grab the session
session = channel_1.json_body["session"]
# also complete the dummy auth
return make_request(
self.reactor,
site,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
)
def test_no_auth(self):
"""With no authentication the request should finish.
"""
request, channel = self._test_register()
self.assertEqual(request.code, 200)
# We're given a registered user.
self.assertEqual(channel.json_body["user_id"], "@user:test")
@override_config({"main_replication_secret": "my-secret"})
def test_missing_auth(self):
"""If the main process expects a secret that is not provided, an error results.
"""
request, channel = self._test_register()
self.assertEqual(request.code, 500)
@override_config(
{
"main_replication_secret": "my-secret",
"worker_replication_secret": "wrong-secret",
}
)
def test_unauthorized(self):
"""If the main process receives the wrong secret, an error results.
"""
request, channel = self._test_register()
self.assertEqual(request.code, 500)
@override_config({"worker_replication_secret": "my-secret"})
def test_authorized(self):
"""The request should finish when the worker provides the authentication header.
"""
request, channel = self._test_register()
self.assertEqual(request.code, 200)
# We're given a registered user.
self.assertEqual(channel.json_body["user_id"], "@user:test")

View File

@@ -14,27 +14,20 @@
# limitations under the License.
import logging
from synapse.api.constants import LoginType
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
from tests.server import FakeChannel, make_request
logger = logging.getLogger(__name__)
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
"""Base class for tests of the replication streams"""
"""Test using one or more client readers for registration."""
servlets = [register.register_servlets]
def prepare(self, reactor, clock, hs):
self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.client_reader"

View File

@@ -67,7 +67,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
data={"url": "https://push.example.com/push"},
data={"url": "https://push.example.com/_matrix/push/v1/notify"},
)
)
@@ -109,7 +109,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock.post_json_get_json.assert_called_once()
self.assertEqual(
http_client_mock.post_json_get_json.call_args[0][0],
"https://push.example.com/push",
"https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
@@ -161,7 +161,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock2.post_json_get_json.assert_not_called()
self.assertEqual(
http_client_mock1.post_json_get_json.call_args[0][0],
"https://push.example.com/push",
"https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
@@ -183,7 +183,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock2.post_json_get_json.assert_called_once()
self.assertEqual(
http_client_mock2.post_json_get_json.call_args[0][0],
"https://push.example.com/push",
"https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,

View File

@@ -561,7 +561,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"admin": True,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
"avatar_url": None,
"avatar_url": "mxc://fibble/wibble",
}
)
@@ -578,6 +578,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(True, channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
request, channel = self.make_request(
@@ -592,6 +593,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(True, channel.json_body["admin"])
self.assertEqual(False, channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
def test_create_user(self):
"""
@@ -606,6 +608,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"admin": False,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
"avatar_url": "mxc://fibble/wibble",
}
)
@@ -622,6 +625,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(False, channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
request, channel = self.make_request(
@@ -636,6 +640,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(False, channel.json_body["admin"])
self.assertEqual(False, channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@@ -1256,7 +1261,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
data={"url": "example.com"},
data={"url": "https://example.com/_matrix/push/v1/notify"},
)
)

View File

@@ -2,7 +2,7 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2018-2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,17 +17,23 @@
# limitations under the License.
import json
import re
import time
import urllib.parse
from typing import Any, Dict, Optional
from mock import patch
import attr
from twisted.web.resource import Resource
from twisted.web.server import Site
from synapse.api.constants import Membership
from synapse.types import JsonDict
from tests.server import FakeSite, make_request
from tests.test_utils import FakeResponse
@attr.s
@@ -344,3 +350,111 @@ class RestHelper:
)
return channel.json_body
def login_via_oidc(self, remote_user_id: str) -> JsonDict:
"""Log in (as a new user) via OIDC
Returns the result of the final token login.
Requires that "oidc_config" in the homeserver config be set appropriately
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
"public_base_url".
Also requires the login servlet and the OIDC callback resource to be mounted at
the normal places.
"""
client_redirect_url = "https://x"
# first hit the redirect url (which will issue a cookie and state)
_, channel = make_request(
self.hs.get_reactor(),
self.site,
"GET",
"/login/sso/redirect?redirectUrl=" + client_redirect_url,
)
# that will redirect to the OIDC IdP, but we skip that and go straight
# back to synapse's OIDC callback resource. However, we do need the "state"
# param that synapse passes to the IdP via query params, and the cookie that
# synapse passes to the client.
assert channel.code == 302
oauth_uri = channel.headers.getRawHeaders("Location")[0]
params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query)
redirect_uri = "%s?%s" % (
urllib.parse.urlparse(params["redirect_uri"][0]).path,
urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
)
cookies = {}
for h in channel.headers.getRawHeaders("Set-Cookie"):
parts = h.split(";")
k, v = parts[0].split("=", maxsplit=1)
cookies[k] = v
# before we hit the callback uri, stub out some methods in the http client so
# that we don't have to handle full HTTPS requests.
# (expected url, json response) pairs, in the order we expect them.
expected_requests = [
# first we get a hit to the token endpoint, which we tell to return
# a dummy OIDC access token
("https://issuer.test/token", {"access_token": "TEST"}),
# and then one to the user_info endpoint, which returns our remote user id.
("https://issuer.test/userinfo", {"sub": remote_user_id}),
]
async def mock_req(method: str, uri: str, data=None, headers=None):
(expected_uri, resp_obj) = expected_requests.pop(0)
assert uri == expected_uri
resp = FakeResponse(
code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
)
return resp
with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
# now hit the callback URI with the right params and a made-up code
_, channel = make_request(
self.hs.get_reactor(),
self.site,
"GET",
redirect_uri,
custom_headers=[
("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
],
)
# expect a confirmation page
assert channel.code == 200
# fish the matrix login token out of the body of the confirmation page
m = re.search(
'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
channel.result["body"].decode("utf-8"),
)
assert m
login_token = m.group(1)
# finally, submit the matrix login token to the login API, which gives us our
# matrix access token and device id.
_, channel = make_request(
self.hs.get_reactor(),
self.site,
"POST",
"/login",
content={"type": "m.login.token", "token": login_token},
)
assert channel.code == 200
return channel.json_body
# an 'oidc_config' suitable for login_via_oidc.
TEST_OIDC_CONFIG = {
"enabled": True,
"discover": False,
"issuer": "https://issuer.test",
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"scopes": ["profile"],
"authorization_endpoint": "https://z",
"token_endpoint": "https://issuer.test/token",
"userinfo_endpoint": "https://issuer.test/userinfo",
"user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
}

View File

@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from twisted.internet.defer import succeed
@@ -22,9 +23,11 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.site import SynapseRequest
from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import auth, devices, register
from synapse.types import JsonDict
from synapse.rest.oidc import OIDCResource
from synapse.types import JsonDict, UserID
from tests import unittest
from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
@@ -156,27 +159,45 @@ class UIAuthTests(unittest.HomeserverTestCase):
register.register_servlets,
]
def default_config(self):
config = super().default_config()
# we enable OIDC as a way of testing SSO flows
oidc_config = {}
oidc_config.update(TEST_OIDC_CONFIG)
oidc_config["allow_existing_users"] = True
config["oidc_config"] = oidc_config
config["public_baseurl"] = "https://synapse.test"
return config
def create_resource_dict(self):
resource_dict = super().create_resource_dict()
# mount the OIDC resource at /_synapse/oidc
resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
return resource_dict
def prepare(self, reactor, clock, hs):
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
self.user_tok = self.login("test", self.user_pass)
def get_device_ids(self) -> List[str]:
def get_device_ids(self, access_token: str) -> List[str]:
# Get the list of devices so one can be deleted.
request, channel = self.make_request(
"GET", "devices", access_token=self.user_tok,
) # type: SynapseRequest, FakeChannel
# Get the ID of the device.
self.assertEqual(request.code, 200)
_, channel = self.make_request("GET", "devices", access_token=access_token,)
self.assertEqual(channel.code, 200)
return [d["device_id"] for d in channel.json_body["devices"]]
def delete_device(
self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
self,
access_token: str,
device: str,
expected_response: int,
body: Union[bytes, JsonDict] = b"",
) -> FakeChannel:
"""Delete an individual device."""
request, channel = self.make_request(
"DELETE", "devices/" + device, body, access_token=self.user_tok
"DELETE", "devices/" + device, body, access_token=access_token,
) # type: SynapseRequest, FakeChannel
# Ensure the response is sane.
@@ -201,11 +222,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
"""
Test user interactive authentication outside of registration.
"""
device_id = self.get_device_ids()[0]
device_id = self.get_device_ids(self.user_tok)[0]
# Attempt to delete this device.
# Returns a 401 as per the spec
channel = self.delete_device(device_id, 401)
channel = self.delete_device(self.user_tok, device_id, 401)
# Grab the session
session = channel.json_body["session"]
@@ -214,6 +235,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow.
self.delete_device(
self.user_tok,
device_id,
200,
{
@@ -233,12 +255,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
UIA - check that still works.
"""
device_id = self.get_device_ids()[0]
channel = self.delete_device(device_id, 401)
device_id = self.get_device_ids(self.user_tok)[0]
channel = self.delete_device(self.user_tok, device_id, 401)
session = channel.json_body["session"]
# Make another request providing the UI auth flow.
self.delete_device(
self.user_tok,
device_id,
200,
{
@@ -264,7 +287,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Create a second login.
self.login("test", self.user_pass)
device_ids = self.get_device_ids()
device_ids = self.get_device_ids(self.user_tok)
self.assertEqual(len(device_ids), 2)
# Attempt to delete the first device.
@@ -298,12 +321,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Create a second login.
self.login("test", self.user_pass)
device_ids = self.get_device_ids()
device_ids = self.get_device_ids(self.user_tok)
self.assertEqual(len(device_ids), 2)
# Attempt to delete the first device.
# Returns a 401 as per the spec
channel = self.delete_device(device_ids[0], 401)
channel = self.delete_device(self.user_tok, device_ids[0], 401)
# Grab the session
session = channel.json_body["session"]
@@ -313,6 +336,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow, but try to delete the
# second device. This results in an error.
self.delete_device(
self.user_tok,
device_ids[1],
403,
{
@@ -324,3 +348,39 @@ class UIAuthTests(unittest.HomeserverTestCase):
},
},
)
def test_does_not_offer_password_for_sso_user(self):
login_resp = self.helper.login_via_oidc("username")
user_tok = login_resp["access_token"]
device_id = login_resp["device_id"]
# now call the device deletion API: we should get the option to auth with SSO
# and not password.
channel = self.delete_device(user_tok, device_id, 401)
flows = channel.json_body["flows"]
self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
def test_does_not_offer_sso_for_password_user(self):
# now call the device deletion API: we should get the option to auth with
# password, and not SSO.
device_ids = self.get_device_ids(self.user_tok)
channel = self.delete_device(self.user_tok, device_ids[0], 401)
flows = channel.json_body["flows"]
self.assertEqual(flows, [{"stages": ["m.login.password"]}])
def test_offers_both_flows_for_upgraded_user(self):
"""A user that had a password and then logged in with SSO should get both flows
"""
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
self.assertEqual(login_resp["user_id"], self.user)
device_ids = self.get_device_ids(self.user_tok)
channel = self.delete_device(self.user_tok, device_ids[0], 401)
flows = channel.json_body["flows"]
# we have no particular expectations of ordering here
self.assertIn({"stages": ["m.login.password"]}, flows)
self.assertIn({"stages": ["m.login.sso"]}, flows)
self.assertEqual(len(flows), 2)

View File

@@ -120,6 +120,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN")
def test_POST_guest_registration(self):
self.hs.config.macaroon_secret_key = "test"

View File

@@ -362,3 +362,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"error": "Not found [b'example.com', b'12345']",
},
)
def test_x_robots_tag_header(self):
"""
Tests that the `X-Robots-Tag` header is present, which informs web crawlers
to not index, archive, or follow links in media.
"""
channel = self._req(b"inline; filename=out" + self.test_image.extension)
headers = channel.headers
self.assertEqual(
headers.getRawHeaders(b"X-Robots-Tag"),
[b"noindex, nofollow, noarchive, noimageindex"],
)
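A minimal sketch of how a Twisted resource might attach this header; illustrative only, not the code under test:

from twisted.web.resource import Resource

class NoIndexMediaResource(Resource):
    isLeaf = True

    def render_GET(self, request):
        # Ask crawlers not to index, archive, or follow links in media.
        request.setHeader(
            b"X-Robots-Tag", b"noindex, nofollow, noarchive, noimageindex"
        )
        return b"..."  # media bytes would be written here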

View File

@@ -18,41 +18,15 @@ import re
from mock import patch
import attr
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web._newclient import ResponseDone
from tests import unittest
from tests.server import FakeTransport
@attr.s
class FakeResponse:
version = attr.ib()
code = attr.ib()
phrase = attr.ib()
headers = attr.ib()
body = attr.ib()
absoluteURI = attr.ib()
@property
def request(self):
@attr.s
class FakeTransport:
absoluteURI = self.absoluteURI
return FakeTransport()
def deliverBody(self, protocol):
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))
class URLPreviewTests(unittest.HomeserverTestCase):
hijack_auth = True

View File

@@ -216,8 +216,9 @@ def make_request(
and not path.startswith(b"/_matrix")
and not path.startswith(b"/_synapse")
):
if path.startswith(b"/"):
path = path[1:]
path = b"/_matrix/client/r0/" + path
path = path.replace(b"//", b"/")
if not path.startswith(b"/"):
path = b"/" + path
@@ -258,6 +259,7 @@ def make_request(
for k, v in custom_headers:
req.requestHeaders.addRawHeader(k, v)
req.parseCookies()
req.requestReceived(method, path, b"1.1")
if await_result:

View File

@@ -24,7 +24,11 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
from synapse.events import make_event_from_dict
from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
from synapse.state.v2 import (
_get_auth_chain_difference,
lexicographical_topological_sort,
resolve_events_with_store,
)
from synapse.types import EventID
from tests import unittest
@@ -587,6 +591,134 @@ class SimpleParamStateTestCase(unittest.TestCase):
self.assert_dict(self.expected_combined_state, state)
class AuthChainDifferenceTestCase(unittest.TestCase):
"""We test that `_get_auth_chain_difference` correctly handles unpersisted
events.
"""
def test_simple(self):
# Test getting the auth difference for a simple chain with a single
# unpersisted event:
#
# Unpersisted | Persisted
# |
# C -|-> B -> A
a = FakeEvent(
id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
).to_event([], [])
b = FakeEvent(
id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
).to_event([a.event_id], [])
c = FakeEvent(
id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
).to_event([b.event_id], [])
persisted_events = {a.event_id: a, b.event_id: b}
unpersisted_events = {c.event_id: c}
state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}]
store = TestStateResolutionStore(persisted_events)
diff_d = _get_auth_chain_difference(
ROOM_ID, state_sets, unpersisted_events, store
)
difference = self.successResultOf(defer.ensureDeferred(diff_d))
self.assertEqual(difference, {c.event_id})
def test_multiple_unpersisted_chain(self):
# Test getting the auth difference for a simple chain with multiple
# unpersisted events:
#
# Unpersisted | Persisted
# |
# D -> C -|-> B -> A
a = FakeEvent(
id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
).to_event([], [])
b = FakeEvent(
id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
).to_event([a.event_id], [])
c = FakeEvent(
id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
).to_event([b.event_id], [])
d = FakeEvent(
id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
).to_event([c.event_id], [])
persisted_events = {a.event_id: a, b.event_id: b}
unpersisted_events = {c.event_id: c, d.event_id: d}
state_sets = [
{"a": a.event_id, "b": b.event_id},
{"c": c.event_id, "d": d.event_id},
]
store = TestStateResolutionStore(persisted_events)
diff_d = _get_auth_chain_difference(
ROOM_ID, state_sets, unpersisted_events, store
)
difference = self.successResultOf(defer.ensureDeferred(diff_d))
self.assertEqual(difference, {d.event_id, c.event_id})
def test_unpersisted_events_different_sets(self):
# Test getting the auth difference with multiple unpersisted events
# in different branches:
#
# Unpersisted | Persisted
# |
# D --> C -|-> B -> A
# E ----^ -|---^
# |
a = FakeEvent(
id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
).to_event([], [])
b = FakeEvent(
id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
).to_event([a.event_id], [])
c = FakeEvent(
id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
).to_event([b.event_id], [])
d = FakeEvent(
id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
).to_event([c.event_id], [])
e = FakeEvent(
id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={},
).to_event([c.event_id, b.event_id], [])
persisted_events = {a.event_id: a, b.event_id: b}
unpersisted_events = {c.event_id: c, d.event_id: d, e.event_id: e}
state_sets = [
{"a": a.event_id, "b": b.event_id, "e": e.event_id},
{"c": c.event_id, "d": d.event_id},
]
store = TestStateResolutionStore(persisted_events)
diff_d = _get_auth_chain_difference(
ROOM_ID, state_sets, unpersisted_events, store
)
difference = self.successResultOf(defer.ensureDeferred(diff_d))
self.assertEqual(difference, {d.event_id, e.event_id})
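The set-level definition these tests exercise, as a sketch; the real implementation also has to walk unpersisted events' auth events via the store:

from typing import FrozenSet, List, Set

def naive_auth_chain_difference(chains: List[FrozenSet[str]]) -> Set[str]:
    # Events in at least one auth chain but not in all of them:
    # union minus intersection.
    union = set().union(*chains)
    common = set(chains[0]).intersection(*chains[1:])
    return union - common

# test_simple: the chain for {a, b} is {A, B}; the chain for {c} is {C, B, A}.
assert naive_auth_chain_difference(
    [frozenset({"A", "B"}), frozenset({"C", "B", "A"})]
) == {"C"}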
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
@@ -647,7 +779,7 @@ class TestStateResolutionStore:
return list(result)
def get_auth_chain_difference(self, auth_sets):
def get_auth_chain_difference(self, room_id, auth_sets):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])

View File

@@ -202,34 +202,41 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# Now actually test that various combinations give the right result:
difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}])
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
)
self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
)
self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
difference = self.get_success(
self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
)
self.assertSetEqual(difference, {"a", "b", "c"})
difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
)
self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success(
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "d", "e"})
difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
)
self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
difference = self.get_success(
self.store.get_auth_chain_difference(room_id, [{"a"}])
)
self.assertSetEqual(difference, set())

View File

@@ -14,9 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from canonicaljson import json
from twisted.internet import defer
@@ -30,12 +27,10 @@ from tests.utils import create_room
class RedactionTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
def default_config(self):
config = super().default_config()
config["redaction_retention_period"] = "30d"
return self.setup_test_homeserver(
resource_for_federation=Mock(), federation_http_client=None, config=config
)
return config
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()

View File

@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock
from synapse.api.constants import Membership
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
@@ -34,12 +32,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
resource_for_federation=Mock(), federation_http_client=None
)
return hs
def prepare(self, reactor, clock, hs: TestHomeServer):
# We can't test the RoomMemberStore on its own without the other event

View File

@@ -56,7 +56,7 @@ class PreviewTestCase(unittest.TestCase):
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
self.assertEquals(
self.assertEqual(
desc,
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -69,7 +69,7 @@ class PreviewTestCase(unittest.TestCase):
desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500)
self.assertEquals(
self.assertEqual(
desc,
"Tromsø lies in Northern Norway. The municipality has a population of"
" (2015) 72,066, but with an annual influx of students it has over 75,000"
@@ -96,7 +96,7 @@ class PreviewTestCase(unittest.TestCase):
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
self.assertEquals(
self.assertEqual(
desc,
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -122,7 +122,7 @@ class PreviewTestCase(unittest.TestCase):
]
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
self.assertEquals(
self.assertEqual(
desc,
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -149,7 +149,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_comment(self):
html = """
@@ -164,7 +164,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_comment2(self):
html = """
@@ -182,7 +182,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(
self.assertEqual(
og,
{
"og:title": "Foo",
@@ -203,7 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_missing_title(self):
html = """
@@ -216,7 +216,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
def test_h1_as_title(self):
html = """
@@ -230,7 +230,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {"og:title": "Title", "og:description": "Some text."})
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
def test_missing_title_and_broken_h1(self):
html = """
@@ -244,4 +244,9 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
def test_empty(self):
html = ""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEqual(og, {})

View File

@@ -22,6 +22,11 @@ import warnings
from asyncio import Future
from typing import Any, Awaitable, Callable, TypeVar
import attr
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
TV = TypeVar("TV")
@@ -80,3 +85,25 @@ def setup_awaitable_errors() -> Callable[[], None]:
sys.unraisablehook = unraisablehook # type: ignore
return cleanup
@attr.s
class FakeResponse:
"""A fake twisted.web.IResponse object
There is a similar class at treq.test.test_response, but it lacks a `phrase`
attribute, and didn't support deliverBody until recently.
"""
# HTTP response code
code = attr.ib(type=int)
# HTTP response phrase (eg b'OK' for a 200)
phrase = attr.ib(type=bytes)
# body of the response
body = attr.ib(type=bytes)
def deliverBody(self, protocol):
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))
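A usage sketch, pairing `FakeResponse` with Twisted's `AccumulatingProtocol` to read a delivered body back:

from twisted.test.proto_helpers import AccumulatingProtocol

response = FakeResponse(code=200, phrase=b"OK", body=b'{"ok": true}')
protocol = AccumulatingProtocol()
response.deliverBody(protocol)
assert protocol.data == b'{"ok": true}'
assert protocol.closed  # connectionLost fired with ResponseDone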

View File

@@ -20,7 +20,7 @@ import hmac
import inspect
import logging
import time
from typing import Optional, Tuple, Type, TypeVar, Union, overload
from typing import Dict, Optional, Tuple, Type, TypeVar, Union, overload
from mock import Mock, patch
@@ -46,6 +46,7 @@ from synapse.logging.context import (
)
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
@@ -320,15 +321,28 @@ class HomeserverTestCase(TestCase):
"""
Create the root resource for the test server.
The default implementation creates a JsonResource and calls each function in
`servlets` to register servlets against it
The default calls `self.create_resource_dict` and builds the resultant dict
into a tree.
"""
resource = JsonResource(self.hs)
root_resource = Resource()
create_resource_tree(self.create_resource_dict(), root_resource)
return root_resource
def create_resource_dict(self) -> Dict[str, Resource]:
"""Create a resource tree for the test server
A resource tree is a mapping from path to twisted.web.resource.
The default implementation creates a JsonResource and calls each function in
`servlets` to register servlets against it.
"""
servlet_resource = JsonResource(self.hs)
for servlet in self.servlets:
servlet(self.hs, resource)
return resource
servlet(self.hs, servlet_resource)
return {
"/_matrix/client": servlet_resource,
"/_synapse/admin": servlet_resource,
}
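A sketch of how a subclass might extend the tree; the extra mount point and accessor are illustrative assumptions, not part of this diff:

class MediaReposTestCase(HomeserverTestCase):
    def create_resource_dict(self) -> Dict[str, Resource]:
        d = super().create_resource_dict()
        # Hypothetical: mount an extra resource alongside the servlet tree.
        d["/_matrix/media"] = self.hs.get_media_repository_resource()
        return d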
def default_config(self):
"""
@@ -691,13 +705,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
A federating homeserver that authenticates incoming requests as `other.example.com`.
"""
def prepare(self, reactor, clock, homeserver):
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
return d
class TestTransportLayerServer(JsonResource):
"""A test implementation of TransportLayerServer
which authenticates incoming requests as `other.example.com`.
"""
def __init__(self, hs):
super().__init__(hs)
class Authenticator:
def authenticate_request(self, request, content):
return succeed("other.example.com")
authenticator = Authenticator()
ratelimiter = FederationRateLimiter(
clock,
hs.get_clock(),
FederationRateLimitConfig(
window_size=1,
sleep_limit=1,
@@ -706,11 +736,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
concurrent_requests=1000,
),
)
federation_server.register_servlets(
homeserver, self.resource, Authenticator(), ratelimiter
)
return super().prepare(reactor, clock, homeserver)
federation_server.register_servlets(hs, self, authenticator, ratelimiter)
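A usage sketch: subclasses now get the federation tree mounted and pre-authenticated as `other.example.com` with no extra setup (the endpoint chosen here is illustrative):

class FederationSmokeTestCase(FederatingHomeserverTestCase):
    def test_version(self):
        channel = self.make_request("GET", "/_matrix/federation/v1/version")
        self.assertEqual(channel.code, 200)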
def override_config(extra_config):

View File

@@ -20,13 +20,12 @@ import os
import time
import uuid
import warnings
from inspect import getcallargs
from typing import Type
from urllib import parse as urlparse
from mock import Mock, patch
from twisted.internet import defer, reactor
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
@@ -34,7 +33,6 @@ from synapse.api.room_versions import RoomVersions
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.federation.transport import server as federation_server
from synapse.http.server import HttpServer
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
@@ -42,7 +40,6 @@ from synapse.storage import DataStore
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
#
@@ -344,32 +341,9 @@ def setup_test_homeserver(
hs.get_auth_handler().validate_hash = validate_hash
fed = kwargs.get("resource_for_federation", None)
if fed:
register_federation_servlets(hs, fed)
return hs
def register_federation_servlets(hs, resource):
federation_server.register_servlets(
hs,
resource=resource,
authenticator=federation_server.Authenticator(hs),
ratelimiter=FederationRateLimiter(
hs.get_clock(), config=hs.config.rc_federation
),
)
def get_mock_call_args(pattern_func, mock_func):
""" Return the arguments the mock function was called with interpreted
by the pattern functions argument list.
"""
invoked_args, invoked_kargs = mock_func.call_args
return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
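A usage sketch for `get_mock_call_args`; `getcallargs` also fills in defaults:

from mock import Mock

def pattern_func(user_id, room_id, limit=10):
    pass

mock_func = Mock()
mock_func("@alice:test", room_id="!room:test")
assert get_mock_call_args(pattern_func, mock_func) == {
    "user_id": "@alice:test",
    "room_id": "!room:test",
    "limit": 10,
}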
def mock_getRawHeaders(headers=None):
headers = headers if headers is not None else {}
@@ -555,86 +529,6 @@ class MockClock:
return d
def _format_call(args, kwargs):
return ", ".join(
["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
)
class DeferredMockCallable:
"""A callable instance that stores a set of pending call expectations and
return values for them. It allows a unit test to assert that the given set
of function calls are eventually made, by awaiting on them to be called.
"""
def __init__(self):
self.expectations = []
self.calls = []
def __call__(self, *args, **kwargs):
self.calls.append((args, kwargs))
if not self.expectations:
raise ValueError(
"%r has no pending calls to handle call(%s)"
% (self, _format_call(args, kwargs))
)
for (call, result, d) in self.expectations:
if args == call[1] and kwargs == call[2]:
d.callback(None)
return result
failure = AssertionError(
"Was not expecting call(%s)" % (_format_call(args, kwargs))
)
for _, _, d in self.expectations:
try:
d.errback(failure)
except Exception:
pass
raise failure
def expect_call_and_return(self, call, result):
self.expectations.append((call, result, defer.Deferred()))
@defer.inlineCallbacks
def await_calls(self, timeout=1000):
deferred = defer.DeferredList(
[d for _, _, d in self.expectations], fireOnOneErrback=True
)
timer = reactor.callLater(
timeout / 1000,
deferred.errback,
AssertionError(
"%d pending calls left: %s"
% (
len([e for e in self.expectations if not e[2].called]),
[e for e in self.expectations if not e[2].called],
)
),
)
yield deferred
timer.cancel()
self.calls = []
def assert_had_no_calls(self):
if self.calls:
calls = self.calls
self.calls = []
raise AssertionError(
"Expected not to received any calls, got:\n"
+ "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
)
async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room
"""