Compare commits
2 Commits

dmr/debug-...
anoa/docs_...

| Author | SHA1 | Date |
|---|---|---|
|  | 3360be1829 |  |
|  | 19ca533bcc |  |
@@ -33,7 +33,7 @@ site-url = "/synapse/"
additional-css = [
"docs/website_files/table-of-contents.css",
"docs/website_files/remove-nav-buttons.css",
"docs/website_files/indent-section-headers.css",
"docs/website_files/section-headers.css",
]
additional-js = ["docs/website_files/table-of-contents.js"]
theme = "docs/website_files/theme"
@@ -1 +0,0 @@
Add config settings for background update parameters.

@@ -1 +1 @@
Add type hints to tests files.
Add type hints to `tests/rest/client`.

@@ -1 +0,0 @@
Add support for cancellation to `ReadWriteLock`.

@@ -1 +1 @@
Add type hints to tests files.
Add type hints to `tests/rest`.

@@ -1 +0,0 @@
Add `delay_cancellation` utility function, which behaves like `stop_cancellation` but waits until the original `Deferred` resolves before raising a `CancelledError`.

@@ -1 +0,0 @@
Add cancellation support to `@cached` and `@cachedList` decorators.

@@ -1 +0,0 @@
Add combined test for HTTP pusher and push rule. Contributed by Nick @ Beeper.

@@ -1 +0,0 @@
Document that the `typing`, `to_device`, `account_data`, `receipts`, and `presence` stream writer can only be used on a single worker.

@@ -1 +0,0 @@
Add tests for database transaction callbacks.

@@ -1 +0,0 @@
Handle cancellation in `DatabasePool.runInteraction()`.

@@ -1 +0,0 @@
The groups/communities feature in Synapse has been deprecated.

@@ -1 +0,0 @@
Avoid trying to calculate the state at outlier events.

@@ -1 +0,0 @@
Fix a misleading comment in the function `check_event_for_spam`.

@@ -1 +0,0 @@
Document that contributors can sign off privately by email.

@@ -1 +0,0 @@
Remove unnecessary `pass` statements.

@@ -1 +0,0 @@
Add type hints to tests files.

@@ -1 +0,0 @@
Add type hints to tests files.

@@ -1 +0,0 @@
Update the SSO username picker template to comply with SIWA guidelines.

@@ -1 +0,0 @@
Improve code documentation for the typing stream over replication.

@@ -1 +0,0 @@
Add a new Jinja2 template filter to extract the local part of an email address.

@@ -1 +0,0 @@
Fix a bug introduced in 1.54.0 that broke background updates on sqlite homeservers while search was disabled.

@@ -1 +0,0 @@
Add missing type hints for cache storage.

@@ -1 +0,0 @@
Clean-up logic around rebasing URLs for URL image previews.

@@ -1 +0,0 @@
Add type hints to tests files.

@@ -1 +0,0 @@
Use the `ignored_users` table in additional places instead of re-parsing the account data.

@@ -1 +0,0 @@
Refactor the relations endpoints to add a `RelationsHandler`.

@@ -1 +0,0 @@
Fix the link to the module documentation in the legacy spam checker warning message.

@@ -1 +0,0 @@
Refactor relations tests to improve code re-use.
@@ -458,17 +458,6 @@ Git allows you to add this signoff automatically when using the `-s`
flag to `git commit`, which uses the name and email set in your
`user.name` and `user.email` git configs.

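For readers unfamiliar with the mechanics described above, a minimal illustration follows; the name and email are placeholders, not values taken from this changeset.

```sh
# One-off setup: the identity git will embed in the sign-off trailer.
git config --global user.name "Jane Example"
git config --global user.email "jane@example.com"

# -s appends "Signed-off-by: Jane Example <jane@example.com>" to the commit message.
git commit -s -m "Fix a typo in the docs"
```
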
### Private Sign off

If you would like to provide your legal name privately to the Matrix.org
Foundation (instead of in a public commit or comment), you can do so
by emailing your legal name and a link to the pull request to
[dco@matrix.org](mailto:dco@matrix.org?subject=Private%20sign%20off).
It helps to include "sign off" or similar in the subject line. You will then
be instructed further.

Once private sign off is complete, doing so for future contributions will not
be required.

# 10. Turn feedback into better code.

@@ -1947,14 +1947,8 @@ saml2_config:
#
# localpart_template: Jinja2 template for the localpart of the MXID.
# If this is not set, the user will be prompted to choose their
# own username (see the documentation for the
# 'sso_auth_account_details.html' template). This template can
# use the 'localpart_from_email' filter.
#
# confirm_localpart: Whether to prompt the user to validate (or
# change) the generated localpart (see the documentation for the
# 'sso_auth_account_details.html' template), instead of
# registering the account right away.
# own username (see 'sso_auth_account_details.html' in the 'sso'
# section of this file).
#
# display_name_template: Jinja2 template for the display name to set
# on first login. If unset, no displayname will be set.
@@ -2735,35 +2729,3 @@ redis:
# Optional password if configured on the Redis instance
#
#password: <secret_password>


## Background Updates ##

# Background updates are database updates that are run in the background in batches.
# The duration, minimum batch size, default batch size, whether to sleep between batches and if so, how long to
# sleep can all be configured. This is helpful to speed up or slow down the updates.
#
background_updates:
# How long in milliseconds to run a batch of background updates for. Defaults to 100. Uncomment and set
# a time to change the default.
#
#background_update_duration_ms: 500

# Whether to sleep between updates. Defaults to True. Uncomment to change the default.
#
#sleep_enabled: false

# If sleeping between updates, how long in milliseconds to sleep for. Defaults to 1000. Uncomment
# and set a duration to change the default.
#
#sleep_duration_ms: 300

# Minimum size a batch of background updates can be. Must be greater than 0. Defaults to 1. Uncomment and
# set a size to change the default.
#
#min_batch_size: 10

# The batch size to use for the first iteration of a new background update. The default is 100.
# Uncomment and set a size to change the default.
#
#default_batch_size: 50

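As a concrete illustration of the options documented in the sample configuration above, an administrator who wants background updates to finish faster could uncomment and tune them roughly as follows; the values are arbitrary examples, not recommendations from this changeset.

```yaml
background_updates:
  # Run each batch for 500ms rather than the default 100ms.
  background_update_duration_ms: 500
  # Do not pause between batches.
  sleep_enabled: false
  # Start new background updates with larger batches.
  min_batch_size: 10
  default_batch_size: 200
```
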
@@ -36,13 +36,6 @@ Turns a `mxc://` URL for media content into an HTTP(S) one using the homeserver'

Example: `message.sender_avatar_url|mxc_to_http(32,32)`

```python
localpart_from_email(address: str) -> str
```

Returns the local part of an email address (e.g. `alice` in `alice@example.com`).

Example: `user.email_address|localpart_from_email`

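To make the filter's contract concrete, a minimal sketch of a function with this behaviour is shown below; it is illustrative only and not the implementation shipped in Synapse.

```python
def localpart_from_email(address: str) -> str:
    """Return the part of an email address before the final '@'.

    e.g. "alice@example.com" -> "alice"
    """
    return address.rsplit("@", 1)[0]


# Registered as a Jinja2 filter, it can then be used in templates,
# e.g. "{{ user.email_address|localpart_from_email }}".
```
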
## Email templates

@@ -183,11 +176,8 @@ Below are the templates Synapse will look for when generating pages related to S
for the brand of the IdP
* `user_attributes`: an object containing details about the user that
we received from the IdP. May have the following attributes:
* `display_name`: the user's display name
* `emails`: a list of email addresses
* `localpart`: the local part of the Matrix user ID to register,
if `localpart_template` is set in the mapping provider configuration (empty
string if not)
* display_name: the user's display_name
* emails: a list of email addresses
The template should render a form which submits the following fields:
* `username`: the localpart of the user's chosen user id
* `sso_new_user_consent.html`: HTML page allowing the user to consent to the

@@ -85,20 +85,6 @@ process, for example:
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
```

# Upgrading to v1.56.0

## Groups/communities feature has been deprecated

The non-standard groups/communities feature in Synapse has been deprecated and will
be disabled by default in Synapse v1.58.0.

You can test disabling it by adding the following to your homeserver configuration:

```yaml
experimental_features:
  groups_enabled: false
```

# Upgrading to v1.55.0

## `synctl` script has been moved

@@ -1,7 +0,0 @@
/*
* Indents each chapter title in the left sidebar so that they aren't
* at the same level as the section headers.
*/
.chapter-item {
margin-left: 1em;
}

20  docs/website_files/section-headers.css  Normal file
@@ -0,0 +1,20 @@
/*
* Indents each chapter title in the left sidebar so that they aren't
* at the same level as the section headers.
*/
.chapter-item {
margin-left: 1em;
}

/*
* Prevents a large gap between successive section headers.
*
* mdbook sets 'margin-top: 2.5em' on h2 and h3 headers. This makes sense when separating
* a header from the paragraph beforehand, but has the downside of introducing a large
* gap between headers that are next to each other with no text in between.
*
* This rule reduces the margin in this case.
*/
h1 + h2, h2 + h3 {
margin-top: 1.0em;
}

@@ -351,11 +351,8 @@ is only supported with Redis-based replication.)

To enable this, the worker must have a HTTP replication listener configured,
have a `worker_name` and be listed in the `instance_map` config. The same worker
can handle multiple streams, but unless otherwise documented, each stream can only
have a single writer.

For example, to move event persistence off to a dedicated worker, the shared
configuration would include:
can handle multiple streams. For example, to move event persistence off to a
dedicated worker, the shared configuration would include:

```yaml
instance_map:
@@ -373,8 +370,8 @@ streams and the endpoints associated with them:

##### The `events` stream

The `events` stream experimentally supports having multiple writers, where work
is sharded between them by room ID. Note that you *must* restart all worker
The `events` stream also experimentally supports having multiple writers, where
work is sharded between them by room ID. Note that you *must* restart all worker
instances when adding or removing event persisters. An example `stream_writers`
configuration with multiple writers:

@@ -387,38 +384,38 @@ stream_writers:

##### The `typing` stream

The following endpoints should be routed directly to the worker configured as
the stream writer for the `typing` stream:
The following endpoints should be routed directly to the workers configured as
stream writers for the `typing` stream:

^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/typing

##### The `to_device` stream

The following endpoints should be routed directly to the worker configured as
the stream writer for the `to_device` stream:
The following endpoints should be routed directly to the workers configured as
stream writers for the `to_device` stream:

^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/

##### The `account_data` stream

The following endpoints should be routed directly to the worker configured as
the stream writer for the `account_data` stream:
The following endpoints should be routed directly to the workers configured as
stream writers for the `account_data` stream:

^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags
^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data

##### The `receipts` stream

The following endpoints should be routed directly to the worker configured as
the stream writer for the `receipts` stream:
The following endpoints should be routed directly to the workers configured as
stream writers for the `receipts` stream:

^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers

##### The `presence` stream

The following endpoints should be routed directly to the worker configured as
the stream writer for the `presence` stream:
The following endpoints should be routed directly to the workers configured as
stream writers for the `presence` stream:

^/_matrix/client/(api/v1|r0|v3|unstable)/presence/

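Tying the routing rules above back to configuration, here is a sketch of a shared configuration that makes one worker the writer for several of these streams. The worker name and port are hypothetical examples, not values taken from this changeset.

```yaml
# homeserver.yaml (shared configuration)
instance_map:
  worker1:
    host: localhost
    port: 8034

stream_writers:
  # Each of these streams supports only a single writer.
  typing: worker1
  to_device: worker1
  account_data: worker1
```
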
6  mypy.ini

@@ -67,8 +67,13 @@ exclude = (?x)
|tests/federation/transport/test_knocking.py
|tests/federation/transport/test_server.py
|tests/handlers/test_cas.py
|tests/handlers/test_directory.py
|tests/handlers/test_e2e_keys.py
|tests/handlers/test_federation.py
|tests/handlers/test_oidc.py
|tests/handlers/test_presence.py
|tests/handlers/test_profile.py
|tests/handlers/test_saml.py
|tests/handlers/test_typing.py
|tests/http/federation/test_matrix_federation_agent.py
|tests/http/federation/test_srv_resolver.py
@@ -85,6 +90,7 @@ exclude = (?x)
|tests/push/test_push_rule_evaluator.py
|tests/rest/client/test_transactions.py
|tests/rest/media/v1/test_media_storage.py
|tests/rest/media/v1/test_url_preview.py
|tests/scripts/test_new_matrix_user.py
|tests/server.py
|tests/server_notices/test_resource_limits_server_notices.py

@@ -322,8 +322,7 @@ class GenericWorkerServer(HomeServer):

presence.register_servlets(self, resource)

if self.config.experimental.groups_enabled:
groups.register_servlets(self, resource)
groups.register_servlets(self, resource)

resources.update({CLIENT_API_PREFIX: resource})

@@ -19,7 +19,6 @@ from synapse.config import (
api,
appservice,
auth,
background_updates,
cache,
captcha,
cas,
@@ -114,7 +113,6 @@ class RootConfig:
caches: cache.CacheConfig
federation: federation.FederationConfig
retention: retention.RetentionConfig
background_updates: background_updates.BackgroundUpdateConfig

config_classes: List[Type["Config"]] = ...
def __init__(self) -> None: ...

@@ -1,68 +0,0 @@
# Copyright 2022 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ._base import Config


class BackgroundUpdateConfig(Config):
section = "background_updates"

def generate_config_section(self, **kwargs) -> str:
return """\
## Background Updates ##

# Background updates are database updates that are run in the background in batches.
# The duration, minimum batch size, default batch size, whether to sleep between batches and if so, how long to
# sleep can all be configured. This is helpful to speed up or slow down the updates.
#
background_updates:
# How long in milliseconds to run a batch of background updates for. Defaults to 100. Uncomment and set
# a time to change the default.
#
#background_update_duration_ms: 500

# Whether to sleep between updates. Defaults to True. Uncomment to change the default.
#
#sleep_enabled: false

# If sleeping between updates, how long in milliseconds to sleep for. Defaults to 1000. Uncomment
# and set a duration to change the default.
#
#sleep_duration_ms: 300

# Minimum size a batch of background updates can be. Must be greater than 0. Defaults to 1. Uncomment and
# set a size to change the default.
#
#min_batch_size: 10

# The batch size to use for the first iteration of a new background update. The default is 100.
# Uncomment and set a size to change the default.
#
#default_batch_size: 50
"""

def read_config(self, config, **kwargs) -> None:
bg_update_config = config.get("background_updates") or {}

self.update_duration_ms = bg_update_config.get(
"background_update_duration_ms", 100
)

self.sleep_enabled = bg_update_config.get("sleep_enabled", True)

self.sleep_duration_ms = bg_update_config.get("sleep_duration_ms", 1000)

self.min_batch_size = bg_update_config.get("min_batch_size", 1)

self.default_batch_size = bg_update_config.get("default_batch_size", 100)

@@ -74,6 +74,3 @@ class ExperimentalConfig(Config):

# MSC3720 (Account status endpoint)
self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False)

# The deprecated groups feature.
self.groups_enabled: bool = experimental.get("groups_enabled", True)

@@ -16,7 +16,6 @@ from .account_validity import AccountValidityConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
from .auth import AuthConfig
from .background_updates import BackgroundUpdateConfig
from .cache import CacheConfig
from .captcha import CaptchaConfig
from .cas import CasConfig
@@ -100,5 +99,4 @@ class HomeServerConfig(RootConfig):
WorkerConfig,
RedisConfig,
ExperimentalConfig,
BackgroundUpdateConfig,
]

@@ -182,14 +182,8 @@ class OIDCConfig(Config):
#
# localpart_template: Jinja2 template for the localpart of the MXID.
# If this is not set, the user will be prompted to choose their
# own username (see the documentation for the
# 'sso_auth_account_details.html' template). This template can
# use the 'localpart_from_email' filter.
#
# confirm_localpart: Whether to prompt the user to validate (or
# change) the generated localpart (see the documentation for the
# 'sso_auth_account_details.html' template), instead of
# registering the account right away.
# own username (see 'sso_auth_account_details.html' in the 'sso'
# section of this file).
#
# display_name_template: Jinja2 template for the display name to set
# on first login. If unset, no displayname will be set.

@@ -25,8 +25,8 @@ logger = logging.getLogger(__name__)
LEGACY_SPAM_CHECKER_WARNING = """
This server is using a spam checker module that is implementing the deprecated spam
checker interface. Please check with the module's maintainer to see if a new version
supporting Synapse's generic modules system is available. For more information, please
see https://matrix-org.github.io/synapse/latest/modules/index.html
supporting Synapse's generic modules system is available.
For more information, please see https://matrix-org.github.io/synapse/latest/modules.html
---------------------------------------------------------------------------------------"""


@@ -245,8 +245,8 @@ class SpamChecker:
"""Checks if a given event is considered "spammy" by this server.

If the server considers an event spammy, then it will be rejected if
sent by a local user. If it is sent by a user on another server, the
event is soft-failed.
sent by a local user. If it is sent by a user on another server, then
users receive a blank event.

Args:
event: the event to be checked

@@ -289,7 +289,7 @@ class OpenIdUserInfo(BaseFederationServlet):
return 200, {"sub": user_id}


SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = {
DEFAULT_SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = {
"federation": FEDERATION_SERVLET_CLASSES,
"room_list": (PublicRoomList,),
"group_server": GROUP_SERVER_SERVLET_CLASSES,
@@ -298,10 +298,6 @@ SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = {
"openid": (OpenIdUserInfo,),
}

DEFAULT_SERVLET_GROUPS = ("federation", "room_list", "openid")

GROUP_SERVLET_GROUPS = ("group_server", "group_local", "group_attestation")


def register_servlets(
hs: "HomeServer",
@@ -324,19 +320,16 @@ def register_servlets(
Defaults to ``DEFAULT_SERVLET_GROUPS``.
"""
if not servlet_groups:
servlet_groups = DEFAULT_SERVLET_GROUPS
# Only allow the groups servlets if the deprecated groups feature is enabled.
if hs.config.experimental.groups_enabled:
servlet_groups = servlet_groups + GROUP_SERVLET_GROUPS
servlet_groups = DEFAULT_SERVLET_GROUPS.keys()

for servlet_group in servlet_groups:
# Skip unknown servlet groups.
if servlet_group not in SERVLET_GROUPS:
if servlet_group not in DEFAULT_SERVLET_GROUPS:
raise RuntimeError(
f"Attempting to register unknown federation servlet: '{servlet_group}'"
)

for servletclass in SERVLET_GROUPS[servlet_group]:
for servletclass in DEFAULT_SERVLET_GROUPS[servlet_group]:
# Only allow the `/timestamp_to_event` servlet if msc3030 is enabled
if (
servletclass == FederationTimestampLookupServlet

@@ -371,6 +371,7 @@ class DeviceHandler(DeviceWorkerHandler):
log_kv(
{"reason": "User doesn't have device id.", "device_id": device_id}
)
pass
else:
raise

@@ -413,6 +414,7 @@ class DeviceHandler(DeviceWorkerHandler):
# no match
set_tag("error", True)
set_tag("reason", "User doesn't have that device id.")
pass
else:
raise

@@ -45,7 +45,6 @@ from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import Clock, json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.templates import _localpart_from_email_filter

if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -1229,7 +1228,6 @@ class OidcSessionData:

class UserAttributeDict(TypedDict):
localpart: Optional[str]
confirm_localpart: bool
display_name: Optional[str]
emails: List[str]

@@ -1309,11 +1307,6 @@ def jinja_finalize(thing: Any) -> Any:


env = Environment(finalize=jinja_finalize)
env.filters.update(
{
"localpart_from_email": _localpart_from_email_filter,
}
)


@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -1323,7 +1316,6 @@ class JinjaOidcMappingConfig:
display_name_template: Optional[Template]
email_template: Optional[Template]
extra_attributes: Dict[str, Template]
confirm_localpart: bool = False


class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1365,17 +1357,12 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
"invalid jinja template", path=["extra_attributes", key]
) from e

confirm_localpart = config.get("confirm_localpart") or False
if not isinstance(confirm_localpart, bool):
raise ConfigError("must be a bool", path=["confirm_localpart"])

return JinjaOidcMappingConfig(
subject_claim=subject_claim,
localpart_template=localpart_template,
display_name_template=display_name_template,
email_template=email_template,
extra_attributes=extra_attributes,
confirm_localpart=confirm_localpart,
)

def get_remote_user_id(self, userinfo: UserInfo) -> str:
@@ -1411,10 +1398,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
emails.append(email)

return UserAttributeDict(
localpart=localpart,
display_name=display_name,
emails=emails,
confirm_localpart=self._config.confirm_localpart,
localpart=localpart, display_name=display_name, emails=emails
)

async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set

import attr

@@ -350,7 +350,7 @@ class PaginationHandler:
"""
self._purges_in_progress_by_room.add(room_id)
try:
async with self.pagination_lock.write(room_id):
with await self.pagination_lock.write(room_id):
await self.storage.purge_events.purge_history(
room_id, token, delete_local_events
)
@@ -406,7 +406,7 @@ class PaginationHandler:
room_id: room to be purged
force: set true to skip checking for joined users.
"""
async with self.pagination_lock.write(room_id):
with await self.pagination_lock.write(room_id):
# first check that we have no users in this room
if not force:
joined = await self.store.is_host_joined(room_id, self._server_name)
@@ -422,7 +422,7 @@ class PaginationHandler:
pagin_config: PaginationConfig,
as_client_event: bool = True,
event_filter: Optional[Filter] = None,
) -> JsonDict:
) -> Dict[str, Any]:
"""Get messages in a room.

Args:
@@ -431,7 +431,6 @@ class PaginationHandler:
pagin_config: The pagination config rules to apply, if any.
as_client_event: True to get events in client-server format.
event_filter: Filter to apply to results or None

Returns:
Pagination API results
"""
@@ -449,7 +448,7 @@ class PaginationHandler:

room_token = from_token.room_key

async with self.pagination_lock.read(room_id):
with await self.pagination_lock.read(room_id):
(
membership,
member_event_id,
@@ -616,7 +615,7 @@ class PaginationHandler:

self._purges_in_progress_by_room.add(room_id)
try:
async with self.pagination_lock.write(room_id):
with await self.pagination_lock.write(room_id):
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
self._delete_by_id[
delete_id

@@ -267,6 +267,7 @@ class BasePresenceHandler(abc.ABC):
is_syncing: Whether or not the user is now syncing
sync_time_msec: Time in ms when the user was last syncing
"""
pass

async def update_external_syncs_clear(self, process_id: str) -> None:
"""Marks all users that had been marked as syncing by a given process
@@ -276,6 +277,7 @@ class BasePresenceHandler(abc.ABC):

This is a no-op when presence is handled by a different worker.
"""
pass

async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list

@@ -1,117 +0,0 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional

from synapse.api.errors import SynapseError
from synapse.types import JsonDict, Requester, StreamToken

if TYPE_CHECKING:
from synapse.server import HomeServer


logger = logging.getLogger(__name__)


class RelationsHandler:
def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main
self._auth = hs.get_auth()
self._clock = hs.get_clock()
self._event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()

async def get_relations(
self,
requester: Requester,
event_id: str,
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.

TODO Accept a PaginationConfig instead of individual pagination parameters.

Args:
requester: The user requesting the relations.
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
aggregation_key: Only fetch events with this aggregation key, if given.
limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.

Returns:
The pagination chunk.
"""

user_id = requester.user.to_string()

await self._auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)

# This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it.
event = await self._event_handler.get_event(requester.user, room_id, event_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")

pagination_chunk = await self._main_store.get_relations_for_event(
event_id=event_id,
event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
aggregation_key=aggregation_key,
limit=limit,
direction=direction,
from_token=from_token,
to_token=to_token,
)

events = await self._main_store.get_events_as_list(
[c["event_id"] for c in pagination_chunk.chunk]
)

now = self._clock.time_msec()
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
original_event = self._event_serializer.serialize_event(
event, now, bundle_aggregations=None
)
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self._main_store.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)

return_value = await pagination_chunk.to_dict(self._main_store)
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event

return return_value
@@ -132,7 +132,6 @@ class UserAttributes:
# if `None`, the mapper has not picked a userid, and the user should be prompted to
# enter one.
localpart: Optional[str]
confirm_localpart: bool = False
display_name: Optional[str] = None
emails: Collection[str] = attr.Factory(list)

@@ -562,10 +561,9 @@ class SsoHandler:
# Must provide either attributes or session, not both
assert (attributes is not None) != (session is not None)

if (
attributes
and (attributes.localpart is None or attributes.confirm_localpart is True)
) or (session and session.chosen_localpart is None):
if (attributes and attributes.localpart is None) or (
session and session.chosen_localpart is None
):
return b"/_synapse/client/pick_username/account_details"
elif self._consent_at_registration and not (
session and session.terms_accepted_version

@@ -28,7 +28,7 @@ from typing import (
import attr
from prometheus_client import Counter

from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -1601,7 +1601,7 @@ class SyncHandler:
return set(), set(), set(), set()

# 3. Work out which rooms need reporting in the sync response.
ignored_users = await self.store.ignored_users(user_id)
ignored_users = await self._get_ignored_users(user_id)
if since_token:
room_changes = await self._get_rooms_changed(
sync_result_builder, ignored_users
@@ -1627,6 +1627,7 @@ class SyncHandler:
logger.debug("Generating room entry for %s", room_entry.room_id)
await self._generate_room_entry(
sync_result_builder,
ignored_users,
room_entry,
ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
tags=tags_by_room.get(room_entry.room_id),
@@ -1656,6 +1657,29 @@ class SyncHandler:
newly_left_users,
)

async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
"""Retrieve the users ignored by the given user from their global account_data.

Returns an empty set if
- there is no global account_data entry for ignored_users
- there is such an entry, but it's not a JSON object.
"""
# TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
ignored_account_data = (
await self.store.get_global_account_data_by_type_for_user(
user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST
)
)

# If there is ignored users account data and it matches the proper type,
# then use it.
ignored_users: FrozenSet[str] = frozenset()
if ignored_account_data:
ignored_users_data = ignored_account_data.get("ignored_users", {})
if isinstance(ignored_users_data, dict):
ignored_users = frozenset(ignored_users_data.keys())
return ignored_users

async def _have_rooms_changed(
self, sync_result_builder: "SyncResultBuilder"
) -> bool:
@@ -1998,6 +2022,7 @@ class SyncHandler:
async def _generate_room_entry(
self,
sync_result_builder: "SyncResultBuilder",
ignored_users: FrozenSet[str],
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[Dict[str, Dict[str, Any]]],
@@ -2026,6 +2051,7 @@ class SyncHandler:

Args:
sync_result_builder
ignored_users: Set of users ignored by user.
room_builder
ephemeral: List of new ephemeral events for room
tags: List of *all* tags for room, or None if there has been

@@ -160,9 +160,8 @@ class FollowerTypingHandler:
"""Should be called whenever we receive updates for typing stream."""

if self._latest_room_serial > token:
# The typing worker has gone backwards (e.g. it may have restarted).
# To prevent inconsistent data, just clear everything.
logger.info("Typing handler stream went backwards; resetting")
# The master has gone backwards. To prevent inconsistent data, just
# clear everything.
self._reset()

# Set the latest serial token to whatever the server gave us.

@@ -120,6 +120,7 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC):
"""Called when response has finished streaming and the parser should
return the final result (or error).
"""
pass


@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -600,6 +601,7 @@ class MatrixFederationHttpClient:
response.code,
response_phrase,
)
pass
else:
logger.info(
"{%s} [%s] Got response headers: %d %s",

@@ -233,6 +233,7 @@ class HttpServer(Protocol):
servlet_classname (str): The name of the handler to be used in prometheus
and opentracing logs.
"""
pass


class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):

@@ -169,7 +169,7 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "content.msgtype",
"pattern": "m.notice",
"_cache_key": "_suppress_notices",
"_id": "_suppress_notices",
}
],
"actions": ["dont_notify"],
@@ -183,13 +183,13 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "type",
"pattern": "m.room.member",
"_cache_key": "_member",
"_id": "_member",
},
{
"kind": "event_match",
"key": "content.membership",
"pattern": "invite",
"_cache_key": "_invite_member",
"_id": "_invite_member",
},
{"kind": "event_match", "key": "state_key", "pattern_type": "user_id"},
],
@@ -212,7 +212,7 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "type",
"pattern": "m.room.member",
"_cache_key": "_member",
"_id": "_member",
}
],
"actions": ["dont_notify"],
@@ -237,12 +237,12 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "content.body",
"pattern": "@room",
"_cache_key": "_roomnotif_content",
"_id": "_roomnotif_content",
},
{
"kind": "sender_notification_permission",
"key": "room",
"_cache_key": "_roomnotif_pl",
"_id": "_roomnotif_pl",
},
],
"actions": ["notify", {"set_tweak": "highlight", "value": True}],
@@ -254,13 +254,13 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "type",
"pattern": "m.room.tombstone",
"_cache_key": "_tombstone",
"_id": "_tombstone",
},
{
"kind": "event_match",
"key": "state_key",
"pattern": "",
"_cache_key": "_tombstone_statekey",
"_id": "_tombstone_statekey",
},
],
"actions": ["notify", {"set_tweak": "highlight", "value": True}],
@@ -272,7 +272,7 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "type",
"pattern": "m.reaction",
"_cache_key": "_reaction",
"_id": "_reaction",
}
],
"actions": ["dont_notify"],
@@ -288,7 +288,7 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "type",
"pattern": "m.call.invite",
"_cache_key": "_call",
"_id": "_call",
}
],
"actions": [
@@ -302,12 +302,12 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
{
"rule_id": "global/underride/.m.rule.room_one_to_one",
"conditions": [
{"kind": "room_member_count", "is": "2", "_cache_key": "member_count"},
{"kind": "room_member_count", "is": "2", "_id": "member_count"},
{
"kind": "event_match",
"key": "type",
"pattern": "m.room.message",
"_cache_key": "_message",
"_id": "_message",
},
],
"actions": [
@@ -321,12 +321,12 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
{
"rule_id": "global/underride/.m.rule.encrypted_room_one_to_one",
"conditions": [
{"kind": "room_member_count", "is": "2", "_cache_key": "member_count"},
{"kind": "room_member_count", "is": "2", "_id": "member_count"},
{
"kind": "event_match",
"key": "type",
"pattern": "m.room.encrypted",
"_cache_key": "_encrypted",
"_id": "_encrypted",
},
],
"actions": [
@@ -342,7 +342,7 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "type",
"pattern": "m.room.message",
"_cache_key": "_message",
"_id": "_message",
}
],
"actions": ["notify", {"set_tweak": "highlight", "value": False}],
@@ -356,7 +356,7 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "type",
"pattern": "m.room.encrypted",
"_cache_key": "_encrypted",
"_id": "_encrypted",
}
],
"actions": ["notify", {"set_tweak": "highlight", "value": False}],
@@ -368,19 +368,19 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "type",
"pattern": "im.vector.modular.widgets",
"_cache_key": "_type_modular_widgets",
"_id": "_type_modular_widgets",
},
{
"kind": "event_match",
"key": "content.type",
"pattern": "jitsi",
"_cache_key": "_content_type_jitsi",
"_id": "_content_type_jitsi",
},
{
"kind": "event_match",
"key": "state_key",
"pattern": "*",
"_cache_key": "_is_state_event",
"_id": "_is_state_event",
},
],
"actions": ["notify", {"set_tweak": "highlight", "value": False}],

@@ -213,7 +213,7 @@ class BulkPushRuleEvaluator:
if not event.is_state():
ignorers = await self.store.ignored_by(event.sender)
else:
ignorers = frozenset()
ignorers = set()

for uid, rules in rules_by_user.items():
if event.sender == uid:
@@ -274,17 +274,17 @@ def _condition_checker(
cache: Dict[str, bool],
) -> bool:
for cond in conditions:
_cache_key = cond.get("_cache_key", None)
if _cache_key:
res = cache.get(_cache_key, None)
_id = cond.get("_id", None)
if _id:
res = cache.get(_id, None)
if res is False:
return False
elif res is True:
continue

res = evaluator.matches(cond, uid, display_name)
if _cache_key:
cache[_cache_key] = bool(res)
if _id:
cache[_id] = bool(res)

if not res:
return False

@@ -40,7 +40,7 @@ def format_push_rules_for_user(

# Remove internal stuff.
for c in r["conditions"]:
c.pop("_cache_key", None)
c.pop("_id", None)

pattern_type = c.pop("pattern_type", None)
if pattern_type == "user_id":

@@ -709,7 +709,7 @@ class ReplicationCommandHandler:
self.send_command(RemoteServerUpCommand(server))

def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None:
"""Called when a new update is available to stream to Redis subscribers.
"""Called when a new update is available to stream to clients.

We need to check if the client is interested in the stream or not
"""

@@ -67,8 +67,8 @@ class ReplicationStreamProtocolFactory(ServerFactory):
class ReplicationStreamer:
"""Handles replication connections.

This needs to be poked when new replication data may be available.
When new data is available it will propagate to all Redis subscribers.
This needs to be poked when new replication data may be available. When new
data is available it will propagate to all connected clients.
"""

def __init__(self, hs: "HomeServer"):
@@ -109,7 +109,7 @@ class ReplicationStreamer:

def on_notifier_poke(self) -> None:
"""Checks if there is actually any new data and sends it to the
Redis subscribers if there are.
connections if there are.

This should get called each time new data is available, even if it
is currently being executed, so that nothing gets missed

@@ -316,19 +316,7 @@ class PresenceFederationStream(Stream):
class TypingStream(Stream):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TypingStreamRow:
"""
An entry in the typing stream.
Describes all the users that are 'typing' right now in one room.

When a user stops typing, it will be streamed as a new update with that
user absent; you can think of the `user_ids` list as overwriting the
entire list that was there previously.
"""

# The room that this update is for.
room_id: str

# All the users that are 'typing' right now in the specified room.
user_ids: List[str]

NAME = "typing"

@@ -130,15 +130,15 @@
</head>
<body>
<header>
<h1>Choose your user name</h1>
<p>This is required to create your account on {{ server_name }}, and you can't change this later.</p>
<h1>Your account is nearly ready</h1>
<p>Check your details before creating an account on {{ server_name }}</p>
</header>
<main>
<form method="post" class="form__input" id="form">
<div class="username_input" id="username_input">
<label for="field-username">Username</label>
<div class="prefix">@</div>
<input type="text" name="username" id="field-username" value="{{ user_attributes.localpart }}" autofocus>
<input type="text" name="username" id="field-username" autofocus>
<div class="postfix">:{{ server_name }}</div>
</div>
<output for="username_input" id="field-username-output"></output>

@@ -118,8 +118,7 @@ class ClientRestResource(JsonResource):
thirdparty.register_servlets(hs, client_resource)
sendtodevice.register_servlets(hs, client_resource)
user_directory.register_servlets(hs, client_resource)
if hs.config.experimental.groups_enabled:
groups.register_servlets(hs, client_resource)
groups.register_servlets(hs, client_resource)
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
room_batch.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)

@@ -293,8 +293,7 @@ def register_servlets_for_client_rest_resource(
ResetPasswordRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
if hs.config.experimental.groups_enabled:
DeleteGroupAdminRestServlet(hs).register(http_server)
DeleteGroupAdminRestServlet(hs).register(http_server)
AccountValidityRenewServlet(hs).register(http_server)

# Load the media repo ones if we're using them. Otherwise load the servlets which

@@ -51,7 +51,9 @@ class RelationPaginationServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self._relations_handler = hs.get_relations_handler()
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()

async def on_GET(
self,
@@ -63,6 +65,16 @@ class RelationPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)

await self.auth.check_user_in_room_or_world_readable(
room_id, requester.user.to_string(), allow_departed_users=True
)

# This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")

limit = parse_integer(request, "limit", default=5)
direction = parse_string(
request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"]
@@ -78,9 +90,9 @@ class RelationPaginationServlet(RestServlet):
if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str)

result = await self._relations_handler.get_relations(
requester=requester,
pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id,
event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
@@ -90,7 +102,30 @@ class RelationPaginationServlet(RestServlet):
to_token=to_token,
)

return 200, result
events = await self.store.get_events_as_list(
[c["event_id"] for c in pagination_chunk.chunk]
)

now = self.clock.time_msec()
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
original_event = self._event_serializer.serialize_event(
event, now, bundle_aggregations=None
)
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self.store.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)

return_value = await pagination_chunk.to_dict(self.store)
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event

return 200, return_value


class RelationAggregationPaginationServlet(RestServlet):
@@ -210,7 +245,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self._relations_handler = hs.get_relations_handler()
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()

async def on_GET(
self,
@@ -223,6 +260,18 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)

await self.auth.check_user_in_room_or_world_readable(
room_id,
requester.user.to_string(),
allow_departed_users=True,
)

# This checks that a) the event exists and b) the user is allowed to
# view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")

if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'")

@@ -237,9 +286,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str)

result = await self._relations_handler.get_relations(
requester=requester,
result = await self.store.get_relations_for_event(
event_id=parent_id,
event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
@@ -249,7 +298,17 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
to_token=to_token,
)

return 200, result
events = await self.store.get_events_as_list(
[c["event_id"] for c in result.chunk]
)

now = self.clock.time_msec()
serialized_events = self._event_serializer.serialize_events(events, now)

return_value = await result.to_dict(self.store)
return_value["chunk"] = serialized_events

return 200, return_value


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:

@@ -298,6 +298,7 @@ class Responder:
Returns:
Resolves once the response has finished being written
"""
pass

def __enter__(self) -> None:
pass

@@ -16,6 +16,7 @@ import itertools
import logging
import re
from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
from urllib import parse as urlparse

if TYPE_CHECKING:
from lxml import etree
@@ -143,7 +144,9 @@ def decode_body(
return etree.fromstring(body, parser)


def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
def parse_html_to_open_graph(
tree: "etree.Element", media_uri: str
) -> Dict[str, Optional[str]]:
"""
Parse the HTML document into an Open Graph response.

@@ -152,6 +155,7 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:

Args:
tree: The parsed HTML document.
media_url: The URI used to download the body.

Returns:
The Open Graph response as a dictionary.
@@ -205,7 +209,7 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
og["og:image"] = meta_image[0]
og["og:image"] = rebase_url(meta_image[0], media_uri)
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
@@ -316,6 +320,37 @@ def _iterate_over_text(
)


def rebase_url(url: str, base: str) -> str:
"""
Resolves a potentially relative `url` against an absolute `base` URL.

For example:

>>> rebase_url("subpage", "https://example.com/foo/")
'https://example.com/foo/subpage'
>>> rebase_url("sibling", "https://example.com/foo")
'https://example.com/sibling'
>>> rebase_url("/bar", "https://example.com/foo/")
'https://example.com/bar'
>>> rebase_url("https://alice.com/a/", "https://example.com/foo/")
'https://alice.com/a'
"""
base_parts = urlparse.urlparse(base)
# Convert the parsed URL to a list for (potential) modification.
url_parts = list(urlparse.urlparse(url))
# Add a scheme, if one does not exist.
if not url_parts[0]:
url_parts[0] = base_parts.scheme or "http"
# Fix up the hostname, if this is not a data URL.
if url_parts[0] != "data" and not url_parts[1]:
url_parts[1] = base_parts.netloc
# If the path does not start with a /, nest it under the base path's last
# directory.
if not url_parts[2].startswith("/"):
url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2]
return urlparse.urlunparse(url_parts)


def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:

@@ -22,7 +22,7 @@ import shutil
|
||||
import sys
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple
|
||||
from urllib.parse import urljoin, urlparse, urlsplit
|
||||
from urllib import parse as urlparse
|
||||
from urllib.request import urlopen
|
||||
|
||||
import attr
|
||||
@@ -44,7 +44,11 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.rest.media.v1._base import get_filename_from_headers
|
||||
from synapse.rest.media.v1.media_storage import MediaStorage
|
||||
from synapse.rest.media.v1.oembed import OEmbedProvider
|
||||
from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph
|
||||
from synapse.rest.media.v1.preview_html import (
|
||||
decode_body,
|
||||
parse_html_to_open_graph,
|
||||
rebase_url,
|
||||
)
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
@@ -183,7 +187,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||
ts = self.clock.time_msec()
|
||||
|
||||
# XXX: we could move this into _do_preview if we wanted.
|
||||
url_tuple = urlsplit(url)
|
||||
url_tuple = urlparse.urlsplit(url)
|
||||
for entry in self.url_preview_url_blacklist:
|
||||
match = True
|
||||
for attrib in entry:
|
||||
@@ -318,7 +322,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||
|
||||
# Parse Open Graph information from the HTML in case the oEmbed
|
||||
# response failed or is incomplete.
|
||||
og_from_html = parse_html_to_open_graph(tree)
|
||||
og_from_html = parse_html_to_open_graph(tree, media_info.uri)
|
||||
|
||||
# Compile the Open Graph response by using the scraped
|
||||
# information from the HTML and overlaying any information
|
||||
@@ -584,17 +588,12 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||
if "og:image" not in og or not og["og:image"]:
|
||||
return
|
||||
|
||||
# The image URL from the HTML might be relative to the previewed page,
|
||||
# convert it to an URL which can be requested directly.
|
||||
image_url = og["og:image"]
|
||||
url_parts = urlparse(image_url)
|
||||
if url_parts.scheme != "data":
|
||||
image_url = urljoin(media_info.uri, image_url)
|
||||
|
||||
# FIXME: it might be cleaner to use the same flow as the main /preview_url
|
||||
# request itself and benefit from the same caching etc. But for now we
|
||||
# just rely on the caching on the master request to speed things up.
|
||||
image_info = await self._handle_url(image_url, user, allow_data_urls=True)
|
||||
image_info = await self._handle_url(
|
||||
rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True
|
||||
)
|
||||
|
||||
if _is_media(image_info.media_type):
|
||||
# TODO: make sure we don't choke on white-on-transparent images
|
||||
|
||||
@@ -92,20 +92,12 @@ class AccountDetailsResource(DirectServeHtmlResource):
|
||||
self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
|
||||
return
|
||||
|
||||
# The configuration might mandate going through this step to validate an
|
||||
# automatically generated localpart, so session.chosen_localpart might already
|
||||
# be set.
|
||||
localpart = ""
|
||||
if session.chosen_localpart is not None:
|
||||
localpart = session.chosen_localpart
|
||||
|
||||
idp_id = session.auth_provider_id
|
||||
template_params = {
|
||||
"idp": self._sso_handler.get_identity_providers()[idp_id],
|
||||
"user_attributes": {
|
||||
"display_name": session.display_name,
|
||||
"emails": session.emails,
|
||||
"localpart": localpart,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -94,7 +94,6 @@ from synapse.handlers.profile import ProfileHandler
|
||||
from synapse.handlers.read_marker import ReadMarkerHandler
|
||||
from synapse.handlers.receipts import ReceiptsHandler
|
||||
from synapse.handlers.register import RegistrationHandler
|
||||
from synapse.handlers.relations import RelationsHandler
|
||||
from synapse.handlers.room import (
|
||||
RoomContextHandler,
|
||||
RoomCreationHandler,
|
||||
@@ -329,6 +328,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
Does nothing in this base class; overridden in derived classes to start the
|
||||
appropriate listeners.
|
||||
"""
|
||||
pass
|
||||
|
||||
def setup_background_tasks(self) -> None:
|
||||
"""
|
||||
@@ -720,10 +720,6 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
def get_pagination_handler(self) -> PaginationHandler:
|
||||
return PaginationHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_relations_handler(self) -> RelationsHandler:
|
||||
return RelationsHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_room_context_handler(self) -> RoomContextHandler:
|
||||
return RoomContextHandler(self)
|
||||
|
||||
@@ -60,19 +60,18 @@ class _BackgroundUpdateHandler:
|
||||
|
||||
|
||||
class _BackgroundUpdateContextManager:
|
||||
def __init__(
|
||||
self, sleep: bool, clock: Clock, sleep_duration_ms: int, update_duration: int
|
||||
):
|
||||
BACKGROUND_UPDATE_INTERVAL_MS = 1000
|
||||
BACKGROUND_UPDATE_DURATION_MS = 100
|
||||
|
||||
def __init__(self, sleep: bool, clock: Clock):
|
||||
self._sleep = sleep
|
||||
self._clock = clock
|
||||
self._sleep_duration_ms = sleep_duration_ms
|
||||
self._update_duration_ms = update_duration
|
||||
|
||||
async def __aenter__(self) -> int:
|
||||
if self._sleep:
|
||||
await self._clock.sleep(self._sleep_duration_ms / 1000)
|
||||
await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
|
||||
|
||||
return self._update_duration_ms
|
||||
return self.BACKGROUND_UPDATE_DURATION_MS
|
||||
|
||||
async def __aexit__(self, *exc) -> None:
|
||||
pass
|
||||
@@ -134,6 +133,9 @@ class BackgroundUpdater:
|
||||
process and autotuning the batch size.
|
||||
"""
|
||||
|
||||
MINIMUM_BACKGROUND_BATCH_SIZE = 1
|
||||
DEFAULT_BACKGROUND_BATCH_SIZE = 100
|
||||
|
||||
def __init__(self, hs: "HomeServer", database: "DatabasePool"):
|
||||
self._clock = hs.get_clock()
|
||||
self.db_pool = database
|
||||
@@ -158,14 +160,6 @@ class BackgroundUpdater:
|
||||
# enable/disable background updates via the admin API.
|
||||
self.enabled = True
|
||||
|
||||
self.minimum_background_batch_size = hs.config.background_updates.min_batch_size
|
||||
self.default_background_batch_size = (
|
||||
hs.config.background_updates.default_batch_size
|
||||
)
|
||||
self.update_duration_ms = hs.config.background_updates.update_duration_ms
|
||||
self.sleep_duration_ms = hs.config.background_updates.sleep_duration_ms
|
||||
self.sleep_enabled = hs.config.background_updates.sleep_enabled
|
||||
|
||||
def register_update_controller_callbacks(
|
||||
self,
|
||||
on_update: ON_UPDATE_CALLBACK,
|
||||
@@ -222,9 +216,7 @@ class BackgroundUpdater:
|
||||
if self._on_update_callback is not None:
|
||||
return self._on_update_callback(update_name, database_name, oneshot)
|
||||
|
||||
return _BackgroundUpdateContextManager(
|
||||
sleep, self._clock, self.sleep_duration_ms, self.update_duration_ms
|
||||
)
|
||||
return _BackgroundUpdateContextManager(sleep, self._clock)
|
||||
|
||||
async def _default_batch_size(self, update_name: str, database_name: str) -> int:
|
||||
"""The batch size to use for the first iteration of a new background
|
||||
@@ -233,7 +225,7 @@ class BackgroundUpdater:
|
||||
if self._default_batch_size_callback is not None:
|
||||
return await self._default_batch_size_callback(update_name, database_name)
|
||||
|
||||
return self.default_background_batch_size
|
||||
return self.DEFAULT_BACKGROUND_BATCH_SIZE
|
||||
|
||||
async def _min_batch_size(self, update_name: str, database_name: str) -> int:
|
||||
"""A lower bound on the batch size of a new background update.
|
||||
@@ -243,7 +235,7 @@ class BackgroundUpdater:
|
||||
if self._min_batch_size_callback is not None:
|
||||
return await self._min_batch_size_callback(update_name, database_name)
|
||||
|
||||
return self.minimum_background_batch_size
|
||||
return self.MINIMUM_BACKGROUND_BATCH_SIZE
|
||||
|
||||
def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
|
||||
"""Returns the current background update, if any."""
|
||||
@@ -262,12 +254,9 @@ class BackgroundUpdater:
|
||||
if self.enabled:
|
||||
# if we start a new background update, not all updates are done.
|
||||
self._all_done = False
|
||||
sleep = self.sleep_enabled
|
||||
run_as_background_process(
|
||||
"background_updates", self.run_background_updates, sleep
|
||||
)
|
||||
run_as_background_process("background_updates", self.run_background_updates)
|
||||
|
||||
async def run_background_updates(self, sleep: bool) -> None:
|
||||
async def run_background_updates(self, sleep: bool = True) -> None:
|
||||
if self._running or not self.enabled:
|
||||
return
|
||||
|
||||
|
||||
@@ -41,7 +41,6 @@ from prometheus_client import Histogram
|
||||
from typing_extensions import Literal
|
||||
|
||||
from twisted.enterprise import adbapi
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.config.database import DatabaseConnectionConfig
|
||||
@@ -56,7 +55,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.background_updates import BackgroundUpdater
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.types import Connection, Cursor
|
||||
from synapse.util.async_helpers import delay_cancellation
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -734,45 +732,34 @@ class DatabasePool:
|
||||
Returns:
|
||||
The result of func
|
||||
"""
|
||||
after_callbacks: List[_CallbackListEntry] = []
|
||||
exception_callbacks: List[_CallbackListEntry] = []
|
||||
|
||||
async def _runInteraction() -> R:
|
||||
after_callbacks: List[_CallbackListEntry] = []
|
||||
exception_callbacks: List[_CallbackListEntry] = []
|
||||
if not current_context():
|
||||
logger.warning("Starting db txn '%s' from sentinel context", desc)
|
||||
|
||||
if not current_context():
|
||||
logger.warning("Starting db txn '%s' from sentinel context", desc)
|
||||
try:
|
||||
with opentracing.start_active_span(f"db.{desc}"):
|
||||
result = await self.runWithConnection(
|
||||
self.new_transaction,
|
||||
desc,
|
||||
after_callbacks,
|
||||
exception_callbacks,
|
||||
func,
|
||||
*args,
|
||||
db_autocommit=db_autocommit,
|
||||
isolation_level=isolation_level,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
try:
|
||||
with opentracing.start_active_span(f"db.{desc}"):
|
||||
result = await self.runWithConnection(
|
||||
self.new_transaction,
|
||||
desc,
|
||||
after_callbacks,
|
||||
exception_callbacks,
|
||||
func,
|
||||
*args,
|
||||
db_autocommit=db_autocommit,
|
||||
isolation_level=isolation_level,
|
||||
**kwargs,
|
||||
)
|
||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||
after_callback(*after_args, **after_kwargs)
|
||||
except Exception:
|
||||
for after_callback, after_args, after_kwargs in exception_callbacks:
|
||||
after_callback(*after_args, **after_kwargs)
|
||||
raise
|
||||
|
||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||
after_callback(*after_args, **after_kwargs)
|
||||
|
||||
return cast(R, result)
|
||||
except Exception:
|
||||
for after_callback, after_args, after_kwargs in exception_callbacks:
|
||||
after_callback(*after_args, **after_kwargs)
|
||||
raise
|
||||
|
||||
# To handle cancellation, we ensure that `after_callback`s and
|
||||
# `exception_callback`s are always run, since the transaction will complete
|
||||
# on another thread regardless of cancellation.
|
||||
#
|
||||
# We also wait until everything above is done before releasing the
|
||||
# `CancelledError`, so that logging contexts won't get used after they have been
|
||||
# finished.
|
||||
return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
|
||||
return cast(R, result)
|
||||
|
||||
async def runWithConnection(
|
||||
self,
|
||||
|
||||
@@ -14,17 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
|
||||
|
||||
from synapse.api.constants import AccountDataTypes
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
@@ -375,7 +365,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
||||
)
|
||||
|
||||
@cached(max_entries=5000, iterable=True)
|
||||
async def ignored_by(self, user_id: str) -> FrozenSet[str]:
|
||||
async def ignored_by(self, user_id: str) -> Set[str]:
|
||||
"""
|
||||
Get users which ignore the given user.
|
||||
|
||||
@@ -385,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
||||
Return:
|
||||
The user IDs which ignore the given user.
|
||||
"""
|
||||
return frozenset(
|
||||
return set(
|
||||
await self.db_pool.simple_select_onecol(
|
||||
table="ignored_users",
|
||||
keyvalues={"ignored_user_id": user_id},
|
||||
@@ -394,26 +384,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
||||
)
|
||||
)
|
||||
|
||||
@cached(max_entries=5000, iterable=True)
|
||||
async def ignored_users(self, user_id: str) -> FrozenSet[str]:
|
||||
"""
|
||||
Get users which the given user ignores.
|
||||
|
||||
Params:
|
||||
user_id: The user ID which is making the request.
|
||||
|
||||
Return:
|
||||
The user IDs which are ignored by the given user.
|
||||
"""
|
||||
return frozenset(
|
||||
await self.db_pool.simple_select_onecol(
|
||||
table="ignored_users",
|
||||
keyvalues={"ignorer_user_id": user_id},
|
||||
retcol="ignored_user_id",
|
||||
desc="ignored_users",
|
||||
)
|
||||
)
|
||||
|
||||
def process_replication_rows(
|
||||
self,
|
||||
stream_name: str,
|
||||
@@ -559,10 +529,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
||||
else:
|
||||
currently_ignored_users = set()
|
||||
|
||||
# If the data has not changed, nothing to do.
|
||||
if previously_ignored_users == currently_ignored_users:
|
||||
return
|
||||
|
||||
# Delete entries which are no longer ignored.
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
@@ -585,7 +551,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
||||
# Invalidate the cache for any ignored users which were added or removed.
|
||||
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
|
||||
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
|
||||
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
|
||||
|
||||
async def purge_account_data_for_user(self, user_id: str) -> None:
|
||||
"""
|
||||
|
||||
@@ -23,7 +23,6 @@ from synapse.replication.tcp.streams.events import (
|
||||
EventsStream,
|
||||
EventsStreamCurrentStateRow,
|
||||
EventsStreamEventRow,
|
||||
EventsStreamRow,
|
||||
)
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import (
|
||||
@@ -32,7 +31,6 @@ from synapse.storage.database import (
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.util.caches.descriptors import _CachedFunction
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -84,9 +82,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
if last_id == current_id:
|
||||
return [], current_id, False
|
||||
|
||||
def get_all_updated_caches_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
def get_all_updated_caches_txn(txn):
|
||||
# We purposefully don't bound by the current token, as we want to
|
||||
# send across cache invalidations as quickly as possible. Cache
|
||||
# invalidations are idempotent, so duplicates are fine.
|
||||
@@ -111,9 +107,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
"get_all_updated_caches", get_all_updated_caches_txn
|
||||
)
|
||||
|
||||
def process_replication_rows(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
|
||||
) -> None:
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == EventsStream.NAME:
|
||||
for row in rows:
|
||||
self._process_event_stream_row(token, row)
|
||||
@@ -148,11 +142,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
|
||||
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
|
||||
def _process_event_stream_row(self, token, row):
|
||||
data = row.data
|
||||
|
||||
if row.type == EventsStreamEventRow.TypeId:
|
||||
assert isinstance(data, EventsStreamEventRow)
|
||||
self._invalidate_caches_for_event(
|
||||
token,
|
||||
data.event_id,
|
||||
@@ -164,8 +157,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
backfilled=False,
|
||||
)
|
||||
elif row.type == EventsStreamCurrentStateRow.TypeId:
|
||||
assert isinstance(data, EventsStreamCurrentStateRow)
|
||||
self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
|
||||
self._curr_state_delta_stream_cache.entity_has_changed(
|
||||
row.data.room_id, token
|
||||
)
|
||||
|
||||
if data.type == EventTypes.Member:
|
||||
self.get_rooms_for_user_with_stream_ordering.invalidate(
|
||||
@@ -176,15 +170,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
|
||||
def _invalidate_caches_for_event(
|
||||
self,
|
||||
stream_ordering: int,
|
||||
event_id: str,
|
||||
room_id: str,
|
||||
etype: str,
|
||||
state_key: Optional[str],
|
||||
redacts: Optional[str],
|
||||
relates_to: Optional[str],
|
||||
backfilled: bool,
|
||||
) -> None:
|
||||
stream_ordering,
|
||||
event_id,
|
||||
room_id,
|
||||
etype,
|
||||
state_key,
|
||||
redacts,
|
||||
relates_to,
|
||||
backfilled,
|
||||
):
|
||||
self._invalidate_get_event_cache(event_id)
|
||||
self.have_seen_event.invalidate((room_id, event_id))
|
||||
|
||||
@@ -213,9 +207,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
self.get_thread_summary.invalidate((relates_to,))
|
||||
self.get_thread_participated.invalidate((relates_to,))
|
||||
|
||||
async def invalidate_cache_and_stream(
|
||||
self, cache_name: str, keys: Tuple[Any, ...]
|
||||
) -> None:
|
||||
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
|
||||
"""Invalidates the cache and adds it to the cache stream so slaves
|
||||
will know to invalidate their caches.
|
||||
|
||||
@@ -235,12 +227,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
keys,
|
||||
)
|
||||
|
||||
def _invalidate_cache_and_stream(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
cache_func: _CachedFunction,
|
||||
keys: Tuple[Any, ...],
|
||||
) -> None:
|
||||
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
|
||||
"""Invalidates the cache and adds it to the cache stream so slaves
|
||||
will know to invalidate their caches.
|
||||
|
||||
@@ -251,9 +238,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
txn.call_after(cache_func.invalidate, keys)
|
||||
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
|
||||
|
||||
def _invalidate_all_cache_and_stream(
|
||||
self, txn: LoggingTransaction, cache_func: _CachedFunction
|
||||
) -> None:
|
||||
def _invalidate_all_cache_and_stream(self, txn, cache_func):
|
||||
"""Invalidates the entire cache and adds it to the cache stream so slaves
|
||||
will know to invalidate their caches.
|
||||
"""
|
||||
@@ -294,8 +279,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
def _send_invalidation_to_replication(
|
||||
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
|
||||
) -> None:
|
||||
self, txn, cache_name: str, keys: Optional[Iterable[Any]]
|
||||
):
|
||||
"""Notifies replication that given cache has been invalidated.
|
||||
|
||||
Note that this does *not* invalidate the cache locally.
|
||||
@@ -330,7 +315,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
"instance_name": self._instance_name,
|
||||
"cache_func": cache_name,
|
||||
"keys": keys,
|
||||
"invalidation_ts": self._clock.time_msec(),
|
||||
"invalidation_ts": self.clock.time_msec(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -48,6 +48,8 @@ class ExternalIDReuseException(Exception):
|
||||
"""Exception if writing an external id for a user fails,
|
||||
because this external id is given to an other user."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||
class TokenLookupResult:
|
||||
|
||||
@@ -125,6 +125,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
if not hs.config.server.enable_search:
|
||||
return
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
|
||||
)
|
||||
@@ -240,13 +243,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
|
||||
return len(event_search_rows)
|
||||
|
||||
if self.hs.config.server.enable_search:
|
||||
result = await self.db_pool.runInteraction(
|
||||
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
|
||||
)
|
||||
else:
|
||||
# Don't index anything if search is not enabled.
|
||||
result = 0
|
||||
result = await self.db_pool.runInteraction(
|
||||
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
|
||||
)
|
||||
|
||||
if not result:
|
||||
await self.db_pool.updates._end_background_update(
|
||||
|
||||
@@ -36,6 +36,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
|
||||
config_files = config.appservice.app_service_config_files
|
||||
except AttributeError:
|
||||
logger.warning("Could not get app_service_config_files from config")
|
||||
pass
|
||||
|
||||
appservices = load_appservices(config.server.server_name, config_files)
|
||||
|
||||
|
||||
@@ -18,10 +18,9 @@ import collections
|
||||
import inspect
|
||||
import itertools
|
||||
import logging
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from contextlib import contextmanager
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
@@ -41,7 +40,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import attr
|
||||
from typing_extensions import AsyncContextManager, Literal
|
||||
from typing_extensions import ContextManager, Literal
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import CancelledError
|
||||
@@ -492,7 +491,7 @@ class ReadWriteLock:
|
||||
|
||||
Example:
|
||||
|
||||
async with read_write_lock.read("test_key"):
|
||||
with await read_write_lock.read("test_key"):
|
||||
# do some work
|
||||
"""
|
||||
|
||||
@@ -515,24 +514,22 @@ class ReadWriteLock:
|
||||
# Latest writer queued
|
||||
self.key_to_current_writer: Dict[str, defer.Deferred] = {}
|
||||
|
||||
def read(self, key: str) -> AsyncContextManager:
|
||||
@asynccontextmanager
|
||||
async def _ctx_manager() -> AsyncIterator[None]:
|
||||
new_defer: "defer.Deferred[None]" = defer.Deferred()
|
||||
async def read(self, key: str) -> ContextManager:
|
||||
new_defer: "defer.Deferred[None]" = defer.Deferred()
|
||||
|
||||
curr_readers = self.key_to_current_readers.setdefault(key, set())
|
||||
curr_writer = self.key_to_current_writer.get(key, None)
|
||||
curr_readers = self.key_to_current_readers.setdefault(key, set())
|
||||
curr_writer = self.key_to_current_writer.get(key, None)
|
||||
|
||||
curr_readers.add(new_defer)
|
||||
curr_readers.add(new_defer)
|
||||
|
||||
# We wait for the latest writer to finish writing. We can safely ignore
|
||||
# any existing readers... as they're readers.
|
||||
if curr_writer:
|
||||
await make_deferred_yieldable(curr_writer)
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager() -> Iterator[None]:
|
||||
try:
|
||||
# We wait for the latest writer to finish writing. We can safely ignore
|
||||
# any existing readers... as they're readers.
|
||||
# May raise a `CancelledError` if the `Deferred` wrapping us is
|
||||
# cancelled. The `Deferred` we are waiting on must not be cancelled,
|
||||
# since we do not own it.
|
||||
if curr_writer:
|
||||
await make_deferred_yieldable(stop_cancellation(curr_writer))
|
||||
yield
|
||||
finally:
|
||||
with PreserveLoggingContext():
|
||||
@@ -541,35 +538,29 @@ class ReadWriteLock:
|
||||
|
||||
return _ctx_manager()
|
||||
|
||||
def write(self, key: str) -> AsyncContextManager:
|
||||
@asynccontextmanager
|
||||
async def _ctx_manager() -> AsyncIterator[None]:
|
||||
new_defer: "defer.Deferred[None]" = defer.Deferred()
|
||||
async def write(self, key: str) -> ContextManager:
|
||||
new_defer: "defer.Deferred[None]" = defer.Deferred()
|
||||
|
||||
curr_readers = self.key_to_current_readers.get(key, set())
|
||||
curr_writer = self.key_to_current_writer.get(key, None)
|
||||
curr_readers = self.key_to_current_readers.get(key, set())
|
||||
curr_writer = self.key_to_current_writer.get(key, None)
|
||||
|
||||
# We wait on all latest readers and writer.
|
||||
to_wait_on = list(curr_readers)
|
||||
if curr_writer:
|
||||
to_wait_on.append(curr_writer)
|
||||
# We wait on all latest readers and writer.
|
||||
to_wait_on = list(curr_readers)
|
||||
if curr_writer:
|
||||
to_wait_on.append(curr_writer)
|
||||
|
||||
# We can clear the list of current readers since `new_defer` waits
|
||||
# for them to finish.
|
||||
curr_readers.clear()
|
||||
self.key_to_current_writer[key] = new_defer
|
||||
# We can clear the list of current readers since the new writer waits
|
||||
# for them to finish.
|
||||
curr_readers.clear()
|
||||
self.key_to_current_writer[key] = new_defer
|
||||
|
||||
to_wait_on_defer = defer.gatherResults(to_wait_on)
|
||||
await make_deferred_yieldable(defer.gatherResults(to_wait_on))
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager() -> Iterator[None]:
|
||||
try:
|
||||
# Wait for all current readers and the latest writer to finish.
|
||||
# May raise a `CancelledError` immediately after the wait if the
|
||||
# `Deferred` wrapping us is cancelled. We must only release the lock
|
||||
# once we have acquired it, hence the use of `delay_cancellation`
|
||||
# rather than `stop_cancellation`.
|
||||
await make_deferred_yieldable(delay_cancellation(to_wait_on_defer))
|
||||
yield
|
||||
finally:
|
||||
# Release the lock.
|
||||
with PreserveLoggingContext():
|
||||
new_defer.callback(None)
|
||||
# `self.key_to_current_writer[key]` may be missing if there was another
|
||||
@@ -695,48 +686,12 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
|
||||
Synapse logcontext rules.
|
||||
|
||||
Returns:
|
||||
A new `Deferred`, which will contain the result of the original `Deferred`.
|
||||
The new `Deferred` will not propagate cancellation through to the original.
|
||||
When cancelled, the new `Deferred` will fail with a `CancelledError`.
|
||||
|
||||
The new `Deferred` will not follow the Synapse logcontext rules and should be
|
||||
wrapped with `make_deferred_yieldable`.
|
||||
A new `Deferred`, which will contain the result of the original `Deferred`,
|
||||
but will not propagate cancellation through to the original. When cancelled,
|
||||
the new `Deferred` will fail with a `CancelledError` and will not follow the
|
||||
Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap
|
||||
the new `Deferred`.
|
||||
"""
|
||||
new_deferred: "defer.Deferred[T]" = defer.Deferred()
|
||||
deferred.chainDeferred(new_deferred)
|
||||
return new_deferred
|
||||
|
||||
|
||||
def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
|
||||
"""Delay cancellation of a `Deferred` until it resolves.
|
||||
|
||||
Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
|
||||
resolve with a `CancelledError` until the original `Deferred` resolves.
|
||||
|
||||
Args:
|
||||
deferred: The `Deferred` to protect against cancellation. May optionally follow
|
||||
the Synapse logcontext rules.
|
||||
|
||||
Returns:
|
||||
A new `Deferred`, which will contain the result of the original `Deferred`.
|
||||
The new `Deferred` will not propagate cancellation through to the original.
|
||||
When cancelled, the new `Deferred` will wait until the original `Deferred`
|
||||
resolves before failing with a `CancelledError`.
|
||||
|
||||
The new `Deferred` will follow the Synapse logcontext rules if `deferred`
|
||||
follows the Synapse logcontext rules. Otherwise the new `Deferred` should be
|
||||
wrapped with `make_deferred_yieldable`.
|
||||
"""
|
||||
|
||||
def handle_cancel(new_deferred: "defer.Deferred[T]") -> None:
|
||||
# before the new deferred is cancelled, we `pause` it to stop the cancellation
|
||||
# propagating. we then `unpause` it once the wrapped deferred completes, to
|
||||
# propagate the exception.
|
||||
new_deferred.pause()
|
||||
new_deferred.errback(Failure(CancelledError()))
|
||||
|
||||
deferred.addBoth(lambda _: new_deferred.unpause())
|
||||
|
||||
new_deferred: "defer.Deferred[T]" = defer.Deferred(handle_cancel)
|
||||
new_deferred: defer.Deferred[T] = defer.Deferred()
|
||||
deferred.chainDeferred(new_deferred)
|
||||
return new_deferred
|
||||
|
||||
@@ -41,7 +41,6 @@ from twisted.python.failure import Failure
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async_helpers import delay_cancellation
|
||||
from synapse.util.caches.deferred_cache import DeferredCache
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
|
||||
@@ -351,11 +350,6 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
|
||||
ret = cache.set(cache_key, ret, callback=invalidate_callback)
|
||||
|
||||
# We started a new call to `self.orig`, so we must always wait for it to
|
||||
# complete. Otherwise we might mark our current logging context as
|
||||
# finished while `self.orig` is still using it in the background.
|
||||
ret = delay_cancellation(ret)
|
||||
|
||||
return make_deferred_yieldable(ret)
|
||||
|
||||
wrapped = cast(_CachedFunction, _wrapped)
|
||||
@@ -516,11 +510,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
|
||||
lambda _: results, unwrapFirstError
|
||||
)
|
||||
if missing:
|
||||
# We started a new call to `self.orig`, so we must always wait for it to
|
||||
# complete. Otherwise we might mark our current logging context as
|
||||
# finished while `self.orig` is still using it in the background.
|
||||
d = delay_cancellation(d)
|
||||
return make_deferred_yieldable(d)
|
||||
else:
|
||||
return defer.succeed(results)
|
||||
|
||||
@@ -22,6 +22,8 @@ class TreeCacheNode(dict):
|
||||
leaves.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TreeCache:
|
||||
"""
|
||||
|
||||
@@ -25,7 +25,6 @@ from typing import Iterable, NamedTuple, Optional
|
||||
|
||||
from packaging.requirements import Requirement
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DISTRIBUTION_NAME = "matrix-synapse"
|
||||
|
||||
try:
|
||||
@@ -165,10 +164,6 @@ def check_requirements(extra: Optional[str] = None) -> None:
|
||||
errors.append(_not_installed(requirement, extra))
|
||||
else:
|
||||
# We specify prereleases=True to allow prereleases such as RCs.
|
||||
logger.warning(
|
||||
"DEP CHECK: %s, %s, %s, %r",
|
||||
requirement, must_be_installed, dist, dist.version,
|
||||
)
|
||||
if not requirement.specifier.contains(dist.version, prereleases=True):
|
||||
deps_unfulfilled.append(requirement.name)
|
||||
errors.append(_incorrect_version(requirement, dist.version, extra))
|
||||
|
||||
@@ -64,7 +64,6 @@ def build_jinja_env(
|
||||
{
|
||||
"format_ts": _format_ts_filter,
|
||||
"mxc_to_http": _create_mxc_to_http_filter(config.server.public_baseurl),
|
||||
"localpart_from_email": _localpart_from_email_filter,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -113,7 +112,3 @@ def _create_mxc_to_http_filter(
|
||||
|
||||
def _format_ts_filter(value: int, format: str) -> str:
|
||||
return time.strftime(format, time.localtime(value / 1000))
|
||||
|
||||
|
||||
def _localpart_from_email_filter(address: str) -> str:
|
||||
return address.rsplit("@", 1)[0]
|
||||
|
||||
@@ -14,7 +14,12 @@
|
||||
import logging
|
||||
from typing import Dict, FrozenSet, List, Optional
|
||||
|
||||
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
|
||||
from synapse.api.constants import (
|
||||
AccountDataTypes,
|
||||
EventTypes,
|
||||
HistoryVisibility,
|
||||
Membership,
|
||||
)
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.storage import Storage
|
||||
@@ -82,8 +87,15 @@ async def filter_events_for_client(
|
||||
state_filter=StateFilter.from_types(types),
|
||||
)
|
||||
|
||||
# Get the users who are ignored by the requesting user.
|
||||
ignore_list = await storage.main.ignored_users(user_id)
|
||||
ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
|
||||
user_id, AccountDataTypes.IGNORED_USER_LIST
|
||||
)
|
||||
|
||||
ignore_list: FrozenSet[str] = frozenset()
|
||||
if ignore_dict_content:
|
||||
ignored_users_dict = ignore_dict_content.get("ignored_users", {})
|
||||
if isinstance(ignored_users_dict, dict):
|
||||
ignore_list = frozenset(ignored_users_dict.keys())
|
||||
|
||||
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
|
||||
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import yaml
|
||||
|
||||
from synapse.storage.background_updates import BackgroundUpdater
|
||||
|
||||
from tests.unittest import HomeserverTestCase, override_config
|
||||
|
||||
|
||||
class BackgroundUpdateConfigTestCase(HomeserverTestCase):
|
||||
# Tests that the default values in the config are correctly loaded. Note that the default
|
||||
# values are loaded when the corresponding config options are commented out, which is why there isn't
|
||||
# a config specified here.
|
||||
def test_default_configuration(self):
|
||||
background_updater = BackgroundUpdater(
|
||||
self.hs, self.hs.get_datastores().main.db_pool
|
||||
)
|
||||
|
||||
self.assertEqual(background_updater.minimum_background_batch_size, 1)
|
||||
self.assertEqual(background_updater.default_background_batch_size, 100)
|
||||
self.assertEqual(background_updater.sleep_enabled, True)
|
||||
self.assertEqual(background_updater.sleep_duration_ms, 1000)
|
||||
self.assertEqual(background_updater.update_duration_ms, 100)
|
||||
|
||||
# Tests that non-default values for the config options are properly picked up and passed on.
|
||||
@override_config(
|
||||
yaml.safe_load(
|
||||
"""
|
||||
background_updates:
|
||||
background_update_duration_ms: 1000
|
||||
sleep_enabled: false
|
||||
sleep_duration_ms: 600
|
||||
min_batch_size: 5
|
||||
default_batch_size: 50
|
||||
"""
|
||||
)
|
||||
)
|
||||
def test_custom_configuration(self):
|
||||
background_updater = BackgroundUpdater(
|
||||
self.hs, self.hs.get_datastores().main.db_pool
|
||||
)
|
||||
|
||||
self.assertEqual(background_updater.minimum_background_batch_size, 5)
|
||||
self.assertEqual(background_updater.default_background_batch_size, 50)
|
||||
self.assertEqual(background_updater.sleep_enabled, False)
|
||||
self.assertEqual(background_updater.sleep_duration_ms, 600)
|
||||
self.assertEqual(background_updater.update_duration_ms, 1000)
|
||||
@@ -15,15 +15,11 @@
|
||||
from collections import Counter
|
||||
from unittest.mock import Mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse.rest.admin
|
||||
import synapse.storage
|
||||
from synapse.api.constants import EventTypes, JoinRules
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.rest.client import knock, login, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
||||
@@ -36,7 +32,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
|
||||
knock.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
|
||||
self.user1 = self.register_user("user1", "password")
|
||||
@@ -45,7 +41,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
|
||||
self.user2 = self.register_user("user2", "password")
|
||||
self.token2 = self.login("user2", "password")
|
||||
|
||||
def test_single_public_joined_room(self) -> None:
|
||||
def test_single_public_joined_room(self):
|
||||
"""Test that we write *all* events for a public room"""
|
||||
room_id = self.helper.create_room_as(
|
||||
self.user1, tok=self.token1, is_public=True
|
||||
@@ -78,7 +74,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
|
||||
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
|
||||
self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
|
||||
|
||||
def test_single_private_joined_room(self) -> None:
|
||||
def test_single_private_joined_room(self):
|
||||
"""Tests that we correctly write state when we can't see all events in
|
||||
a room.
|
||||
"""
|
||||
@@ -116,7 +112,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
|
||||
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
|
||||
self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
|
||||
|
||||
def test_single_left_room(self) -> None:
|
||||
def test_single_left_room(self):
|
||||
"""Tests that we don't see events in the room after we leave."""
|
||||
room_id = self.helper.create_room_as(self.user1, tok=self.token1)
|
||||
self.helper.send(room_id, body="Hello!", tok=self.token1)
|
||||
@@ -148,7 +144,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
|
||||
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
|
||||
self.assertEqual(counter[(EventTypes.Member, self.user2)], 2)
|
||||
|
||||
def test_single_left_rejoined_private_room(self) -> None:
|
||||
def test_single_left_rejoined_private_room(self):
|
||||
"""Tests that see the correct events in private rooms when we
|
||||
repeatedly join and leave.
|
||||
"""
|
||||
@@ -189,7 +185,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
|
||||
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
|
||||
self.assertEqual(counter[(EventTypes.Member, self.user2)], 3)
|
||||
|
||||
def test_invite(self) -> None:
|
||||
def test_invite(self):
|
||||
"""Tests that pending invites get handled correctly."""
|
||||
room_id = self.helper.create_room_as(self.user1, tok=self.token1)
|
||||
self.helper.send(room_id, body="Hello!", tok=self.token1)
|
||||
@@ -208,7 +204,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
|
||||
self.assertEqual(args[1].content["membership"], "invite")
|
||||
self.assertTrue(args[2]) # Assert there is at least one bit of state
|
||||
|
||||
def test_knock(self) -> None:
|
||||
def test_knock(self):
|
||||
"""Tests that knock get handled correctly."""
|
||||
# create a knockable v7 room
|
||||
room_id = self.helper.create_room_as(
|
||||
|
||||
@@ -15,12 +15,8 @@ from unittest.mock import Mock
|
||||
|
||||
import pymacaroons
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import AuthError, ResourceLimitError
|
||||
from synapse.rest import admin
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import make_awaitable
|
||||
@@ -31,7 +27,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
admin.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.macaroon_generator = hs.get_macaroon_generator()
|
||||
|
||||
@@ -46,23 +42,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.user1 = self.register_user("a_user", "pass")
|
||||
|
||||
def test_macaroon_caveats(self) -> None:
|
||||
def test_macaroon_caveats(self):
|
||||
token = self.macaroon_generator.generate_guest_access_token("a_user")
|
||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||
|
||||
def verify_gen(caveat: str) -> bool:
|
||||
def verify_gen(caveat):
|
||||
return caveat == "gen = 1"
|
||||
|
||||
def verify_user(caveat: str) -> bool:
|
||||
def verify_user(caveat):
|
||||
return caveat == "user_id = a_user"
|
||||
|
||||
def verify_type(caveat: str) -> bool:
|
||||
def verify_type(caveat):
|
||||
return caveat == "type = access"
|
||||
|
||||
def verify_nonce(caveat: str) -> bool:
|
||||
def verify_nonce(caveat):
|
||||
return caveat.startswith("nonce =")
|
||||
|
||||
def verify_guest(caveat: str) -> bool:
|
||||
def verify_guest(caveat):
|
||||
return caveat == "guest = true"
|
||||
|
||||
v = pymacaroons.Verifier()
|
||||
@@ -73,7 +69,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
v.satisfy_general(verify_guest)
|
||||
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
|
||||
|
||||
def test_short_term_login_token_gives_user_id(self) -> None:
|
||||
def test_short_term_login_token_gives_user_id(self):
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
self.user1, "", duration_in_ms=5000
|
||||
)
|
||||
@@ -88,7 +84,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
AuthError,
|
||||
)
|
||||
|
||||
def test_short_term_login_token_gives_auth_provider(self) -> None:
|
||||
def test_short_term_login_token_gives_auth_provider(self):
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
self.user1, auth_provider_id="my_idp"
|
||||
)
|
||||
@@ -96,7 +92,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(self.user1, res.user_id)
|
||||
self.assertEqual("my_idp", res.auth_provider_id)
|
||||
|
||||
def test_short_term_login_token_cannot_replace_user_id(self) -> None:
|
||||
def test_short_term_login_token_cannot_replace_user_id(self):
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
self.user1, "", duration_in_ms=5000
|
||||
)
|
||||
@@ -116,7 +112,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
AuthError,
|
||||
)
|
||||
|
||||
def test_mau_limits_disabled(self) -> None:
|
||||
def test_mau_limits_disabled(self):
|
||||
self.auth_blocking._limit_usage_by_mau = False
|
||||
# Ensure does not throw exception
|
||||
self.get_success(
|
||||
@@ -131,7 +127,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_mau_limits_exceeded_large(self) -> None:
|
||||
def test_mau_limits_exceeded_large(self):
|
||||
self.auth_blocking._limit_usage_by_mau = True
|
||||
self.hs.get_datastores().main.get_monthly_active_count = Mock(
|
||||
return_value=make_awaitable(self.large_number_of_users)
|
||||
@@ -154,7 +150,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
ResourceLimitError,
|
||||
)
|
||||
|
||||
def test_mau_limits_parity(self) -> None:
|
||||
def test_mau_limits_parity(self):
|
||||
# Ensure we're not at the unix epoch.
|
||||
self.reactor.advance(1)
|
||||
self.auth_blocking._limit_usage_by_mau = True
|
||||
@@ -193,7 +189,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_mau_limits_not_exceeded(self) -> None:
|
||||
def test_mau_limits_not_exceeded(self):
|
||||
self.auth_blocking._limit_usage_by_mau = True
|
||||
|
||||
self.hs.get_datastores().main.get_monthly_active_count = Mock(
|
||||
@@ -215,7 +211,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def _get_macaroon(self) -> pymacaroons.Macaroon:
|
||||
def _get_macaroon(self):
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
self.user1, "", duration_in_ms=5000
|
||||
)
|
||||
|
||||
@@ -39,7 +39,7 @@ class DeactivateAccountTestCase(HomeserverTestCase):
|
||||
self.user = self.register_user("user", "pass")
|
||||
self.token = self.login("user", "pass")
|
||||
|
||||
def _deactivate_my_account(self) -> None:
|
||||
def _deactivate_my_account(self):
|
||||
"""
|
||||
Deactivates the account `self.user` using `self.token` and asserts
|
||||
that it returns a 200 success code.
|
||||
|
||||
@@ -14,14 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import NotFoundError, SynapseError
|
||||
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
import synapse.api.errors
|
||||
import synapse.handlers.device
|
||||
import synapse.storage
|
||||
|
||||
from tests import unittest
|
||||
|
||||
@@ -30,27 +25,28 @@ user2 = "@theresa:bbb"
|
||||
|
||||
|
||||
class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
def make_homeserver(self, reactor, clock):
|
||||
hs = self.setup_test_homeserver("server", federation_http_client=None)
|
||||
self.handler = hs.get_device_handler()
|
||||
self.store = hs.get_datastores().main
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
def prepare(self, reactor, clock, hs):
|
||||
# These tests assume that it starts 1000 seconds in.
|
||||
self.reactor.advance(1000)
|
||||
|
||||
def test_device_is_created_with_invalid_name(self) -> None:
|
||||
def test_device_is_created_with_invalid_name(self):
|
||||
self.get_failure(
|
||||
self.handler.check_device_registered(
|
||||
user_id="@boris:foo",
|
||||
device_id="foo",
|
||||
initial_device_display_name="a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1),
|
||||
initial_device_display_name="a"
|
||||
* (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1),
|
||||
),
|
||||
SynapseError,
|
||||
synapse.api.errors.SynapseError,
|
||||
)
|
||||
|
||||
def test_device_is_created_if_doesnt_exist(self) -> None:
|
||||
def test_device_is_created_if_doesnt_exist(self):
|
||||
res = self.get_success(
|
||||
self.handler.check_device_registered(
|
||||
user_id="@boris:foo",
|
||||
@@ -63,7 +59,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
|
||||
self.assertEqual(dev["display_name"], "display name")
|
||||
|
||||
def test_device_is_preserved_if_exists(self) -> None:
|
||||
def test_device_is_preserved_if_exists(self):
|
||||
res1 = self.get_success(
|
||||
self.handler.check_device_registered(
|
||||
user_id="@boris:foo",
|
||||
@@ -85,7 +81,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
|
||||
self.assertEqual(dev["display_name"], "display name")
|
||||
|
||||
def test_device_id_is_made_up_if_unspecified(self) -> None:
|
||||
def test_device_id_is_made_up_if_unspecified(self):
|
||||
device_id = self.get_success(
|
||||
self.handler.check_device_registered(
|
||||
user_id="@theresa:foo",
|
||||
@@ -97,7 +93,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
|
||||
self.assertEqual(dev["display_name"], "display")
|
||||
|
||||
def test_get_devices_by_user(self) -> None:
|
||||
def test_get_devices_by_user(self):
|
||||
self._record_users()
|
||||
|
||||
res = self.get_success(self.handler.get_devices_by_user(user1))
|
||||
@@ -135,7 +131,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
device_map["abc"],
|
||||
)
|
||||
|
||||
def test_get_device(self) -> None:
|
||||
def test_get_device(self):
|
||||
self._record_users()
|
||||
|
||||
res = self.get_success(self.handler.get_device(user1, "abc"))
|
||||
@@ -150,19 +146,21 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
res,
|
||||
)
|
||||
|
||||
def test_delete_device(self) -> None:
|
||||
def test_delete_device(self):
|
||||
self._record_users()
|
||||
|
||||
# delete the device
|
||||
self.get_success(self.handler.delete_device(user1, "abc"))
|
||||
|
||||
# check the device was deleted
|
||||
self.get_failure(self.handler.get_device(user1, "abc"), NotFoundError)
|
||||
self.get_failure(
|
||||
self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
|
||||
)
|
||||
|
||||
# we'd like to check the access token was invalidated, but that's a
|
||||
# bit of a PITA.
|
||||
|
||||
def test_delete_device_and_device_inbox(self) -> None:
|
||||
def test_delete_device_and_device_inbox(self):
|
||||
self._record_users()
|
||||
|
||||
# add an device_inbox
|
||||
@@ -193,7 +191,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertIsNone(res)
|
||||
|
||||
def test_update_device(self) -> None:
|
||||
def test_update_device(self):
|
||||
self._record_users()
|
||||
|
||||
update = {"display_name": "new display"}
|
||||
@@ -202,29 +200,32 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
res = self.get_success(self.handler.get_device(user1, "abc"))
|
||||
self.assertEqual(res["display_name"], "new display")
|
||||
|
||||
def test_update_device_too_long_display_name(self) -> None:
|
||||
def test_update_device_too_long_display_name(self):
|
||||
"""Update a device with a display name that is invalid (too long)."""
|
||||
self._record_users()
|
||||
|
||||
# Request to update a device display name with a new value that is longer than allowed.
|
||||
update = {"display_name": "a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1)}
|
||||
update = {
|
||||
"display_name": "a"
|
||||
* (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
|
||||
}
|
||||
self.get_failure(
|
||||
self.handler.update_device(user1, "abc", update),
|
||||
SynapseError,
|
||||
synapse.api.errors.SynapseError,
|
||||
)
|
||||
|
||||
# Ensure the display name was not updated.
|
||||
res = self.get_success(self.handler.get_device(user1, "abc"))
|
||||
self.assertEqual(res["display_name"], "display 2")
|
||||
|
||||
def test_update_unknown_device(self) -> None:
|
||||
def test_update_unknown_device(self):
|
||||
update = {"display_name": "new_display"}
|
||||
self.get_failure(
|
||||
self.handler.update_device("user_id", "unknown_device_id", update),
|
||||
NotFoundError,
|
||||
synapse.api.errors.NotFoundError,
|
||||
)
|
||||
|
||||
def _record_users(self) -> None:
|
||||
def _record_users(self):
|
||||
# check this works for both devices which have a recorded client_ip,
|
||||
# and those which don't.
|
||||
self._record_user(user1, "xyz", "display 0")
|
||||
@@ -237,13 +238,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
self.reactor.advance(10000)
|
||||
|
||||
def _record_user(
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
display_name: str,
|
||||
access_token: Optional[str] = None,
|
||||
ip: Optional[str] = None,
|
||||
) -> None:
|
||||
self, user_id, device_id, display_name, access_token=None, ip=None
|
||||
):
|
||||
device_id = self.get_success(
|
||||
self.handler.check_device_registered(
|
||||
user_id=user_id,
|
||||
@@ -252,7 +248,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
if access_token is not None and ip is not None:
|
||||
if ip is not None:
|
||||
self.get_success(
|
||||
self.store.insert_client_ip(
|
||||
user_id, access_token, ip, "user_agent", device_id
|
||||
@@ -262,7 +258,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
|
||||
class DehydrationTestCase(unittest.HomeserverTestCase):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
def make_homeserver(self, reactor, clock):
|
||||
hs = self.setup_test_homeserver("server", federation_http_client=None)
|
||||
self.handler = hs.get_device_handler()
|
||||
self.registration = hs.get_registration_handler()
|
||||
@@ -270,7 +266,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
|
||||
self.store = hs.get_datastores().main
|
||||
return hs
|
||||
|
||||
def test_dehydrate_and_rehydrate_device(self) -> None:
|
||||
def test_dehydrate_and_rehydrate_device(self):
|
||||
user_id = "@boris:dehydration"
|
||||
|
||||
self.get_success(self.store.register_user(user_id, "foobar"))
|
||||
@@ -307,7 +303,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
|
||||
access_token=access_token,
|
||||
device_id="not the right device ID",
|
||||
),
|
||||
NotFoundError,
|
||||
synapse.api.errors.NotFoundError,
|
||||
)
|
||||
|
||||
# dehydrating the right devices should succeed and change our device ID
|
||||
@@ -335,7 +331,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
|
||||
# make sure that the device ID that we were initially assigned no longer exists
|
||||
self.get_failure(
|
||||
self.handler.get_device(user_id, device_id),
|
||||
NotFoundError,
|
||||
synapse.api.errors.NotFoundError,
|
||||
)
|
||||
|
||||
# make sure that there's no device available for dehydrating now
|
||||
|
||||
@@ -12,18 +12,14 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Awaitable, Callable, Dict
|
||||
from unittest.mock import Mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
from unittest.mock import Mock
|
||||
|
||||
import synapse.api.errors
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.rest.client import directory, login, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, RoomAlias, create_requester
|
||||
from synapse.util import Clock
|
||||
from synapse.types import RoomAlias, create_requester
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import make_awaitable
|
||||
@@ -32,15 +28,13 @@ from tests.test_utils import make_awaitable
|
||||
class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
"""Tests the directory service."""
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
def make_homeserver(self, reactor, clock):
|
||||
self.mock_federation = Mock()
|
||||
self.mock_registry = Mock()
|
||||
|
||||
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
|
||||
self.query_handlers = {}
|
||||
|
||||
def register_query_handler(
|
||||
query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
|
||||
) -> None:
|
||||
def register_query_handler(query_type, handler):
|
||||
self.query_handlers[query_type] = handler
|
||||
|
||||
self.mock_registry.register_query_handler = register_query_handler
|
||||
@@ -60,7 +54,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
return hs
|
||||
|
||||
def test_get_local_association(self) -> None:
|
||||
def test_get_local_association(self):
|
||||
self.get_success(
|
||||
self.store.create_room_alias_association(
|
||||
self.my_room, "!8765qwer:test", ["test"]
|
||||
@@ -71,7 +65,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
|
||||
|
||||
def test_get_remote_association(self) -> None:
|
||||
def test_get_remote_association(self):
|
||||
self.mock_federation.make_query.return_value = make_awaitable(
|
||||
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
|
||||
)
|
||||
@@ -89,7 +83,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
def test_incoming_fed_query(self) -> None:
|
||||
def test_incoming_fed_query(self):
|
||||
self.get_success(
|
||||
self.store.create_room_alias_association(
|
||||
self.your_room, "!8765asdf:test", ["test"]
|
||||
@@ -111,7 +105,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
|
||||
directory.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.handler = hs.get_directory_handler()
|
||||
|
||||
# Create user
|
||||
@@ -131,7 +125,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
|
||||
self.test_user_tok = self.login("user", "pass")
|
||||
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
|
||||
|
||||
def test_create_alias_joined_room(self) -> None:
|
||||
def test_create_alias_joined_room(self):
|
||||
"""A user can create an alias for a room they're in."""
|
||||
self.get_success(
|
||||
self.handler.create_association(
|
||||
@@ -141,7 +135,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_create_alias_other_room(self) -> None:
|
||||
def test_create_alias_other_room(self):
|
||||
"""A user cannot create an alias for a room they're NOT in."""
|
||||
other_room_id = self.helper.create_room_as(
|
||||
self.admin_user, tok=self.admin_user_tok
|
||||
@@ -156,7 +150,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
|
||||
synapse.api.errors.SynapseError,
|
||||
)
|
||||
|
||||
def test_create_alias_admin(self) -> None:
|
||||
def test_create_alias_admin(self):
|
||||
"""An admin can create an alias for a room they're NOT in."""
|
||||
other_room_id = self.helper.create_room_as(
|
||||
self.test_user, tok=self.test_user_tok
|
||||
@@ -179,7 +173,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||
directory.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastores().main
|
||||
self.handler = hs.get_directory_handler()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
@@ -201,7 +195,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||
self.test_user_tok = self.login("user", "pass")
|
||||
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
|
||||
|
||||
def _create_alias(self, user) -> None:
|
||||
def _create_alias(self, user):
|
||||
# Create a new alias to this room.
|
||||
self.get_success(
|
||||
self.store.create_room_alias_association(
|
||||
@@ -209,7 +203,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_delete_alias_not_allowed(self) -> None:
|
||||
def test_delete_alias_not_allowed(self):
|
||||
"""A user that doesn't meet the expected guidelines cannot delete an alias."""
|
||||
self._create_alias(self.admin_user)
|
||||
self.get_failure(
|
||||
@@ -219,7 +213,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||
synapse.api.errors.AuthError,
|
||||
)
|
||||
|
||||
def test_delete_alias_creator(self) -> None:
|
||||
def test_delete_alias_creator(self):
|
||||
"""An alias creator can delete their own alias."""
|
||||
# Create an alias from a different user.
|
||||
self._create_alias(self.test_user)
|
||||
@@ -238,7 +232,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||
synapse.api.errors.SynapseError,
|
||||
)
|
||||
|
||||
def test_delete_alias_admin(self) -> None:
|
||||
def test_delete_alias_admin(self):
|
||||
"""A server admin can delete an alias created by another user."""
|
||||
# Create an alias from a different user.
|
||||
self._create_alias(self.test_user)
|
||||
@@ -257,7 +251,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||
synapse.api.errors.SynapseError,
|
||||
)
|
||||
|
||||
def test_delete_alias_sufficient_power(self) -> None:
|
||||
def test_delete_alias_sufficient_power(self):
|
||||
"""A user with a sufficient power level should be able to delete an alias."""
|
||||
self._create_alias(self.admin_user)
|
||||
|
||||
@@ -294,7 +288,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
directory.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastores().main
|
||||
self.handler = hs.get_directory_handler()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
@@ -323,7 +317,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
return room_alias
|
||||
|
||||
def _set_canonical_alias(self, content) -> None:
|
||||
def _set_canonical_alias(self, content):
|
||||
"""Configure the canonical alias state on the room."""
|
||||
self.helper.send_state(
|
||||
self.room_id,
|
||||
@@ -340,7 +334,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_remove_alias(self) -> None:
|
||||
def test_remove_alias(self):
|
||||
"""Removing an alias that is the canonical alias should remove it there too."""
|
||||
# Set this new alias as the canonical alias for this room
|
||||
self._set_canonical_alias(
|
||||
@@ -362,7 +356,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
self.assertNotIn("alias", data["content"])
|
||||
self.assertNotIn("alt_aliases", data["content"])
|
||||
|
||||
def test_remove_other_alias(self) -> None:
|
||||
def test_remove_other_alias(self):
|
||||
"""Removing an alias listed as in alt_aliases should remove it there too."""
|
||||
# Create a second alias.
|
||||
other_test_alias = "#test2:test"
|
||||
@@ -399,7 +393,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||
|
||||
servlets = [directory.register_servlets, room.register_servlets]
|
||||
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
|
||||
# Add custom alias creation rules to the config.
|
||||
@@ -409,7 +403,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||
|
||||
return config
|
||||
|
||||
def test_denied(self) -> None:
|
||||
def test_denied(self):
|
||||
room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
channel = self.make_request(
|
||||
@@ -419,7 +413,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(403, channel.code, channel.result)
|
||||
|
||||
def test_allowed(self) -> None:
|
||||
def test_allowed(self):
|
||||
room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
channel = self.make_request(
|
||||
@@ -429,7 +423,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
|
||||
def test_denied_during_creation(self) -> None:
|
||||
def test_denied_during_creation(self):
|
||||
"""A room alias that is not allowed should be rejected during creation."""
|
||||
# Invalid room alias.
|
||||
self.helper.create_room_as(
|
||||
@@ -438,7 +432,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||
extra_content={"room_alias_name": "foo"},
|
||||
)
|
||||
|
||||
def test_allowed_during_creation(self) -> None:
|
||||
def test_allowed_during_creation(self):
|
||||
"""A valid room alias should be allowed during creation."""
|
||||
room_id = self.helper.create_room_as(
|
||||
self.user_id,
|
||||
@@ -465,7 +459,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||
data = {"room_alias_name": "unofficial_test"}
|
||||
allowed_localpart = "allowed"
|
||||
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
|
||||
# Add custom room list publication rules to the config.
|
||||
@@ -480,9 +474,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||
|
||||
return config
|
||||
|
||||
def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
|
||||
) -> HomeServer:
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
|
||||
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
|
||||
|
||||
@@ -491,7 +483,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||
|
||||
return hs
|
||||
|
||||
def test_denied_without_publication_permission(self) -> None:
|
||||
def test_denied_without_publication_permission(self):
|
||||
"""
|
||||
Try to create a room, register an alias for it, and publish it,
|
||||
as a user without permission to publish rooms.
|
||||
@@ -505,7 +497,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||
expect_code=403,
|
||||
)
|
||||
|
||||
def test_allowed_when_creating_private_room(self) -> None:
|
||||
def test_allowed_when_creating_private_room(self):
|
||||
"""
|
||||
Try to create a room, register an alias for it, and NOT publish it,
|
||||
as a user without permission to publish rooms.
|
||||
@@ -519,7 +511,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||
expect_code=200,
|
||||
)
|
||||
|
||||
def test_allowed_with_publication_permission(self) -> None:
|
||||
def test_allowed_with_publication_permission(self):
|
||||
"""
|
||||
Try to create a room, register an alias for it, and publish it,
|
||||
as a user WITH permission to publish rooms.
|
||||
@@ -533,7 +525,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||
expect_code=200,
|
||||
)
|
||||
|
||||
def test_denied_publication_with_invalid_alias(self) -> None:
|
||||
def test_denied_publication_with_invalid_alias(self):
|
||||
"""
|
||||
Try to create a room, register an alias for it, and publish it,
|
||||
as a user WITH permission to publish rooms.
|
||||
@@ -546,7 +538,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||
expect_code=403,
|
||||
)
|
||||
|
||||
def test_can_create_as_private_room_after_rejection(self) -> None:
|
||||
def test_can_create_as_private_room_after_rejection(self):
|
||||
"""
|
||||
After failing to publish a room with an alias as a user without publish permission,
|
||||
retry as the same user, but without publishing the room.
|
||||
@@ -557,7 +549,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||
self.test_denied_without_publication_permission()
|
||||
self.test_allowed_when_creating_private_room()
|
||||
|
||||
def test_can_create_with_permission_after_rejection(self) -> None:
|
||||
def test_can_create_with_permission_after_rejection(self):
|
||||
"""
|
||||
After failing to publish a room with an alias as a user without publish permission,
|
||||
retry as someone with permission, using the same alias.
|
||||
@@ -574,9 +566,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
|
||||
|
||||
servlets = [directory.register_servlets, room.register_servlets]
|
||||
|
||||
def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
|
||||
) -> HomeServer:
|
||||
def prepare(self, reactor, clock, hs):
|
||||
room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
channel = self.make_request(
|
||||
@@ -589,7 +579,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
|
||||
|
||||
return hs
|
||||
|
||||
def test_disabling_room_list(self) -> None:
|
||||
def test_disabling_room_list(self):
|
||||
self.room_list_handler.enable_room_list_search = True
|
||||
self.directory_handler.enable_room_list_search = True
|
||||
|
||||
|
||||
@@ -20,37 +20,33 @@ from parameterized import parameterized
|
||||
from signedjson import key as key, sign as sign
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import RoomEncryptionAlgorithms
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import make_awaitable
|
||||
|
||||
|
||||
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
def make_homeserver(self, reactor, clock):
|
||||
return self.setup_test_homeserver(federation_client=mock.Mock())
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.handler = hs.get_e2e_keys_handler()
|
||||
self.store = self.hs.get_datastores().main
|
||||
|
||||
def test_query_local_devices_no_devices(self) -> None:
|
||||
def test_query_local_devices_no_devices(self):
|
||||
"""If the user has no devices, we expect an empty list."""
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
res = self.get_success(self.handler.query_local_devices({local_user: None}))
|
||||
self.assertDictEqual(res, {local_user: {}})
|
||||
|
||||
def test_reupload_one_time_keys(self) -> None:
|
||||
def test_reupload_one_time_keys(self):
|
||||
"""we should be able to re-upload the same keys"""
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
keys: JsonDict = {
|
||||
keys = {
|
||||
"alg1:k1": "key1",
|
||||
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
|
||||
"alg2:k3": {"key": "key3"},
|
||||
@@ -78,7 +74,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
|
||||
)
|
||||
|
||||
def test_change_one_time_keys(self) -> None:
|
||||
def test_change_one_time_keys(self):
|
||||
"""attempts to change one-time-keys should be rejected"""
|
||||
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
@@ -138,7 +134,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
SynapseError,
|
||||
)
|
||||
|
||||
def test_claim_one_time_key(self) -> None:
|
||||
def test_claim_one_time_key(self):
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
keys = {"alg1:k1": "key1"}
|
||||
@@ -165,7 +161,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_fallback_key(self) -> None:
|
||||
def test_fallback_key(self):
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
fallback_key = {"alg1:k1": "fallback_key1"}
|
||||
@@ -298,7 +294,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
|
||||
)
|
||||
|
||||
def test_replace_master_key(self) -> None:
|
||||
def test_replace_master_key(self):
|
||||
"""uploading a new signing key should make the old signing key unavailable"""
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
keys1 = {
|
||||
@@ -332,7 +328,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
|
||||
|
||||
def test_reupload_signatures(self) -> None:
|
||||
def test_reupload_signatures(self):
|
||||
"""re-uploading a signature should not fail"""
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
keys1 = {
|
||||
@@ -437,7 +433,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
|
||||
self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
|
||||
|
||||
def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
|
||||
def test_self_signing_key_doesnt_show_up_as_device(self):
|
||||
"""signing keys should be hidden when fetching a user's devices"""
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
keys1 = {
|
||||
@@ -466,7 +462,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
res = self.get_success(self.handler.query_local_devices({local_user: None}))
|
||||
self.assertDictEqual(res, {local_user: {}})
|
||||
|
||||
def test_upload_signatures(self) -> None:
|
||||
def test_upload_signatures(self):
|
||||
"""should check signatures that are uploaded"""
|
||||
# set up a user with cross-signing keys and a device. This user will
|
||||
# try uploading signatures
|
||||
@@ -690,7 +686,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
|
||||
)
|
||||
|
||||
def test_query_devices_remote_no_sync(self) -> None:
|
||||
def test_query_devices_remote_no_sync(self):
|
||||
"""Tests that querying keys for a remote user that we don't share a room
|
||||
with returns the cross signing keys correctly.
|
||||
"""
|
||||
@@ -763,7 +759,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_query_devices_remote_sync(self) -> None:
|
||||
def test_query_devices_remote_sync(self):
|
||||
"""Tests that querying keys for a remote user that we share a room with,
|
||||
but haven't yet fetched the keys for, returns the cross signing keys
|
||||
correctly.
|
||||
@@ -849,7 +845,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
(["device_1", "device_2"],),
|
||||
]
|
||||
)
|
||||
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None:
|
||||
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
|
||||
"""Test that requests for all of a remote user's devices are cached.
|
||||
|
||||
We do this by asserting that only one call over federation was made, and that
|
||||
@@ -857,7 +853,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
"""
|
||||
local_user_id = "@test:test"
|
||||
remote_user_id = "@test:other"
|
||||
request_body: JsonDict = {"device_keys": {remote_user_id: []}}
|
||||
request_body = {"device_keys": {remote_user_id: []}}
|
||||
|
||||
response_devices = [
|
||||
{
|
||||
|
||||
@@ -13,18 +13,14 @@
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pymacaroons
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.handlers.sso import MappingException
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util import Clock
|
||||
from synapse.types import UserID
|
||||
from synapse.util.macaroons import get_value_from_macaroon
|
||||
|
||||
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
|
||||
@@ -102,7 +98,7 @@ class TestMappingProviderFailures(TestMappingProvider):
|
||||
}
|
||||
|
||||
|
||||
async def get_json(url: str) -> JsonDict:
|
||||
async def get_json(url):
|
||||
# Mock get_json calls to handle jwks & oidc discovery endpoints
|
||||
if url == WELL_KNOWN:
|
||||
# Minimal discovery document, as defined in OpenID.Discovery
|
||||
@@ -120,8 +116,6 @@ async def get_json(url: str) -> JsonDict:
|
||||
elif url == JWKS_URI:
|
||||
return {"keys": []}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def _key_file_path() -> str:
|
||||
"""path to a file containing the private half of a test key"""
|
||||
@@ -153,12 +147,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
if not HAS_OIDC:
|
||||
skip = "requires OIDC"
|
||||
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
config["public_baseurl"] = BASE_URL
|
||||
return config
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
def make_homeserver(self, reactor, clock):
|
||||
self.http_client = Mock(spec=["get_json"])
|
||||
self.http_client.get_json.side_effect = get_json
|
||||
self.http_client.user_agent = b"Synapse Test"
|
||||
@@ -170,7 +164,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
sso_handler = hs.get_sso_handler()
|
||||
# Mock the render error method.
|
||||
self.render_error = Mock(return_value=None)
|
||||
sso_handler.render_error = self.render_error # type: ignore[assignment]
|
||||
sso_handler.render_error = self.render_error
|
||||
|
||||
# Reduce the number of attempts when generating MXIDs.
|
||||
sso_handler._MAP_USERNAME_RETRIES = 3
|
||||
@@ -199,14 +193,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
return args
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_config(self) -> None:
|
||||
def test_config(self):
|
||||
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
||||
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
|
||||
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
|
||||
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
|
||||
|
||||
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
|
||||
def test_discovery(self) -> None:
|
||||
def test_discovery(self):
|
||||
"""The handler should discover the endpoints from OIDC discovery document."""
|
||||
# This would throw if some metadata were invalid
|
||||
metadata = self.get_success(self.provider.load_metadata())
|
||||
@@ -225,13 +219,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.http_client.get_json.assert_not_called()
|
||||
|
||||
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
||||
def test_no_discovery(self) -> None:
|
||||
def test_no_discovery(self):
|
||||
"""When discovery is disabled, it should not try to load from discovery document."""
|
||||
self.get_success(self.provider.load_metadata())
|
||||
self.http_client.get_json.assert_not_called()
|
||||
|
||||
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
||||
def test_load_jwks(self) -> None:
|
||||
def test_load_jwks(self):
|
||||
"""JWKS loading is done once (then cached) if used."""
|
||||
jwks = self.get_success(self.provider.load_jwks())
|
||||
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
||||
@@ -259,7 +253,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_validate_config(self) -> None:
|
||||
def test_validate_config(self):
|
||||
"""Provider metadatas are extensively validated."""
|
||||
h = self.provider
|
||||
|
||||
@@ -342,14 +336,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
force_load_metadata()
|
||||
|
||||
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
|
||||
def test_skip_verification(self) -> None:
|
||||
def test_skip_verification(self):
|
||||
"""Provider metadata validation can be disabled by config."""
|
||||
with self.metadata_edit({"issuer": "http://insecure"}):
|
||||
# This should not throw
|
||||
get_awaitable_result(self.provider.load_metadata())
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_redirect_request(self) -> None:
|
||||
def test_redirect_request(self):
|
||||
"""The redirect request has the right arguments & generates a valid session cookie."""
|
||||
req = Mock(spec=["cookies"])
|
||||
req.cookies = []
|
||||
@@ -393,7 +387,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.assertEqual(redirect, "http://client/redirect")
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_callback_error(self) -> None:
|
||||
def test_callback_error(self):
|
||||
"""Errors from the provider returned in the callback are displayed."""
|
||||
request = Mock(args={})
|
||||
request.args[b"error"] = [b"invalid_client"]
|
||||
@@ -405,7 +399,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.assertRenderedError("invalid_client", "some description")
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_callback(self) -> None:
|
||||
def test_callback(self):
|
||||
"""Code callback works and display errors if something went wrong.
|
||||
|
||||
A lot of scenarios are tested here:
|
||||
@@ -434,9 +428,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
"username": username,
|
||||
}
|
||||
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
||||
self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||
self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
||||
self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
@@ -474,7 +468,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.assertRenderedError("mapping_error")
|
||||
|
||||
# Handle ID token errors
|
||||
self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
|
||||
self.provider._parse_id_token = simple_async_mock(raises=Exception())
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("invalid_token")
|
||||
|
||||
@@ -489,7 +483,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
"type": "bearer",
|
||||
"access_token": "access_token",
|
||||
}
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
@@ -516,8 +510,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
id_token = {
|
||||
"sid": "abcdefgh",
|
||||
}
|
||||
self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
||||
self.provider._parse_id_token = simple_async_mock(return_value=id_token)
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
self.provider._fetch_userinfo.reset_mock()
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
@@ -537,21 +531,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.render_error.assert_not_called()
|
||||
|
||||
# Handle userinfo fetching error
|
||||
self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
|
||||
self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("fetch_error")
|
||||
|
||||
# Handle code exchange failure
|
||||
from synapse.handlers.oidc import OidcError
|
||||
|
||||
self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
|
||||
self.provider._exchange_code = simple_async_mock(
|
||||
raises=OidcError("invalid_request")
|
||||
)
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("invalid_request")
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_callback_session(self) -> None:
|
||||
def test_callback_session(self):
|
||||
"""The callback verifies the session presence and validity"""
|
||||
request = Mock(spec=["args", "getCookie", "cookies"])
|
||||
|
||||
@@ -596,7 +590,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
@override_config(
|
||||
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
|
||||
)
|
||||
def test_exchange_code(self) -> None:
|
||||
def test_exchange_code(self):
|
||||
"""Code exchange behaves correctly and handles various error scenarios."""
|
||||
token = {"type": "bearer"}
|
||||
token_json = json.dumps(token).encode("utf-8")
|
||||
@@ -692,7 +686,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_exchange_code_jwt_key(self) -> None:
|
||||
def test_exchange_code_jwt_key(self):
|
||||
"""Test that code exchange works with a JWK client secret."""
|
||||
from authlib.jose import jwt
|
||||
|
||||
@@ -747,7 +741,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_exchange_code_no_auth(self) -> None:
|
||||
def test_exchange_code_no_auth(self):
|
||||
"""Test that code exchange works with no client secret."""
|
||||
token = {"type": "bearer"}
|
||||
self.http_client.request = simple_async_mock(
|
||||
@@ -782,7 +776,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_extra_attributes(self) -> None:
|
||||
def test_extra_attributes(self):
|
||||
"""
|
||||
Login while using a mapping provider that implements get_extra_attributes.
|
||||
"""
|
||||
@@ -796,8 +790,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
"username": "foo",
|
||||
"phone": "1234567",
|
||||
}
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
||||
self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
||||
self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
@@ -823,12 +817,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_map_userinfo_to_user(self) -> None:
|
||||
def test_map_userinfo_to_user(self):
|
||||
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
userinfo: dict = {
|
||||
userinfo = {
|
||||
"sub": "test_user",
|
||||
"username": "test_user",
|
||||
}
|
||||
@@ -876,7 +870,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
|
||||
def test_map_userinfo_to_existing_user(self) -> None:
|
||||
def test_map_userinfo_to_existing_user(self):
|
||||
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
|
||||
store = self.hs.get_datastores().main
|
||||
user = UserID.from_string("@test_user:test")
|
||||
@@ -980,7 +974,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_map_userinfo_to_invalid_localpart(self) -> None:
|
||||
def test_map_userinfo_to_invalid_localpart(self):
|
||||
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
||||
self.get_success(
|
||||
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
|
||||
@@ -997,7 +991,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_map_userinfo_to_user_retries(self) -> None:
|
||||
def test_map_userinfo_to_user_retries(self):
|
||||
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
@@ -1045,7 +1039,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_empty_localpart(self) -> None:
|
||||
def test_empty_localpart(self):
|
||||
"""Attempts to map onto an empty localpart should be rejected."""
|
||||
userinfo = {
|
||||
"sub": "tester",
|
||||
@@ -1064,7 +1058,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_null_localpart(self) -> None:
|
||||
def test_null_localpart(self):
|
||||
"""Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
|
||||
userinfo = {
|
||||
"sub": "tester",
|
||||
@@ -1081,7 +1075,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_attribute_requirements(self) -> None:
|
||||
def test_attribute_requirements(self):
|
||||
"""The required attributes must be met from the OIDC userinfo response."""
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
@@ -1121,7 +1115,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_attribute_requirements_contains(self) -> None:
|
||||
def test_attribute_requirements_contains(self):
|
||||
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
@@ -1152,7 +1146,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_attribute_requirements_mismatch(self) -> None:
|
||||
def test_attribute_requirements_mismatch(self):
|
||||
"""
|
||||
Test that auth fails if attributes exist but don't match,
|
||||
or are non-string values.
|
||||
@@ -1160,7 +1154,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
# userinfo with "test": "not_foobar" attribute should fail
|
||||
userinfo: dict = {
|
||||
userinfo = {
|
||||
"sub": "tester",
|
||||
"username": "tester",
|
||||
"test": "not_foobar",
|
||||
@@ -1254,9 +1248,9 @@ async def _make_callback_with_userinfo(
|
||||
|
||||
handler = hs.get_oidc_handler()
|
||||
provider = handler._providers["oidc"]
|
||||
provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
|
||||
provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||
provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||
provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
|
||||
provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||
|
||||
state = "state"
|
||||
session = handler._token_generator.generate_oidc_session_token(
|
||||
|
||||
@@ -124,6 +124,7 @@ class PasswordCustomAuthProvider:
|
||||
("m.login.password", ("password",)): self.check_auth,
|
||||
}
|
||||
)
|
||||
pass
|
||||
|
||||
def check_auth(self, *args):
|
||||
return mock_password_provider.check_auth(*args)
|
||||
|
||||
@@ -11,17 +11,14 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Awaitable, Callable, Dict
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import Mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse.types
|
||||
from synapse.api.errors import AuthError, SynapseError
|
||||
from synapse.rest import admin
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util import Clock
|
||||
from synapse.types import UserID
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import make_awaitable
|
||||
@@ -32,15 +29,13 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
servlets = [admin.register_servlets]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
def make_homeserver(self, reactor, clock):
|
||||
self.mock_federation = Mock()
|
||||
self.mock_registry = Mock()
|
||||
|
||||
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
|
||||
self.query_handlers = {}
|
||||
|
||||
def register_query_handler(
|
||||
query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
|
||||
) -> None:
|
||||
def register_query_handler(query_type, handler):
|
||||
self.query_handlers[query_type] = handler
|
||||
|
||||
self.mock_registry.register_query_handler = register_query_handler
|
||||
@@ -52,7 +47,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
def prepare(self, reactor, clock, hs: HomeServer):
|
||||
self.store = hs.get_datastores().main
|
||||
|
||||
self.frank = UserID.from_string("@1234abcd:test")
|
||||
@@ -63,7 +58,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.handler = hs.get_profile_handler()
|
||||
|
||||
def test_get_my_name(self) -> None:
|
||||
def test_get_my_name(self):
|
||||
self.get_success(
|
||||
self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
||||
)
|
||||
@@ -72,7 +67,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual("Frank", displayname)
|
||||
|
||||
def test_set_my_name(self) -> None:
|
||||
def test_set_my_name(self):
|
||||
self.get_success(
|
||||
self.handler.set_displayname(
|
||||
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
|
||||
@@ -115,7 +110,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
|
||||
)
|
||||
|
||||
def test_set_my_name_if_disabled(self) -> None:
|
||||
def test_set_my_name_if_disabled(self):
|
||||
self.hs.config.registration.enable_set_displayname = False
|
||||
|
||||
# Setting displayname for the first time is allowed
|
||||
@@ -140,7 +135,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
SynapseError,
|
||||
)
|
||||
|
||||
def test_set_my_name_noauth(self) -> None:
|
||||
def test_set_my_name_noauth(self):
|
||||
self.get_failure(
|
||||
self.handler.set_displayname(
|
||||
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
|
||||
@@ -148,7 +143,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
AuthError,
|
||||
)
|
||||
|
||||
def test_get_other_name(self) -> None:
|
||||
def test_get_other_name(self):
|
||||
self.mock_federation.make_query.return_value = make_awaitable(
|
||||
{"displayname": "Alice"}
|
||||
)
|
||||
@@ -163,7 +158,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
def test_incoming_fed_query(self) -> None:
|
||||
def test_incoming_fed_query(self):
|
||||
self.get_success(self.store.create_profile("caroline"))
|
||||
self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
|
||||
|
||||
@@ -179,7 +174,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual({"displayname": "Caroline"}, response)
|
||||
|
||||
def test_get_my_avatar(self) -> None:
|
||||
def test_get_my_avatar(self):
|
||||
self.get_success(
|
||||
self.store.set_profile_avatar_url(
|
||||
self.frank.localpart, "http://my.server/me.png"
|
||||
@@ -189,7 +184,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual("http://my.server/me.png", avatar_url)
|
||||
|
||||
def test_set_my_avatar(self) -> None:
|
||||
def test_set_my_avatar(self):
|
||||
self.get_success(
|
||||
self.handler.set_avatar_url(
|
||||
self.frank,
|
||||
@@ -230,7 +225,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
||||
)
|
||||
|
||||
def test_set_my_avatar_if_disabled(self) -> None:
|
||||
def test_set_my_avatar_if_disabled(self):
|
||||
self.hs.config.registration.enable_set_avatar_url = False
|
||||
|
||||
# Setting displayname for the first time is allowed
|
||||
@@ -255,7 +250,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
SynapseError,
|
||||
)
|
||||
|
||||
def test_avatar_constraints_no_config(self) -> None:
|
||||
def test_avatar_constraints_no_config(self):
|
||||
"""Tests that the method to check an avatar against configured constraints skips
|
||||
all of its check if no constraint is configured.
|
||||
"""
|
||||
@@ -268,7 +263,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertTrue(res)
|
||||
|
||||
@unittest.override_config({"max_avatar_size": 50})
|
||||
def test_avatar_constraints_missing(self) -> None:
|
||||
def test_avatar_constraints_missing(self):
|
||||
"""Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
|
||||
be found.
|
||||
"""
|
||||
@@ -278,7 +273,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertFalse(res)
|
||||
|
||||
@unittest.override_config({"max_avatar_size": 50})
|
||||
def test_avatar_constraints_file_size(self) -> None:
|
||||
def test_avatar_constraints_file_size(self):
|
||||
"""Tests that a file that's above the allowed file size is forbidden but one
|
||||
that's below it is allowed.
|
||||
"""
|
||||
@@ -300,7 +295,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertFalse(res)
|
||||
|
||||
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
|
||||
def test_avatar_constraint_mime_type(self) -> None:
|
||||
def test_avatar_constraint_mime_type(self):
|
||||
"""Tests that a file with an unauthorised MIME type is forbidden but one with
|
||||
an authorised content type is allowed.
|
||||
"""
|
||||
|
||||
@@ -12,16 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Optional
|
||||
from unittest.mock import Mock
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import RedirectException
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.test_utils import simple_async_mock
|
||||
from tests.unittest import HomeserverTestCase, override_config
|
||||
@@ -85,10 +81,10 @@ class TestRedirectMappingProvider(TestMappingProvider):
|
||||
|
||||
|
||||
class SamlHandlerTestCase(HomeserverTestCase):
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
config["public_baseurl"] = BASE_URL
|
||||
saml_config: Dict[str, Any] = {
|
||||
saml_config = {
|
||||
"sp_config": {"metadata": {}},
|
||||
# Disable grandfathering.
|
||||
"grandfathered_mxid_source_attribute": None,
|
||||
@@ -102,7 +98,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
return config
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
def make_homeserver(self, reactor, clock):
|
||||
hs = self.setup_test_homeserver()
|
||||
|
||||
self.handler = hs.get_saml_handler()
|
||||
@@ -118,7 +114,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
elif not has_xmlsec1:
|
||||
skip = "Requires xmlsec1"
|
||||
|
||||
def test_map_saml_response_to_user(self) -> None:
|
||||
def test_map_saml_response_to_user(self):
|
||||
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
|
||||
|
||||
# stub out the auth handler
|
||||
@@ -144,7 +140,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
||||
def test_map_saml_response_to_existing_user(self) -> None:
|
||||
def test_map_saml_response_to_existing_user(self):
|
||||
"""Existing users can log in with SAML account."""
|
||||
store = self.hs.get_datastores().main
|
||||
self.get_success(
|
||||
@@ -190,7 +186,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
auth_provider_session_id=None,
|
||||
)
|
||||
|
||||
def test_map_saml_response_to_invalid_localpart(self) -> None:
|
||||
def test_map_saml_response_to_invalid_localpart(self):
|
||||
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
||||
|
||||
# stub out the auth handler
|
||||
@@ -211,7 +207,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
auth_handler.complete_sso_login.assert_not_called()
|
||||
|
||||
def test_map_saml_response_to_user_retries(self) -> None:
|
||||
def test_map_saml_response_to_user_retries(self):
|
||||
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
||||
|
||||
# stub out the auth handler and error renderer
|
||||
@@ -275,7 +271,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_map_saml_response_redirect(self) -> None:
|
||||
def test_map_saml_response_redirect(self):
|
||||
"""Test a mapping provider that raises a RedirectException"""
|
||||
|
||||
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
|
||||
@@ -296,7 +292,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_attribute_requirements(self) -> None:
|
||||
def test_attribute_requirements(self):
|
||||
"""The required attributes must be met from the SAML response."""
|
||||
|
||||
# stub out the auth handler
|
||||
|
||||
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
from unittest.mock import Mock

from twisted.internet.defer import Deferred
@@ -19,7 +18,7 @@ from twisted.internet.defer import Deferred
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfigException
from synapse.rest.client import login, push_rule, receipts, room
from synapse.rest.client import login, receipts, room

from tests.unittest import HomeserverTestCase, override_config

@@ -30,7 +29,6 @@ class HTTPPusherTests(HomeserverTestCase):
room.register_servlets,
login.register_servlets,
receipts.register_servlets,
push_rule.register_servlets,
]
user_id = True
hijack_auth = False
@@ -41,12 +39,12 @@ class HTTPPusherTests(HomeserverTestCase):
return config

def make_homeserver(self, reactor, clock):
self.push_attempts: List[tuple[Deferred, str, dict]] = []
self.push_attempts = []

m = Mock()

def post_json_get_json(url, body):
d: Deferred = Deferred()
d = Deferred()
self.push_attempts.append((d, url, body))
return make_deferred_yieldable(d)

@@ -721,67 +719,3 @@ class HTTPPusherTests(HomeserverTestCase):
access_token=access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)

def _make_user_with_pusher(self, username: str) -> Tuple[str, str]:
user_id = self.register_user(username, "pass")
access_token = self.login(username, "pass")

# Register the pusher
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
token_id = user_tuple.token_id

self.get_success(
self.hs.get_pusherpool().add_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
app_id="m.http",
app_display_name="HTTP Push Notifications",
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)

return user_id, access_token

def test_dont_notify_rule_overrides_message(self):
"""
The override push rule will suppress notification
"""

user_id, access_token = self._make_user_with_pusher("user")
other_user_id, other_access_token = self._make_user_with_pusher("otheruser")

# Create a room
room = self.helper.create_room_as(user_id, tok=access_token)

# Disable user notifications for this room -> user
body = {
"conditions": [{"kind": "event_match", "key": "room_id", "pattern": room}],
"actions": ["dont_notify"],
}
channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend",
body,
access_token=access_token,
)
self.assertEqual(channel.code, 200)

# Check we start with no pushes
self.assertEqual(len(self.push_attempts), 0)

# The other user joins
self.helper.join(room=room, user=other_user_id, tok=other_access_token)

# The other user sends a message (ignored by dont_notify push rule set above)
self.helper.send(room, body="Hi!", tok=other_access_token)
self.assertEqual(len(self.push_attempts), 0)

# The user sends a message back (sends a notification)
self.helper.send(room, body="Hello", tok=access_token)
self.assertEqual(len(self.push_attempts), 1)
@@ -39,7 +39,6 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.updater = BackgroundUpdater(hs, self.store.db_pool)

@parameterized.expand(
[
@@ -136,10 +135,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"""Test the status API works with a background update."""

# Create a new background update

self._register_bg_update()

self.store.db_pool.updates.start_doing_background_updates()

self.reactor.pump([1.0, 1.0, 1.0])

channel = self.make_request(
@@ -159,7 +158,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"average_items_per_ms": 0.1,
"total_duration_ms": 1000.0,
"total_item_count": (
self.updater.default_background_batch_size
BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE
),
}
},
@@ -214,7 +213,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"average_items_per_ms": 0.1,
"total_duration_ms": 1000.0,
"total_item_count": (
self.updater.default_background_batch_size
BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE
),
}
},
@@ -243,7 +242,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"average_items_per_ms": 0.1,
"total_duration_ms": 1000.0,
"total_item_count": (
self.updater.default_background_batch_size
BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE
),
}
},
File diff suppressed because it is too large
@@ -24,7 +24,6 @@ from synapse.util import Clock
from synapse.visibility import filter_events_for_client

from tests import unittest
from tests.unittest import override_config

one_hour_ms = 3600000
one_day_ms = one_hour_ms * 24
@@ -39,10 +38,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):

def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()

# merge this default retention config with anything that was specified in
# @override_config
retention_config = {
config["retention"] = {
"enabled": True,
"default_policy": {
"min_lifetime": one_day_ms,
@@ -51,8 +47,6 @@ class RetentionTestCase(unittest.HomeserverTestCase):
"allowed_lifetime_min": one_day_ms,
"allowed_lifetime_max": one_day_ms * 3,
}
retention_config.update(config.get("retention", {}))
config["retention"] = retention_config

self.hs = self.setup_test_homeserver(config=config)

@@ -121,20 +115,22 @@ class RetentionTestCase(unittest.HomeserverTestCase):

self._test_retention_event_purged(room_id, one_day_ms * 2)

@override_config({"retention": {"purge_jobs": [{"interval": "5d"}]}})
def test_visibility(self) -> None:
"""Tests that synapse.visibility.filter_events_for_client correctly filters out
outdated events, even if the purge job hasn't got to them yet.

We do this by setting a very long time between purge jobs.
outdated events
"""
store = self.hs.get_datastores().main
storage = self.hs.get_storage()
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
events = []

# Send a first event, which should be filtered out at the end of the test.
resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
first_event_id = resp.get("event_id")

# Get the event from the store so that we end up with a FrozenEvent that we can
# give to filter_events_for_client. We need to do this now because the event won't
# be in the database anymore after it has expired.
events.append(self.get_success(store.get_event(resp.get("event_id"))))

# Advance the time by 2 days. We're using the default retention policy, therefore
# after this the first event will still be valid.
@@ -142,17 +138,16 @@ class RetentionTestCase(unittest.HomeserverTestCase):

# Send another event, which shouldn't get filtered out.
resp = self.helper.send(room_id=room_id, body="2", tok=self.token)

valid_event_id = resp.get("event_id")

events.append(self.get_success(store.get_event(valid_event_id)))

# Advance the time by another 2 days. After this, the first event should be
# outdated but not the second one.
self.reactor.advance(one_day_ms * 2 / 1000)

# Fetch the events, and run filter_events_for_client on them
events = self.get_success(
store.get_events_as_list([first_event_id, valid_event_id])
)
self.assertEqual(2, len(events), "events retrieved from database")
# Run filter_events_for_client with our list of FrozenEvents.
filtered_events = self.get_success(
filter_events_for_client(storage, self.user_id, events)
)
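The retention hunk above swaps a defaults-plus-overrides merge for a fixed dict assignment. As a standalone reference, here is a minimal sketch of the merge pattern the older `make_homeserver` used, built from the keys visible in the hunk (the helper name is made up, and `max_lifetime` is an assumed filler for a value the hunk cuts off):

one_hour_ms = 3600000
one_day_ms = one_hour_ms * 24


def build_retention_config(overrides: dict) -> dict:
    # Defaults mirror the hunk above; whatever a test supplies via
    # @override_config (the "retention" key of `overrides`) wins over them.
    retention_config = {
        "enabled": True,
        "default_policy": {
            "min_lifetime": one_day_ms,
            "max_lifetime": one_day_ms * 3,  # assumed, not shown in the hunk
        },
        "allowed_lifetime_min": one_day_ms,
        "allowed_lifetime_max": one_day_ms * 3,
    }
    retention_config.update(overrides.get("retention", {}))
    return retention_config


# A test overriding only purge_jobs keeps every other default:
print(build_retention_config({"retention": {"purge_jobs": [{"interval": "5d"}]}}))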
@@ -1,18 +1,3 @@
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from http import HTTPStatus
from unittest.mock import Mock, call

from twisted.internet import defer, reactor
@@ -26,14 +11,14 @@ from tests.utils import MockClock


class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self) -> None:
def setUp(self):
self.clock = MockClock()
self.hs = Mock()
self.hs.get_clock = Mock(return_value=self.clock)
self.hs.get_auth = Mock()
self.cache = HttpTransactionCache(self.hs)

self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
self.mock_http_response = (200, "GOOD JOB!")
self.mock_key = "foo"

@defer.inlineCallbacks
Some files were not shown because too many files have changed in this diff