Merge commit 'cf7d3c90d' into dinsic
This commit is contained in:
1
changelog.d/8802.doc
Normal file
1
changelog.d/8802.doc
Normal file
@@ -0,0 +1 @@
|
||||
Fix the "Event persist rate" section of the included grafana dashboard by adding missing prometheus rules.
|
||||
1
changelog.d/8827.bugfix
Normal file
1
changelog.d/8827.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix bug where we might not correctly calculate the current state for rooms with multiple extremities.
|
||||
1
changelog.d/8837.bugfix
Normal file
1
changelog.d/8837.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a long standing bug in the register admin endpoint (`/_synapse/admin/v1/register`) when the `mac` field was not provided. The endpoint now properly returns a 400 error. Contributed by @edwargix.
|
||||
1
changelog.d/8853.feature
Normal file
1
changelog.d/8853.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add optional HTTP authentication to replication endpoints.
|
||||
1
changelog.d/8858.bugfix
Normal file
1
changelog.d/8858.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a long-standing bug on Synapse instances supporting Single-Sign-On, where users would be prompted to enter their password to confirm certain actions, even though they have not set a password.
|
||||
1
changelog.d/8861.misc
Normal file
1
changelog.d/8861.misc
Normal file
@@ -0,0 +1 @@
|
||||
Remove some unnecessary stubbing from unit tests.
|
||||
1
changelog.d/8862.bugfix
Normal file
1
changelog.d/8862.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a longstanding bug where a 500 error would be returned if the `Content-Length` header was not provided to the upload media resource.
|
||||
1
changelog.d/8864.misc
Normal file
1
changelog.d/8864.misc
Normal file
@@ -0,0 +1 @@
|
||||
Remove unused `FakeResponse` class from unit tests.
|
||||
1
changelog.d/8865.bugfix
Normal file
1
changelog.d/8865.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Add additional validation to pusher URLs to be compliant with the specification.
|
||||
1
changelog.d/8867.bugfix
Normal file
1
changelog.d/8867.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix the error code that is returned when a user tries to register on a homeserver on which new-user registration has been disabled.
|
||||
1
changelog.d/8872.bugfix
Normal file
1
changelog.d/8872.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a bug where `PUT /_synapse/admin/v2/users/<user_id>` failed to create a new user when `avatar_url` is specified. Bug introduced in Synapse v1.9.0.
|
||||
1
changelog.d/8873.doc
Normal file
1
changelog.d/8873.doc
Normal file
@@ -0,0 +1 @@
|
||||
Fix an error in the documentation for the SAML username mapping provider.
|
||||
1
changelog.d/8874.feature
Normal file
1
changelog.d/8874.feature
Normal file
@@ -0,0 +1 @@
|
||||
Improve the error messages printed as a result of configuration problems for extension modules.
|
||||
1
changelog.d/8879.misc
Normal file
1
changelog.d/8879.misc
Normal file
@@ -0,0 +1 @@
|
||||
Pass `room_id` to `get_auth_chain_difference`.
|
||||
1
changelog.d/8880.misc
Normal file
1
changelog.d/8880.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add type hints to push module.
|
||||
1
changelog.d/8881.misc
Normal file
1
changelog.d/8881.misc
Normal file
@@ -0,0 +1 @@
|
||||
Simplify logic for handling user-interactive-auth via single-sign-on servers.
|
||||
1
changelog.d/8882.misc
Normal file
1
changelog.d/8882.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add type hints to push module.
|
||||
1
changelog.d/8883.bugfix
Normal file
1
changelog.d/8883.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a 500 error when attempting to preview an empty HTML file.
|
||||
1
changelog.d/8887.feature
Normal file
1
changelog.d/8887.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add `X-Robots-Tag` header to stop web crawlers from indexing media.
|
||||
1
changelog.d/8891.doc
Normal file
1
changelog.d/8891.doc
Normal file
@@ -0,0 +1 @@
|
||||
Clarify comments around template directories in `sample_config.yaml`.
|
||||
@@ -58,3 +58,21 @@ groups:
|
||||
labels:
|
||||
type: "PDU"
|
||||
expr: 'synapse_federation_transaction_queue_pending_pdus + 0'
|
||||
|
||||
- record: synapse_storage_events_persisted_by_source_type
|
||||
expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_type="remote"})
|
||||
labels:
|
||||
type: remote
|
||||
- record: synapse_storage_events_persisted_by_source_type
|
||||
expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_entity="*client*",origin_type="local"})
|
||||
labels:
|
||||
type: local
|
||||
- record: synapse_storage_events_persisted_by_source_type
|
||||
expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_entity!="*client*",origin_type="local"})
|
||||
labels:
|
||||
type: bridges
|
||||
- record: synapse_storage_events_persisted_by_event_type
|
||||
expr: sum without(origin_entity, origin_type) (synapse_storage_events_persisted_events_sep)
|
||||
- record: synapse_storage_events_persisted_by_origin
|
||||
expr: sum without(type) (synapse_storage_events_persisted_events_sep)
|
||||
|
||||
|
||||
@@ -69,7 +69,8 @@ RUN apt-get update -qq -o Acquire::Languages=none \
|
||||
python3-setuptools \
|
||||
python3-venv \
|
||||
sqlite3 \
|
||||
libpq-dev
|
||||
libpq-dev \
|
||||
xmlsec1
|
||||
|
||||
COPY --from=builder /dh-virtualenv_1.2~dev-1_all.deb /
|
||||
|
||||
|
||||
@@ -2059,11 +2059,8 @@ sso:
|
||||
# - https://my.custom.client/
|
||||
|
||||
# Directory in which Synapse will try to find the template files below.
|
||||
# If not set, default templates from within the Synapse package will be used.
|
||||
#
|
||||
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
|
||||
# If you *do* uncomment it, you will need to make sure that all the templates
|
||||
# below are in the directory.
|
||||
# If not set, or the files named below are not found within the template
|
||||
# directory, default templates from within the Synapse package will be used.
|
||||
#
|
||||
# Synapse will look for the following templates in this directory:
|
||||
#
|
||||
@@ -2293,9 +2290,8 @@ email:
|
||||
#validation_token_lifetime: 15m
|
||||
|
||||
# Directory in which Synapse will try to find the template files below.
|
||||
# If not set, default templates from within the Synapse package will be used.
|
||||
#
|
||||
# Do not uncomment this setting unless you want to customise the templates.
|
||||
# If not set, or the files named below are not found within the template
|
||||
# directory, default templates from within the Synapse package will be used.
|
||||
#
|
||||
# Synapse will look for the following templates in this directory:
|
||||
#
|
||||
@@ -2779,6 +2775,13 @@ opentracing:
|
||||
#
|
||||
#run_background_tasks_on: worker1
|
||||
|
||||
# A shared secret used by the replication APIs to authenticate HTTP requests
|
||||
# from workers.
|
||||
#
|
||||
# By default this is unused and traffic is not authenticated.
|
||||
#
|
||||
#worker_replication_secret: ""
|
||||
|
||||
|
||||
# Configuration for Redis when using workers. This *must* be enabled when
|
||||
# using workers (unless using old style direct TCP configuration).
|
||||
|
||||
@@ -116,11 +116,13 @@ comment these options out and use those specified by the module instead.
|
||||
|
||||
A custom mapping provider must specify the following methods:
|
||||
|
||||
* `__init__(self, parsed_config)`
|
||||
* `__init__(self, parsed_config, module_api)`
|
||||
- Arguments:
|
||||
- `parsed_config` - A configuration object that is the return value of the
|
||||
`parse_config` method. You should set any configuration options needed by
|
||||
the module here.
|
||||
- `module_api` - a `synapse.module_api.ModuleApi` object which provides the
|
||||
stable API available for extension modules.
|
||||
* `parse_config(config)`
|
||||
- This method should have the `@staticmethod` decoration.
|
||||
- Arguments:
|
||||
|
||||
@@ -89,7 +89,8 @@ shared configuration file.
|
||||
Normally, only a couple of changes are needed to make an existing configuration
|
||||
file suitable for use with workers. First, you need to enable an "HTTP replication
|
||||
listener" for the main process; and secondly, you need to enable redis-based
|
||||
replication. For example:
|
||||
replication. Optionally, a shared secret can be used to authenticate HTTP
|
||||
traffic between workers. For example:
|
||||
|
||||
|
||||
```yaml
|
||||
@@ -103,6 +104,9 @@ listeners:
|
||||
resources:
|
||||
- names: [replication]
|
||||
|
||||
# Add a random shared secret to authenticate traffic.
|
||||
worker_replication_secret: ""
|
||||
|
||||
redis:
|
||||
enabled: true
|
||||
```
|
||||
|
||||
5
mypy.ini
5
mypy.ini
@@ -43,6 +43,7 @@ files =
|
||||
synapse/handlers/room_member.py,
|
||||
synapse/handlers/room_member_worker.py,
|
||||
synapse/handlers/saml_handler.py,
|
||||
synapse/handlers/sso.py,
|
||||
synapse/handlers/sync.py,
|
||||
synapse/handlers/ui_auth,
|
||||
synapse/http/client.py,
|
||||
@@ -55,6 +56,10 @@ files =
|
||||
synapse/metrics,
|
||||
synapse/module_api,
|
||||
synapse/notifier.py,
|
||||
synapse/push/emailpusher.py,
|
||||
synapse/push/httppusher.py,
|
||||
synapse/push/mailer.py,
|
||||
synapse/push/pusher.py,
|
||||
synapse/push/pusherpool.py,
|
||||
synapse/push/push_rule_evaluator.py,
|
||||
synapse/replication,
|
||||
|
||||
@@ -19,7 +19,7 @@ import gc
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Iterable
|
||||
from typing import Iterable, Iterator
|
||||
|
||||
from twisted.application import service
|
||||
from twisted.internet import defer, reactor
|
||||
@@ -90,7 +90,7 @@ class SynapseHomeServer(HomeServer):
|
||||
tls = listener_config.tls
|
||||
site_tag = listener_config.http_options.tag
|
||||
if site_tag is None:
|
||||
site_tag = port
|
||||
site_tag = str(port)
|
||||
|
||||
# We always include a health resource.
|
||||
resources = {"/health": HealthResource()}
|
||||
@@ -107,7 +107,10 @@ class SynapseHomeServer(HomeServer):
|
||||
logger.debug("Configuring additional resources: %r", additional_resources)
|
||||
module_api = self.get_module_api()
|
||||
for path, resmodule in additional_resources.items():
|
||||
handler_cls, config = load_module(resmodule)
|
||||
handler_cls, config = load_module(
|
||||
resmodule,
|
||||
("listeners", site_tag, "additional_resources", "<%s>" % (path,)),
|
||||
)
|
||||
handler = handler_cls(config, module_api)
|
||||
if IResource.providedBy(handler):
|
||||
resource = handler
|
||||
@@ -342,7 +345,10 @@ def setup(config_options):
|
||||
"Synapse Homeserver", config_options
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\nERROR: %s\n" % (e,))
|
||||
sys.stderr.write("\n")
|
||||
for f in format_config_error(e):
|
||||
sys.stderr.write(f)
|
||||
sys.stderr.write("\n")
|
||||
sys.exit(1)
|
||||
|
||||
if not config:
|
||||
@@ -445,6 +451,38 @@ def setup(config_options):
|
||||
return hs
|
||||
|
||||
|
||||
def format_config_error(e: ConfigError) -> Iterator[str]:
|
||||
"""
|
||||
Formats a config error neatly
|
||||
|
||||
The idea is to format the immediate error, plus the "causes" of those errors,
|
||||
hopefully in a way that makes sense to the user. For example:
|
||||
|
||||
Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template':
|
||||
Failed to parse config for module 'JinjaOidcMappingProvider':
|
||||
invalid jinja template:
|
||||
unexpected end of template, expected 'end of print statement'.
|
||||
|
||||
Args:
|
||||
e: the error to be formatted
|
||||
|
||||
Returns: An iterator which yields string fragments to be formatted
|
||||
"""
|
||||
yield "Error in configuration"
|
||||
|
||||
if e.path:
|
||||
yield " at '%s'" % (".".join(e.path),)
|
||||
|
||||
yield ":\n %s" % (e.msg,)
|
||||
|
||||
e = e.__cause__
|
||||
indent = 1
|
||||
while e:
|
||||
indent += 1
|
||||
yield ":\n%s%s" % (" " * indent, str(e))
|
||||
e = e.__cause__
|
||||
|
||||
|
||||
class SynapseService(service.Service):
|
||||
"""
|
||||
A twisted Service class that will start synapse. Used to run synapse
|
||||
|
||||
@@ -24,7 +24,7 @@ from collections import OrderedDict
|
||||
from hashlib import sha256
|
||||
from io import open as io_open
|
||||
from textwrap import dedent
|
||||
from typing import Any, Callable, List, MutableMapping, Optional
|
||||
from typing import Any, Callable, Iterable, List, MutableMapping, Optional
|
||||
|
||||
import attr
|
||||
import jinja2
|
||||
@@ -33,7 +33,17 @@ import yaml
|
||||
|
||||
|
||||
class ConfigError(Exception):
|
||||
pass
|
||||
"""Represents a problem parsing the configuration
|
||||
|
||||
Args:
|
||||
msg: A textual description of the error.
|
||||
path: Where appropriate, an indication of where in the configuration
|
||||
the problem lies.
|
||||
"""
|
||||
|
||||
def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
|
||||
self.msg = msg
|
||||
self.path = path
|
||||
|
||||
|
||||
# We split these messages out to allow packages to override with package
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from synapse.config import (
|
||||
account_validity,
|
||||
@@ -37,7 +37,10 @@ from synapse.config import (
|
||||
workers,
|
||||
)
|
||||
|
||||
class ConfigError(Exception): ...
|
||||
class ConfigError(Exception):
|
||||
def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
|
||||
self.msg = msg
|
||||
self.path = path
|
||||
|
||||
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
|
||||
MISSING_REPORT_STATS_SPIEL: str
|
||||
|
||||
@@ -38,14 +38,27 @@ def validate_config(
|
||||
try:
|
||||
jsonschema.validate(config, json_schema)
|
||||
except jsonschema.ValidationError as e:
|
||||
# copy `config_path` before modifying it.
|
||||
path = list(config_path)
|
||||
for p in list(e.path):
|
||||
if isinstance(p, int):
|
||||
path.append("<item %i>" % p)
|
||||
else:
|
||||
path.append(str(p))
|
||||
raise json_error_to_config_error(e, config_path)
|
||||
|
||||
raise ConfigError(
|
||||
"Unable to parse configuration: %s at %s" % (e.message, ".".join(path))
|
||||
)
|
||||
|
||||
def json_error_to_config_error(
|
||||
e: jsonschema.ValidationError, config_path: Iterable[str]
|
||||
) -> ConfigError:
|
||||
"""Converts a json validation error to a user-readable ConfigError
|
||||
|
||||
Args:
|
||||
e: the exception to be converted
|
||||
config_path: the path within the config file. This will be used as a basis
|
||||
for the error message.
|
||||
|
||||
Returns:
|
||||
a ConfigError
|
||||
"""
|
||||
# copy `config_path` before modifying it.
|
||||
path = list(config_path)
|
||||
for p in list(e.path):
|
||||
if isinstance(p, int):
|
||||
path.append("<item %i>" % p)
|
||||
else:
|
||||
path.append(str(p))
|
||||
return ConfigError(e.message, path)
|
||||
|
||||
@@ -390,9 +390,8 @@ class EmailConfig(Config):
|
||||
#validation_token_lifetime: 15m
|
||||
|
||||
# Directory in which Synapse will try to find the template files below.
|
||||
# If not set, default templates from within the Synapse package will be used.
|
||||
#
|
||||
# Do not uncomment this setting unless you want to customise the templates.
|
||||
# If not set, or the files named below are not found within the template
|
||||
# directory, default templates from within the Synapse package will be used.
|
||||
#
|
||||
# Synapse will look for the following templates in this directory:
|
||||
#
|
||||
|
||||
@@ -66,7 +66,7 @@ class OIDCConfig(Config):
|
||||
(
|
||||
self.oidc_user_mapping_provider_class,
|
||||
self.oidc_user_mapping_provider_config,
|
||||
) = load_module(ump_config)
|
||||
) = load_module(ump_config, ("oidc_config", "user_mapping_provider"))
|
||||
|
||||
# Ensure loaded user mapping module has defined all necessary methods
|
||||
required_methods = [
|
||||
|
||||
@@ -36,7 +36,7 @@ class PasswordAuthProviderConfig(Config):
|
||||
providers.append({"module": LDAP_PROVIDER, "config": ldap_config})
|
||||
|
||||
providers.extend(config.get("password_providers") or [])
|
||||
for provider in providers:
|
||||
for i, provider in enumerate(providers):
|
||||
mod_name = provider["module"]
|
||||
|
||||
# This is for backwards compat when the ldap auth provider resided
|
||||
@@ -45,7 +45,8 @@ class PasswordAuthProviderConfig(Config):
|
||||
mod_name = LDAP_PROVIDER
|
||||
|
||||
(provider_class, provider_config) = load_module(
|
||||
{"module": mod_name, "config": provider["config"]}
|
||||
{"module": mod_name, "config": provider["config"]},
|
||||
("password_providers", "<item %i>" % i),
|
||||
)
|
||||
|
||||
self.password_providers.append((provider_class, provider_config))
|
||||
|
||||
@@ -148,7 +148,7 @@ class ContentRepositoryConfig(Config):
|
||||
# them to be started.
|
||||
self.media_storage_providers = [] # type: List[tuple]
|
||||
|
||||
for provider_config in storage_providers:
|
||||
for i, provider_config in enumerate(storage_providers):
|
||||
# We special case the module "file_system" so as not to need to
|
||||
# expose FileStorageProviderBackend
|
||||
if provider_config["module"] == "file_system":
|
||||
@@ -157,7 +157,9 @@ class ContentRepositoryConfig(Config):
|
||||
".FileStorageProviderBackend"
|
||||
)
|
||||
|
||||
provider_class, parsed_config = load_module(provider_config)
|
||||
provider_class, parsed_config = load_module(
|
||||
provider_config, ("media_storage_providers", "<item %i>" % i)
|
||||
)
|
||||
|
||||
wrapper_config = MediaStorageProviderConfig(
|
||||
provider_config.get("store_local", False),
|
||||
|
||||
@@ -180,7 +180,7 @@ class _RoomDirectoryRule:
|
||||
self._alias_regex = glob_to_regex(alias)
|
||||
self._room_id_regex = glob_to_regex(room_id)
|
||||
except Exception as e:
|
||||
raise ConfigError("Failed to parse glob into regex: %s", e)
|
||||
raise ConfigError("Failed to parse glob into regex") from e
|
||||
|
||||
def matches(self, user_id, room_id, aliases):
|
||||
"""Tests if this rule matches the given user_id, room_id and aliases.
|
||||
|
||||
@@ -125,7 +125,7 @@ class SAML2Config(Config):
|
||||
(
|
||||
self.saml2_user_mapping_provider_class,
|
||||
self.saml2_user_mapping_provider_config,
|
||||
) = load_module(ump_dict)
|
||||
) = load_module(ump_dict, ("saml2_config", "user_mapping_provider"))
|
||||
|
||||
# Ensure loaded user mapping module has defined all necessary methods
|
||||
# Note parse_config() is already checked during the call to load_module
|
||||
|
||||
@@ -33,13 +33,14 @@ class SpamCheckerConfig(Config):
|
||||
# spam checker, and thus was simply a dictionary with module
|
||||
# and config keys. Support this old behaviour by checking
|
||||
# to see if the option resolves to a dictionary
|
||||
self.spam_checkers.append(load_module(spam_checkers))
|
||||
self.spam_checkers.append(load_module(spam_checkers, ("spam_checker",)))
|
||||
elif isinstance(spam_checkers, list):
|
||||
for spam_checker in spam_checkers:
|
||||
for i, spam_checker in enumerate(spam_checkers):
|
||||
config_path = ("spam_checker", "<item %i>" % i)
|
||||
if not isinstance(spam_checker, dict):
|
||||
raise ConfigError("spam_checker syntax is incorrect")
|
||||
raise ConfigError("expected a mapping", config_path)
|
||||
|
||||
self.spam_checkers.append(load_module(spam_checker))
|
||||
self.spam_checkers.append(load_module(spam_checker, config_path))
|
||||
else:
|
||||
raise ConfigError("spam_checker syntax is incorrect")
|
||||
|
||||
|
||||
@@ -93,11 +93,8 @@ class SSOConfig(Config):
|
||||
# - https://my.custom.client/
|
||||
|
||||
# Directory in which Synapse will try to find the template files below.
|
||||
# If not set, default templates from within the Synapse package will be used.
|
||||
#
|
||||
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
|
||||
# If you *do* uncomment it, you will need to make sure that all the templates
|
||||
# below are in the directory.
|
||||
# If not set, or the files named below are not found within the template
|
||||
# directory, default templates from within the Synapse package will be used.
|
||||
#
|
||||
# Synapse will look for the following templates in this directory:
|
||||
#
|
||||
|
||||
@@ -26,7 +26,9 @@ class ThirdPartyRulesConfig(Config):
|
||||
|
||||
provider = config.get("third_party_event_rules", None)
|
||||
if provider is not None:
|
||||
self.third_party_event_rules = load_module(provider)
|
||||
self.third_party_event_rules = load_module(
|
||||
provider, ("third_party_event_rules",)
|
||||
)
|
||||
|
||||
def generate_config_section(self, **kwargs):
|
||||
return """\
|
||||
|
||||
@@ -85,6 +85,9 @@ class WorkerConfig(Config):
|
||||
# The port on the main synapse for HTTP replication endpoint
|
||||
self.worker_replication_http_port = config.get("worker_replication_http_port")
|
||||
|
||||
# The shared secret used for authentication when connecting to the main synapse.
|
||||
self.worker_replication_secret = config.get("worker_replication_secret", None)
|
||||
|
||||
self.worker_name = config.get("worker_name", self.worker_app)
|
||||
|
||||
self.worker_main_http_uri = config.get("worker_main_http_uri", None)
|
||||
@@ -185,6 +188,13 @@ class WorkerConfig(Config):
|
||||
# data). If not provided this defaults to the main process.
|
||||
#
|
||||
#run_background_tasks_on: worker1
|
||||
|
||||
# A shared secret used by the replication APIs to authenticate HTTP requests
|
||||
# from workers.
|
||||
#
|
||||
# By default this is unused and traffic is not authenticated.
|
||||
#
|
||||
#worker_replication_secret: ""
|
||||
"""
|
||||
|
||||
def read_arguments(self, args):
|
||||
|
||||
@@ -1546,7 +1546,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
|
||||
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): homeserver
|
||||
resource (TransportLayerServer): resource class to register to
|
||||
resource (JsonResource): resource class to register to
|
||||
authenticator (Authenticator): authenticator to use
|
||||
ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
|
||||
servlet_groups (list[str], optional): List of servlet groups to register.
|
||||
|
||||
@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
|
||||
class BaseHandler:
|
||||
"""
|
||||
Common base class for the event handlers.
|
||||
|
||||
Deprecated: new code should not use this. Instead, Handler classes should define the
|
||||
fields they actually need. The utility methods should either be factored out to
|
||||
standalone helper functions, or to different Handler classes.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
|
||||
@@ -36,6 +36,8 @@ import attr
|
||||
import bcrypt
|
||||
import pymacaroons
|
||||
|
||||
from twisted.web.http import Request
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
@@ -193,9 +195,7 @@ class AuthHandler(BaseHandler):
|
||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
self._password_enabled = hs.config.password_enabled
|
||||
self._sso_enabled = (
|
||||
hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
|
||||
)
|
||||
self._password_localdb_enabled = hs.config.password_localdb_enabled
|
||||
|
||||
# we keep this as a list despite the O(N^2) implication so that we can
|
||||
# keep PASSWORD first and avoid confusing clients which pick the first
|
||||
@@ -205,7 +205,7 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
# start out by assuming PASSWORD is enabled; we will remove it later if not.
|
||||
login_types = []
|
||||
if hs.config.password_localdb_enabled:
|
||||
if self._password_localdb_enabled:
|
||||
login_types.append(LoginType.PASSWORD)
|
||||
|
||||
for provider in self.password_providers:
|
||||
@@ -219,14 +219,6 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
self._supported_login_types = login_types
|
||||
|
||||
# Login types and UI Auth types have a heavy overlap, but are not
|
||||
# necessarily identical. Login types have SSO (and other login types)
|
||||
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
|
||||
ui_auth_types = login_types.copy()
|
||||
if self._sso_enabled:
|
||||
ui_auth_types.append(LoginType.SSO)
|
||||
self._supported_ui_auth_types = ui_auth_types
|
||||
|
||||
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
||||
# as per `rc_login.failed_attempts`.
|
||||
self._failed_uia_attempts_ratelimiter = Ratelimiter(
|
||||
@@ -339,7 +331,10 @@ class AuthHandler(BaseHandler):
|
||||
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
|
||||
|
||||
# build a list of supported flows
|
||||
flows = [[login_type] for login_type in self._supported_ui_auth_types]
|
||||
supported_ui_auth_types = await self._get_available_ui_auth_types(
|
||||
requester.user
|
||||
)
|
||||
flows = [[login_type] for login_type in supported_ui_auth_types]
|
||||
|
||||
try:
|
||||
result, params, session_id = await self.check_ui_auth(
|
||||
@@ -351,7 +346,7 @@ class AuthHandler(BaseHandler):
|
||||
raise
|
||||
|
||||
# find the completed login type
|
||||
for login_type in self._supported_ui_auth_types:
|
||||
for login_type in supported_ui_auth_types:
|
||||
if login_type not in result:
|
||||
continue
|
||||
|
||||
@@ -367,6 +362,41 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
return params, session_id
|
||||
|
||||
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
|
||||
"""Get a list of the authentication types this user can use
|
||||
"""
|
||||
|
||||
ui_auth_types = set()
|
||||
|
||||
# if the HS supports password auth, and the user has a non-null password, we
|
||||
# support password auth
|
||||
if self._password_localdb_enabled and self._password_enabled:
|
||||
lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
|
||||
if lookupres:
|
||||
_, password_hash = lookupres
|
||||
if password_hash:
|
||||
ui_auth_types.add(LoginType.PASSWORD)
|
||||
|
||||
# also allow auth from password providers
|
||||
for provider in self.password_providers:
|
||||
for t in provider.get_supported_login_types().keys():
|
||||
if t == LoginType.PASSWORD and not self._password_enabled:
|
||||
continue
|
||||
ui_auth_types.add(t)
|
||||
|
||||
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
|
||||
# from sso to mxid.
|
||||
if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
|
||||
if await self.store.get_external_ids_by_user(user.to_string()):
|
||||
ui_auth_types.add(LoginType.SSO)
|
||||
|
||||
# Our CAS impl does not (yet) correctly register users in user_external_ids,
|
||||
# so always offer that if it's available.
|
||||
if self.hs.config.cas.cas_enabled:
|
||||
ui_auth_types.add(LoginType.SSO)
|
||||
|
||||
return ui_auth_types
|
||||
|
||||
def get_enabled_auth_types(self):
|
||||
"""Return the enabled user-interactive authentication types
|
||||
|
||||
@@ -1019,7 +1049,7 @@ class AuthHandler(BaseHandler):
|
||||
if result:
|
||||
return result
|
||||
|
||||
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
|
||||
if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
|
||||
known_login_type = True
|
||||
|
||||
# we've already checked that there is a (valid) password field
|
||||
@@ -1293,15 +1323,14 @@ class AuthHandler(BaseHandler):
|
||||
)
|
||||
|
||||
async def complete_sso_ui_auth(
|
||||
self, registered_user_id: str, session_id: str, request: SynapseRequest,
|
||||
self, registered_user_id: str, session_id: str, request: Request,
|
||||
):
|
||||
"""Having figured out a mxid for this user, complete the HTTP request
|
||||
|
||||
Args:
|
||||
registered_user_id: The registered user ID to complete SSO login for.
|
||||
session_id: The ID of the user-interactive auth session.
|
||||
request: The request to complete.
|
||||
client_redirect_url: The URL to which to redirect the user at the end of the
|
||||
process.
|
||||
"""
|
||||
# Mark the stage of the authentication as successful.
|
||||
# Save the user who authenticated with SSO, this will be used to ensure
|
||||
@@ -1317,7 +1346,7 @@ class AuthHandler(BaseHandler):
|
||||
async def complete_sso_login(
|
||||
self,
|
||||
registered_user_id: str,
|
||||
request: SynapseRequest,
|
||||
request: Request,
|
||||
client_redirect_url: str,
|
||||
extra_attributes: Optional[JsonDict] = None,
|
||||
):
|
||||
@@ -1345,7 +1374,7 @@ class AuthHandler(BaseHandler):
|
||||
def _complete_sso_login(
|
||||
self,
|
||||
registered_user_id: str,
|
||||
request: SynapseRequest,
|
||||
request: Request,
|
||||
client_redirect_url: str,
|
||||
extra_attributes: Optional[JsonDict] = None,
|
||||
):
|
||||
|
||||
@@ -47,7 +47,7 @@ class IdentityHandler(BaseHandler):
|
||||
super().__init__(hs)
|
||||
|
||||
# An HTTP client for contacting trusted URLs.
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
self.http_client = SimpleHttpClient(hs)
|
||||
# An HTTP client for contacting identity servers specified by clients.
|
||||
self.blacklisting_http_client = SimpleHttpClient(
|
||||
hs, ip_blacklist=hs.config.federation_ip_range_blacklist
|
||||
|
||||
@@ -674,6 +674,21 @@ class OidcHandler(BaseHandler):
|
||||
self._sso_handler.render_error(request, "invalid_token", str(e))
|
||||
return
|
||||
|
||||
# first check if we're doing a UIA
|
||||
if ui_auth_session_id:
|
||||
try:
|
||||
remote_user_id = self._remote_id_from_userinfo(userinfo)
|
||||
except Exception as e:
|
||||
logger.exception("Could not extract remote user id")
|
||||
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||
return
|
||||
|
||||
return await self._sso_handler.complete_sso_ui_auth_request(
|
||||
self._auth_provider_id, remote_user_id, ui_auth_session_id, request
|
||||
)
|
||||
|
||||
# otherwise, it's a login
|
||||
|
||||
# Pull out the user-agent and IP from the request.
|
||||
user_agent = request.get_user_agent("")
|
||||
ip_address = self.hs.get_ip_from_request(request)
|
||||
@@ -698,14 +713,9 @@ class OidcHandler(BaseHandler):
|
||||
extra_attributes = await get_extra_attributes(userinfo, token)
|
||||
|
||||
# and finally complete the login
|
||||
if ui_auth_session_id:
|
||||
await self._auth_handler.complete_sso_ui_auth(
|
||||
user_id, ui_auth_session_id, request
|
||||
)
|
||||
else:
|
||||
await self._auth_handler.complete_sso_login(
|
||||
user_id, request, client_redirect_url, extra_attributes
|
||||
)
|
||||
await self._auth_handler.complete_sso_login(
|
||||
user_id, request, client_redirect_url, extra_attributes
|
||||
)
|
||||
|
||||
def _generate_oidc_session_token(
|
||||
self,
|
||||
@@ -856,14 +866,11 @@ class OidcHandler(BaseHandler):
|
||||
The mxid of the user
|
||||
"""
|
||||
try:
|
||||
remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
|
||||
remote_user_id = self._remote_id_from_userinfo(userinfo)
|
||||
except Exception as e:
|
||||
raise MappingException(
|
||||
"Failed to extract subject from OIDC response: %s" % (e,)
|
||||
)
|
||||
# Some OIDC providers use integer IDs, but Synapse expects external IDs
|
||||
# to be strings.
|
||||
remote_user_id = str(remote_user_id)
|
||||
|
||||
# Older mapping providers don't accept the `failures` argument, so we
|
||||
# try and detect support.
|
||||
@@ -933,6 +940,19 @@ class OidcHandler(BaseHandler):
|
||||
grandfather_existing_users,
|
||||
)
|
||||
|
||||
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
|
||||
"""Extract the unique remote id from an OIDC UserInfo block
|
||||
|
||||
Args:
|
||||
userinfo: An object representing the user given by the OIDC provider
|
||||
Returns:
|
||||
remote user id
|
||||
"""
|
||||
remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
|
||||
# Some OIDC providers use integer IDs, but Synapse expects external IDs
|
||||
# to be strings.
|
||||
return str(remote_user_id)
|
||||
|
||||
|
||||
UserAttributeDict = TypedDict(
|
||||
"UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
|
||||
|
||||
@@ -183,6 +183,24 @@ class SamlHandler(BaseHandler):
|
||||
saml2_auth.in_response_to, None
|
||||
)
|
||||
|
||||
# first check if we're doing a UIA
|
||||
if current_session and current_session.ui_auth_session_id:
|
||||
try:
|
||||
remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
|
||||
except MappingException as e:
|
||||
logger.exception("Failed to extract remote user id from SAML response")
|
||||
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||
return
|
||||
|
||||
return await self._sso_handler.complete_sso_ui_auth_request(
|
||||
self._auth_provider_id,
|
||||
remote_user_id,
|
||||
current_session.ui_auth_session_id,
|
||||
request,
|
||||
)
|
||||
|
||||
# otherwise, we're handling a login request.
|
||||
|
||||
# Ensure that the attributes of the logged in user meet the required
|
||||
# attributes.
|
||||
for requirement in self._saml2_attribute_requirements:
|
||||
@@ -206,14 +224,7 @@ class SamlHandler(BaseHandler):
|
||||
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||
return
|
||||
|
||||
# Complete the interactive auth session or the login.
|
||||
if current_session and current_session.ui_auth_session_id:
|
||||
await self._auth_handler.complete_sso_ui_auth(
|
||||
user_id, current_session.ui_auth_session_id, request
|
||||
)
|
||||
|
||||
else:
|
||||
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
|
||||
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
|
||||
|
||||
async def _map_saml_response_to_user(
|
||||
self,
|
||||
@@ -239,16 +250,10 @@ class SamlHandler(BaseHandler):
|
||||
RedirectException: some mapping providers may raise this if they need
|
||||
to redirect to an interstitial page.
|
||||
"""
|
||||
|
||||
remote_user_id = self._user_mapping_provider.get_remote_user_id(
|
||||
remote_user_id = self._remote_id_from_saml_response(
|
||||
saml2_auth, client_redirect_url
|
||||
)
|
||||
|
||||
if not remote_user_id:
|
||||
raise MappingException(
|
||||
"Failed to extract remote user id from SAML response"
|
||||
)
|
||||
|
||||
async def saml_response_to_remapped_user_attributes(
|
||||
failures: int,
|
||||
) -> UserAttributes:
|
||||
@@ -304,6 +309,35 @@ class SamlHandler(BaseHandler):
|
||||
grandfather_existing_users,
|
||||
)
|
||||
|
||||
def _remote_id_from_saml_response(
|
||||
self,
|
||||
saml2_auth: saml2.response.AuthnResponse,
|
||||
client_redirect_url: Optional[str],
|
||||
) -> str:
|
||||
"""Extract the unique remote id from a SAML2 AuthnResponse
|
||||
|
||||
Args:
|
||||
saml2_auth: The parsed SAML2 response.
|
||||
client_redirect_url: The redirect URL passed in by the client.
|
||||
Returns:
|
||||
remote user id
|
||||
|
||||
Raises:
|
||||
MappingException if there was an error extracting the user id
|
||||
"""
|
||||
# It's not obvious why we need to pass in the redirect URI to the mapping
|
||||
# provider, but we do :/
|
||||
remote_user_id = self._user_mapping_provider.get_remote_user_id(
|
||||
saml2_auth, client_redirect_url
|
||||
)
|
||||
|
||||
if not remote_user_id:
|
||||
raise MappingException(
|
||||
"Failed to extract remote user id from SAML response"
|
||||
)
|
||||
|
||||
return remote_user_id
|
||||
|
||||
def expire_sessions(self):
|
||||
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
|
||||
to_expire = set()
|
||||
|
||||
@@ -17,8 +17,9 @@ from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.web.http import Request
|
||||
|
||||
from synapse.api.errors import RedirectException
|
||||
from synapse.handlers._base import BaseHandler
|
||||
from synapse.http.server import respond_with_html
|
||||
from synapse.types import UserID, contains_invalid_mxid_characters
|
||||
|
||||
@@ -42,14 +43,16 @@ class UserAttributes:
|
||||
emails = attr.ib(type=List[str], default=attr.Factory(list))
|
||||
|
||||
|
||||
class SsoHandler(BaseHandler):
|
||||
class SsoHandler:
|
||||
# The number of attempts to ask the mapping provider for when generating an MXID.
|
||||
_MAP_USERNAME_RETRIES = 1000
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self._store = hs.get_datastore()
|
||||
self._server_name = hs.hostname
|
||||
self._registration_handler = hs.get_registration_handler()
|
||||
self._error_template = hs.config.sso_error_template
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
|
||||
def render_error(
|
||||
self, request, error: str, error_description: Optional[str] = None
|
||||
@@ -95,7 +98,7 @@ class SsoHandler(BaseHandler):
|
||||
)
|
||||
|
||||
# Check if we already have a mapping for this user.
|
||||
previously_registered_user_id = await self.store.get_user_by_external_id(
|
||||
previously_registered_user_id = await self._store.get_user_by_external_id(
|
||||
auth_provider_id, remote_user_id,
|
||||
)
|
||||
|
||||
@@ -181,7 +184,7 @@ class SsoHandler(BaseHandler):
|
||||
previously_registered_user_id = await grandfather_existing_users()
|
||||
if previously_registered_user_id:
|
||||
# Future logins should also match this user ID.
|
||||
await self.store.record_user_external_id(
|
||||
await self._store.record_user_external_id(
|
||||
auth_provider_id, remote_user_id, previously_registered_user_id
|
||||
)
|
||||
return previously_registered_user_id
|
||||
@@ -214,8 +217,8 @@ class SsoHandler(BaseHandler):
|
||||
)
|
||||
|
||||
# Check if this mxid already exists
|
||||
user_id = UserID(attributes.localpart, self.server_name).to_string()
|
||||
if not await self.store.get_users_by_id_case_insensitive(user_id):
|
||||
user_id = UserID(attributes.localpart, self._server_name).to_string()
|
||||
if not await self._store.get_users_by_id_case_insensitive(user_id):
|
||||
# This mxid is free
|
||||
break
|
||||
else:
|
||||
@@ -238,7 +241,47 @@ class SsoHandler(BaseHandler):
|
||||
user_agent_ips=[(user_agent, ip_address)],
|
||||
)
|
||||
|
||||
await self.store.record_user_external_id(
|
||||
await self._store.record_user_external_id(
|
||||
auth_provider_id, remote_user_id, registered_user_id
|
||||
)
|
||||
return registered_user_id
|
||||
|
||||
async def complete_sso_ui_auth_request(
|
||||
self,
|
||||
auth_provider_id: str,
|
||||
remote_user_id: str,
|
||||
ui_auth_session_id: str,
|
||||
request: Request,
|
||||
) -> None:
|
||||
"""
|
||||
Given an SSO ID, retrieve the user ID for it and complete UIA.
|
||||
|
||||
Note that this requires that the user is mapped in the "user_external_ids"
|
||||
table. This will be the case if they have ever logged in via SAML or OIDC in
|
||||
recentish synapse versions, but may not be for older users.
|
||||
|
||||
Args:
|
||||
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
||||
"oidc" or "saml".
|
||||
remote_user_id: The unique identifier from the SSO provider.
|
||||
ui_auth_session_id: The ID of the user-interactive auth session.
|
||||
request: The request to complete.
|
||||
"""
|
||||
|
||||
user_id = await self.get_sso_user_by_remote_user_id(
|
||||
auth_provider_id, remote_user_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
logger.warning(
|
||||
"Remote user %s/%s has not previously logged in here: UIA will fail",
|
||||
auth_provider_id,
|
||||
remote_user_id,
|
||||
)
|
||||
# Let the UIA flow handle this the same as if they presented creds for a
|
||||
# different user.
|
||||
user_id = ""
|
||||
|
||||
await self._auth_handler.complete_sso_ui_auth(
|
||||
user_id, ui_auth_session_id, request
|
||||
)
|
||||
|
||||
@@ -13,7 +13,56 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from synapse.types import RoomStreamToken
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
|
||||
class Pusher(metaclass=abc.ABCMeta):
|
||||
def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
|
||||
self.hs = hs
|
||||
self.store = self.hs.get_datastore()
|
||||
self.clock = self.hs.get_clock()
|
||||
|
||||
self.pusher_id = pusherdict["id"]
|
||||
self.user_id = pusherdict["user_name"]
|
||||
self.app_id = pusherdict["app_id"]
|
||||
self.pushkey = pusherdict["pushkey"]
|
||||
|
||||
# This is the highest stream ordering we know it's safe to process.
|
||||
# When new events arrive, we'll be given a window of new events: we
|
||||
# should honour this rather than just looking for anything higher
|
||||
# because of potential out-of-order event serialisation. This starts
|
||||
# off as None though as we don't know any better.
|
||||
self.max_stream_ordering = None # type: Optional[int]
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_started(self, have_notifs: bool) -> None:
|
||||
"""Called when this pusher has been started.
|
||||
|
||||
Args:
|
||||
should_check_for_notifs: Whether we should immediately
|
||||
check for push to send. Set to False only if it's known there
|
||||
is nothing to send
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_stop(self) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class PusherConfigException(Exception):
|
||||
def __init__(self, msg):
|
||||
super().__init__(msg)
|
||||
"""An error occurred when creating a pusher."""
|
||||
|
||||
@@ -14,12 +14,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from twisted.internet.base import DelayedCall
|
||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.push import Pusher
|
||||
from synapse.push.mailer import Mailer
|
||||
from synapse.types import RoomStreamToken
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The amount of time we always wait before ever emailing about a notification
|
||||
@@ -46,7 +53,7 @@ THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000
|
||||
INCLUDE_ALL_UNREAD_NOTIFS = False
|
||||
|
||||
|
||||
class EmailPusher:
|
||||
class EmailPusher(Pusher):
|
||||
"""
|
||||
A pusher that sends email notifications about events (approximately)
|
||||
when they happen.
|
||||
@@ -54,37 +61,31 @@ class EmailPusher:
|
||||
factor out the common parts
|
||||
"""
|
||||
|
||||
def __init__(self, hs, pusherdict, mailer):
|
||||
self.hs = hs
|
||||
def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any], mailer: Mailer):
|
||||
super().__init__(hs, pusherdict)
|
||||
self.mailer = mailer
|
||||
|
||||
self.store = self.hs.get_datastore()
|
||||
self.clock = self.hs.get_clock()
|
||||
self.pusher_id = pusherdict["id"]
|
||||
self.user_id = pusherdict["user_name"]
|
||||
self.app_id = pusherdict["app_id"]
|
||||
self.email = pusherdict["pushkey"]
|
||||
self.last_stream_ordering = pusherdict["last_stream_ordering"]
|
||||
self.timed_call = None
|
||||
self.throttle_params = None
|
||||
|
||||
# See httppusher
|
||||
self.max_stream_ordering = None
|
||||
self.timed_call = None # type: Optional[DelayedCall]
|
||||
self.throttle_params = {} # type: Dict[str, Dict[str, int]]
|
||||
self._inited = False
|
||||
|
||||
self._is_processing = False
|
||||
|
||||
def on_started(self, should_check_for_notifs):
|
||||
def on_started(self, should_check_for_notifs: bool) -> None:
|
||||
"""Called when this pusher has been started.
|
||||
|
||||
Args:
|
||||
should_check_for_notifs (bool): Whether we should immediately
|
||||
should_check_for_notifs: Whether we should immediately
|
||||
check for push to send. Set to False only if it's known there
|
||||
is nothing to send
|
||||
"""
|
||||
if should_check_for_notifs and self.mailer is not None:
|
||||
self._start_processing()
|
||||
|
||||
def on_stop(self):
|
||||
def on_stop(self) -> None:
|
||||
if self.timed_call:
|
||||
try:
|
||||
self.timed_call.cancel()
|
||||
@@ -92,7 +93,7 @@ class EmailPusher:
|
||||
pass
|
||||
self.timed_call = None
|
||||
|
||||
def on_new_notifications(self, max_token: RoomStreamToken):
|
||||
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
|
||||
# We just use the minimum stream ordering and ignore the vector clock
|
||||
# component. This is safe to do as long as we *always* ignore the vector
|
||||
# clock components.
|
||||
@@ -106,23 +107,23 @@ class EmailPusher:
|
||||
self.max_stream_ordering = max_stream_ordering
|
||||
self._start_processing()
|
||||
|
||||
def on_new_receipts(self, min_stream_id, max_stream_id):
|
||||
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
|
||||
# We could wake up and cancel the timer but there tend to be quite a
|
||||
# lot of read receipts so it's probably less work to just let the
|
||||
# timer fire
|
||||
pass
|
||||
|
||||
def on_timer(self):
|
||||
def on_timer(self) -> None:
|
||||
self.timed_call = None
|
||||
self._start_processing()
|
||||
|
||||
def _start_processing(self):
|
||||
def _start_processing(self) -> None:
|
||||
if self._is_processing:
|
||||
return
|
||||
|
||||
run_as_background_process("emailpush.process", self._process)
|
||||
|
||||
def _pause_processing(self):
|
||||
def _pause_processing(self) -> None:
|
||||
"""Used by tests to temporarily pause processing of events.
|
||||
|
||||
Asserts that its not currently processing.
|
||||
@@ -130,25 +131,26 @@ class EmailPusher:
|
||||
assert not self._is_processing
|
||||
self._is_processing = True
|
||||
|
||||
def _resume_processing(self):
|
||||
def _resume_processing(self) -> None:
|
||||
"""Used by tests to resume processing of events after pausing.
|
||||
"""
|
||||
assert self._is_processing
|
||||
self._is_processing = False
|
||||
self._start_processing()
|
||||
|
||||
async def _process(self):
|
||||
async def _process(self) -> None:
|
||||
# we should never get here if we are already processing
|
||||
assert not self._is_processing
|
||||
|
||||
try:
|
||||
self._is_processing = True
|
||||
|
||||
if self.throttle_params is None:
|
||||
if not self._inited:
|
||||
# this is our first loop: load up the throttle params
|
||||
self.throttle_params = await self.store.get_throttle_params_by_room(
|
||||
self.pusher_id
|
||||
)
|
||||
self._inited = True
|
||||
|
||||
# if the max ordering changes while we're running _unsafe_process,
|
||||
# call it again, and so on until we've caught up.
|
||||
@@ -163,17 +165,19 @@ class EmailPusher:
|
||||
finally:
|
||||
self._is_processing = False
|
||||
|
||||
async def _unsafe_process(self):
|
||||
async def _unsafe_process(self) -> None:
|
||||
"""
|
||||
Main logic of the push loop without the wrapper function that sets
|
||||
up logging, measures and guards against multiple instances of it
|
||||
being run.
|
||||
"""
|
||||
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
|
||||
fn = self.store.get_unread_push_actions_for_user_in_range_for_email
|
||||
unprocessed = await fn(self.user_id, start, self.max_stream_ordering)
|
||||
assert self.max_stream_ordering is not None
|
||||
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
|
||||
self.user_id, start, self.max_stream_ordering
|
||||
)
|
||||
|
||||
soonest_due_at = None
|
||||
soonest_due_at = None # type: Optional[int]
|
||||
|
||||
if not unprocessed:
|
||||
await self.save_last_stream_ordering_and_success(self.max_stream_ordering)
|
||||
@@ -230,7 +234,9 @@ class EmailPusher:
|
||||
self.seconds_until(soonest_due_at), self.on_timer
|
||||
)
|
||||
|
||||
async def save_last_stream_ordering_and_success(self, last_stream_ordering):
|
||||
async def save_last_stream_ordering_and_success(
|
||||
self, last_stream_ordering: Optional[int]
|
||||
) -> None:
|
||||
if last_stream_ordering is None:
|
||||
# This happens if we haven't yet processed anything
|
||||
return
|
||||
@@ -248,28 +254,30 @@ class EmailPusher:
|
||||
# lets just stop and return.
|
||||
self.on_stop()
|
||||
|
||||
def seconds_until(self, ts_msec):
|
||||
def seconds_until(self, ts_msec: int) -> float:
|
||||
secs = (ts_msec - self.clock.time_msec()) / 1000
|
||||
return max(secs, 0)
|
||||
|
||||
def get_room_throttle_ms(self, room_id):
|
||||
def get_room_throttle_ms(self, room_id: str) -> int:
|
||||
if room_id in self.throttle_params:
|
||||
return self.throttle_params[room_id]["throttle_ms"]
|
||||
else:
|
||||
return 0
|
||||
|
||||
def get_room_last_sent_ts(self, room_id):
|
||||
def get_room_last_sent_ts(self, room_id: str) -> int:
|
||||
if room_id in self.throttle_params:
|
||||
return self.throttle_params[room_id]["last_sent_ts"]
|
||||
else:
|
||||
return 0
|
||||
|
||||
def room_ready_to_notify_at(self, room_id):
|
||||
def room_ready_to_notify_at(self, room_id: str) -> int:
|
||||
"""
|
||||
Determines whether throttling should prevent us from sending an email
|
||||
for the given room
|
||||
Returns: The timestamp when we are next allowed to send an email notif
|
||||
for this room
|
||||
|
||||
Returns:
|
||||
The timestamp when we are next allowed to send an email notif
|
||||
for this room
|
||||
"""
|
||||
last_sent_ts = self.get_room_last_sent_ts(room_id)
|
||||
throttle_ms = self.get_room_throttle_ms(room_id)
|
||||
@@ -277,7 +285,9 @@ class EmailPusher:
|
||||
may_send_at = last_sent_ts + throttle_ms
|
||||
return may_send_at
|
||||
|
||||
async def sent_notif_update_throttle(self, room_id, notified_push_action):
|
||||
async def sent_notif_update_throttle(
|
||||
self, room_id: str, notified_push_action: dict
|
||||
) -> None:
|
||||
# We have sent a notification, so update the throttle accordingly.
|
||||
# If the event that triggered the notif happened more than
|
||||
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a
|
||||
@@ -315,7 +325,7 @@ class EmailPusher:
|
||||
self.pusher_id, room_id, self.throttle_params[room_id]
|
||||
)
|
||||
|
||||
async def send_notification(self, push_actions, reason):
|
||||
async def send_notification(self, push_actions: List[dict], reason: dict) -> None:
|
||||
logger.info("Sending notif email for user %r", self.user_id)
|
||||
|
||||
await self.mailer.send_notification_mail(
|
||||
|
||||
@@ -14,19 +14,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import urllib.parse
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging import opentracing
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.push import PusherConfigException
|
||||
from synapse.push import Pusher, PusherConfigException
|
||||
from synapse.types import RoomStreamToken
|
||||
|
||||
from . import push_rule_evaluator, push_tools
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
http_push_processed_counter = Counter(
|
||||
@@ -50,24 +56,18 @@ http_badges_failed_counter = Counter(
|
||||
)
|
||||
|
||||
|
||||
class HttpPusher:
|
||||
class HttpPusher(Pusher):
|
||||
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
|
||||
MAX_BACKOFF_SEC = 60 * 60
|
||||
|
||||
# This one's in ms because we compare it against the clock
|
||||
GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
|
||||
|
||||
def __init__(self, hs, pusherdict):
|
||||
self.hs = hs
|
||||
self.store = self.hs.get_datastore()
|
||||
def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
|
||||
super().__init__(hs, pusherdict)
|
||||
self.storage = self.hs.get_storage()
|
||||
self.clock = self.hs.get_clock()
|
||||
self.state_handler = self.hs.get_state_handler()
|
||||
self.user_id = pusherdict["user_name"]
|
||||
self.app_id = pusherdict["app_id"]
|
||||
self.app_display_name = pusherdict["app_display_name"]
|
||||
self.device_display_name = pusherdict["device_display_name"]
|
||||
self.pushkey = pusherdict["pushkey"]
|
||||
self.pushkey_ts = pusherdict["ts"]
|
||||
self.data = pusherdict["data"]
|
||||
self.last_stream_ordering = pusherdict["last_stream_ordering"]
|
||||
@@ -77,13 +77,6 @@ class HttpPusher:
|
||||
self._is_processing = False
|
||||
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
|
||||
|
||||
# This is the highest stream ordering we know it's safe to process.
|
||||
# When new events arrive, we'll be given a window of new events: we
|
||||
# should honour this rather than just looking for anything higher
|
||||
# because of potential out-of-order event serialisation. This starts
|
||||
# off as None though as we don't know any better.
|
||||
self.max_stream_ordering = None
|
||||
|
||||
if "data" not in pusherdict:
|
||||
raise PusherConfigException("No 'data' key for HTTP pusher")
|
||||
self.data = pusherdict["data"]
|
||||
@@ -97,26 +90,39 @@ class HttpPusher:
|
||||
if self.data is None:
|
||||
raise PusherConfigException("data can not be null for HTTP pusher")
|
||||
|
||||
# Validate that there's a URL and it is of the proper form.
|
||||
if "url" not in self.data:
|
||||
raise PusherConfigException("'url' required in data for HTTP pusher")
|
||||
self.url = self.data["url"]
|
||||
|
||||
url = self.data["url"]
|
||||
if not isinstance(url, str):
|
||||
raise PusherConfigException("'url' must be a string")
|
||||
url_parts = urllib.parse.urlparse(url)
|
||||
# Note that the specification also says the scheme must be HTTPS, but
|
||||
# it isn't up to the homeserver to verify that.
|
||||
if url_parts.path != "/_matrix/push/v1/notify":
|
||||
raise PusherConfigException(
|
||||
"'url' must have a path of '/_matrix/push/v1/notify'"
|
||||
)
|
||||
|
||||
self.url = url
|
||||
self.http_client = hs.get_proxied_blacklisted_http_client()
|
||||
self.data_minus_url = {}
|
||||
self.data_minus_url.update(self.data)
|
||||
del self.data_minus_url["url"]
|
||||
|
||||
def on_started(self, should_check_for_notifs):
|
||||
def on_started(self, should_check_for_notifs: bool) -> None:
|
||||
"""Called when this pusher has been started.
|
||||
|
||||
Args:
|
||||
should_check_for_notifs (bool): Whether we should immediately
|
||||
should_check_for_notifs: Whether we should immediately
|
||||
check for push to send. Set to False only if it's known there
|
||||
is nothing to send
|
||||
"""
|
||||
if should_check_for_notifs:
|
||||
self._start_processing()
|
||||
|
||||
def on_new_notifications(self, max_token: RoomStreamToken):
|
||||
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
|
||||
# We just use the minimum stream ordering and ignore the vector clock
|
||||
# component. This is safe to do as long as we *always* ignore the vector
|
||||
# clock components.
|
||||
@@ -127,14 +133,14 @@ class HttpPusher:
|
||||
)
|
||||
self._start_processing()
|
||||
|
||||
def on_new_receipts(self, min_stream_id, max_stream_id):
|
||||
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
|
||||
# Note that the min here shouldn't be relied upon to be accurate.
|
||||
|
||||
# We could check the receipts are actually m.read receipts here,
|
||||
# but currently that's the only type of receipt anyway...
|
||||
run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
|
||||
|
||||
async def _update_badge(self):
|
||||
async def _update_badge(self) -> None:
|
||||
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
|
||||
# to be largely redundant. perhaps we can remove it.
|
||||
badge = await push_tools.get_badge_count(
|
||||
@@ -144,10 +150,10 @@ class HttpPusher:
|
||||
)
|
||||
await self._send_badge(badge)
|
||||
|
||||
def on_timer(self):
|
||||
def on_timer(self) -> None:
|
||||
self._start_processing()
|
||||
|
||||
def on_stop(self):
|
||||
def on_stop(self) -> None:
|
||||
if self.timed_call:
|
||||
try:
|
||||
self.timed_call.cancel()
|
||||
@@ -155,13 +161,13 @@ class HttpPusher:
|
||||
pass
|
||||
self.timed_call = None
|
||||
|
||||
def _start_processing(self):
|
||||
def _start_processing(self) -> None:
|
||||
if self._is_processing:
|
||||
return
|
||||
|
||||
run_as_background_process("httppush.process", self._process)
|
||||
|
||||
async def _process(self):
|
||||
async def _process(self) -> None:
|
||||
# we should never get here if we are already processing
|
||||
assert not self._is_processing
|
||||
|
||||
@@ -180,7 +186,7 @@ class HttpPusher:
|
||||
finally:
|
||||
self._is_processing = False
|
||||
|
||||
async def _unsafe_process(self):
|
||||
async def _unsafe_process(self) -> None:
|
||||
"""
|
||||
Looks for unset notifications and dispatch them, in order
|
||||
Never call this directly: use _process which will only allow this to
|
||||
@@ -188,6 +194,7 @@ class HttpPusher:
|
||||
"""
|
||||
|
||||
fn = self.store.get_unread_push_actions_for_user_in_range_for_http
|
||||
assert self.max_stream_ordering is not None
|
||||
unprocessed = await fn(
|
||||
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
||||
)
|
||||
@@ -257,17 +264,12 @@ class HttpPusher:
|
||||
)
|
||||
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
||||
self.last_stream_ordering = push_action["stream_ordering"]
|
||||
pusher_still_exists = await self.store.update_pusher_last_stream_ordering(
|
||||
await self.store.update_pusher_last_stream_ordering(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_id,
|
||||
self.last_stream_ordering,
|
||||
)
|
||||
if not pusher_still_exists:
|
||||
# The pusher has been deleted while we were processing, so
|
||||
# lets just stop and return.
|
||||
self.on_stop()
|
||||
return
|
||||
|
||||
self.failing_since = None
|
||||
await self.store.update_pusher_failing_since(
|
||||
@@ -283,7 +285,7 @@ class HttpPusher:
|
||||
)
|
||||
break
|
||||
|
||||
async def _process_one(self, push_action):
|
||||
async def _process_one(self, push_action: dict) -> bool:
|
||||
if "notify" not in push_action["actions"]:
|
||||
return True
|
||||
|
||||
@@ -314,7 +316,9 @@ class HttpPusher:
|
||||
await self.hs.remove_pusher(self.app_id, pk, self.user_id)
|
||||
return True
|
||||
|
||||
async def _build_notification_dict(self, event, tweaks, badge):
|
||||
async def _build_notification_dict(
|
||||
self, event: EventBase, tweaks: Dict[str, bool], badge: int
|
||||
) -> Dict[str, Any]:
|
||||
priority = "low"
|
||||
if (
|
||||
event.type == EventTypes.Encrypted
|
||||
@@ -344,9 +348,7 @@ class HttpPusher:
|
||||
}
|
||||
return d
|
||||
|
||||
ctx = await push_tools.get_context_for_event(
|
||||
self.storage, self.state_handler, event, self.user_id
|
||||
)
|
||||
ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id)
|
||||
|
||||
d = {
|
||||
"notification": {
|
||||
@@ -386,7 +388,9 @@ class HttpPusher:
|
||||
|
||||
return d
|
||||
|
||||
async def dispatch_push(self, event, tweaks, badge):
|
||||
async def dispatch_push(
|
||||
self, event: EventBase, tweaks: Dict[str, bool], badge: int
|
||||
) -> Union[bool, Iterable[str]]:
|
||||
notification_dict = await self._build_notification_dict(event, tweaks, badge)
|
||||
if not notification_dict:
|
||||
return []
|
||||
|
||||
@@ -19,7 +19,7 @@ import logging
|
||||
import urllib.parse
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Iterable, List, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
|
||||
|
||||
import bleach
|
||||
import jinja2
|
||||
@@ -27,16 +27,20 @@ import jinja2
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.config.emailconfig import EmailSubjectConfig
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.push.presentable_names import (
|
||||
calculate_room_name,
|
||||
descriptor_from_member_events,
|
||||
name_from_member_event,
|
||||
)
|
||||
from synapse.types import UserID
|
||||
from synapse.types import StateMap, UserID
|
||||
from synapse.util.async_helpers import concurrently_execute
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -93,7 +97,13 @@ ALLOWED_ATTRS = {
|
||||
|
||||
|
||||
class Mailer:
|
||||
def __init__(self, hs, app_name, template_html, template_text):
|
||||
def __init__(
|
||||
self,
|
||||
hs: "HomeServer",
|
||||
app_name: str,
|
||||
template_html: jinja2.Template,
|
||||
template_text: jinja2.Template,
|
||||
):
|
||||
self.hs = hs
|
||||
self.template_html = template_html
|
||||
self.template_text = template_text
|
||||
@@ -108,17 +118,19 @@ class Mailer:
|
||||
|
||||
logger.info("Created Mailer for app_name %s" % app_name)
|
||||
|
||||
async def send_password_reset_mail(self, email_address, token, client_secret, sid):
|
||||
async def send_password_reset_mail(
|
||||
self, email_address: str, token: str, client_secret: str, sid: str
|
||||
) -> None:
|
||||
"""Send an email with a password reset link to a user
|
||||
|
||||
Args:
|
||||
email_address (str): Email address we're sending the password
|
||||
email_address: Email address we're sending the password
|
||||
reset to
|
||||
token (str): Unique token generated by the server to verify
|
||||
token: Unique token generated by the server to verify
|
||||
the email was received
|
||||
client_secret (str): Unique token generated by the client to
|
||||
client_secret: Unique token generated by the client to
|
||||
group together multiple email sending attempts
|
||||
sid (str): The generated session ID
|
||||
sid: The generated session ID
|
||||
"""
|
||||
params = {"token": token, "client_secret": client_secret, "sid": sid}
|
||||
link = (
|
||||
@@ -136,17 +148,19 @@ class Mailer:
|
||||
template_vars,
|
||||
)
|
||||
|
||||
async def send_registration_mail(self, email_address, token, client_secret, sid):
|
||||
async def send_registration_mail(
|
||||
self, email_address: str, token: str, client_secret: str, sid: str
|
||||
) -> None:
|
||||
"""Send an email with a registration confirmation link to a user
|
||||
|
||||
Args:
|
||||
email_address (str): Email address we're sending the registration
|
||||
email_address: Email address we're sending the registration
|
||||
link to
|
||||
token (str): Unique token generated by the server to verify
|
||||
token: Unique token generated by the server to verify
|
||||
the email was received
|
||||
client_secret (str): Unique token generated by the client to
|
||||
client_secret: Unique token generated by the client to
|
||||
group together multiple email sending attempts
|
||||
sid (str): The generated session ID
|
||||
sid: The generated session ID
|
||||
"""
|
||||
params = {"token": token, "client_secret": client_secret, "sid": sid}
|
||||
link = (
|
||||
@@ -164,18 +178,20 @@ class Mailer:
|
||||
template_vars,
|
||||
)
|
||||
|
||||
async def send_add_threepid_mail(self, email_address, token, client_secret, sid):
|
||||
async def send_add_threepid_mail(
|
||||
self, email_address: str, token: str, client_secret: str, sid: str
|
||||
) -> None:
|
||||
"""Send an email with a validation link to a user for adding a 3pid to their account
|
||||
|
||||
Args:
|
||||
email_address (str): Email address we're sending the validation link to
|
||||
email_address: Email address we're sending the validation link to
|
||||
|
||||
token (str): Unique token generated by the server to verify the email was received
|
||||
token: Unique token generated by the server to verify the email was received
|
||||
|
||||
client_secret (str): Unique token generated by the client to group together
|
||||
client_secret: Unique token generated by the client to group together
|
||||
multiple email sending attempts
|
||||
|
||||
sid (str): The generated session ID
|
||||
sid: The generated session ID
|
||||
"""
|
||||
params = {"token": token, "client_secret": client_secret, "sid": sid}
|
||||
link = (
|
||||
@@ -194,8 +210,13 @@ class Mailer:
|
||||
)
|
||||
|
||||
async def send_notification_mail(
|
||||
self, app_id, user_id, email_address, push_actions, reason
|
||||
):
|
||||
self,
|
||||
app_id: str,
|
||||
user_id: str,
|
||||
email_address: str,
|
||||
push_actions: Iterable[Dict[str, Any]],
|
||||
reason: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Send email regarding a user's room notifications"""
|
||||
rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
|
||||
|
||||
@@ -203,7 +224,7 @@ class Mailer:
|
||||
[pa["event_id"] for pa in push_actions]
|
||||
)
|
||||
|
||||
notifs_by_room = {}
|
||||
notifs_by_room = {} # type: Dict[str, List[Dict[str, Any]]]
|
||||
for pa in push_actions:
|
||||
notifs_by_room.setdefault(pa["room_id"], []).append(pa)
|
||||
|
||||
@@ -262,7 +283,9 @@ class Mailer:
|
||||
|
||||
await self.send_email(email_address, summary_text, template_vars)
|
||||
|
||||
async def send_email(self, email_address, subject, extra_template_vars):
|
||||
async def send_email(
|
||||
self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Send an email with the given information and template text"""
|
||||
try:
|
||||
from_string = self.hs.config.email_notif_from % {"app": self.app_name}
|
||||
@@ -315,8 +338,13 @@ class Mailer:
|
||||
)
|
||||
|
||||
async def get_room_vars(
|
||||
self, room_id, user_id, notifs, notif_events, room_state_ids
|
||||
):
|
||||
self,
|
||||
room_id: str,
|
||||
user_id: str,
|
||||
notifs: Iterable[Dict[str, Any]],
|
||||
notif_events: Dict[str, EventBase],
|
||||
room_state_ids: StateMap[str],
|
||||
) -> Dict[str, Any]:
|
||||
# Check if one of the notifs is an invite event for the user.
|
||||
is_invite = False
|
||||
for n in notifs:
|
||||
@@ -334,7 +362,7 @@ class Mailer:
|
||||
"notifs": [],
|
||||
"invite": is_invite,
|
||||
"link": self.make_room_link(room_id),
|
||||
}
|
||||
} # type: Dict[str, Any]
|
||||
|
||||
if not is_invite:
|
||||
for n in notifs:
|
||||
@@ -365,7 +393,13 @@ class Mailer:
|
||||
|
||||
return room_vars
|
||||
|
||||
async def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
|
||||
async def get_notif_vars(
|
||||
self,
|
||||
notif: Dict[str, Any],
|
||||
user_id: str,
|
||||
notif_event: EventBase,
|
||||
room_state_ids: StateMap[str],
|
||||
) -> Dict[str, Any]:
|
||||
results = await self.store.get_events_around(
|
||||
notif["room_id"],
|
||||
notif["event_id"],
|
||||
@@ -391,7 +425,9 @@ class Mailer:
|
||||
|
||||
return ret
|
||||
|
||||
async def get_message_vars(self, notif, event, room_state_ids):
|
||||
async def get_message_vars(
|
||||
self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
|
||||
return None
|
||||
|
||||
@@ -432,7 +468,9 @@ class Mailer:
|
||||
|
||||
return ret
|
||||
|
||||
def add_text_message_vars(self, messagevars, event):
|
||||
def add_text_message_vars(
|
||||
self, messagevars: Dict[str, Any], event: EventBase
|
||||
) -> None:
|
||||
msgformat = event.content.get("format")
|
||||
|
||||
messagevars["format"] = msgformat
|
||||
@@ -445,15 +483,18 @@ class Mailer:
|
||||
elif body:
|
||||
messagevars["body_text_html"] = safe_text(body)
|
||||
|
||||
return messagevars
|
||||
|
||||
def add_image_message_vars(self, messagevars, event):
|
||||
def add_image_message_vars(
|
||||
self, messagevars: Dict[str, Any], event: EventBase
|
||||
) -> None:
|
||||
messagevars["image_url"] = event.content["url"]
|
||||
|
||||
return messagevars
|
||||
|
||||
async def make_summary_text(
|
||||
self, notifs_by_room, room_state_ids, notif_events, user_id, reason
|
||||
self,
|
||||
notifs_by_room: Dict[str, List[Dict[str, Any]]],
|
||||
room_state_ids: Dict[str, StateMap[str]],
|
||||
notif_events: Dict[str, EventBase],
|
||||
user_id: str,
|
||||
reason: Dict[str, Any],
|
||||
):
|
||||
if len(notifs_by_room) == 1:
|
||||
# Only one room has new stuff
|
||||
@@ -580,7 +621,7 @@ class Mailer:
|
||||
"app": self.app_name,
|
||||
}
|
||||
|
||||
def make_room_link(self, room_id):
|
||||
def make_room_link(self, room_id: str) -> str:
|
||||
if self.hs.config.email_riot_base_url:
|
||||
base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
|
||||
elif self.app_name == "Vector":
|
||||
@@ -590,7 +631,7 @@ class Mailer:
|
||||
base_url = "https://matrix.to/#"
|
||||
return "%s/%s" % (base_url, room_id)
|
||||
|
||||
def make_notif_link(self, notif):
|
||||
def make_notif_link(self, notif: Dict[str, str]) -> str:
|
||||
if self.hs.config.email_riot_base_url:
|
||||
return "%s/#/room/%s/%s" % (
|
||||
self.hs.config.email_riot_base_url,
|
||||
@@ -606,7 +647,9 @@ class Mailer:
|
||||
else:
|
||||
return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
|
||||
|
||||
def make_unsubscribe_link(self, user_id, app_id, email_address):
|
||||
def make_unsubscribe_link(
|
||||
self, user_id: str, app_id: str, email_address: str
|
||||
) -> str:
|
||||
params = {
|
||||
"access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
|
||||
"app_id": app_id,
|
||||
@@ -620,7 +663,7 @@ class Mailer:
|
||||
)
|
||||
|
||||
|
||||
def safe_markup(raw_html):
|
||||
def safe_markup(raw_html: str) -> jinja2.Markup:
|
||||
return jinja2.Markup(
|
||||
bleach.linkify(
|
||||
bleach.clean(
|
||||
@@ -635,7 +678,7 @@ def safe_markup(raw_html):
|
||||
)
|
||||
|
||||
|
||||
def safe_text(raw_text):
|
||||
def safe_text(raw_text: str) -> jinja2.Markup:
|
||||
"""
|
||||
Process text: treat it as HTML but escape any tags (ie. just escape the
|
||||
HTML) then linkify it.
|
||||
@@ -655,7 +698,7 @@ def deduped_ordered_list(it: Iterable[T]) -> List[T]:
|
||||
return ret
|
||||
|
||||
|
||||
def string_ordinal_total(s):
|
||||
def string_ordinal_total(s: str) -> int:
|
||||
tot = 0
|
||||
for c in s:
|
||||
tot += ord(c)
|
||||
|
||||
@@ -12,6 +12,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict
|
||||
|
||||
from synapse.events import EventBase
|
||||
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
|
||||
from synapse.storage import Storage
|
||||
from synapse.storage.databases.main import DataStore
|
||||
@@ -46,7 +49,9 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
|
||||
return badge
|
||||
|
||||
|
||||
async def get_context_for_event(storage: Storage, state_handler, ev, user_id):
|
||||
async def get_context_for_event(
|
||||
storage: Storage, ev: EventBase, user_id: str
|
||||
) -> Dict[str, str]:
|
||||
ctx = {}
|
||||
|
||||
room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id)
|
||||
|
||||
@@ -14,25 +14,31 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
||||
|
||||
from synapse.push import Pusher
|
||||
from synapse.push.emailpusher import EmailPusher
|
||||
from synapse.push.httppusher import HttpPusher
|
||||
from synapse.push.mailer import Mailer
|
||||
|
||||
from .httppusher import HttpPusher
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PusherFactory:
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.config = hs.config
|
||||
|
||||
self.pusher_types = {"http": HttpPusher}
|
||||
self.pusher_types = {
|
||||
"http": HttpPusher
|
||||
} # type: Dict[str, Callable[[HomeServer, dict], Pusher]]
|
||||
|
||||
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
|
||||
if hs.config.email_enable_notifs:
|
||||
self.mailers = {} # app_name -> Mailer
|
||||
self.mailers = {} # type: Dict[str, Mailer]
|
||||
|
||||
self._notif_template_html = hs.config.email_notif_template_html
|
||||
self._notif_template_text = hs.config.email_notif_template_text
|
||||
@@ -41,7 +47,7 @@ class PusherFactory:
|
||||
|
||||
logger.info("defined email pusher type")
|
||||
|
||||
def create_pusher(self, pusherdict):
|
||||
def create_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
|
||||
kind = pusherdict["kind"]
|
||||
f = self.pusher_types.get(kind, None)
|
||||
if not f:
|
||||
@@ -49,7 +55,9 @@ class PusherFactory:
|
||||
logger.debug("creating %s pusher for %r", kind, pusherdict)
|
||||
return f(self.hs, pusherdict)
|
||||
|
||||
def _create_email_pusher(self, _hs, pusherdict):
|
||||
def _create_email_pusher(
|
||||
self, _hs: "HomeServer", pusherdict: Dict[str, Any]
|
||||
) -> EmailPusher:
|
||||
app_name = self._app_name_from_pusherdict(pusherdict)
|
||||
mailer = self.mailers.get(app_name)
|
||||
if not mailer:
|
||||
@@ -62,7 +70,7 @@ class PusherFactory:
|
||||
self.mailers[app_name] = mailer
|
||||
return EmailPusher(self.hs, pusherdict, mailer)
|
||||
|
||||
def _app_name_from_pusherdict(self, pusherdict):
|
||||
def _app_name_from_pusherdict(self, pusherdict: Dict[str, Any]) -> str:
|
||||
data = pusherdict["data"]
|
||||
|
||||
if isinstance(data, dict):
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from prometheus_client import Gauge
|
||||
|
||||
@@ -23,9 +23,7 @@ from synapse.metrics.background_process_metrics import (
|
||||
run_as_background_process,
|
||||
wrap_as_background_process,
|
||||
)
|
||||
from synapse.push import PusherConfigException
|
||||
from synapse.push.emailpusher import EmailPusher
|
||||
from synapse.push.httppusher import HttpPusher
|
||||
from synapse.push import Pusher, PusherConfigException
|
||||
from synapse.push.pusher import PusherFactory
|
||||
from synapse.types import RoomStreamToken
|
||||
from synapse.util.async_helpers import concurrently_execute
|
||||
@@ -77,7 +75,7 @@ class PusherPool:
|
||||
self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
|
||||
|
||||
# map from user id to app_id:pushkey to pusher
|
||||
self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
|
||||
self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
|
||||
|
||||
def start(self):
|
||||
"""Starts the pushers off in a background process.
|
||||
@@ -99,11 +97,11 @@ class PusherPool:
|
||||
lang,
|
||||
data,
|
||||
profile_tag="",
|
||||
):
|
||||
) -> Optional[Pusher]:
|
||||
"""Creates a new pusher and adds it to the pool
|
||||
|
||||
Returns:
|
||||
EmailPusher|HttpPusher
|
||||
The newly created pusher.
|
||||
"""
|
||||
|
||||
time_now_msec = self.clock.time_msec()
|
||||
@@ -267,17 +265,19 @@ class PusherPool:
|
||||
except Exception:
|
||||
logger.exception("Exception in pusher on_new_receipts")
|
||||
|
||||
async def start_pusher_by_id(self, app_id, pushkey, user_id):
|
||||
async def start_pusher_by_id(
|
||||
self, app_id: str, pushkey: str, user_id: str
|
||||
) -> Optional[Pusher]:
|
||||
"""Look up the details for the given pusher, and start it
|
||||
|
||||
Returns:
|
||||
EmailPusher|HttpPusher|None: The pusher started, if any
|
||||
The pusher started, if any
|
||||
"""
|
||||
if not self._should_start_pushers:
|
||||
return
|
||||
return None
|
||||
|
||||
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
|
||||
return
|
||||
return None
|
||||
|
||||
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
|
||||
|
||||
@@ -303,19 +303,19 @@ class PusherPool:
|
||||
|
||||
logger.info("Started pushers")
|
||||
|
||||
async def _start_pusher(self, pusherdict):
|
||||
async def _start_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
|
||||
"""Start the given pusher
|
||||
|
||||
Args:
|
||||
pusherdict (dict): dict with the values pulled from the db table
|
||||
pusherdict: dict with the values pulled from the db table
|
||||
|
||||
Returns:
|
||||
EmailPusher|HttpPusher
|
||||
The newly created pusher or None.
|
||||
"""
|
||||
if not self._pusher_shard_config.should_handle(
|
||||
self._instance_name, pusherdict["user_name"]
|
||||
):
|
||||
return
|
||||
return None
|
||||
|
||||
try:
|
||||
p = self.pusher_factory.create_pusher(pusherdict)
|
||||
@@ -328,15 +328,15 @@ class PusherPool:
|
||||
pusherdict.get("pushkey"),
|
||||
e,
|
||||
)
|
||||
return
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Couldn't start pusher id %i: caught Exception", pusherdict["id"],
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
if not p:
|
||||
return
|
||||
return None
|
||||
|
||||
appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
|
||||
|
||||
|
||||
@@ -106,6 +106,25 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||
|
||||
assert self.METHOD in ("PUT", "POST", "GET")
|
||||
|
||||
self._replication_secret = None
|
||||
if hs.config.worker.worker_replication_secret:
|
||||
self._replication_secret = hs.config.worker.worker_replication_secret
|
||||
|
||||
def _check_auth(self, request) -> None:
|
||||
# Get the authorization header.
|
||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||
|
||||
if len(auth_headers) > 1:
|
||||
raise RuntimeError("Too many Authorization headers.")
|
||||
parts = auth_headers[0].split(b" ")
|
||||
if parts[0] == b"Bearer" and len(parts) == 2:
|
||||
received_secret = parts[1].decode("ascii")
|
||||
if self._replication_secret == received_secret:
|
||||
# Success!
|
||||
return
|
||||
|
||||
raise RuntimeError("Invalid Authorization header.")
|
||||
|
||||
@abc.abstractmethod
|
||||
async def _serialize_payload(**kwargs):
|
||||
"""Static method that is called when creating a request.
|
||||
@@ -150,6 +169,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||
|
||||
outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)
|
||||
|
||||
replication_secret = None
|
||||
if hs.config.worker.worker_replication_secret:
|
||||
replication_secret = hs.config.worker.worker_replication_secret.encode(
|
||||
"ascii"
|
||||
)
|
||||
|
||||
@trace(opname="outgoing_replication_request")
|
||||
@outgoing_gauge.track_inprogress()
|
||||
async def send_request(instance_name="master", **kwargs):
|
||||
@@ -202,6 +227,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||
# the master, and so whether we should clean up or not.
|
||||
while True:
|
||||
headers = {} # type: Dict[bytes, List[bytes]]
|
||||
# Add an authorization header, if configured.
|
||||
if replication_secret:
|
||||
headers[b"Authorization"] = [b"Bearer " + replication_secret]
|
||||
inject_active_span_byte_dict(headers, None, check_destination=False)
|
||||
try:
|
||||
result = await request_func(uri, data, headers=headers)
|
||||
@@ -236,21 +264,19 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
||||
url_args = list(self.PATH_ARGS)
|
||||
handler = self._handle_request
|
||||
method = self.METHOD
|
||||
|
||||
if self.CACHE:
|
||||
handler = self._cached_handler # type: ignore
|
||||
url_args.append("txn_id")
|
||||
|
||||
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
|
||||
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
|
||||
|
||||
http_server.register_paths(
|
||||
method, [pattern], handler, self.__class__.__name__,
|
||||
method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
|
||||
)
|
||||
|
||||
def _cached_handler(self, request, txn_id, **kwargs):
|
||||
def _check_auth_and_handle(self, request, **kwargs):
|
||||
"""Called on new incoming requests when caching is enabled. Checks
|
||||
if there is a cached response for the request and returns that,
|
||||
otherwise calls `_handle_request` and caches its response.
|
||||
@@ -258,6 +284,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||
# We just use the txn_id here, but we probably also want to use the
|
||||
# other PATH_ARGS as well.
|
||||
|
||||
assert self.CACHE
|
||||
# Check the authorization headers before handling the request.
|
||||
if self._replication_secret:
|
||||
self._check_auth(request)
|
||||
|
||||
return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs)
|
||||
if self.CACHE:
|
||||
txn_id = kwargs.pop("txn_id")
|
||||
|
||||
return self.response_cache.wrap(
|
||||
txn_id, self._handle_request, request, **kwargs
|
||||
)
|
||||
|
||||
return self._handle_request(request, **kwargs)
|
||||
|
||||
@@ -320,9 +320,9 @@ class UserRestServletV2(RestServlet):
|
||||
data={},
|
||||
)
|
||||
|
||||
if "avatar_url" in body and type(body["avatar_url"]) == str:
|
||||
if "avatar_url" in body and isinstance(body["avatar_url"], str):
|
||||
await self.profile_handler.set_avatar_url(
|
||||
user_id, requester, body["avatar_url"], True
|
||||
target_user, requester, body["avatar_url"], True
|
||||
)
|
||||
|
||||
ret = await self.admin_handler.get_user(target_user)
|
||||
@@ -420,6 +420,9 @@ class UserRegisterServlet(RestServlet):
|
||||
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
|
||||
raise SynapseError(400, "Invalid user type")
|
||||
|
||||
if "mac" not in body:
|
||||
raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON)
|
||||
|
||||
got_mac = body["mac"]
|
||||
|
||||
want_mac_builder = hmac.new(
|
||||
|
||||
@@ -458,7 +458,7 @@ class RegisterRestServlet(RestServlet):
|
||||
|
||||
# == Normal User Registration == (everyone else)
|
||||
if not self._registration_enabled:
|
||||
raise SynapseError(403, "Registration has been disabled")
|
||||
raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
|
||||
|
||||
# Check if this account is upgrading from a guest account.
|
||||
guest_access_token = body.get("guest_access_token", None)
|
||||
|
||||
@@ -155,6 +155,11 @@ def add_file_headers(request, media_type, file_size, upload_name):
|
||||
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
|
||||
request.setHeader(b"Content-Length", b"%d" % (file_size,))
|
||||
|
||||
# Tell web crawlers to not index, archive, or follow links in media. This
|
||||
# should help to prevent things in the media repo from showing up in web
|
||||
# search results.
|
||||
request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex")
|
||||
|
||||
|
||||
# separators as defined in RFC2616. SP and HT are handled separately.
|
||||
# see _can_encode_filename_as_token.
|
||||
|
||||
@@ -675,7 +675,11 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||
logger.debug("No media removed from url cache")
|
||||
|
||||
|
||||
def decode_and_calc_og(body, media_uri, request_encoding=None):
|
||||
def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
|
||||
# If there's no body, nothing useful is going to be found.
|
||||
if not body:
|
||||
return {}
|
||||
|
||||
from lxml import etree
|
||||
|
||||
try:
|
||||
|
||||
@@ -44,7 +44,7 @@ class UploadResource(DirectServeJsonResource):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
content_length = request.getHeader(b"Content-Length").decode("ascii")
|
||||
content_length = request.getHeader("Content-Length")
|
||||
if content_length is None:
|
||||
raise SynapseError(msg="Request must specify a Content-Length", code=400)
|
||||
if int(content_length) > self.max_upload_size:
|
||||
|
||||
@@ -381,6 +381,28 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
)
|
||||
return MatrixFederationHttpClient(self, tls_client_options_factory)
|
||||
|
||||
@cache_in_self
|
||||
def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient:
|
||||
"""
|
||||
An HTTP client that uses configured HTTP(S) proxies and blacklists IPs
|
||||
based on the IP range blacklist.
|
||||
"""
|
||||
return SimpleHttpClient(
|
||||
self,
|
||||
ip_blacklist=self.config.ip_range_blacklist,
|
||||
use_proxy=True,
|
||||
)
|
||||
|
||||
@cache_in_self
|
||||
def get_federation_http_client(self) -> MatrixFederationHttpClient:
|
||||
"""
|
||||
An HTTP client for federation.
|
||||
"""
|
||||
tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
|
||||
self.config
|
||||
)
|
||||
return MatrixFederationHttpClient(self, tls_client_options_factory)
|
||||
|
||||
@cache_in_self
|
||||
def get_room_creation_handler(self) -> RoomCreationHandler:
|
||||
return RoomCreationHandler(self)
|
||||
|
||||
@@ -783,7 +783,7 @@ class StateResolutionStore:
|
||||
)
|
||||
|
||||
def get_auth_chain_difference(
|
||||
self, state_sets: List[Set[str]]
|
||||
self, room_id: str, state_sets: List[Set[str]]
|
||||
) -> Awaitable[Set[str]]:
|
||||
"""Given sets of state events figure out the auth chain difference (as
|
||||
per state res v2 algorithm).
|
||||
@@ -796,4 +796,4 @@ class StateResolutionStore:
|
||||
An awaitable that resolves to a set of event IDs.
|
||||
"""
|
||||
|
||||
return self.store.get_auth_chain_difference(state_sets)
|
||||
return self.store.get_auth_chain_difference(room_id, state_sets)
|
||||
|
||||
@@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import MutableStateMap, StateMap
|
||||
from synapse.types import Collection, MutableStateMap, StateMap
|
||||
from synapse.util import Clock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -97,7 +97,9 @@ async def resolve_events_with_store(
|
||||
|
||||
# Also fetch all auth events that appear in only some of the state sets'
|
||||
# auth chains.
|
||||
auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
|
||||
auth_diff = await _get_auth_chain_difference(
|
||||
room_id, state_sets, event_map, state_res_store
|
||||
)
|
||||
|
||||
full_conflicted_set = set(
|
||||
itertools.chain(
|
||||
@@ -236,6 +238,7 @@ async def _get_power_level_for_sender(
|
||||
|
||||
|
||||
async def _get_auth_chain_difference(
|
||||
room_id: str,
|
||||
state_sets: Sequence[StateMap[str]],
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
@@ -252,9 +255,90 @@ async def _get_auth_chain_difference(
|
||||
Set of event IDs
|
||||
"""
|
||||
|
||||
# The `StateResolutionStore.get_auth_chain_difference` function assumes that
|
||||
# all events passed to it (and their auth chains) have been persisted
|
||||
# previously. This is not the case for any events in the `event_map`, and so
|
||||
# we need to manually handle those events.
|
||||
#
|
||||
# We do this by:
|
||||
# 1. calculating the auth chain difference for the state sets based on the
|
||||
# events in `event_map` alone
|
||||
# 2. replacing any events in the state_sets that are also in `event_map`
|
||||
# with their auth events (recursively), and then calling
|
||||
# `store.get_auth_chain_difference` as normal
|
||||
# 3. adding the results of 1 and 2 together.
|
||||
|
||||
# Map from event ID in `event_map` to their auth event IDs, and their auth
|
||||
# event IDs if they appear in the `event_map`. This is the intersection of
|
||||
# the event's auth chain with the events in the `event_map` *plus* their
|
||||
# auth event IDs.
|
||||
events_to_auth_chain = {} # type: Dict[str, Set[str]]
|
||||
for event in event_map.values():
|
||||
chain = {event.event_id}
|
||||
events_to_auth_chain[event.event_id] = chain
|
||||
|
||||
to_search = [event]
|
||||
while to_search:
|
||||
for auth_id in to_search.pop().auth_event_ids():
|
||||
chain.add(auth_id)
|
||||
auth_event = event_map.get(auth_id)
|
||||
if auth_event:
|
||||
to_search.append(auth_event)
|
||||
|
||||
# We now a) calculate the auth chain difference for the unpersisted events
|
||||
# and b) work out the state sets to pass to the store.
|
||||
#
|
||||
# Note: If the `event_map` is empty (which is the common case), we can do a
|
||||
# much simpler calculation.
|
||||
if event_map:
|
||||
# The list of state sets to pass to the store, where each state set is a set
|
||||
# of the event ids making up the state. This is similar to `state_sets`,
|
||||
# except that (a) we only have event ids, not the complete
|
||||
# ((type, state_key)->event_id) mappings; and (b) we have stripped out
|
||||
# unpersisted events and replaced them with the persisted events in
|
||||
# their auth chain.
|
||||
state_sets_ids = [] # type: List[Set[str]]
|
||||
|
||||
# For each state set, the unpersisted event IDs reachable (by their auth
|
||||
# chain) from the events in that set.
|
||||
unpersisted_set_ids = [] # type: List[Set[str]]
|
||||
|
||||
for state_set in state_sets:
|
||||
set_ids = set() # type: Set[str]
|
||||
state_sets_ids.append(set_ids)
|
||||
|
||||
unpersisted_ids = set() # type: Set[str]
|
||||
unpersisted_set_ids.append(unpersisted_ids)
|
||||
|
||||
for event_id in state_set.values():
|
||||
event_chain = events_to_auth_chain.get(event_id)
|
||||
if event_chain is not None:
|
||||
# We have an event in `event_map`. We add all the auth
|
||||
# events that it references (that aren't also in `event_map`).
|
||||
set_ids.update(e for e in event_chain if e not in event_map)
|
||||
|
||||
# We also add the full chain of unpersisted event IDs
|
||||
# referenced by this state set, so that we can work out the
|
||||
# auth chain difference of the unpersisted events.
|
||||
unpersisted_ids.update(e for e in event_chain if e in event_map)
|
||||
else:
|
||||
set_ids.add(event_id)
|
||||
|
||||
# The auth chain difference of the unpersisted events of the state sets
|
||||
# is calculated by taking the difference between the union and
|
||||
# intersections.
|
||||
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
|
||||
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
|
||||
|
||||
difference_from_event_map = union - intersection # type: Collection[str]
|
||||
else:
|
||||
difference_from_event_map = ()
|
||||
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
|
||||
|
||||
difference = await state_res_store.get_auth_chain_difference(
|
||||
[set(state_set.values()) for state_set in state_sets]
|
||||
room_id, state_sets_ids
|
||||
)
|
||||
difference.update(difference_from_event_map)
|
||||
|
||||
return difference
|
||||
|
||||
|
||||
@@ -137,7 +137,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
|
||||
return list(results)
|
||||
|
||||
async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
|
||||
async def get_auth_chain_difference(
|
||||
self, room_id: str, state_sets: List[Set[str]]
|
||||
) -> Set[str]:
|
||||
"""Given sets of state events figure out the auth chain difference (as
|
||||
per state res v2 algorithm).
|
||||
|
||||
|
||||
@@ -566,6 +566,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
desc="get_user_by_external_id",
|
||||
)
|
||||
|
||||
async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
|
||||
"""Look up external ids for the given user
|
||||
|
||||
Args:
|
||||
mxid: the MXID to be looked up
|
||||
|
||||
Returns:
|
||||
Tuples of (auth_provider, external_id)
|
||||
"""
|
||||
res = await self.db_pool.simple_select_list(
|
||||
table="user_external_ids",
|
||||
keyvalues={"user_id": mxid},
|
||||
retcols=("auth_provider", "external_id"),
|
||||
desc="get_external_ids_by_user",
|
||||
)
|
||||
return [(r["auth_provider"], r["external_id"]) for r in res]
|
||||
|
||||
async def count_all_users(self):
|
||||
"""Counts all users registered on the homeserver."""
|
||||
|
||||
@@ -1066,6 +1083,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
"user_external_ids_user_id_idx",
|
||||
index_name="user_external_ids_user_id_idx",
|
||||
table="user_external_ids",
|
||||
columns=["user_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
async def _background_update_set_deactivated_flag(self, progress, batch_size):
|
||||
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
|
||||
for each of them.
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(5825, 'user_external_ids_user_id_idx', '{}');
|
||||
@@ -15,28 +15,56 @@
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
import itertools
|
||||
from typing import Any, Iterable, Tuple, Type
|
||||
|
||||
import jsonschema
|
||||
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config._util import json_error_to_config_error
|
||||
|
||||
|
||||
def load_module(provider):
|
||||
def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
|
||||
""" Loads a synapse module with its config
|
||||
Take a dict with keys 'module' (the module name) and 'config'
|
||||
(the config dict).
|
||||
|
||||
Args:
|
||||
provider: a dict with keys 'module' (the module name) and 'config'
|
||||
(the config dict).
|
||||
config_path: the path within the config file. This will be used as a basis
|
||||
for any error message.
|
||||
|
||||
Returns
|
||||
Tuple of (provider class, parsed config object)
|
||||
"""
|
||||
|
||||
modulename = provider.get("module")
|
||||
if not isinstance(modulename, str):
|
||||
raise ConfigError(
|
||||
"expected a string", path=itertools.chain(config_path, ("module",))
|
||||
)
|
||||
|
||||
# We need to import the module, and then pick the class out of
|
||||
# that, so we split based on the last dot.
|
||||
module, clz = provider["module"].rsplit(".", 1)
|
||||
module, clz = modulename.rsplit(".", 1)
|
||||
module = importlib.import_module(module)
|
||||
provider_class = getattr(module, clz)
|
||||
|
||||
module_config = provider.get("config")
|
||||
try:
|
||||
provider_config = provider_class.parse_config(provider.get("config"))
|
||||
provider_config = provider_class.parse_config(module_config)
|
||||
except jsonschema.ValidationError as e:
|
||||
raise json_error_to_config_error(e, itertools.chain(config_path, ("config",)))
|
||||
except ConfigError as e:
|
||||
raise _wrap_config_error(
|
||||
"Failed to parse config for module %r" % (modulename,),
|
||||
prefix=itertools.chain(config_path, ("config",)),
|
||||
e=e,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ConfigError("Failed to parse config for %r: %s" % (provider["module"], e))
|
||||
raise ConfigError(
|
||||
"Failed to parse config for module %r" % (modulename,),
|
||||
path=itertools.chain(config_path, ("config",)),
|
||||
) from e
|
||||
|
||||
return provider_class, provider_config
|
||||
|
||||
@@ -56,3 +84,27 @@ def load_python_module(location: str):
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod) # type: ignore
|
||||
return mod
|
||||
|
||||
|
||||
def _wrap_config_error(
|
||||
msg: str, prefix: Iterable[str], e: ConfigError
|
||||
) -> "ConfigError":
|
||||
"""Wrap a relative ConfigError with a new path
|
||||
|
||||
This is useful when we have a ConfigError with a relative path due to a problem
|
||||
parsing part of the config, and we now need to set it in context.
|
||||
"""
|
||||
path = prefix
|
||||
if e.path:
|
||||
path = itertools.chain(prefix, e.path)
|
||||
|
||||
e1 = ConfigError(msg, path)
|
||||
|
||||
# ideally we would set the 'cause' of the new exception to the original exception;
|
||||
# however now that we have merged the path into our own, the stringification of
|
||||
# e will be incorrect, so instead we create a new exception with just the "msg"
|
||||
# part.
|
||||
|
||||
e1.__cause__ = Exception(e.msg)
|
||||
e1.__cause__.__cause__ = e.__cause__
|
||||
return e1
|
||||
|
||||
@@ -16,8 +16,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mock import Mock
|
||||
|
||||
import jsonschema
|
||||
|
||||
from twisted.internet import defer
|
||||
@@ -28,7 +26,7 @@ from synapse.api.filtering import Filter
|
||||
from synapse.events import make_event_from_dict
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver
|
||||
from tests.utils import setup_test_homeserver
|
||||
|
||||
user_localpart = "test_user"
|
||||
|
||||
@@ -42,21 +40,9 @@ def MockEvent(**kwargs):
|
||||
|
||||
|
||||
class FilteringTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.mock_federation_resource = MockHttpResource()
|
||||
|
||||
self.mock_http_client = Mock(spec=[])
|
||||
self.mock_http_client.put_json = DeferredMockCallable()
|
||||
|
||||
hs = yield setup_test_homeserver(
|
||||
self.addCleanup,
|
||||
federation_http_client=self.mock_http_client,
|
||||
keyring=Mock(),
|
||||
)
|
||||
|
||||
hs = setup_test_homeserver(self.addCleanup)
|
||||
self.filtering = hs.get_filtering()
|
||||
|
||||
self.datastore = hs.get_datastore()
|
||||
|
||||
def test_errors_on_invalid_filters(self):
|
||||
|
||||
@@ -17,30 +17,15 @@ from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from mock import Mock, patch
|
||||
|
||||
import attr
|
||||
import pymacaroons
|
||||
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web._newclient import ResponseDone
|
||||
|
||||
from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
|
||||
from synapse.handlers.sso import MappingException
|
||||
from synapse.types import UserID
|
||||
|
||||
from tests.test_utils import FakeResponse
|
||||
from tests.unittest import HomeserverTestCase, override_config
|
||||
|
||||
|
||||
@attr.s
|
||||
class FakeResponse:
|
||||
code = attr.ib()
|
||||
body = attr.ib()
|
||||
phrase = attr.ib()
|
||||
|
||||
def deliverBody(self, protocol):
|
||||
protocol.dataReceived(self.body)
|
||||
protocol.connectionLost(Failure(ResponseDone()))
|
||||
|
||||
|
||||
# These are a few constants that are used as config parameters in the tests.
|
||||
ISSUER = "https://issuer/"
|
||||
CLIENT_ID = "test-client-id"
|
||||
|
||||
@@ -44,8 +44,6 @@ class ProfileTestCase(unittest.TestCase):
|
||||
|
||||
hs = yield setup_test_homeserver(
|
||||
self.addCleanup,
|
||||
federation_http_client=None,
|
||||
resource_for_federation=Mock(),
|
||||
federation_client=self.mock_federation,
|
||||
federation_server=Mock(),
|
||||
federation_registry=self.mock_registry,
|
||||
|
||||
@@ -580,7 +580,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Mock Synapse's http json post method to check for the internal bind call
|
||||
post_json_get_json = Mock(return_value=make_awaitable(None))
|
||||
self.hs.get_simple_http_client().post_json_get_json = post_json_get_json
|
||||
self.hs.get_identity_handler().http_client.post_json_get_json = post_json_get_json
|
||||
|
||||
# Retrieve a UIA session ID
|
||||
channel = self.uia_register(
|
||||
|
||||
@@ -15,18 +15,20 @@
|
||||
|
||||
|
||||
import json
|
||||
from typing import Dict
|
||||
|
||||
from mock import ANY, Mock, call
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.federation.transport.server import TransportLayerServer
|
||||
from synapse.types import UserID, create_requester
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import make_awaitable
|
||||
from tests.unittest import override_config
|
||||
from tests.utils import register_federation_servlets
|
||||
|
||||
# Some local users to test with
|
||||
U_APPLE = UserID.from_string("@apple:test")
|
||||
@@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
|
||||
|
||||
|
||||
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [register_federation_servlets]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
# we mock out the keyring so as to skip the authentication check on the
|
||||
# federation API call.
|
||||
@@ -77,6 +77,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
return hs
|
||||
|
||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||
d = super().create_resource_dict()
|
||||
d["/_matrix/federation"] = TransportLayerServer(self.hs)
|
||||
return d
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
mock_notifier = hs.get_notifier()
|
||||
self.on_new_event = mock_notifier.on_new_event
|
||||
|
||||
@@ -18,6 +18,7 @@ from twisted.internet.defer import Deferred
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.push import PusherConfigException
|
||||
from synapse.rest.client.v1 import login, room
|
||||
from synapse.rest.client.v2_alpha import receipts
|
||||
|
||||
@@ -34,6 +35,11 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
user_id = True
|
||||
hijack_auth = False
|
||||
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
config["start_pushers"] = True
|
||||
return config
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
self.push_attempts = []
|
||||
|
||||
@@ -46,15 +52,49 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
|
||||
m.post_json_get_json = post_json_get_json
|
||||
|
||||
config = self.default_config()
|
||||
config["start_pushers"] = True
|
||||
|
||||
hs = self.setup_test_homeserver(
|
||||
config=config, proxied_blacklisted_http_client=m
|
||||
)
|
||||
hs = self.setup_test_homeserver(proxied_blacklisted_http_client=m)
|
||||
|
||||
return hs
|
||||
|
||||
def test_invalid_configuration(self):
|
||||
"""Invalid push configurations should be rejected."""
|
||||
# Register the user who gets notified
|
||||
user_id = self.register_user("user", "pass")
|
||||
access_token = self.login("user", "pass")
|
||||
|
||||
# Register the pusher
|
||||
user_tuple = self.get_success(
|
||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||
)
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
def test_data(data):
|
||||
self.get_failure(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
app_id="m.http",
|
||||
app_display_name="HTTP Push Notifications",
|
||||
device_display_name="pushy push",
|
||||
pushkey="a@example.com",
|
||||
lang=None,
|
||||
data=data,
|
||||
),
|
||||
PusherConfigException,
|
||||
)
|
||||
|
||||
# Data must be provided with a URL.
|
||||
test_data(None)
|
||||
test_data({})
|
||||
test_data({"url": 1})
|
||||
# A bare domain name isn't accepted.
|
||||
test_data({"url": "example.com"})
|
||||
# A URL without a path isn't accepted.
|
||||
test_data({"url": "http://example.com"})
|
||||
# A url with an incorrect path isn't accepted.
|
||||
test_data({"url": "http://example.com/foo"})
|
||||
|
||||
def test_sends_http(self):
|
||||
"""
|
||||
The HTTP pusher will send pushes for each message to a HTTP endpoint
|
||||
@@ -84,7 +124,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
device_display_name="pushy push",
|
||||
pushkey="a@example.com",
|
||||
lang=None,
|
||||
data={"url": "example.com"},
|
||||
data={"url": "http://example.com/_matrix/push/v1/notify"},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -119,7 +159,9 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
|
||||
# One push was attempted to be sent -- it'll be the first message
|
||||
self.assertEqual(len(self.push_attempts), 1)
|
||||
self.assertEqual(self.push_attempts[0][1], "example.com")
|
||||
self.assertEqual(
|
||||
self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
|
||||
)
|
||||
self.assertEqual(
|
||||
self.push_attempts[0][2]["notification"]["content"]["body"], "Hi!"
|
||||
)
|
||||
@@ -139,7 +181,9 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
|
||||
# Now it'll try and send the second push message, which will be the second one
|
||||
self.assertEqual(len(self.push_attempts), 2)
|
||||
self.assertEqual(self.push_attempts[1][1], "example.com")
|
||||
self.assertEqual(
|
||||
self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
|
||||
)
|
||||
self.assertEqual(
|
||||
self.push_attempts[1][2]["notification"]["content"]["body"], "There!"
|
||||
)
|
||||
@@ -196,7 +240,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
device_display_name="pushy push",
|
||||
pushkey="a@example.com",
|
||||
lang=None,
|
||||
data={"url": "example.com"},
|
||||
data={"url": "http://example.com/_matrix/push/v1/notify"},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -232,7 +276,9 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
|
||||
# Check our push made it with high priority
|
||||
self.assertEqual(len(self.push_attempts), 1)
|
||||
self.assertEqual(self.push_attempts[0][1], "example.com")
|
||||
self.assertEqual(
|
||||
self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
|
||||
)
|
||||
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
|
||||
|
||||
# Add yet another person — we want to make this room not a 1:1
|
||||
@@ -270,7 +316,9 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
# Advance time a bit, so the pusher will register something has happened
|
||||
self.pump()
|
||||
self.assertEqual(len(self.push_attempts), 2)
|
||||
self.assertEqual(self.push_attempts[1][1], "example.com")
|
||||
self.assertEqual(
|
||||
self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
|
||||
)
|
||||
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
|
||||
|
||||
def test_sends_high_priority_for_one_to_one_only(self):
|
||||
@@ -312,7 +360,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
device_display_name="pushy push",
|
||||
pushkey="a@example.com",
|
||||
lang=None,
|
||||
data={"url": "example.com"},
|
||||
data={"url": "http://example.com/_matrix/push/v1/notify"},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -328,7 +376,9 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
|
||||
# Check our push made it with high priority — this is a one-to-one room
|
||||
self.assertEqual(len(self.push_attempts), 1)
|
||||
self.assertEqual(self.push_attempts[0][1], "example.com")
|
||||
self.assertEqual(
|
||||
self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
|
||||
)
|
||||
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
|
||||
|
||||
# Yet another user joins
|
||||
@@ -347,7 +397,9 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
# Advance time a bit, so the pusher will register something has happened
|
||||
self.pump()
|
||||
self.assertEqual(len(self.push_attempts), 2)
|
||||
self.assertEqual(self.push_attempts[1][1], "example.com")
|
||||
self.assertEqual(
|
||||
self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
|
||||
)
|
||||
|
||||
# check that this is high-priority
|
||||
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
|
||||
@@ -394,7 +446,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
device_display_name="pushy push",
|
||||
pushkey="a@example.com",
|
||||
lang=None,
|
||||
data={"url": "example.com"},
|
||||
data={"url": "http://example.com/_matrix/push/v1/notify"},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -410,7 +462,9 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
|
||||
# Check our push made it with high priority
|
||||
self.assertEqual(len(self.push_attempts), 1)
|
||||
self.assertEqual(self.push_attempts[0][1], "example.com")
|
||||
self.assertEqual(
|
||||
self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
|
||||
)
|
||||
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
|
||||
|
||||
# Send another event, this time with no mention
|
||||
@@ -419,7 +473,9 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
# Advance time a bit, so the pusher will register something has happened
|
||||
self.pump()
|
||||
self.assertEqual(len(self.push_attempts), 2)
|
||||
self.assertEqual(self.push_attempts[1][1], "example.com")
|
||||
self.assertEqual(
|
||||
self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
|
||||
)
|
||||
|
||||
# check that this is high-priority
|
||||
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
|
||||
@@ -467,7 +523,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
device_display_name="pushy push",
|
||||
pushkey="a@example.com",
|
||||
lang=None,
|
||||
data={"url": "example.com"},
|
||||
data={"url": "http://example.com/_matrix/push/v1/notify"},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -487,7 +543,9 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
|
||||
# Check our push made it with high priority
|
||||
self.assertEqual(len(self.push_attempts), 1)
|
||||
self.assertEqual(self.push_attempts[0][1], "example.com")
|
||||
self.assertEqual(
|
||||
self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
|
||||
)
|
||||
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
|
||||
|
||||
# Send another event, this time as someone without the power of @room
|
||||
@@ -498,7 +556,9 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
# Advance time a bit, so the pusher will register something has happened
|
||||
self.pump()
|
||||
self.assertEqual(len(self.push_attempts), 2)
|
||||
self.assertEqual(self.push_attempts[1][1], "example.com")
|
||||
self.assertEqual(
|
||||
self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
|
||||
)
|
||||
|
||||
# check that this is high-priority
|
||||
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
|
||||
@@ -572,7 +632,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
device_display_name="pushy push",
|
||||
pushkey="a@example.com",
|
||||
lang=None,
|
||||
data={"url": "example.com"},
|
||||
data={"url": "http://example.com/_matrix/push/v1/notify"},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -591,7 +651,9 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
|
||||
# Check our push made it
|
||||
self.assertEqual(len(self.push_attempts), 1)
|
||||
self.assertEqual(self.push_attempts[0][1], "example.com")
|
||||
self.assertEqual(
|
||||
self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
|
||||
)
|
||||
|
||||
# Check that the unread count for the room is 0
|
||||
#
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
@@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.internet.task import LoopingCall
|
||||
from twisted.web.http import HTTPChannel
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.app.generic_worker import (
|
||||
GenericWorkerReplicationHandler,
|
||||
@@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
|
||||
)
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.http.site import SynapseRequest, SynapseSite
|
||||
from synapse.replication.http import ReplicationRestResource, streams
|
||||
from synapse.replication.http import ReplicationRestResource
|
||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
@@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
if not hiredis:
|
||||
skip = "Requires hiredis"
|
||||
|
||||
servlets = [
|
||||
streams.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
# build a replication server
|
||||
server_factory = ReplicationStreamProtocolFactory(hs)
|
||||
@@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
self._client_transport = None
|
||||
self._server_transport = None
|
||||
|
||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||
d = super().create_resource_dict()
|
||||
d["/_synapse/replication"] = ReplicationRestResource(self.hs)
|
||||
return d
|
||||
|
||||
def _get_worker_hs_config(self) -> dict:
|
||||
config = self.default_config()
|
||||
config["worker_app"] = "synapse.app.generic_worker"
|
||||
|
||||
119
tests/replication/test_auth.py
Normal file
119
tests/replication/test_auth.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client.v2_alpha import register
|
||||
|
||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||
from tests.server import FakeChannel, make_request
|
||||
from tests.unittest import override_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
|
||||
"""Test the authentication of HTTP calls between workers."""
|
||||
|
||||
servlets = [register.register_servlets]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
config = self.default_config()
|
||||
# This isn't a real configuration option but is used to provide the main
|
||||
# homeserver and worker homeserver different options.
|
||||
main_replication_secret = config.pop("main_replication_secret", None)
|
||||
if main_replication_secret:
|
||||
config["worker_replication_secret"] = main_replication_secret
|
||||
return self.setup_test_homeserver(config=config)
|
||||
|
||||
def _get_worker_hs_config(self) -> dict:
|
||||
config = self.default_config()
|
||||
config["worker_app"] = "synapse.app.client_reader"
|
||||
config["worker_replication_host"] = "testserv"
|
||||
config["worker_replication_http_port"] = "8765"
|
||||
|
||||
return config
|
||||
|
||||
def _test_register(self) -> Tuple[SynapseRequest, FakeChannel]:
|
||||
"""Run the actual test:
|
||||
|
||||
1. Create a worker homeserver.
|
||||
2. Start registration by providing a user/password.
|
||||
3. Complete registration by providing dummy auth (this hits the main synapse).
|
||||
4. Return the final request.
|
||||
|
||||
"""
|
||||
worker_hs = self.make_worker_hs("synapse.app.client_reader")
|
||||
site = self._hs_to_site[worker_hs]
|
||||
|
||||
request_1, channel_1 = make_request(
|
||||
self.reactor,
|
||||
site,
|
||||
"POST",
|
||||
"register",
|
||||
{"username": "user", "type": "m.login.password", "password": "bar"},
|
||||
) # type: SynapseRequest, FakeChannel
|
||||
self.assertEqual(request_1.code, 401)
|
||||
|
||||
# Grab the session
|
||||
session = channel_1.json_body["session"]
|
||||
|
||||
# also complete the dummy auth
|
||||
return make_request(
|
||||
self.reactor,
|
||||
site,
|
||||
"POST",
|
||||
"register",
|
||||
{"auth": {"session": session, "type": "m.login.dummy"}},
|
||||
)
|
||||
|
||||
def test_no_auth(self):
|
||||
"""With no authentication the request should finish.
|
||||
"""
|
||||
request, channel = self._test_register()
|
||||
self.assertEqual(request.code, 200)
|
||||
|
||||
# We're given a registered user.
|
||||
self.assertEqual(channel.json_body["user_id"], "@user:test")
|
||||
|
||||
@override_config({"main_replication_secret": "my-secret"})
|
||||
def test_missing_auth(self):
|
||||
"""If the main process expects a secret that is not provided, an error results.
|
||||
"""
|
||||
request, channel = self._test_register()
|
||||
self.assertEqual(request.code, 500)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"main_replication_secret": "my-secret",
|
||||
"worker_replication_secret": "wrong-secret",
|
||||
}
|
||||
)
|
||||
def test_unauthorized(self):
|
||||
"""If the main process receives the wrong secret, an error results.
|
||||
"""
|
||||
request, channel = self._test_register()
|
||||
self.assertEqual(request.code, 500)
|
||||
|
||||
@override_config({"worker_replication_secret": "my-secret"})
|
||||
def test_authorized(self):
|
||||
"""The request should finish when the worker provides the authentication header.
|
||||
"""
|
||||
request, channel = self._test_register()
|
||||
self.assertEqual(request.code, 200)
|
||||
|
||||
# We're given a registered user.
|
||||
self.assertEqual(channel.json_body["user_id"], "@user:test")
|
||||
@@ -14,27 +14,20 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client.v2_alpha import register
|
||||
|
||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
|
||||
from tests.server import FakeChannel, make_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
|
||||
"""Base class for tests of the replication streams"""
|
||||
"""Test using one or more client readers for registration."""
|
||||
|
||||
servlets = [register.register_servlets]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.recaptcha_checker = DummyRecaptchaChecker(hs)
|
||||
auth_handler = hs.get_auth_handler()
|
||||
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
|
||||
|
||||
def _get_worker_hs_config(self) -> dict:
|
||||
config = self.default_config()
|
||||
config["worker_app"] = "synapse.app.client_reader"
|
||||
|
||||
@@ -67,7 +67,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||
device_display_name="pushy push",
|
||||
pushkey="a@example.com",
|
||||
lang=None,
|
||||
data={"url": "https://push.example.com/push"},
|
||||
data={"url": "https://push.example.com/_matrix/push/v1/notify"},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -109,7 +109,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||
http_client_mock.post_json_get_json.assert_called_once()
|
||||
self.assertEqual(
|
||||
http_client_mock.post_json_get_json.call_args[0][0],
|
||||
"https://push.example.com/push",
|
||||
"https://push.example.com/_matrix/push/v1/notify",
|
||||
)
|
||||
self.assertEqual(
|
||||
event_id,
|
||||
@@ -161,7 +161,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||
http_client_mock2.post_json_get_json.assert_not_called()
|
||||
self.assertEqual(
|
||||
http_client_mock1.post_json_get_json.call_args[0][0],
|
||||
"https://push.example.com/push",
|
||||
"https://push.example.com/_matrix/push/v1/notify",
|
||||
)
|
||||
self.assertEqual(
|
||||
event_id,
|
||||
@@ -183,7 +183,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||
http_client_mock2.post_json_get_json.assert_called_once()
|
||||
self.assertEqual(
|
||||
http_client_mock2.post_json_get_json.call_args[0][0],
|
||||
"https://push.example.com/push",
|
||||
"https://push.example.com/_matrix/push/v1/notify",
|
||||
)
|
||||
self.assertEqual(
|
||||
event_id,
|
||||
|
||||
@@ -561,7 +561,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
"admin": True,
|
||||
"displayname": "Bob's name",
|
||||
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
|
||||
"avatar_url": None,
|
||||
"avatar_url": "mxc://fibble/wibble",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -578,6 +578,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||
self.assertEqual(True, channel.json_body["admin"])
|
||||
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
||||
|
||||
# Get user
|
||||
request, channel = self.make_request(
|
||||
@@ -592,6 +593,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(True, channel.json_body["admin"])
|
||||
self.assertEqual(False, channel.json_body["is_guest"])
|
||||
self.assertEqual(False, channel.json_body["deactivated"])
|
||||
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
||||
|
||||
def test_create_user(self):
|
||||
"""
|
||||
@@ -606,6 +608,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
"admin": False,
|
||||
"displayname": "Bob's name",
|
||||
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
|
||||
"avatar_url": "mxc://fibble/wibble",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -622,6 +625,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||
self.assertEqual(False, channel.json_body["admin"])
|
||||
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
||||
|
||||
# Get user
|
||||
request, channel = self.make_request(
|
||||
@@ -636,6 +640,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(False, channel.json_body["admin"])
|
||||
self.assertEqual(False, channel.json_body["is_guest"])
|
||||
self.assertEqual(False, channel.json_body["deactivated"])
|
||||
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
||||
|
||||
@override_config(
|
||||
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
|
||||
@@ -1256,7 +1261,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
|
||||
device_display_name="pushy push",
|
||||
pushkey="a@example.com",
|
||||
lang=None,
|
||||
data={"url": "example.com"},
|
||||
data={"url": "https://example.com/_matrix/push/v1/notify"},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
# Copyright 2018-2019 New Vector Ltd
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -17,17 +17,23 @@
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from mock import patch
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import Site
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from tests.server import FakeSite, make_request
|
||||
from tests.test_utils import FakeResponse
|
||||
|
||||
|
||||
@attr.s
|
||||
@@ -344,3 +350,111 @@ class RestHelper:
|
||||
)
|
||||
|
||||
return channel.json_body
|
||||
|
||||
def login_via_oidc(self, remote_user_id: str) -> JsonDict:
|
||||
"""Log in (as a new user) via OIDC
|
||||
|
||||
Returns the result of the final token login.
|
||||
|
||||
Requires that "oidc_config" in the homeserver config be set appropriately
|
||||
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
|
||||
"public_base_url".
|
||||
|
||||
Also requires the login servlet and the OIDC callback resource to be mounted at
|
||||
the normal places.
|
||||
"""
|
||||
client_redirect_url = "https://x"
|
||||
|
||||
# first hit the redirect url (which will issue a cookie and state)
|
||||
_, channel = make_request(
|
||||
self.hs.get_reactor(),
|
||||
self.site,
|
||||
"GET",
|
||||
"/login/sso/redirect?redirectUrl=" + client_redirect_url,
|
||||
)
|
||||
# that will redirect to the OIDC IdP, but we skip that and go straight
|
||||
# back to synapse's OIDC callback resource. However, we do need the "state"
|
||||
# param that synapse passes to the IdP via query params, and the cookie that
|
||||
# synapse passes to the client.
|
||||
assert channel.code == 302
|
||||
oauth_uri = channel.headers.getRawHeaders("Location")[0]
|
||||
params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query)
|
||||
redirect_uri = "%s?%s" % (
|
||||
urllib.parse.urlparse(params["redirect_uri"][0]).path,
|
||||
urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
|
||||
)
|
||||
cookies = {}
|
||||
for h in channel.headers.getRawHeaders("Set-Cookie"):
|
||||
parts = h.split(";")
|
||||
k, v = parts[0].split("=", maxsplit=1)
|
||||
cookies[k] = v
|
||||
|
||||
# before we hit the callback uri, stub out some methods in the http client so
|
||||
# that we don't have to handle full HTTPS requests.
|
||||
|
||||
# (expected url, json response) pairs, in the order we expect them.
|
||||
expected_requests = [
|
||||
# first we get a hit to the token endpoint, which we tell to return
|
||||
# a dummy OIDC access token
|
||||
("https://issuer.test/token", {"access_token": "TEST"}),
|
||||
# and then one to the user_info endpoint, which returns our remote user id.
|
||||
("https://issuer.test/userinfo", {"sub": remote_user_id}),
|
||||
]
|
||||
|
||||
async def mock_req(method: str, uri: str, data=None, headers=None):
|
||||
(expected_uri, resp_obj) = expected_requests.pop(0)
|
||||
assert uri == expected_uri
|
||||
resp = FakeResponse(
|
||||
code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
|
||||
)
|
||||
return resp
|
||||
|
||||
with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
|
||||
# now hit the callback URI with the right params and a made-up code
|
||||
_, channel = make_request(
|
||||
self.hs.get_reactor(),
|
||||
self.site,
|
||||
"GET",
|
||||
redirect_uri,
|
||||
custom_headers=[
|
||||
("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
|
||||
],
|
||||
)
|
||||
|
||||
# expect a confirmation page
|
||||
assert channel.code == 200
|
||||
|
||||
# fish the matrix login token out of the body of the confirmation page
|
||||
m = re.search(
|
||||
'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
|
||||
channel.result["body"].decode("utf-8"),
|
||||
)
|
||||
assert m
|
||||
login_token = m.group(1)
|
||||
|
||||
# finally, submit the matrix login token to the login API, which gives us our
|
||||
# matrix access token and device id.
|
||||
_, channel = make_request(
|
||||
self.hs.get_reactor(),
|
||||
self.site,
|
||||
"POST",
|
||||
"/login",
|
||||
content={"type": "m.login.token", "token": login_token},
|
||||
)
|
||||
assert channel.code == 200
|
||||
return channel.json_body
|
||||
|
||||
|
||||
# an 'oidc_config' suitable for login_with_oidc.
|
||||
TEST_OIDC_CONFIG = {
|
||||
"enabled": True,
|
||||
"discover": False,
|
||||
"issuer": "https://issuer.test",
|
||||
"client_id": "test-client-id",
|
||||
"client_secret": "test-client-secret",
|
||||
"scopes": ["profile"],
|
||||
"authorization_endpoint": "https://z",
|
||||
"token_endpoint": "https://issuer.test/token",
|
||||
"userinfo_endpoint": "https://issuer.test/userinfo",
|
||||
"user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from twisted.internet.defer import succeed
|
||||
@@ -22,9 +23,11 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client.v1 import login
|
||||
from synapse.rest.client.v2_alpha import auth, devices, register
|
||||
from synapse.types import JsonDict
|
||||
from synapse.rest.oidc import OIDCResource
|
||||
from synapse.types import JsonDict, UserID
|
||||
|
||||
from tests import unittest
|
||||
from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
|
||||
from tests.server import FakeChannel
|
||||
|
||||
|
||||
@@ -156,27 +159,45 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
register.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
|
||||
# we enable OIDC as a way of testing SSO flows
|
||||
oidc_config = {}
|
||||
oidc_config.update(TEST_OIDC_CONFIG)
|
||||
oidc_config["allow_existing_users"] = True
|
||||
|
||||
config["oidc_config"] = oidc_config
|
||||
config["public_baseurl"] = "https://synapse.test"
|
||||
return config
|
||||
|
||||
def create_resource_dict(self):
|
||||
resource_dict = super().create_resource_dict()
|
||||
# mount the OIDC resource at /_synapse/oidc
|
||||
resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
|
||||
return resource_dict
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.user_pass = "pass"
|
||||
self.user = self.register_user("test", self.user_pass)
|
||||
self.user_tok = self.login("test", self.user_pass)
|
||||
|
||||
def get_device_ids(self) -> List[str]:
|
||||
def get_device_ids(self, access_token: str) -> List[str]:
|
||||
# Get the list of devices so one can be deleted.
|
||||
request, channel = self.make_request(
|
||||
"GET", "devices", access_token=self.user_tok,
|
||||
) # type: SynapseRequest, FakeChannel
|
||||
|
||||
# Get the ID of the device.
|
||||
self.assertEqual(request.code, 200)
|
||||
_, channel = self.make_request("GET", "devices", access_token=access_token,)
|
||||
self.assertEqual(channel.code, 200)
|
||||
return [d["device_id"] for d in channel.json_body["devices"]]
|
||||
|
||||
def delete_device(
|
||||
self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
|
||||
self,
|
||||
access_token: str,
|
||||
device: str,
|
||||
expected_response: int,
|
||||
body: Union[bytes, JsonDict] = b"",
|
||||
) -> FakeChannel:
|
||||
"""Delete an individual device."""
|
||||
request, channel = self.make_request(
|
||||
"DELETE", "devices/" + device, body, access_token=self.user_tok
|
||||
"DELETE", "devices/" + device, body, access_token=access_token,
|
||||
) # type: SynapseRequest, FakeChannel
|
||||
|
||||
# Ensure the response is sane.
|
||||
@@ -201,11 +222,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
"""
|
||||
Test user interactive authentication outside of registration.
|
||||
"""
|
||||
device_id = self.get_device_ids()[0]
|
||||
device_id = self.get_device_ids(self.user_tok)[0]
|
||||
|
||||
# Attempt to delete this device.
|
||||
# Returns a 401 as per the spec
|
||||
channel = self.delete_device(device_id, 401)
|
||||
channel = self.delete_device(self.user_tok, device_id, 401)
|
||||
|
||||
# Grab the session
|
||||
session = channel.json_body["session"]
|
||||
@@ -214,6 +235,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
|
||||
# Make another request providing the UI auth flow.
|
||||
self.delete_device(
|
||||
self.user_tok,
|
||||
device_id,
|
||||
200,
|
||||
{
|
||||
@@ -233,12 +255,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
UIA - check that still works.
|
||||
"""
|
||||
|
||||
device_id = self.get_device_ids()[0]
|
||||
channel = self.delete_device(device_id, 401)
|
||||
device_id = self.get_device_ids(self.user_tok)[0]
|
||||
channel = self.delete_device(self.user_tok, device_id, 401)
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# Make another request providing the UI auth flow.
|
||||
self.delete_device(
|
||||
self.user_tok,
|
||||
device_id,
|
||||
200,
|
||||
{
|
||||
@@ -264,7 +287,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
# Create a second login.
|
||||
self.login("test", self.user_pass)
|
||||
|
||||
device_ids = self.get_device_ids()
|
||||
device_ids = self.get_device_ids(self.user_tok)
|
||||
self.assertEqual(len(device_ids), 2)
|
||||
|
||||
# Attempt to delete the first device.
|
||||
@@ -298,12 +321,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
# Create a second login.
|
||||
self.login("test", self.user_pass)
|
||||
|
||||
device_ids = self.get_device_ids()
|
||||
device_ids = self.get_device_ids(self.user_tok)
|
||||
self.assertEqual(len(device_ids), 2)
|
||||
|
||||
# Attempt to delete the first device.
|
||||
# Returns a 401 as per the spec
|
||||
channel = self.delete_device(device_ids[0], 401)
|
||||
channel = self.delete_device(self.user_tok, device_ids[0], 401)
|
||||
|
||||
# Grab the session
|
||||
session = channel.json_body["session"]
|
||||
@@ -313,6 +336,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
# Make another request providing the UI auth flow, but try to delete the
|
||||
# second device. This results in an error.
|
||||
self.delete_device(
|
||||
self.user_tok,
|
||||
device_ids[1],
|
||||
403,
|
||||
{
|
||||
@@ -324,3 +348,39 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def test_does_not_offer_password_for_sso_user(self):
|
||||
login_resp = self.helper.login_via_oidc("username")
|
||||
user_tok = login_resp["access_token"]
|
||||
device_id = login_resp["device_id"]
|
||||
|
||||
# now call the device deletion API: we should get the option to auth with SSO
|
||||
# and not password.
|
||||
channel = self.delete_device(user_tok, device_id, 401)
|
||||
|
||||
flows = channel.json_body["flows"]
|
||||
self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
|
||||
|
||||
def test_does_not_offer_sso_for_password_user(self):
|
||||
# now call the device deletion API: we should get the option to auth with SSO
|
||||
# and not password.
|
||||
device_ids = self.get_device_ids(self.user_tok)
|
||||
channel = self.delete_device(self.user_tok, device_ids[0], 401)
|
||||
|
||||
flows = channel.json_body["flows"]
|
||||
self.assertEqual(flows, [{"stages": ["m.login.password"]}])
|
||||
|
||||
def test_offers_both_flows_for_upgraded_user(self):
|
||||
"""A user that had a password and then logged in with SSO should get both flows
|
||||
"""
|
||||
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
|
||||
self.assertEqual(login_resp["user_id"], self.user)
|
||||
|
||||
device_ids = self.get_device_ids(self.user_tok)
|
||||
channel = self.delete_device(self.user_tok, device_ids[0], 401)
|
||||
|
||||
flows = channel.json_body["flows"]
|
||||
# we have no particular expectations of ordering here
|
||||
self.assertIn({"stages": ["m.login.password"]}, flows)
|
||||
self.assertIn({"stages": ["m.login.sso"]}, flows)
|
||||
self.assertEqual(len(flows), 2)
|
||||
|
||||
@@ -120,6 +120,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
|
||||
self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
|
||||
def test_POST_guest_registration(self):
|
||||
self.hs.config.macaroon_secret_key = "test"
|
||||
|
||||
@@ -362,3 +362,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
"error": "Not found [b'example.com', b'12345']",
|
||||
},
|
||||
)
|
||||
|
||||
def test_x_robots_tag_header(self):
|
||||
"""
|
||||
Tests that the `X-Robots-Tag` header is present, which informs web crawlers
|
||||
to not index, archive, or follow links in media.
|
||||
"""
|
||||
channel = self._req(b"inline; filename=out" + self.test_image.extension)
|
||||
|
||||
headers = channel.headers
|
||||
self.assertEqual(
|
||||
headers.getRawHeaders(b"X-Robots-Tag"),
|
||||
[b"noindex, nofollow, noarchive, noimageindex"],
|
||||
)
|
||||
|
||||
@@ -18,41 +18,15 @@ import re
|
||||
|
||||
from mock import patch
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet._resolver import HostResolution
|
||||
from twisted.internet.address import IPv4Address, IPv6Address
|
||||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.test.proto_helpers import AccumulatingProtocol
|
||||
from twisted.web._newclient import ResponseDone
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeTransport
|
||||
|
||||
|
||||
@attr.s
|
||||
class FakeResponse:
|
||||
version = attr.ib()
|
||||
code = attr.ib()
|
||||
phrase = attr.ib()
|
||||
headers = attr.ib()
|
||||
body = attr.ib()
|
||||
absoluteURI = attr.ib()
|
||||
|
||||
@property
|
||||
def request(self):
|
||||
@attr.s
|
||||
class FakeTransport:
|
||||
absoluteURI = self.absoluteURI
|
||||
|
||||
return FakeTransport()
|
||||
|
||||
def deliverBody(self, protocol):
|
||||
protocol.dataReceived(self.body)
|
||||
protocol.connectionLost(Failure(ResponseDone()))
|
||||
|
||||
|
||||
class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
|
||||
hijack_auth = True
|
||||
|
||||
@@ -216,8 +216,9 @@ def make_request(
|
||||
and not path.startswith(b"/_matrix")
|
||||
and not path.startswith(b"/_synapse")
|
||||
):
|
||||
if path.startswith(b"/"):
|
||||
path = path[1:]
|
||||
path = b"/_matrix/client/r0/" + path
|
||||
path = path.replace(b"//", b"/")
|
||||
|
||||
if not path.startswith(b"/"):
|
||||
path = b"/" + path
|
||||
@@ -258,6 +259,7 @@ def make_request(
|
||||
for k, v in custom_headers:
|
||||
req.requestHeaders.addRawHeader(k, v)
|
||||
|
||||
req.parseCookies()
|
||||
req.requestReceived(method, path, b"1.1")
|
||||
|
||||
if await_result:
|
||||
|
||||
@@ -24,7 +24,11 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.event_auth import auth_types_for_event
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
|
||||
from synapse.state.v2 import (
|
||||
_get_auth_chain_difference,
|
||||
lexicographical_topological_sort,
|
||||
resolve_events_with_store,
|
||||
)
|
||||
from synapse.types import EventID
|
||||
|
||||
from tests import unittest
|
||||
@@ -587,6 +591,134 @@ class SimpleParamStateTestCase(unittest.TestCase):
|
||||
self.assert_dict(self.expected_combined_state, state)
|
||||
|
||||
|
||||
class AuthChainDifferenceTestCase(unittest.TestCase):
|
||||
"""We test that `_get_auth_chain_difference` correctly handles unpersisted
|
||||
events.
|
||||
"""
|
||||
|
||||
def test_simple(self):
|
||||
# Test getting the auth difference for a simple chain with a single
|
||||
# unpersisted event:
|
||||
#
|
||||
# Unpersisted | Persisted
|
||||
# |
|
||||
# C -|-> B -> A
|
||||
|
||||
a = FakeEvent(
|
||||
id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
|
||||
).to_event([], [])
|
||||
|
||||
b = FakeEvent(
|
||||
id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
|
||||
).to_event([a.event_id], [])
|
||||
|
||||
c = FakeEvent(
|
||||
id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
|
||||
).to_event([b.event_id], [])
|
||||
|
||||
persisted_events = {a.event_id: a, b.event_id: b}
|
||||
unpersited_events = {c.event_id: c}
|
||||
|
||||
state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}]
|
||||
|
||||
store = TestStateResolutionStore(persisted_events)
|
||||
|
||||
diff_d = _get_auth_chain_difference(
|
||||
ROOM_ID, state_sets, unpersited_events, store
|
||||
)
|
||||
difference = self.successResultOf(defer.ensureDeferred(diff_d))
|
||||
|
||||
self.assertEqual(difference, {c.event_id})
|
||||
|
||||
def test_multiple_unpersisted_chain(self):
|
||||
# Test getting the auth difference for a simple chain with multiple
|
||||
# unpersisted events:
|
||||
#
|
||||
# Unpersisted | Persisted
|
||||
# |
|
||||
# D -> C -|-> B -> A
|
||||
|
||||
a = FakeEvent(
|
||||
id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
|
||||
).to_event([], [])
|
||||
|
||||
b = FakeEvent(
|
||||
id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
|
||||
).to_event([a.event_id], [])
|
||||
|
||||
c = FakeEvent(
|
||||
id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
|
||||
).to_event([b.event_id], [])
|
||||
|
||||
d = FakeEvent(
|
||||
id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
|
||||
).to_event([c.event_id], [])
|
||||
|
||||
persisted_events = {a.event_id: a, b.event_id: b}
|
||||
unpersited_events = {c.event_id: c, d.event_id: d}
|
||||
|
||||
state_sets = [
|
||||
{"a": a.event_id, "b": b.event_id},
|
||||
{"c": c.event_id, "d": d.event_id},
|
||||
]
|
||||
|
||||
store = TestStateResolutionStore(persisted_events)
|
||||
|
||||
diff_d = _get_auth_chain_difference(
|
||||
ROOM_ID, state_sets, unpersited_events, store
|
||||
)
|
||||
difference = self.successResultOf(defer.ensureDeferred(diff_d))
|
||||
|
||||
self.assertEqual(difference, {d.event_id, c.event_id})
|
||||
|
||||
def test_unpersisted_events_different_sets(self):
|
||||
# Test getting the auth difference for with multiple unpersisted events
|
||||
# in different branches:
|
||||
#
|
||||
# Unpersisted | Persisted
|
||||
# |
|
||||
# D --> C -|-> B -> A
|
||||
# E ----^ -|---^
|
||||
# |
|
||||
|
||||
a = FakeEvent(
|
||||
id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
|
||||
).to_event([], [])
|
||||
|
||||
b = FakeEvent(
|
||||
id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
|
||||
).to_event([a.event_id], [])
|
||||
|
||||
c = FakeEvent(
|
||||
id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
|
||||
).to_event([b.event_id], [])
|
||||
|
||||
d = FakeEvent(
|
||||
id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
|
||||
).to_event([c.event_id], [])
|
||||
|
||||
e = FakeEvent(
|
||||
id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={},
|
||||
).to_event([c.event_id, b.event_id], [])
|
||||
|
||||
persisted_events = {a.event_id: a, b.event_id: b}
|
||||
unpersited_events = {c.event_id: c, d.event_id: d, e.event_id: e}
|
||||
|
||||
state_sets = [
|
||||
{"a": a.event_id, "b": b.event_id, "e": e.event_id},
|
||||
{"c": c.event_id, "d": d.event_id},
|
||||
]
|
||||
|
||||
store = TestStateResolutionStore(persisted_events)
|
||||
|
||||
diff_d = _get_auth_chain_difference(
|
||||
ROOM_ID, state_sets, unpersited_events, store
|
||||
)
|
||||
difference = self.successResultOf(defer.ensureDeferred(diff_d))
|
||||
|
||||
self.assertEqual(difference, {d.event_id, e.event_id})
|
||||
|
||||
|
||||
def pairwise(iterable):
|
||||
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
|
||||
a, b = itertools.tee(iterable)
|
||||
@@ -647,7 +779,7 @@ class TestStateResolutionStore:
|
||||
|
||||
return list(result)
|
||||
|
||||
def get_auth_chain_difference(self, auth_sets):
|
||||
def get_auth_chain_difference(self, room_id, auth_sets):
|
||||
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
|
||||
|
||||
common = set(chains[0]).intersection(*chains[1:])
|
||||
|
||||
@@ -202,34 +202,41 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||
# Now actually test that various combinations give the right result:
|
||||
|
||||
difference = self.get_success(
|
||||
self.store.get_auth_chain_difference([{"a"}, {"b"}])
|
||||
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
|
||||
)
|
||||
self.assertSetEqual(difference, {"a", "b"})
|
||||
|
||||
difference = self.get_success(
|
||||
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
|
||||
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
|
||||
)
|
||||
self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
|
||||
|
||||
difference = self.get_success(
|
||||
self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
|
||||
self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
|
||||
)
|
||||
self.assertSetEqual(difference, {"a", "b", "c"})
|
||||
|
||||
difference = self.get_success(
|
||||
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
|
||||
self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
|
||||
)
|
||||
self.assertSetEqual(difference, {"a", "b"})
|
||||
|
||||
difference = self.get_success(
|
||||
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
|
||||
)
|
||||
self.assertSetEqual(difference, {"a", "b", "d", "e"})
|
||||
|
||||
difference = self.get_success(
|
||||
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
|
||||
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
|
||||
)
|
||||
self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
|
||||
|
||||
difference = self.get_success(
|
||||
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
|
||||
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
|
||||
)
|
||||
self.assertSetEqual(difference, {"a", "b"})
|
||||
|
||||
difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
|
||||
difference = self.get_success(
|
||||
self.store.get_auth_chain_difference(room_id, [{"a"}])
|
||||
)
|
||||
self.assertSetEqual(difference, set())
|
||||
|
||||
@@ -14,9 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer
|
||||
@@ -30,12 +27,10 @@ from tests.utils import create_room
|
||||
|
||||
|
||||
class RedactionTestCase(unittest.HomeserverTestCase):
|
||||
def make_homeserver(self, reactor, clock):
|
||||
config = self.default_config()
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
config["redaction_retention_period"] = "30d"
|
||||
return self.setup_test_homeserver(
|
||||
resource_for_federation=Mock(), federation_http_client=None, config=config
|
||||
)
|
||||
return config
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@@ -14,8 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.rest.admin import register_servlets_for_client_rest_resource
|
||||
from synapse.rest.client.v1 import login, room
|
||||
@@ -34,12 +32,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
hs = self.setup_test_homeserver(
|
||||
resource_for_federation=Mock(), federation_http_client=None
|
||||
)
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor, clock, hs: TestHomeServer):
|
||||
|
||||
# We can't test the RoomMemberStore on its own without the other event
|
||||
|
||||
@@ -56,7 +56,7 @@ class PreviewTestCase(unittest.TestCase):
|
||||
|
||||
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
|
||||
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
desc,
|
||||
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
|
||||
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
|
||||
@@ -69,7 +69,7 @@ class PreviewTestCase(unittest.TestCase):
|
||||
|
||||
desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500)
|
||||
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
desc,
|
||||
"Tromsø lies in Northern Norway. The municipality has a population of"
|
||||
" (2015) 72,066, but with an annual influx of students it has over 75,000"
|
||||
@@ -96,7 +96,7 @@ class PreviewTestCase(unittest.TestCase):
|
||||
|
||||
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
|
||||
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
desc,
|
||||
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
|
||||
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
|
||||
@@ -122,7 +122,7 @@ class PreviewTestCase(unittest.TestCase):
|
||||
]
|
||||
|
||||
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
desc,
|
||||
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
|
||||
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
|
||||
@@ -149,7 +149,7 @@ class PreviewUrlTestCase(unittest.TestCase):
|
||||
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
|
||||
self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
|
||||
def test_comment(self):
|
||||
html = """
|
||||
@@ -164,7 +164,7 @@ class PreviewUrlTestCase(unittest.TestCase):
|
||||
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
|
||||
self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
|
||||
def test_comment2(self):
|
||||
html = """
|
||||
@@ -182,7 +182,7 @@ class PreviewUrlTestCase(unittest.TestCase):
|
||||
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
og,
|
||||
{
|
||||
"og:title": "Foo",
|
||||
@@ -203,7 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase):
|
||||
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
|
||||
self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
|
||||
def test_missing_title(self):
|
||||
html = """
|
||||
@@ -216,7 +216,7 @@ class PreviewUrlTestCase(unittest.TestCase):
|
||||
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
|
||||
self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
|
||||
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
|
||||
|
||||
def test_h1_as_title(self):
|
||||
html = """
|
||||
@@ -230,7 +230,7 @@ class PreviewUrlTestCase(unittest.TestCase):
|
||||
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
|
||||
self.assertEquals(og, {"og:title": "Title", "og:description": "Some text."})
|
||||
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
|
||||
|
||||
def test_missing_title_and_broken_h1(self):
|
||||
html = """
|
||||
@@ -244,4 +244,9 @@ class PreviewUrlTestCase(unittest.TestCase):
|
||||
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
|
||||
self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
|
||||
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
|
||||
|
||||
def test_empty(self):
|
||||
html = ""
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
self.assertEqual(og, {})
|
||||
|
||||
@@ -22,6 +22,11 @@ import warnings
|
||||
from asyncio import Future
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web.client import ResponseDone
|
||||
|
||||
TV = TypeVar("TV")
|
||||
|
||||
|
||||
@@ -80,3 +85,25 @@ def setup_awaitable_errors() -> Callable[[], None]:
|
||||
sys.unraisablehook = unraisablehook # type: ignore
|
||||
|
||||
return cleanup
|
||||
|
||||
|
||||
@attr.s
|
||||
class FakeResponse:
|
||||
"""A fake twisted.web.IResponse object
|
||||
|
||||
there is a similar class at treq.test.test_response, but it lacks a `phrase`
|
||||
attribute, and didn't support deliverBody until recently.
|
||||
"""
|
||||
|
||||
# HTTP response code
|
||||
code = attr.ib(type=int)
|
||||
|
||||
# HTTP response phrase (eg b'OK' for a 200)
|
||||
phrase = attr.ib(type=bytes)
|
||||
|
||||
# body of the response
|
||||
body = attr.ib(type=bytes)
|
||||
|
||||
def deliverBody(self, protocol):
|
||||
protocol.dataReceived(self.body)
|
||||
protocol.connectionLost(Failure(ResponseDone()))
|
||||
|
||||
@@ -20,7 +20,7 @@ import hmac
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Tuple, Type, TypeVar, Union, overload
|
||||
from typing import Dict, Optional, Tuple, Type, TypeVar, Union, overload
|
||||
|
||||
from mock import Mock, patch
|
||||
|
||||
@@ -46,6 +46,7 @@ from synapse.logging.context import (
|
||||
)
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import UserID, create_requester
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
|
||||
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
|
||||
@@ -320,15 +321,28 @@ class HomeserverTestCase(TestCase):
|
||||
"""
|
||||
Create a the root resource for the test server.
|
||||
|
||||
The default implementation creates a JsonResource and calls each function in
|
||||
`servlets` to register servletes against it
|
||||
The default calls `self.create_resource_dict` and builds the resultant dict
|
||||
into a tree.
|
||||
"""
|
||||
resource = JsonResource(self.hs)
|
||||
root_resource = Resource()
|
||||
create_resource_tree(self.create_resource_dict(), root_resource)
|
||||
return root_resource
|
||||
|
||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||
"""Create a resource tree for the test server
|
||||
|
||||
A resource tree is a mapping from path to twisted.web.resource.
|
||||
|
||||
The default implementation creates a JsonResource and calls each function in
|
||||
`servlets` to register servlets against it.
|
||||
"""
|
||||
servlet_resource = JsonResource(self.hs)
|
||||
for servlet in self.servlets:
|
||||
servlet(self.hs, resource)
|
||||
|
||||
return resource
|
||||
servlet(self.hs, servlet_resource)
|
||||
return {
|
||||
"/_matrix/client": servlet_resource,
|
||||
"/_synapse/admin": servlet_resource,
|
||||
}
|
||||
|
||||
def default_config(self):
|
||||
"""
|
||||
@@ -691,13 +705,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
|
||||
A federating homeserver that authenticates incoming requests as `other.example.com`.
|
||||
"""
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||
d = super().create_resource_dict()
|
||||
d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
|
||||
return d
|
||||
|
||||
|
||||
class TestTransportLayerServer(JsonResource):
|
||||
"""A test implementation of TransportLayerServer
|
||||
|
||||
authenticates incoming requests as `other.example.com`.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
super().__init__(hs)
|
||||
|
||||
class Authenticator:
|
||||
def authenticate_request(self, request, content):
|
||||
return succeed("other.example.com")
|
||||
|
||||
authenticator = Authenticator()
|
||||
|
||||
ratelimiter = FederationRateLimiter(
|
||||
clock,
|
||||
hs.get_clock(),
|
||||
FederationRateLimitConfig(
|
||||
window_size=1,
|
||||
sleep_limit=1,
|
||||
@@ -706,11 +736,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
|
||||
concurrent_requests=1000,
|
||||
),
|
||||
)
|
||||
federation_server.register_servlets(
|
||||
homeserver, self.resource, Authenticator(), ratelimiter
|
||||
)
|
||||
|
||||
return super().prepare(reactor, clock, homeserver)
|
||||
federation_server.register_servlets(hs, self, authenticator, ratelimiter)
|
||||
|
||||
|
||||
def override_config(extra_config):
|
||||
|
||||
108
tests/utils.py
108
tests/utils.py
@@ -20,13 +20,12 @@ import os
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from inspect import getcallargs
|
||||
from typing import Type
|
||||
from urllib import parse as urlparse
|
||||
|
||||
from mock import Mock, patch
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import CodeMessageException, cs_error
|
||||
@@ -34,7 +33,6 @@ from synapse.api.room_versions import RoomVersions
|
||||
from synapse.config.database import DatabaseConnectionConfig
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.server import DEFAULT_ROOM_VERSION
|
||||
from synapse.federation.transport import server as federation_server
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.logging.context import current_context, set_current_context
|
||||
from synapse.server import HomeServer
|
||||
@@ -42,7 +40,6 @@ from synapse.storage import DataStore
|
||||
from synapse.storage.database import LoggingDatabaseConnection
|
||||
from synapse.storage.engines import PostgresEngine, create_engine
|
||||
from synapse.storage.prepare_database import prepare_database
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
|
||||
# set this to True to run the tests against postgres instead of sqlite.
|
||||
#
|
||||
@@ -344,32 +341,9 @@ def setup_test_homeserver(
|
||||
|
||||
hs.get_auth_handler().validate_hash = validate_hash
|
||||
|
||||
fed = kwargs.get("resource_for_federation", None)
|
||||
if fed:
|
||||
register_federation_servlets(hs, fed)
|
||||
|
||||
return hs
|
||||
|
||||
|
||||
def register_federation_servlets(hs, resource):
|
||||
federation_server.register_servlets(
|
||||
hs,
|
||||
resource=resource,
|
||||
authenticator=federation_server.Authenticator(hs),
|
||||
ratelimiter=FederationRateLimiter(
|
||||
hs.get_clock(), config=hs.config.rc_federation
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_mock_call_args(pattern_func, mock_func):
|
||||
""" Return the arguments the mock function was called with interpreted
|
||||
by the pattern functions argument list.
|
||||
"""
|
||||
invoked_args, invoked_kargs = mock_func.call_args
|
||||
return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
|
||||
|
||||
|
||||
def mock_getRawHeaders(headers=None):
|
||||
headers = headers if headers is not None else {}
|
||||
|
||||
@@ -555,86 +529,6 @@ class MockClock:
|
||||
return d
|
||||
|
||||
|
||||
def _format_call(args, kwargs):
|
||||
return ", ".join(
|
||||
["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
|
||||
)
|
||||
|
||||
|
||||
class DeferredMockCallable:
|
||||
"""A callable instance that stores a set of pending call expectations and
|
||||
return values for them. It allows a unit test to assert that the given set
|
||||
of function calls are eventually made, by awaiting on them to be called.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.expectations = []
|
||||
self.calls = []
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self.calls.append((args, kwargs))
|
||||
|
||||
if not self.expectations:
|
||||
raise ValueError(
|
||||
"%r has no pending calls to handle call(%s)"
|
||||
% (self, _format_call(args, kwargs))
|
||||
)
|
||||
|
||||
for (call, result, d) in self.expectations:
|
||||
if args == call[1] and kwargs == call[2]:
|
||||
d.callback(None)
|
||||
return result
|
||||
|
||||
failure = AssertionError(
|
||||
"Was not expecting call(%s)" % (_format_call(args, kwargs))
|
||||
)
|
||||
|
||||
for _, _, d in self.expectations:
|
||||
try:
|
||||
d.errback(failure)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise failure
|
||||
|
||||
def expect_call_and_return(self, call, result):
|
||||
self.expectations.append((call, result, defer.Deferred()))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def await_calls(self, timeout=1000):
|
||||
deferred = defer.DeferredList(
|
||||
[d for _, _, d in self.expectations], fireOnOneErrback=True
|
||||
)
|
||||
|
||||
timer = reactor.callLater(
|
||||
timeout / 1000,
|
||||
deferred.errback,
|
||||
AssertionError(
|
||||
"%d pending calls left: %s"
|
||||
% (
|
||||
len([e for e in self.expectations if not e[2].called]),
|
||||
[e for e in self.expectations if not e[2].called],
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
yield deferred
|
||||
|
||||
timer.cancel()
|
||||
|
||||
self.calls = []
|
||||
|
||||
def assert_had_no_calls(self):
|
||||
if self.calls:
|
||||
calls = self.calls
|
||||
self.calls = []
|
||||
|
||||
raise AssertionError(
|
||||
"Expected not to received any calls, got:\n"
|
||||
+ "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
|
||||
)
|
||||
|
||||
|
||||
async def create_room(hs, room_id: str, creator_id: str):
|
||||
"""Creates and persist a creation event for the given room
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user