
Compare commits


27 Commits

Author SHA1 Message Date
David Robertson
4c98817264 debug check deps for #12223 2022-03-16 15:33:23 +00:00
Sean Quah
6121056740 Handle cancellation in DatabasePool.runInteraction() (#12199)
To handle cancellation, we ensure that `after_callback`s and
`exception_callback`s are always run, since the transaction will
complete on another thread regardless of cancellation.

We also wait until everything is done before releasing the
`CancelledError`, so that logging contexts won't get used after they
have been finished.

Signed-off-by: Sean Quah <seanq@element.io>
2022-03-16 15:07:41 +00:00
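In outline, the pattern this commit describes looks like the sketch below. It is illustrative only: `delay_cancellation` is the utility from #12180 (further down this list), and the callback-list shapes and helper names are assumptions, not the real `runInteraction` signature.

```python
from twisted.internet.defer import ensureDeferred

from synapse.util.async_helpers import delay_cancellation  # assumed import path


async def run_interaction_sketch(run_txn_on_db_thread, after_callbacks, exception_callbacks):
    async def _run():
        try:
            result = await run_txn_on_db_thread()
        except Exception:
            # The transaction failed: always run the failure callbacks.
            for callback, args, kwargs in exception_callbacks:
                callback(*args, **kwargs)
            raise
        # The transaction committed: always run the success callbacks.
        for callback, args, kwargs in after_callbacks:
            callback(*args, **kwargs)
        return result

    # Wrap the whole thing so that a cancelled caller still lets the
    # transaction (running on its database thread) and its callbacks run to
    # completion; only then is the CancelledError released to the caller.
    return await delay_cancellation(ensureDeferred(_run()))
```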
Patrick Cloke
fc9bd620ce Add a relations handler to avoid duplication. (#12227)
Adds a handler layer between the REST and datastore layers for relations.
2022-03-16 10:39:15 -04:00
Patrick Cloke
c486fa5fd9 Add some missing type hints to cache datastore. (#12216) 2022-03-16 10:37:04 -04:00
David Robertson
86965605a4 Fix dead link in spam checker warning (#12231) 2022-03-16 13:52:59 +00:00
Patrick Cloke
1da0f79d54 Refactor relations tests (#12232)
* Moves the relation pagination tests to a separate class.
* Move the assertion of the response code into the `_send_relation` helper.
* Moves some helpers into the base-class.
2022-03-16 09:20:57 -04:00
Patrick Cloke
4587b35929 Clean-up logic for rebasing URLs during URL preview. (#12219)
By using urljoin from the standard library and reducing the number
of places URLs are rebased.
2022-03-16 07:21:36 -04:00
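For reference, the standard library's `urljoin` covers the cases that the removed `rebase_url` helper's doctests exercise (see the `preview_html.py` diff near the end of this page):

```python
from urllib.parse import urljoin

assert urljoin("https://example.com/foo/", "subpage") == "https://example.com/foo/subpage"
assert urljoin("https://example.com/foo", "sibling") == "https://example.com/sibling"
assert urljoin("https://example.com/foo/", "/bar") == "https://example.com/bar"
# Absolute URLs are passed through unchanged:
assert urljoin("https://example.com/foo/", "https://alice.com/a/") == "https://alice.com/a/"
```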
Patrick Cloke
dda9b7fc4d Use the ignored_users table to test event visibility & sync. (#12225)
Instead of fetching the raw account data and re-parsing it: the
`ignored_users` table is a denormalised version of the account data,
maintained for quick searching.
2022-03-15 14:06:05 -04:00
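A sketch of the lookup this denormalisation enables. The table and column names come from a TODO comment visible in the sync handler diff further down; the helper itself is illustrative, not Synapse's actual store method:

```python
def get_ignored_users_txn(txn, user_id: str) -> frozenset:
    # One row per (ignorer, ignored) pair, so no JSON re-parsing is needed.
    txn.execute(
        "SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id = ?",
        (user_id,),
    )
    return frozenset(row[0] for row in txn)
```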
Sean Quah
dea577998f Add tests for database transaction callbacks (#12198)
Signed-off-by: Sean Quah <seanq@element.io>
2022-03-15 15:40:34 +00:00
Dirk Klimpel
5dd949bee6 Add type hints to some tests/handlers files. (#12224) 2022-03-15 09:16:37 -04:00
Sean Quah
2fcf4b3f6c Add cancellation support to @cached and @cachedList decorators (#12183)
These decorators mostly support cancellation already. Add cancellation
tests and fix use of finished logging contexts by delaying cancellation,
as suggested by @erikjohnston.

Signed-off-by: Sean Quah <seanq@element.io>
2022-03-14 19:04:29 +00:00
Sean Quah
605d161d7d Add cancellation support to ReadWriteLock (#12120)
Also convert `ReadWriteLock` to use async context managers.

Signed-off-by: Sean Quah <seanq@element.io>
2022-03-14 18:49:07 +00:00
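Illustrative usage, mirroring the pagination handler changes further down this page (the import path is an assumption):

```python
from synapse.util.async_helpers import ReadWriteLock  # assumed location

pagination_lock = ReadWriteLock()


async def purge_history(room_id: str) -> None:
    # Writers get exclusive access; `async with` replaces the previous
    # `with await lock.write(...)` style.
    async with pagination_lock.write(room_id):
        ...  # purge events for this room


async def get_messages(room_id: str) -> None:
    # Multiple readers may hold the lock for a key concurrently.
    async with pagination_lock.read(room_id):
        ...  # paginate over this room's events
```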
Sean Quah
8e5706d144 Fix broken background updates when using sqlite with enable_search off (#12215)
Signed-off-by: Sean Quah <seanq@element.io>
2022-03-14 17:52:58 +00:00
Sean Quah
90b2327066 Add delay_cancellation utility function (#12180)
`delay_cancellation` behaves like `stop_cancellation`, except it
delays `CancelledError`s until the original `Deferred` resolves.
This is handy for unifying cleanup paths and ensuring that uncancelled
coroutines don't use finished logcontexts.

Signed-off-by: Sean Quah <seanq@element.io>
2022-03-14 17:52:15 +00:00
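One way to implement this on top of Twisted's `Deferred` pause/unpause machinery is sketched below. This is close in spirit to, but not guaranteed to match, the helper the commit adds:

```python
from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.python.failure import Failure


def delay_cancellation_sketch(deferred: "defer.Deferred") -> "defer.Deferred":
    """Wrap `deferred` so that cancelling the wrapper does not cancel it;
    any CancelledError is only delivered once `deferred` itself resolves."""

    def handle_cancel(new_deferred: "defer.Deferred") -> None:
        # Buffer the CancelledError by pausing the wrapper before
        # errbacking it...
        new_deferred.pause()
        new_deferred.errback(Failure(CancelledError()))
        # ...and release it only once the wrapped deferred has completed,
        # so cleanup paths and logging contexts finish first.
        deferred.addBoth(lambda _: new_deferred.unpause())

    new_deferred: "defer.Deferred" = defer.Deferred(handle_cancel)
    deferred.chainDeferred(new_deferred)
    return new_deferred
```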
Patrick Cloke
54f674f7a9 Deprecate the groups/communities endpoints and add an experimental configuration flag. (#12200) 2022-03-12 13:23:37 -05:00
Shay
ef3619e61d Add config settings for background update parameters (#11980) 2022-03-11 10:46:45 -08:00
Brendan Abolivier
e6a106fd5e Implement a Jinja2 filter to extract localparts from email addresses (#12212) 2022-03-11 15:15:11 +00:00
reivilibre
4a53f35737 Improve code documentation for the typing stream over replication. (#12211) 2022-03-11 14:00:15 +00:00
Nick Mills-Barrett
735e89bd3a Add an additional HTTP pusher + push rule tests. (#12188)
And rename the field used for caching from _id to _cache_key.
2022-03-11 08:45:26 -05:00
Brendan Abolivier
003cc6910a Update the SSO username picker template to comply with SIWA guidelines (#12210)
Fixes https://github.com/matrix-org/synapse/issues/12205
2022-03-11 13:20:00 +00:00
Dirk Klimpel
32c828d0f7 Add type hints to tests/rest. (#12208)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
2022-03-11 12:42:22 +00:00
Patrick Cloke
e10a2fe0c2 Add some type hints to the tests.handlers module. (#12207) 2022-03-11 07:07:15 -05:00
Patrick Cloke
bc9dff1d95 Remove unnecessary pass statements. (#12206) 2022-03-11 07:06:21 -05:00
Andrew Morgan
3b12f6d61b Note that contributors can sign off privately (#12204)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
2022-03-11 11:10:20 +00:00
Richard van der Hoff
483f2aa2ec Retention test: avoid relying on state at purged events (#12202)
This test was relying on poking events which weren't in the database into
filter_events_for_client.
2022-03-11 10:33:49 +00:00
~creme
7577894bec Document that most streams can only have a single writer. (#12196)
This includes the `typing`, `to_device`, `account_data`, `receipts`, and `presence`
streams (really anything except the `events` stream).
2022-03-10 18:15:19 +00:00
Shay
ed9aea42fa fix misleading comment in check_events_for_spam (#12203) 2022-03-10 09:40:07 -08:00
109 changed files with 2607 additions and 1204 deletions


@@ -33,7 +33,7 @@ site-url = "/synapse/"
additional-css = [
"docs/website_files/table-of-contents.css",
"docs/website_files/remove-nav-buttons.css",
"docs/website_files/section-headers.css",
"docs/website_files/indent-section-headers.css",
]
additional-js = ["docs/website_files/table-of-contents.js"]
theme = "docs/website_files/theme"

changelog.d/11980.misc

@@ -0,0 +1 @@
Add config settings for background update parameters.


@@ -1 +1 @@
Add type hints to `tests/rest/client`.
Add type hints to tests files.

changelog.d/12120.misc

@@ -0,0 +1 @@
Add support for cancellation to `ReadWriteLock`.


@@ -1 +1 @@
Add type hints to `tests/rest`.
Add type hints to tests files.

changelog.d/12180.misc

@@ -0,0 +1 @@
Add `delay_cancellation` utility function, which behaves like `stop_cancellation` but waits until the original `Deferred` resolves before raising a `CancelledError`.

changelog.d/12183.misc

@@ -0,0 +1 @@
Add cancellation support to `@cached` and `@cachedList` decorators.

changelog.d/12188.misc

@@ -0,0 +1 @@
Add combined test for HTTP pusher and push rule. Contributed by Nick @ Beeper.

changelog.d/12196.doc

@@ -0,0 +1 @@
Document that the `typing`, `to_device`, `account_data`, `receipts`, and `presence` streams can each only be written to by a single worker.

changelog.d/12198.misc

@@ -0,0 +1 @@
Add tests for database transaction callbacks.

changelog.d/12199.misc

@@ -0,0 +1 @@
Handle cancellation in `DatabasePool.runInteraction()`.


@@ -0,0 +1 @@
The groups/communities feature in Synapse has been deprecated.

changelog.d/12202.misc

@@ -0,0 +1 @@
Avoid trying to calculate the state at outlier events.

changelog.d/12203.misc

@@ -0,0 +1 @@
Fix a misleading comment in the function `check_event_for_spam`.

changelog.d/12204.doc

@@ -0,0 +1 @@
Document that contributors can sign off privately by email.

changelog.d/12206.misc

@@ -0,0 +1 @@
Remove unnecessary `pass` statements.

changelog.d/12207.misc

@@ -0,0 +1 @@
Add type hints to tests files.

changelog.d/12208.misc

@@ -0,0 +1 @@
Add type hints to tests files.

changelog.d/12210.misc

@@ -0,0 +1 @@
Update the SSO username picker template to comply with SIWA guidelines.

changelog.d/12211.misc

@@ -0,0 +1 @@
Improve code documentation for the typing stream over replication.


@@ -0,0 +1 @@
Add a new Jinja2 template filter to extract the local part of an email address.

changelog.d/12215.bugfix

@@ -0,0 +1 @@
Fix a bug introduced in 1.54.0 that broke background updates on sqlite homeservers while search was disabled.

changelog.d/12216.misc

@@ -0,0 +1 @@
Add missing type hints for cache storage.

changelog.d/12219.misc

@@ -0,0 +1 @@
Clean-up logic around rebasing URLs for URL image previews.

changelog.d/12224.misc

@@ -0,0 +1 @@
Add type hints to tests files.

changelog.d/12225.misc

@@ -0,0 +1 @@
Use the `ignored_users` table in additional places instead of re-parsing the account data.

changelog.d/12227.misc

@@ -0,0 +1 @@
Refactor the relations endpoints to add a `RelationsHandler`.

changelog.d/12231.doc

@@ -0,0 +1 @@
Fix the link to the module documentation in the legacy spam checker warning message.

changelog.d/12232.misc

@@ -0,0 +1 @@
Refactor relations tests to improve code re-use.


@@ -458,6 +458,17 @@ Git allows you to add this signoff automatically when using the `-s`
flag to `git commit`, which uses the name and email set in your
`user.name` and `user.email` git configs.
### Private Sign off
If you would like to provide your legal name privately to the Matrix.org
Foundation (instead of in a public commit or comment), you can do so
by emailing your legal name and a link to the pull request to
[dco@matrix.org](mailto:dco@matrix.org?subject=Private%20sign%20off).
It helps to include "sign off" or similar in the subject line. You will then
be instructed further.
Once private sign off is complete, doing so for future contributions will not
be required.
# 10. Turn feedback into better code.


@@ -1947,8 +1947,14 @@ saml2_config:
#
# localpart_template: Jinja2 template for the localpart of the MXID.
# If this is not set, the user will be prompted to choose their
# own username (see 'sso_auth_account_details.html' in the 'sso'
# section of this file).
# own username (see the documentation for the
# 'sso_auth_account_details.html' template). This template can
# use the 'localpart_from_email' filter.
#
# confirm_localpart: Whether to prompt the user to validate (or
# change) the generated localpart (see the documentation for the
# 'sso_auth_account_details.html' template), instead of
# registering the account right away.
#
# display_name_template: Jinja2 template for the display name to set
# on first login. If unset, no displayname will be set.
@@ -2729,3 +2735,35 @@ redis:
# Optional password if configured on the Redis instance
#
#password: <secret_password>
## Background Updates ##
# Background updates are database updates that are run in the background in batches.
# The duration, minimum batch size, default batch size, whether to sleep between batches, and if so, how long to
# sleep can all be configured. This is helpful to speed up or slow down the updates.
#
background_updates:
# How long in milliseconds to run a batch of background updates for. Defaults to 100. Uncomment and set
# a time to change the default.
#
#background_update_duration_ms: 500
# Whether to sleep between updates. Defaults to True. Uncomment to change the default.
#
#sleep_enabled: false
# If sleeping between updates, how long in milliseconds to sleep for. Defaults to 1000. Uncomment
# and set a duration to change the default.
#
#sleep_duration_ms: 300
# Minimum size a batch of background updates can be. Must be greater than 0. Defaults to 1. Uncomment and
# set a size to change the default.
#
#min_batch_size: 10
# The batch size to use for the first iteration of a new background update. The default is 100.
# Uncomment and set a size to change the default.
#
#default_batch_size: 50
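In outline, these knobs plug into the update loop like this (a sketch based on the `_BackgroundUpdateContextManager` diff at the end of this page; names are illustrative):

```python
async def background_update_loop(clock, config, run_batch):
    while True:
        # Optionally sleep between batches...
        if config.sleep_enabled:
            await clock.sleep(config.sleep_duration_ms / 1000)
        # ...then run one batch, sized so that it should take roughly
        # `background_update_duration_ms` of work; the updater autotunes
        # the batch size around this target.
        await run_batch(desired_duration_ms=config.update_duration_ms)
```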


@@ -36,6 +36,13 @@ Turns a `mxc://` URL for media content into an HTTP(S) one using the homeserver'
Example: `message.sender_avatar_url|mxc_to_http(32,32)`
```python
localpart_from_email(address: str) -> str
```
Returns the local part of an email address (e.g. `alice` in `alice@example.com`).
Example: `user.email_address|localpart_from_email`
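A minimal sketch of what this filter does. The real implementation is the `_localpart_from_email_filter` helper imported in the OIDC handler diff later on this page; this version is illustrative:

```python
def localpart_from_email(address: str) -> str:
    # Split on the last "@", so the local part is everything before it.
    return address.rsplit("@", 1)[0]
```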
## Email templates
@@ -176,8 +183,11 @@ Below are the templates Synapse will look for when generating pages related to S
for the brand of the IdP
* `user_attributes`: an object containing details about the user that
we received from the IdP. May have the following attributes:
* display_name: the user's display_name
* emails: a list of email addresses
* `display_name`: the user's display name
* `emails`: a list of email addresses
* `localpart`: the local part of the Matrix user ID to register,
if `localpart_template` is set in the mapping provider configuration (empty
string if not)
The template should render a form which submits the following fields:
* `username`: the localpart of the user's chosen user id
* `sso_new_user_consent.html`: HTML page allowing the user to consent to the


@@ -85,6 +85,20 @@ process, for example:
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
```
# Upgrading to v1.56.0
## Groups/communities feature has been deprecated
The non-standard groups/communities feature in Synapse has been deprecated and will
be disabled by default in Synapse v1.58.0.
You can test disabling it by adding the following to your homeserver configuration:
```yaml
experimental_features:
groups_enabled: false
```
# Upgrading to v1.55.0
## `synctl` script has been moved


@@ -0,0 +1,7 @@
/*
* Indents each chapter title in the left sidebar so that they aren't
* at the same level as the section headers.
*/
.chapter-item {
margin-left: 1em;
}


@@ -1,20 +0,0 @@
/*
* Indents each chapter title in the left sidebar so that they aren't
* at the same level as the section headers.
*/
.chapter-item {
margin-left: 1em;
}
/*
* Prevents a large gap between successive section headers.
*
* mdbook sets 'margin-top: 2.5em' on h2 and h3 headers. This makes sense when separating
* a header from the paragraph beforehand, but has the downside of introducing a large
* gap between headers that are next to each other with no text in between.
*
* This rule reduces the margin in this case.
*/
h1 + h2, h2 + h3 {
margin-top: 1.0em;
}


@@ -351,8 +351,11 @@ is only supported with Redis-based replication.)
To enable this, the worker must have a HTTP replication listener configured,
have a `worker_name` and be listed in the `instance_map` config. The same worker
can handle multiple streams. For example, to move event persistence off to a
dedicated worker, the shared configuration would include:
can handle multiple streams, but unless otherwise documented, each stream can only
have a single writer.
For example, to move event persistence off to a dedicated worker, the shared
configuration would include:
```yaml
instance_map:
@@ -370,8 +373,8 @@ streams and the endpoints associated with them:
##### The `events` stream
The `events` stream also experimentally supports having multiple writers, where
work is sharded between them by room ID. Note that you *must* restart all worker
The `events` stream experimentally supports having multiple writers, where work
is sharded between them by room ID. Note that you *must* restart all worker
instances when adding or removing event persisters. An example `stream_writers`
configuration with multiple writers:
@@ -384,38 +387,38 @@ stream_writers:
##### The `typing` stream
The following endpoints should be routed directly to the workers configured as
stream writers for the `typing` stream:
The following endpoints should be routed directly to the worker configured as
the stream writer for the `typing` stream:
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/typing
##### The `to_device` stream
The following endpoints should be routed directly to the workers configured as
stream writers for the `to_device` stream:
The following endpoints should be routed directly to the worker configured as
the stream writer for the `to_device` stream:
^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/
##### The `account_data` stream
The following endpoints should be routed directly to the workers configured as
stream writers for the `account_data` stream:
The following endpoints should be routed directly to the worker configured as
the stream writer for the `account_data` stream:
^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags
^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data
##### The `receipts` stream
The following endpoints should be routed directly to the workers configured as
stream writers for the `receipts` stream:
The following endpoints should be routed directly to the worker configured as
the stream writer for the `receipts` stream:
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers
##### The `presence` stream
The following endpoints should be routed directly to the workers configured as
stream writers for the `presence` stream:
The following endpoints should be routed directly to the worker configured as
the stream writer for the `presence` stream:
^/_matrix/client/(api/v1|r0|v3|unstable)/presence/


@@ -67,13 +67,8 @@ exclude = (?x)
|tests/federation/transport/test_knocking.py
|tests/federation/transport/test_server.py
|tests/handlers/test_cas.py
|tests/handlers/test_directory.py
|tests/handlers/test_e2e_keys.py
|tests/handlers/test_federation.py
|tests/handlers/test_oidc.py
|tests/handlers/test_presence.py
|tests/handlers/test_profile.py
|tests/handlers/test_saml.py
|tests/handlers/test_typing.py
|tests/http/federation/test_matrix_federation_agent.py
|tests/http/federation/test_srv_resolver.py
@@ -90,7 +85,6 @@ exclude = (?x)
|tests/push/test_push_rule_evaluator.py
|tests/rest/client/test_transactions.py
|tests/rest/media/v1/test_media_storage.py
|tests/rest/media/v1/test_url_preview.py
|tests/scripts/test_new_matrix_user.py
|tests/server.py
|tests/server_notices/test_resource_limits_server_notices.py


@@ -322,7 +322,8 @@ class GenericWorkerServer(HomeServer):
presence.register_servlets(self, resource)
groups.register_servlets(self, resource)
if self.config.experimental.groups_enabled:
groups.register_servlets(self, resource)
resources.update({CLIENT_API_PREFIX: resource})


@@ -19,6 +19,7 @@ from synapse.config import (
api,
appservice,
auth,
background_updates,
cache,
captcha,
cas,
@@ -113,6 +114,7 @@ class RootConfig:
caches: cache.CacheConfig
federation: federation.FederationConfig
retention: retention.RetentionConfig
background_updates: background_updates.BackgroundUpdateConfig
config_classes: List[Type["Config"]] = ...
def __init__(self) -> None: ...


@@ -0,0 +1,68 @@
# Copyright 2022 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class BackgroundUpdateConfig(Config):
section = "background_updates"
def generate_config_section(self, **kwargs) -> str:
return """\
## Background Updates ##
# Background updates are database updates that are run in the background in batches.
# The duration, minimum batch size, default batch size, whether to sleep between batches, and if so, how long to
# sleep can all be configured. This is helpful to speed up or slow down the updates.
#
background_updates:
# How long in milliseconds to run a batch of background updates for. Defaults to 100. Uncomment and set
# a time to change the default.
#
#background_update_duration_ms: 500
# Whether to sleep between updates. Defaults to True. Uncomment to change the default.
#
#sleep_enabled: false
# If sleeping between updates, how long in milliseconds to sleep for. Defaults to 1000. Uncomment
# and set a duration to change the default.
#
#sleep_duration_ms: 300
# Minimum size a batch of background updates can be. Must be greater than 0. Defaults to 1. Uncomment and
# set a size to change the default.
#
#min_batch_size: 10
# The batch size to use for the first iteration of a new background update. The default is 100.
# Uncomment and set a size to change the default.
#
#default_batch_size: 50
"""
def read_config(self, config, **kwargs) -> None:
bg_update_config = config.get("background_updates") or {}
self.update_duration_ms = bg_update_config.get(
"background_update_duration_ms", 100
)
self.sleep_enabled = bg_update_config.get("sleep_enabled", True)
self.sleep_duration_ms = bg_update_config.get("sleep_duration_ms", 1000)
self.min_batch_size = bg_update_config.get("min_batch_size", 1)
self.default_batch_size = bg_update_config.get("default_batch_size", 100)


@@ -74,3 +74,6 @@ class ExperimentalConfig(Config):
# MSC3720 (Account status endpoint)
self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False)
# The deprecated groups feature.
self.groups_enabled: bool = experimental.get("groups_enabled", True)


@@ -16,6 +16,7 @@ from .account_validity import AccountValidityConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
from .auth import AuthConfig
from .background_updates import BackgroundUpdateConfig
from .cache import CacheConfig
from .captcha import CaptchaConfig
from .cas import CasConfig
@@ -99,4 +100,5 @@ class HomeServerConfig(RootConfig):
WorkerConfig,
RedisConfig,
ExperimentalConfig,
BackgroundUpdateConfig,
]


@@ -182,8 +182,14 @@ class OIDCConfig(Config):
#
# localpart_template: Jinja2 template for the localpart of the MXID.
# If this is not set, the user will be prompted to choose their
# own username (see 'sso_auth_account_details.html' in the 'sso'
# section of this file).
# own username (see the documentation for the
# 'sso_auth_account_details.html' template). This template can
# use the 'localpart_from_email' filter.
#
# confirm_localpart: Whether to prompt the user to validate (or
# change) the generated localpart (see the documentation for the
# 'sso_auth_account_details.html' template), instead of
# registering the account right away.
#
# display_name_template: Jinja2 template for the display name to set
# on first login. If unset, no displayname will be set.


@@ -25,8 +25,8 @@ logger = logging.getLogger(__name__)
LEGACY_SPAM_CHECKER_WARNING = """
This server is using a spam checker module that is implementing the deprecated spam
checker interface. Please check with the module's maintainer to see if a new version
supporting Synapse's generic modules system is available.
For more information, please see https://matrix-org.github.io/synapse/latest/modules.html
supporting Synapse's generic modules system is available. For more information, please
see https://matrix-org.github.io/synapse/latest/modules/index.html
---------------------------------------------------------------------------------------"""


@@ -245,8 +245,8 @@ class SpamChecker:
"""Checks if a given event is considered "spammy" by this server.
If the server considers an event spammy, then it will be rejected if
sent by a local user. If it is sent by a user on another server, then
users receive a blank event.
sent by a local user. If it is sent by a user on another server, the
event is soft-failed.
Args:
event: the event to be checked


@@ -289,7 +289,7 @@ class OpenIdUserInfo(BaseFederationServlet):
return 200, {"sub": user_id}
DEFAULT_SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = {
SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = {
"federation": FEDERATION_SERVLET_CLASSES,
"room_list": (PublicRoomList,),
"group_server": GROUP_SERVER_SERVLET_CLASSES,
@@ -298,6 +298,10 @@ DEFAULT_SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = {
"openid": (OpenIdUserInfo,),
}
DEFAULT_SERVLET_GROUPS = ("federation", "room_list", "openid")
GROUP_SERVLET_GROUPS = ("group_server", "group_local", "group_attestation")
def register_servlets(
hs: "HomeServer",
@@ -320,16 +324,19 @@ def register_servlets(
Defaults to ``DEFAULT_SERVLET_GROUPS``.
"""
if not servlet_groups:
servlet_groups = DEFAULT_SERVLET_GROUPS.keys()
servlet_groups = DEFAULT_SERVLET_GROUPS
# Only allow the groups servlets if the deprecated groups feature is enabled.
if hs.config.experimental.groups_enabled:
servlet_groups = servlet_groups + GROUP_SERVLET_GROUPS
for servlet_group in servlet_groups:
# Skip unknown servlet groups.
if servlet_group not in DEFAULT_SERVLET_GROUPS:
if servlet_group not in SERVLET_GROUPS:
raise RuntimeError(
f"Attempting to register unknown federation servlet: '{servlet_group}'"
)
for servletclass in DEFAULT_SERVLET_GROUPS[servlet_group]:
for servletclass in SERVLET_GROUPS[servlet_group]:
# Only allow the `/timestamp_to_event` servlet if msc3030 is enabled
if (
servletclass == FederationTimestampLookupServlet


@@ -371,7 +371,6 @@ class DeviceHandler(DeviceWorkerHandler):
log_kv(
{"reason": "User doesn't have device id.", "device_id": device_id}
)
pass
else:
raise
@@ -414,7 +413,6 @@ class DeviceHandler(DeviceWorkerHandler):
# no match
set_tag("error", True)
set_tag("reason", "User doesn't have that device id.")
pass
else:
raise


@@ -45,6 +45,7 @@ from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import Clock, json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.templates import _localpart_from_email_filter
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -1228,6 +1229,7 @@ class OidcSessionData:
class UserAttributeDict(TypedDict):
localpart: Optional[str]
confirm_localpart: bool
display_name: Optional[str]
emails: List[str]
@@ -1307,6 +1309,11 @@ def jinja_finalize(thing: Any) -> Any:
env = Environment(finalize=jinja_finalize)
env.filters.update(
{
"localpart_from_email": _localpart_from_email_filter,
}
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -1316,6 +1323,7 @@ class JinjaOidcMappingConfig:
display_name_template: Optional[Template]
email_template: Optional[Template]
extra_attributes: Dict[str, Template]
confirm_localpart: bool = False
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1357,12 +1365,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
"invalid jinja template", path=["extra_attributes", key]
) from e
confirm_localpart = config.get("confirm_localpart") or False
if not isinstance(confirm_localpart, bool):
raise ConfigError("must be a bool", path=["confirm_localpart"])
return JinjaOidcMappingConfig(
subject_claim=subject_claim,
localpart_template=localpart_template,
display_name_template=display_name_template,
email_template=email_template,
extra_attributes=extra_attributes,
confirm_localpart=confirm_localpart,
)
def get_remote_user_id(self, userinfo: UserInfo) -> str:
@@ -1398,7 +1411,10 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
emails.append(email)
return UserAttributeDict(
localpart=localpart, display_name=display_name, emails=emails
localpart=localpart,
display_name=display_name,
emails=emails,
confirm_localpart=self._config.confirm_localpart,
)
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:


@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
import attr
@@ -350,7 +350,7 @@ class PaginationHandler:
"""
self._purges_in_progress_by_room.add(room_id)
try:
with await self.pagination_lock.write(room_id):
async with self.pagination_lock.write(room_id):
await self.storage.purge_events.purge_history(
room_id, token, delete_local_events
)
@@ -406,7 +406,7 @@ class PaginationHandler:
room_id: room to be purged
force: set true to skip checking for joined users.
"""
with await self.pagination_lock.write(room_id):
async with self.pagination_lock.write(room_id):
# first check that we have no users in this room
if not force:
joined = await self.store.is_host_joined(room_id, self._server_name)
@@ -422,7 +422,7 @@ class PaginationHandler:
pagin_config: PaginationConfig,
as_client_event: bool = True,
event_filter: Optional[Filter] = None,
) -> Dict[str, Any]:
) -> JsonDict:
"""Get messages in a room.
Args:
@@ -431,6 +431,7 @@ class PaginationHandler:
pagin_config: The pagination config rules to apply, if any.
as_client_event: True to get events in client-server format.
event_filter: Filter to apply to results or None
Returns:
Pagination API results
"""
@@ -448,7 +449,7 @@ class PaginationHandler:
room_token = from_token.room_key
with await self.pagination_lock.read(room_id):
async with self.pagination_lock.read(room_id):
(
membership,
member_event_id,
@@ -615,7 +616,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id)
try:
with await self.pagination_lock.write(room_id):
async with self.pagination_lock.write(room_id):
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
self._delete_by_id[
delete_id


@@ -267,7 +267,6 @@ class BasePresenceHandler(abc.ABC):
is_syncing: Whether or not the user is now syncing
sync_time_msec: Time in ms when the user was last syncing
"""
pass
async def update_external_syncs_clear(self, process_id: str) -> None:
"""Marks all users that had been marked as syncing by a given process
@@ -277,7 +276,6 @@ class BasePresenceHandler(abc.ABC):
This is a no-op when presence is handled by a different worker.
"""
pass
async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list


@@ -0,0 +1,117 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import SynapseError
from synapse.types import JsonDict, Requester, StreamToken
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class RelationsHandler:
def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main
self._auth = hs.get_auth()
self._clock = hs.get_clock()
self._event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
async def get_relations(
self,
requester: Requester,
event_id: str,
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.
TODO Accept a PaginationConfig instead of individual pagination parameters.
Args:
requester: The user requesting the relations.
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
aggregation_key: Only fetch events with this aggregation key, if given.
limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
The pagination chunk.
"""
user_id = requester.user.to_string()
await self._auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)
# This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it.
event = await self._event_handler.get_event(requester.user, room_id, event_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
pagination_chunk = await self._main_store.get_relations_for_event(
event_id=event_id,
event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
aggregation_key=aggregation_key,
limit=limit,
direction=direction,
from_token=from_token,
to_token=to_token,
)
events = await self._main_store.get_events_as_list(
[c["event_id"] for c in pagination_chunk.chunk]
)
now = self._clock.time_msec()
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
original_event = self._event_serializer.serialize_event(
event, now, bundle_aggregations=None
)
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self._main_store.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)
return_value = await pagination_chunk.to_dict(self._main_store)
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event
return return_value


@@ -132,6 +132,7 @@ class UserAttributes:
# if `None`, the mapper has not picked a userid, and the user should be prompted to
# enter one.
localpart: Optional[str]
confirm_localpart: bool = False
display_name: Optional[str] = None
emails: Collection[str] = attr.Factory(list)
@@ -561,9 +562,10 @@ class SsoHandler:
# Must provide either attributes or session, not both
assert (attributes is not None) != (session is not None)
if (attributes and attributes.localpart is None) or (
session and session.chosen_localpart is None
):
if (
attributes
and (attributes.localpart is None or attributes.confirm_localpart is True)
) or (session and session.chosen_localpart is None):
return b"/_synapse/client/pick_username/account_details"
elif self._consent_at_registration and not (
session and session.terms_accepted_version


@@ -28,7 +28,7 @@ from typing import (
import attr
from prometheus_client import Counter
from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -1601,7 +1601,7 @@ class SyncHandler:
return set(), set(), set(), set()
# 3. Work out which rooms need reporting in the sync response.
ignored_users = await self._get_ignored_users(user_id)
ignored_users = await self.store.ignored_users(user_id)
if since_token:
room_changes = await self._get_rooms_changed(
sync_result_builder, ignored_users
@@ -1627,7 +1627,6 @@ class SyncHandler:
logger.debug("Generating room entry for %s", room_entry.room_id)
await self._generate_room_entry(
sync_result_builder,
ignored_users,
room_entry,
ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
tags=tags_by_room.get(room_entry.room_id),
@@ -1657,29 +1656,6 @@ class SyncHandler:
newly_left_users,
)
async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
"""Retrieve the users ignored by the given user from their global account_data.
Returns an empty set if
- there is no global account_data entry for ignored_users
- there is such an entry, but it's not a JSON object.
"""
# TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
ignored_account_data = (
await self.store.get_global_account_data_by_type_for_user(
user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST
)
)
# If there is ignored users account data and it matches the proper type,
# then use it.
ignored_users: FrozenSet[str] = frozenset()
if ignored_account_data:
ignored_users_data = ignored_account_data.get("ignored_users", {})
if isinstance(ignored_users_data, dict):
ignored_users = frozenset(ignored_users_data.keys())
return ignored_users
async def _have_rooms_changed(
self, sync_result_builder: "SyncResultBuilder"
) -> bool:
@@ -2022,7 +1998,6 @@ class SyncHandler:
async def _generate_room_entry(
self,
sync_result_builder: "SyncResultBuilder",
ignored_users: FrozenSet[str],
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[Dict[str, Dict[str, Any]]],
@@ -2051,7 +2026,6 @@ class SyncHandler:
Args:
sync_result_builder
ignored_users: Set of users ignored by user.
room_builder
ephemeral: List of new ephemeral events for room
tags: List of *all* tags for room, or None if there has been


@@ -160,8 +160,9 @@ class FollowerTypingHandler:
"""Should be called whenever we receive updates for typing stream."""
if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just
# clear everything.
# The typing worker has gone backwards (e.g. it may have restarted).
# To prevent inconsistent data, just clear everything.
logger.info("Typing handler stream went backwards; resetting")
self._reset()
# Set the latest serial token to whatever the server gave us.


@@ -120,7 +120,6 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC):
"""Called when response has finished streaming and the parser should
return the final result (or error).
"""
pass
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -601,7 +600,6 @@ class MatrixFederationHttpClient:
response.code,
response_phrase,
)
pass
else:
logger.info(
"{%s} [%s] Got response headers: %d %s",


@@ -233,7 +233,6 @@ class HttpServer(Protocol):
servlet_classname (str): The name of the handler to be used in prometheus
and opentracing logs.
"""
pass
class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):


@@ -169,7 +169,7 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "content.msgtype",
"pattern": "m.notice",
"_id": "_suppress_notices",
"_cache_key": "_suppress_notices",
}
],
"actions": ["dont_notify"],
@@ -183,13 +183,13 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "type",
"pattern": "m.room.member",
"_id": "_member",
"_cache_key": "_member",
},
{
"kind": "event_match",
"key": "content.membership",
"pattern": "invite",
"_id": "_invite_member",
"_cache_key": "_invite_member",
},
{"kind": "event_match", "key": "state_key", "pattern_type": "user_id"},
],
@@ -212,7 +212,7 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "type",
"pattern": "m.room.member",
"_id": "_member",
"_cache_key": "_member",
}
],
"actions": ["dont_notify"],
@@ -237,12 +237,12 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "content.body",
"pattern": "@room",
"_id": "_roomnotif_content",
"_cache_key": "_roomnotif_content",
},
{
"kind": "sender_notification_permission",
"key": "room",
"_id": "_roomnotif_pl",
"_cache_key": "_roomnotif_pl",
},
],
"actions": ["notify", {"set_tweak": "highlight", "value": True}],
@@ -254,13 +254,13 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "type",
"pattern": "m.room.tombstone",
"_id": "_tombstone",
"_cache_key": "_tombstone",
},
{
"kind": "event_match",
"key": "state_key",
"pattern": "",
"_id": "_tombstone_statekey",
"_cache_key": "_tombstone_statekey",
},
],
"actions": ["notify", {"set_tweak": "highlight", "value": True}],
@@ -272,7 +272,7 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "type",
"pattern": "m.reaction",
"_id": "_reaction",
"_cache_key": "_reaction",
}
],
"actions": ["dont_notify"],
@@ -288,7 +288,7 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "type",
"pattern": "m.call.invite",
"_id": "_call",
"_cache_key": "_call",
}
],
"actions": [
@@ -302,12 +302,12 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
{
"rule_id": "global/underride/.m.rule.room_one_to_one",
"conditions": [
{"kind": "room_member_count", "is": "2", "_id": "member_count"},
{"kind": "room_member_count", "is": "2", "_cache_key": "member_count"},
{
"kind": "event_match",
"key": "type",
"pattern": "m.room.message",
"_id": "_message",
"_cache_key": "_message",
},
],
"actions": [
@@ -321,12 +321,12 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
{
"rule_id": "global/underride/.m.rule.encrypted_room_one_to_one",
"conditions": [
{"kind": "room_member_count", "is": "2", "_id": "member_count"},
{"kind": "room_member_count", "is": "2", "_cache_key": "member_count"},
{
"kind": "event_match",
"key": "type",
"pattern": "m.room.encrypted",
"_id": "_encrypted",
"_cache_key": "_encrypted",
},
],
"actions": [
@@ -342,7 +342,7 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "type",
"pattern": "m.room.message",
"_id": "_message",
"_cache_key": "_message",
}
],
"actions": ["notify", {"set_tweak": "highlight", "value": False}],
@@ -356,7 +356,7 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "type",
"pattern": "m.room.encrypted",
"_id": "_encrypted",
"_cache_key": "_encrypted",
}
],
"actions": ["notify", {"set_tweak": "highlight", "value": False}],
@@ -368,19 +368,19 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"kind": "event_match",
"key": "type",
"pattern": "im.vector.modular.widgets",
"_id": "_type_modular_widgets",
"_cache_key": "_type_modular_widgets",
},
{
"kind": "event_match",
"key": "content.type",
"pattern": "jitsi",
"_id": "_content_type_jitsi",
"_cache_key": "_content_type_jitsi",
},
{
"kind": "event_match",
"key": "state_key",
"pattern": "*",
"_id": "_is_state_event",
"_cache_key": "_is_state_event",
},
],
"actions": ["notify", {"set_tweak": "highlight", "value": False}],


@@ -213,7 +213,7 @@ class BulkPushRuleEvaluator:
if not event.is_state():
ignorers = await self.store.ignored_by(event.sender)
else:
ignorers = set()
ignorers = frozenset()
for uid, rules in rules_by_user.items():
if event.sender == uid:
@@ -274,17 +274,17 @@ def _condition_checker(
cache: Dict[str, bool],
) -> bool:
for cond in conditions:
_id = cond.get("_id", None)
if _id:
res = cache.get(_id, None)
_cache_key = cond.get("_cache_key", None)
if _cache_key:
res = cache.get(_cache_key, None)
if res is False:
return False
elif res is True:
continue
res = evaluator.matches(cond, uid, display_name)
if _id:
cache[_id] = bool(res)
if _cache_key:
cache[_cache_key] = bool(res)
if not res:
return False


@@ -40,7 +40,7 @@ def format_push_rules_for_user(
# Remove internal stuff.
for c in r["conditions"]:
c.pop("_id", None)
c.pop("_cache_key", None)
pattern_type = c.pop("pattern_type", None)
if pattern_type == "user_id":


@@ -709,7 +709,7 @@ class ReplicationCommandHandler:
self.send_command(RemoteServerUpCommand(server))
def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None:
"""Called when a new update is available to stream to clients.
"""Called when a new update is available to stream to Redis subscribers.
We need to check if the client is interested in the stream or not
"""


@@ -67,8 +67,8 @@ class ReplicationStreamProtocolFactory(ServerFactory):
class ReplicationStreamer:
"""Handles replication connections.
This needs to be poked when new replication data may be available. When new
data is available it will propagate to all connected clients.
This needs to be poked when new replication data may be available.
When new data is available it will propagate to all Redis subscribers.
"""
def __init__(self, hs: "HomeServer"):
@@ -109,7 +109,7 @@ class ReplicationStreamer:
def on_notifier_poke(self) -> None:
"""Checks if there is actually any new data and sends it to the
connections if there are.
Redis subscribers if there are.
This should get called each time new data is available, even if it
is currently being executed, so that nothing gets missed


@@ -316,7 +316,19 @@ class PresenceFederationStream(Stream):
class TypingStream(Stream):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TypingStreamRow:
"""
An entry in the typing stream.
Describes all the users that are 'typing' right now in one room.
When a user stops typing, it will be streamed as a new update with that
user absent; you can think of the `user_ids` list as overwriting the
entire list that was there previously.
"""
# The room that this update is for.
room_id: str
# All the users that are 'typing' right now in the specified room.
user_ids: List[str]
NAME = "typing"


@@ -130,15 +130,15 @@
</head>
<body>
<header>
<h1>Your account is nearly ready</h1>
<p>Check your details before creating an account on {{ server_name }}</p>
<h1>Choose your user name</h1>
<p>This is required to create your account on {{ server_name }}, and you can't change this later.</p>
</header>
<main>
<form method="post" class="form__input" id="form">
<div class="username_input" id="username_input">
<label for="field-username">Username</label>
<div class="prefix">@</div>
<input type="text" name="username" id="field-username" autofocus>
<input type="text" name="username" id="field-username" value="{{ user_attributes.localpart }}" autofocus>
<div class="postfix">:{{ server_name }}</div>
</div>
<output for="username_input" id="field-username-output"></output>


@@ -118,7 +118,8 @@ class ClientRestResource(JsonResource):
thirdparty.register_servlets(hs, client_resource)
sendtodevice.register_servlets(hs, client_resource)
user_directory.register_servlets(hs, client_resource)
groups.register_servlets(hs, client_resource)
if hs.config.experimental.groups_enabled:
groups.register_servlets(hs, client_resource)
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
room_batch.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)


@@ -293,7 +293,8 @@ def register_servlets_for_client_rest_resource(
ResetPasswordRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
DeleteGroupAdminRestServlet(hs).register(http_server)
if hs.config.experimental.groups_enabled:
DeleteGroupAdminRestServlet(hs).register(http_server)
AccountValidityRenewServlet(hs).register(http_server)
# Load the media repo ones if we're using them. Otherwise load the servlets which


@@ -51,9 +51,7 @@ class RelationPaginationServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
self._relations_handler = hs.get_relations_handler()
async def on_GET(
self,
@@ -65,16 +63,6 @@ class RelationPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
room_id, requester.user.to_string(), allow_departed_users=True
)
# This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
limit = parse_integer(request, "limit", default=5)
direction = parse_string(
request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"]
@@ -90,9 +78,9 @@ class RelationPaginationServlet(RestServlet):
if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str)
pagination_chunk = await self.store.get_relations_for_event(
result = await self._relations_handler.get_relations(
requester=requester,
event_id=parent_id,
event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
@@ -102,30 +90,7 @@ class RelationPaginationServlet(RestServlet):
to_token=to_token,
)
events = await self.store.get_events_as_list(
[c["event_id"] for c in pagination_chunk.chunk]
)
now = self.clock.time_msec()
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
original_event = self._event_serializer.serialize_event(
event, now, bundle_aggregations=None
)
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self.store.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)
return_value = await pagination_chunk.to_dict(self.store)
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event
return 200, return_value
return 200, result
class RelationAggregationPaginationServlet(RestServlet):
@@ -245,9 +210,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
self._relations_handler = hs.get_relations_handler()
async def on_GET(
self,
@@ -260,18 +223,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
room_id,
requester.user.to_string(),
allow_departed_users=True,
)
# This checks that a) the event exists and b) the user is allowed to
# view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -286,9 +237,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str)
result = await self.store.get_relations_for_event(
result = await self._relations_handler.get_relations(
requester=requester,
event_id=parent_id,
event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
@@ -298,17 +249,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
to_token=to_token,
)
events = await self.store.get_events_as_list(
[c["event_id"] for c in result.chunk]
)
now = self.clock.time_msec()
serialized_events = self._event_serializer.serialize_events(events, now)
return_value = await result.to_dict(self.store)
return_value["chunk"] = serialized_events
return 200, return_value
return 200, result
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:


@@ -298,7 +298,6 @@ class Responder:
Returns:
Resolves once the response has finished being written
"""
pass
def __enter__(self) -> None:
pass


@@ -16,7 +16,6 @@ import itertools
import logging
import re
from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
from urllib import parse as urlparse
if TYPE_CHECKING:
from lxml import etree
@@ -144,9 +143,7 @@ def decode_body(
return etree.fromstring(body, parser)
def parse_html_to_open_graph(
tree: "etree.Element", media_uri: str
) -> Dict[str, Optional[str]]:
def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
"""
Parse the HTML document into an Open Graph response.
@@ -155,7 +152,6 @@ def parse_html_to_open_graph(
Args:
tree: The parsed HTML document.
media_url: The URI used to download the body.
Returns:
The Open Graph response as a dictionary.
@@ -209,7 +205,7 @@ def parse_html_to_open_graph(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
og["og:image"] = rebase_url(meta_image[0], media_uri)
og["og:image"] = meta_image[0]
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
@@ -320,37 +316,6 @@ def _iterate_over_text(
)
def rebase_url(url: str, base: str) -> str:
"""
Resolves a potentially relative `url` against an absolute `base` URL.
For example:
>>> rebase_url("subpage", "https://example.com/foo/")
'https://example.com/foo/subpage'
>>> rebase_url("sibling", "https://example.com/foo")
'https://example.com/sibling'
>>> rebase_url("/bar", "https://example.com/foo/")
'https://example.com/bar'
>>> rebase_url("https://alice.com/a/", "https://example.com/foo/")
'https://alice.com/a'
"""
base_parts = urlparse.urlparse(base)
# Convert the parsed URL to a list for (potential) modification.
url_parts = list(urlparse.urlparse(url))
# Add a scheme, if one does not exist.
if not url_parts[0]:
url_parts[0] = base_parts.scheme or "http"
# Fix up the hostname, if this is not a data URL.
if url_parts[0] != "data" and not url_parts[1]:
url_parts[1] = base_parts.netloc
# If the path does not start with a /, nest it under the base path's last
# directory.
if not url_parts[2].startswith("/"):
url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2]
return urlparse.urlunparse(url_parts)
def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:


@@ -22,7 +22,7 @@ import shutil
import sys
import traceback
from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple
from urllib import parse as urlparse
from urllib.parse import urljoin, urlparse, urlsplit
from urllib.request import urlopen
import attr
@@ -44,11 +44,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.oembed import OEmbedProvider
from synapse.rest.media.v1.preview_html import (
decode_body,
parse_html_to_open_graph,
rebase_url,
)
from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph
from synapse.types import JsonDict, UserID
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
@@ -187,7 +183,7 @@ class PreviewUrlResource(DirectServeJsonResource):
ts = self.clock.time_msec()
# XXX: we could move this into _do_preview if we wanted.
url_tuple = urlparse.urlsplit(url)
url_tuple = urlsplit(url)
for entry in self.url_preview_url_blacklist:
match = True
for attrib in entry:
@@ -322,7 +318,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# Parse Open Graph information from the HTML in case the oEmbed
# response failed or is incomplete.
og_from_html = parse_html_to_open_graph(tree, media_info.uri)
og_from_html = parse_html_to_open_graph(tree)
# Compile the Open Graph response by using the scraped
# information from the HTML and overlaying any information
@@ -588,12 +584,17 @@ class PreviewUrlResource(DirectServeJsonResource):
if "og:image" not in og or not og["og:image"]:
return
# The image URL from the HTML might be relative to the previewed page;
# convert it to a URL which can be requested directly.
image_url = og["og:image"]
url_parts = urlparse(image_url)
if url_parts.scheme != "data":
image_url = urljoin(media_info.uri, image_url)
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
image_info = await self._handle_url(
rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True
)
image_info = await self._handle_url(image_url, user, allow_data_urls=True)
if _is_media(image_info.media_type):
# TODO: make sure we don't choke on white-on-transparent images


@@ -92,12 +92,20 @@ class AccountDetailsResource(DirectServeHtmlResource):
self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
return
# The configuration might mandate going through this step to validate an
# automatically generated localpart, so session.chosen_localpart might already
# be set.
localpart = ""
if session.chosen_localpart is not None:
localpart = session.chosen_localpart
idp_id = session.auth_provider_id
template_params = {
"idp": self._sso_handler.get_identity_providers()[idp_id],
"user_attributes": {
"display_name": session.display_name,
"emails": session.emails,
"localpart": localpart,
},
}

View File

@@ -94,6 +94,7 @@ from synapse.handlers.profile import ProfileHandler
from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.register import RegistrationHandler
from synapse.handlers.relations import RelationsHandler
from synapse.handlers.room import (
RoomContextHandler,
RoomCreationHandler,
@@ -328,7 +329,6 @@ class HomeServer(metaclass=abc.ABCMeta):
Does nothing in this base class; overridden in derived classes to start the
appropriate listeners.
"""
pass
def setup_background_tasks(self) -> None:
"""
@@ -720,6 +720,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_pagination_handler(self) -> PaginationHandler:
return PaginationHandler(self)
@cache_in_self
def get_relations_handler(self) -> RelationsHandler:
return RelationsHandler(self)
@cache_in_self
def get_room_context_handler(self) -> RoomContextHandler:
return RoomContextHandler(self)

View File

@@ -60,18 +60,19 @@ class _BackgroundUpdateHandler:
class _BackgroundUpdateContextManager:
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, sleep: bool, clock: Clock):
def __init__(
self, sleep: bool, clock: Clock, sleep_duration_ms: int, update_duration: int
):
self._sleep = sleep
self._clock = clock
self._sleep_duration_ms = sleep_duration_ms
self._update_duration_ms = update_duration
async def __aenter__(self) -> int:
if self._sleep:
await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
await self._clock.sleep(self._sleep_duration_ms / 1000)
return self.BACKGROUND_UPDATE_DURATION_MS
return self._update_duration_ms
async def __aexit__(self, *exc) -> None:
pass
@@ -133,9 +134,6 @@ class BackgroundUpdater:
process and autotuning the batch size.
"""
MINIMUM_BACKGROUND_BATCH_SIZE = 1
DEFAULT_BACKGROUND_BATCH_SIZE = 100
def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock()
self.db_pool = database
@@ -160,6 +158,14 @@ class BackgroundUpdater:
# enable/disable background updates via the admin API.
self.enabled = True
self.minimum_background_batch_size = hs.config.background_updates.min_batch_size
self.default_background_batch_size = (
hs.config.background_updates.default_batch_size
)
self.update_duration_ms = hs.config.background_updates.update_duration_ms
self.sleep_duration_ms = hs.config.background_updates.sleep_duration_ms
self.sleep_enabled = hs.config.background_updates.sleep_enabled
def register_update_controller_callbacks(
self,
on_update: ON_UPDATE_CALLBACK,
@@ -216,7 +222,9 @@ class BackgroundUpdater:
if self._on_update_callback is not None:
return self._on_update_callback(update_name, database_name, oneshot)
return _BackgroundUpdateContextManager(sleep, self._clock)
return _BackgroundUpdateContextManager(
sleep, self._clock, self.sleep_duration_ms, self.update_duration_ms
)
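
For orientation: __aenter__ above both performs the configured sleep and hands back the target duration for the next batch. A hedged sketch of how a consumer drives it (names are illustrative):

async def one_iteration(make_ctx, run_batch) -> None:
    # __aenter__ sleeps for the configured interval (if enabled) and
    # returns the number of milliseconds the next batch should aim for.
    async with make_ctx() as desired_duration_ms:
        await run_batch(desired_duration_ms)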
async def _default_batch_size(self, update_name: str, database_name: str) -> int:
"""The batch size to use for the first iteration of a new background
@@ -225,7 +233,7 @@ class BackgroundUpdater:
if self._default_batch_size_callback is not None:
return await self._default_batch_size_callback(update_name, database_name)
return self.DEFAULT_BACKGROUND_BATCH_SIZE
return self.default_background_batch_size
async def _min_batch_size(self, update_name: str, database_name: str) -> int:
"""A lower bound on the batch size of a new background update.
@@ -235,7 +243,7 @@ class BackgroundUpdater:
if self._min_batch_size_callback is not None:
return await self._min_batch_size_callback(update_name, database_name)
return self.MINIMUM_BACKGROUND_BATCH_SIZE
return self.minimum_background_batch_size
def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
"""Returns the current background update, if any."""
@@ -254,9 +262,12 @@ class BackgroundUpdater:
if self.enabled:
# if we start a new background update, not all updates are done.
self._all_done = False
run_as_background_process("background_updates", self.run_background_updates)
sleep = self.sleep_enabled
run_as_background_process(
"background_updates", self.run_background_updates, sleep
)
async def run_background_updates(self, sleep: bool = True) -> None:
async def run_background_updates(self, sleep: bool) -> None:
if self._running or not self.enabled:
return

View File

@@ -41,6 +41,7 @@ from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -732,34 +734,45 @@ class DatabasePool:
Returns:
The result of func
"""
after_callbacks: List[_CallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []
if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)
async def _runInteraction() -> R:
after_callbacks: List[_CallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []
try:
with opentracing.start_active_span(f"db.{desc}"):
result = await self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
func,
*args,
db_autocommit=db_autocommit,
isolation_level=isolation_level,
**kwargs,
)
if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
except Exception:
for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs)
raise
try:
with opentracing.start_active_span(f"db.{desc}"):
result = await self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
func,
*args,
db_autocommit=db_autocommit,
isolation_level=isolation_level,
**kwargs,
)
return cast(R, result)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
return cast(R, result)
except Exception:
for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs)
raise
# To handle cancellation, we ensure that `after_callback`s and
# `exception_callback`s are always run, since the transaction will complete
# on another thread regardless of cancellation.
#
# We also wait until everything above is done before releasing the
# `CancelledError`, so that logging contexts won't get used after they have been
# finished.
return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
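
The shape of the change above, reduced to its essentials: the interaction moves into an inner coroutine whose after/exception callbacks always run, and the caller's cancellation is held back until that coroutine finishes. A hedged sketch of the same pattern (names are illustrative, not Synapse APIs), reusing the delay_cancellation imported above:

from twisted.internet import defer

def run_protected(work, after_callbacks, exception_callbacks):
    async def _inner():
        try:
            result = await work()
        except Exception:
            for cb, args, kwargs in exception_callbacks:
                cb(*args, **kwargs)
            raise
        for cb, args, kwargs in after_callbacks:
            cb(*args, **kwargs)
        return result

    # The transaction completes on a database thread regardless of
    # cancellation, so the callbacks must always run; delay_cancellation
    # holds the CancelledError back until _inner has finished.
    return delay_cancellation(defer.ensureDeferred(_inner()))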
async def runWithConnection(
self,

View File

@@ -14,7 +14,17 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
from typing import (
TYPE_CHECKING,
Any,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Tuple,
cast,
)
from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -365,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
@cached(max_entries=5000, iterable=True)
async def ignored_by(self, user_id: str) -> Set[str]:
async def ignored_by(self, user_id: str) -> FrozenSet[str]:
"""
Get users which ignore the given user.
@@ -375,7 +385,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
Returns:
The user IDs which ignore the given user.
"""
return set(
return frozenset(
await self.db_pool.simple_select_onecol(
table="ignored_users",
keyvalues={"ignored_user_id": user_id},
@@ -384,6 +394,26 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
)
@cached(max_entries=5000, iterable=True)
async def ignored_users(self, user_id: str) -> FrozenSet[str]:
"""
Get users which the given user ignores.
Args:
user_id: The user ID which is making the request.
Returns:
The user IDs which are ignored by the given user.
"""
return frozenset(
await self.db_pool.simple_select_onecol(
table="ignored_users",
keyvalues={"ignorer_user_id": user_id},
retcol="ignored_user_id",
desc="ignored_users",
)
)
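
Both cached methods are single indexed selects against the denormalised ignored_users table, rather than a fetch-and-parse of the raw account data. A self-contained sketch (sqlite3, hypothetical rows) of the two directions of the lookup:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE ignored_users (ignorer_user_id TEXT, ignored_user_id TEXT)"
)
conn.execute("INSERT INTO ignored_users VALUES ('@alice:test', '@spam:test')")

# ignored_users("@alice:test"): whom does alice ignore?
ignored = conn.execute(
    "SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id = ?",
    ("@alice:test",),
).fetchall()
assert frozenset(row[0] for row in ignored) == frozenset({"@spam:test"})

# ignored_by("@spam:test"): who ignores spam?
ignorers = conn.execute(
    "SELECT ignorer_user_id FROM ignored_users WHERE ignored_user_id = ?",
    ("@spam:test",),
).fetchall()
assert frozenset(row[0] for row in ignorers) == frozenset({"@alice:test"})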
def process_replication_rows(
self,
stream_name: str,
@@ -529,6 +559,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
else:
currently_ignored_users = set()
# If the data has not changed, nothing to do.
if previously_ignored_users == currently_ignored_users:
return
# Delete entries which are no longer ignored.
self.db_pool.simple_delete_many_txn(
txn,
@@ -551,6 +585,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# Invalidate the cache for any ignored users which were added or removed.
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
async def purge_account_data_for_user(self, user_id: str) -> None:
"""

View File

@@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -31,6 +32,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
def get_all_updated_caches_txn(txn):
def get_all_updated_caches_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
@@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"get_all_updated_caches", get_all_updated_caches_txn
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == EventsStream.NAME:
for row in rows:
self._process_event_stream_row(token, row)
@@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
def _process_event_stream_row(self, token, row):
def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data
if row.type == EventsStreamEventRow.TypeId:
assert isinstance(data, EventsStreamEventRow)
self._invalidate_caches_for_event(
token,
data.event_id,
@@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
self._curr_state_delta_stream_cache.entity_has_changed(
row.data.room_id, token
)
assert isinstance(data, EventsStreamCurrentStateRow)
self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate(
@@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
def _invalidate_caches_for_event(
self,
stream_ordering,
event_id,
room_id,
etype,
state_key,
redacts,
relates_to,
backfilled,
):
stream_ordering: int,
event_id: str,
room_id: str,
etype: str,
state_key: Optional[str],
redacts: Optional[str],
relates_to: Optional[str],
backfilled: bool,
) -> None:
self._invalidate_get_event_cache(event_id)
self.have_seen_event.invalidate((room_id, event_id))
@@ -207,7 +213,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_thread_summary.invalidate((relates_to,))
self.get_thread_participated.invalidate((relates_to,))
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
async def invalidate_cache_and_stream(
self, cache_name: str, keys: Tuple[Any, ...]
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -227,7 +235,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
keys,
)
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
def _invalidate_cache_and_stream(
self,
txn: LoggingTransaction,
cache_func: _CachedFunction,
keys: Tuple[Any, ...],
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -238,7 +251,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
def _invalidate_all_cache_and_stream(self, txn, cache_func):
def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: _CachedFunction
) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
"""
@@ -279,8 +294,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)
def _send_invalidation_to_replication(
self, txn, cache_name: str, keys: Optional[Iterable[Any]]
):
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
) -> None:
"""Notifies replication that given cache has been invalidated.
Note that this does *not* invalidate the cache locally.
@@ -315,7 +330,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
"invalidation_ts": self.clock.time_msec(),
"invalidation_ts": self._clock.time_msec(),
},
)

View File

@@ -48,8 +48,6 @@ class ExternalIDReuseException(Exception):
"""Exception if writing an external id for a user fails,
because this external id is given to another user."""
pass
@attr.s(frozen=True, slots=True, auto_attribs=True)
class TokenLookupResult:

View File

@@ -125,9 +125,6 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
):
super().__init__(database, db_conn, hs)
if not hs.config.server.enable_search:
return
self.db_pool.updates.register_background_update_handler(
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
)
@@ -243,9 +240,13 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
return len(event_search_rows)
result = await self.db_pool.runInteraction(
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
)
if self.hs.config.server.enable_search:
result = await self.db_pool.runInteraction(
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
)
else:
# Don't index anything if search is not enabled.
result = 0
if not result:
await self.db_pool.updates._end_background_update(

View File

@@ -36,7 +36,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
config_files = config.appservice.app_service_config_files
except AttributeError:
logger.warning("Could not get app_service_config_files from config")
pass
appservices = load_appservices(config.server.server_name, config_files)

View File

@@ -18,9 +18,10 @@ import collections
import inspect
import itertools
import logging
from contextlib import contextmanager
from contextlib import asynccontextmanager, contextmanager
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Collection,
@@ -40,7 +41,7 @@ from typing import (
)
import attr
from typing_extensions import ContextManager, Literal
from typing_extensions import AsyncContextManager, Literal
from twisted.internet import defer
from twisted.internet.defer import CancelledError
@@ -491,7 +492,7 @@ class ReadWriteLock:
Example:
with await read_write_lock.read("test_key"):
async with read_write_lock.read("test_key"):
# do some work
"""
@@ -514,22 +515,24 @@ class ReadWriteLock:
# Latest writer queued
self.key_to_current_writer: Dict[str, defer.Deferred] = {}
async def read(self, key: str) -> ContextManager:
new_defer: "defer.Deferred[None]" = defer.Deferred()
def read(self, key: str) -> AsyncContextManager:
@asynccontextmanager
async def _ctx_manager() -> AsyncIterator[None]:
new_defer: "defer.Deferred[None]" = defer.Deferred()
curr_readers = self.key_to_current_readers.setdefault(key, set())
curr_writer = self.key_to_current_writer.get(key, None)
curr_readers = self.key_to_current_readers.setdefault(key, set())
curr_writer = self.key_to_current_writer.get(key, None)
curr_readers.add(new_defer)
curr_readers.add(new_defer)
# We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers.
if curr_writer:
await make_deferred_yieldable(curr_writer)
@contextmanager
def _ctx_manager() -> Iterator[None]:
try:
# We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers.
# May raise a `CancelledError` if the `Deferred` wrapping us is
# cancelled. The `Deferred` we are waiting on must not be cancelled,
# since we do not own it.
if curr_writer:
await make_deferred_yieldable(stop_cancellation(curr_writer))
yield
finally:
with PreserveLoggingContext():
@@ -538,29 +541,35 @@ class ReadWriteLock:
return _ctx_manager()
async def write(self, key: str) -> ContextManager:
new_defer: "defer.Deferred[None]" = defer.Deferred()
def write(self, key: str) -> AsyncContextManager:
@asynccontextmanager
async def _ctx_manager() -> AsyncIterator[None]:
new_defer: "defer.Deferred[None]" = defer.Deferred()
curr_readers = self.key_to_current_readers.get(key, set())
curr_writer = self.key_to_current_writer.get(key, None)
curr_readers = self.key_to_current_readers.get(key, set())
curr_writer = self.key_to_current_writer.get(key, None)
# We wait on all latest readers and writer.
to_wait_on = list(curr_readers)
if curr_writer:
to_wait_on.append(curr_writer)
# We wait on all latest readers and writer.
to_wait_on = list(curr_readers)
if curr_writer:
to_wait_on.append(curr_writer)
# We can clear the list of current readers since the new writer waits
# for them to finish.
curr_readers.clear()
self.key_to_current_writer[key] = new_defer
# We can clear the list of current readers since `new_defer` waits
# for them to finish.
curr_readers.clear()
self.key_to_current_writer[key] = new_defer
await make_deferred_yieldable(defer.gatherResults(to_wait_on))
@contextmanager
def _ctx_manager() -> Iterator[None]:
to_wait_on_defer = defer.gatherResults(to_wait_on)
try:
# Wait for all current readers and the latest writer to finish.
# May raise a `CancelledError` immediately after the wait if the
# `Deferred` wrapping us is cancelled. We must only release the lock
# once we have acquired it, hence the use of `delay_cancellation`
# rather than `stop_cancellation`.
await make_deferred_yieldable(delay_cancellation(to_wait_on_defer))
yield
finally:
# Release the lock.
with PreserveLoggingContext():
new_defer.callback(None)
# `self.key_to_current_writer[key]` may be missing if there was another
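
With both methods converted, callers hold the lock through plain async context managers. A hedged usage sketch (the key name is illustrative):

async def example(lock: ReadWriteLock) -> None:
    async with lock.read("some_key"):
        pass  # many readers may hold the lock at once

    async with lock.write("some_key"):
        pass  # a writer waits for readers and gets exclusive access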
@@ -686,12 +695,48 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
Synapse logcontext rules.
Returns:
A new `Deferred`, which will contain the result of the original `Deferred`,
but will not propagate cancellation through to the original. When cancelled,
the new `Deferred` will fail with a `CancelledError` and will not follow the
Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap
the new `Deferred`.
A new `Deferred`, which will contain the result of the original `Deferred`.
The new `Deferred` will not propagate cancellation through to the original.
When cancelled, the new `Deferred` will fail with a `CancelledError`.
The new `Deferred` will not follow the Synapse logcontext rules and should be
wrapped with `make_deferred_yieldable`.
"""
new_deferred: defer.Deferred[T] = defer.Deferred()
new_deferred: "defer.Deferred[T]" = defer.Deferred()
deferred.chainDeferred(new_deferred)
return new_deferred
def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
"""Delay cancellation of a `Deferred` until it resolves.
Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
resolve with a `CancelledError` until the original `Deferred` resolves.
Args:
deferred: The `Deferred` to protect against cancellation. May optionally follow
the Synapse logcontext rules.
Returns:
A new `Deferred`, which will contain the result of the original `Deferred`.
The new `Deferred` will not propagate cancellation through to the original.
When cancelled, the new `Deferred` will wait until the original `Deferred`
resolves before failing with a `CancelledError`.
The new `Deferred` will follow the Synapse logcontext rules if `deferred`
follows the Synapse logcontext rules. Otherwise the new `Deferred` should be
wrapped with `make_deferred_yieldable`.
"""
def handle_cancel(new_deferred: "defer.Deferred[T]") -> None:
# before the new deferred is cancelled, we `pause` it to stop the cancellation
# propagating. we then `unpause` it once the wrapped deferred completes, to
# propagate the exception.
new_deferred.pause()
new_deferred.errback(Failure(CancelledError()))
deferred.addBoth(lambda _: new_deferred.unpause())
new_deferred: "defer.Deferred[T]" = defer.Deferred(handle_cancel)
deferred.chainDeferred(new_deferred)
return new_deferred
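
A small test-style sketch (not from the patch) of the promised semantics, using the delay_cancellation defined above:

from twisted.internet import defer
from twisted.internet.defer import CancelledError

inner: "defer.Deferred[str]" = defer.Deferred()
wrapped = delay_cancellation(inner)

failures = []
wrapped.addErrback(failures.append)

wrapped.cancel()         # cancellation does not propagate to `inner`...
assert not inner.called  # ...which keeps running
inner.callback("done")   # only now is the CancelledError released
assert failures[0].check(CancelledError)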

View File

@@ -41,6 +41,7 @@ from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import delay_cancellation
from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.lrucache import LruCache
@@ -350,6 +351,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
ret = cache.set(cache_key, ret, callback=invalidate_callback)
# We started a new call to `self.orig`, so we must always wait for it to
# complete. Otherwise we might mark our current logging context as
# finished while `self.orig` is still using it in the background.
ret = delay_cancellation(ret)
return make_deferred_yieldable(ret)
wrapped = cast(_CachedFunction, _wrapped)
@@ -510,6 +516,11 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
lambda _: results, unwrapFirstError
)
if missing:
# We started a new call to `self.orig`, so we must always wait for it to
# complete. Otherwise we might mark our current logging context as
# finished while `self.orig` is still using it in the background.
d = delay_cancellation(d)
return make_deferred_yieldable(d)
else:
return defer.succeed(results)

View File

@@ -22,8 +22,6 @@ class TreeCacheNode(dict):
leaves.
"""
pass
class TreeCache:
"""

View File

@@ -25,6 +25,7 @@ from typing import Iterable, NamedTuple, Optional
from packaging.requirements import Requirement
logger = logging.getLogger(__name__)
DISTRIBUTION_NAME = "matrix-synapse"
try:
@@ -164,6 +165,10 @@ def check_requirements(extra: Optional[str] = None) -> None:
errors.append(_not_installed(requirement, extra))
else:
# We specify prereleases=True to allow prereleases such as RCs.
logger.warning(
"DEP CHECK: %s, %s, %s, %r",
requirement, must_be_installed, dist, dist.version,
)
if not requirement.specifier.contains(dist.version, prereleases=True):
deps_unfulfilled.append(requirement.name)
errors.append(_incorrect_version(requirement, dist.version, extra))

View File

@@ -64,6 +64,7 @@ def build_jinja_env(
{
"format_ts": _format_ts_filter,
"mxc_to_http": _create_mxc_to_http_filter(config.server.public_baseurl),
"localpart_from_email": _localpart_from_email_filter,
}
)
@@ -112,3 +113,7 @@ def _create_mxc_to_http_filter(
def _format_ts_filter(value: int, format: str) -> str:
return time.strftime(format, time.localtime(value / 1000))
def _localpart_from_email_filter(address: str) -> str:
return address.rsplit("@", 1)[0]
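
A hedged sketch of the new filter wired into a bare Jinja2 environment (the template text is illustrative):

import jinja2

env = jinja2.Environment()
env.filters["localpart_from_email"] = _localpart_from_email_filter

template = env.from_string("Suggested username: {{ email | localpart_from_email }}")
assert template.render(email="alice@example.com") == "Suggested username: alice"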

View File

@@ -14,12 +14,7 @@
import logging
from typing import Dict, FrozenSet, List, Optional
from synapse.api.constants import (
AccountDataTypes,
EventTypes,
HistoryVisibility,
Membership,
)
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase
from synapse.events.utils import prune_event
from synapse.storage import Storage
@@ -87,15 +82,8 @@ async def filter_events_for_client(
state_filter=StateFilter.from_types(types),
)
ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
user_id, AccountDataTypes.IGNORED_USER_LIST
)
ignore_list: FrozenSet[str] = frozenset()
if ignore_dict_content:
ignored_users_dict = ignore_dict_content.get("ignored_users", {})
if isinstance(ignored_users_dict, dict):
ignore_list = frozenset(ignored_users_dict.keys())
# Get the users who are ignored by the requesting user.
ignore_list = await storage.main.ignored_users(user_id)
erased_senders = await storage.main.are_users_erased(e.sender for e in events)

View File

@@ -0,0 +1,58 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
from synapse.storage.background_updates import BackgroundUpdater
from tests.unittest import HomeserverTestCase, override_config
class BackgroundUpdateConfigTestCase(HomeserverTestCase):
# Tests that the default values in the config are correctly loaded. Note that the default
# values are used when the corresponding config options are commented out, which is why no
# config is specified here.
def test_default_configuration(self):
background_updater = BackgroundUpdater(
self.hs, self.hs.get_datastores().main.db_pool
)
self.assertEqual(background_updater.minimum_background_batch_size, 1)
self.assertEqual(background_updater.default_background_batch_size, 100)
self.assertEqual(background_updater.sleep_enabled, True)
self.assertEqual(background_updater.sleep_duration_ms, 1000)
self.assertEqual(background_updater.update_duration_ms, 100)
# Tests that non-default values for the config options are properly picked up and passed on.
@override_config(
yaml.safe_load(
"""
background_updates:
background_update_duration_ms: 1000
sleep_enabled: false
sleep_duration_ms: 600
min_batch_size: 5
default_batch_size: 50
"""
)
)
def test_custom_configuration(self):
background_updater = BackgroundUpdater(
self.hs, self.hs.get_datastores().main.db_pool
)
self.assertEqual(background_updater.minimum_background_batch_size, 5)
self.assertEqual(background_updater.default_background_batch_size, 50)
self.assertEqual(background_updater.sleep_enabled, False)
self.assertEqual(background_updater.sleep_duration_ms, 600)
self.assertEqual(background_updater.update_duration_ms, 1000)

View File

@@ -15,11 +15,15 @@
from collections import Counter
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
import synapse.storage
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.room_versions import RoomVersions
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
@@ -32,7 +36,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
knock.register_servlets,
]
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_handler = hs.get_admin_handler()
self.user1 = self.register_user("user1", "password")
@@ -41,7 +45,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.user2 = self.register_user("user2", "password")
self.token2 = self.login("user2", "password")
def test_single_public_joined_room(self):
def test_single_public_joined_room(self) -> None:
"""Test that we write *all* events for a public room"""
room_id = self.helper.create_room_as(
self.user1, tok=self.token1, is_public=True
@@ -74,7 +78,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
def test_single_private_joined_room(self):
def test_single_private_joined_room(self) -> None:
"""Tests that we correctly write state when we can't see all events in
a room.
"""
@@ -112,7 +116,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
def test_single_left_room(self):
def test_single_left_room(self) -> None:
"""Tests that we don't see events in the room after we leave."""
room_id = self.helper.create_room_as(self.user1, tok=self.token1)
self.helper.send(room_id, body="Hello!", tok=self.token1)
@@ -144,7 +148,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user2)], 2)
def test_single_left_rejoined_private_room(self):
def test_single_left_rejoined_private_room(self) -> None:
"""Tests that see the correct events in private rooms when we
repeatedly join and leave.
"""
@@ -185,7 +189,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user2)], 3)
def test_invite(self):
def test_invite(self) -> None:
"""Tests that pending invites get handled correctly."""
room_id = self.helper.create_room_as(self.user1, tok=self.token1)
self.helper.send(room_id, body="Hello!", tok=self.token1)
@@ -204,7 +208,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(args[1].content["membership"], "invite")
self.assertTrue(args[2]) # Assert there is at least one bit of state
def test_knock(self):
def test_knock(self) -> None:
"""Tests that knock get handled correctly."""
# create a knockable v7 room
room_id = self.helper.create_room_as(

View File

@@ -15,8 +15,12 @@ from unittest.mock import Mock
import pymacaroons
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import AuthError, ResourceLimitError
from synapse.rest import admin
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -27,7 +31,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
admin.register_servlets,
]
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.auth_handler = hs.get_auth_handler()
self.macaroon_generator = hs.get_macaroon_generator()
@@ -42,23 +46,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.user1 = self.register_user("a_user", "pass")
def test_macaroon_caveats(self):
def test_macaroon_caveats(self) -> None:
token = self.macaroon_generator.generate_guest_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
def verify_gen(caveat):
def verify_gen(caveat: str) -> bool:
return caveat == "gen = 1"
def verify_user(caveat):
def verify_user(caveat: str) -> bool:
return caveat == "user_id = a_user"
def verify_type(caveat):
def verify_type(caveat: str) -> bool:
return caveat == "type = access"
def verify_nonce(caveat):
def verify_nonce(caveat: str) -> bool:
return caveat.startswith("nonce =")
def verify_guest(caveat):
def verify_guest(caveat: str) -> bool:
return caveat == "guest = true"
v = pymacaroons.Verifier()
@@ -69,7 +73,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
v.satisfy_general(verify_guest)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
def test_short_term_login_token_gives_user_id(self):
def test_short_term_login_token_gives_user_id(self) -> None:
token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", duration_in_ms=5000
)
@@ -84,7 +88,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
AuthError,
)
def test_short_term_login_token_gives_auth_provider(self):
def test_short_term_login_token_gives_auth_provider(self) -> None:
token = self.macaroon_generator.generate_short_term_login_token(
self.user1, auth_provider_id="my_idp"
)
@@ -92,7 +96,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.user1, res.user_id)
self.assertEqual("my_idp", res.auth_provider_id)
def test_short_term_login_token_cannot_replace_user_id(self):
def test_short_term_login_token_cannot_replace_user_id(self) -> None:
token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", duration_in_ms=5000
)
@@ -112,7 +116,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
AuthError,
)
def test_mau_limits_disabled(self):
def test_mau_limits_disabled(self) -> None:
self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception
self.get_success(
@@ -127,7 +131,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
def test_mau_limits_exceeded_large(self):
def test_mau_limits_exceeded_large(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users)
@@ -150,7 +154,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
def test_mau_limits_parity(self):
def test_mau_limits_parity(self) -> None:
# Ensure we're not at the unix epoch.
self.reactor.advance(1)
self.auth_blocking._limit_usage_by_mau = True
@@ -189,7 +193,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
def test_mau_limits_not_exceeded(self):
def test_mau_limits_not_exceeded(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock(
@@ -211,7 +215,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
def _get_macaroon(self):
def _get_macaroon(self) -> pymacaroons.Macaroon:
token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", duration_in_ms=5000
)

View File

@@ -39,7 +39,7 @@ class DeactivateAccountTestCase(HomeserverTestCase):
self.user = self.register_user("user", "pass")
self.token = self.login("user", "pass")
def _deactivate_my_account(self):
def _deactivate_my_account(self) -> None:
"""
Deactivates the account `self.user` using `self.token` and asserts
that it returns a 200 success code.

View File

@@ -14,9 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse.api.errors
import synapse.handlers.device
import synapse.storage
from typing import Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError, SynapseError
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
@@ -25,28 +30,27 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler()
self.store = hs.get_datastores().main
return hs
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# These tests assume that it starts 1000 seconds in.
self.reactor.advance(1000)
def test_device_is_created_with_invalid_name(self):
def test_device_is_created_with_invalid_name(self) -> None:
self.get_failure(
self.handler.check_device_registered(
user_id="@boris:foo",
device_id="foo",
initial_device_display_name="a"
* (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1),
initial_device_display_name="a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1),
),
synapse.api.errors.SynapseError,
SynapseError,
)
def test_device_is_created_if_doesnt_exist(self):
def test_device_is_created_if_doesnt_exist(self) -> None:
res = self.get_success(
self.handler.check_device_registered(
user_id="@boris:foo",
@@ -59,7 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name")
def test_device_is_preserved_if_exists(self):
def test_device_is_preserved_if_exists(self) -> None:
res1 = self.get_success(
self.handler.check_device_registered(
user_id="@boris:foo",
@@ -81,7 +85,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name")
def test_device_id_is_made_up_if_unspecified(self):
def test_device_id_is_made_up_if_unspecified(self) -> None:
device_id = self.get_success(
self.handler.check_device_registered(
user_id="@theresa:foo",
@@ -93,7 +97,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
self.assertEqual(dev["display_name"], "display")
def test_get_devices_by_user(self):
def test_get_devices_by_user(self) -> None:
self._record_users()
res = self.get_success(self.handler.get_devices_by_user(user1))
@@ -131,7 +135,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
device_map["abc"],
)
def test_get_device(self):
def test_get_device(self) -> None:
self._record_users()
res = self.get_success(self.handler.get_device(user1, "abc"))
@@ -146,21 +150,19 @@ class DeviceTestCase(unittest.HomeserverTestCase):
res,
)
def test_delete_device(self):
def test_delete_device(self) -> None:
self._record_users()
# delete the device
self.get_success(self.handler.delete_device(user1, "abc"))
# check the device was deleted
self.get_failure(
self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
)
self.get_failure(self.handler.get_device(user1, "abc"), NotFoundError)
# we'd like to check the access token was invalidated, but that's a
# bit of a PITA.
def test_delete_device_and_device_inbox(self):
def test_delete_device_and_device_inbox(self) -> None:
self._record_users()
# add a device_inbox
@@ -191,7 +193,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(res)
def test_update_device(self):
def test_update_device(self) -> None:
self._record_users()
update = {"display_name": "new display"}
@@ -200,32 +202,29 @@ class DeviceTestCase(unittest.HomeserverTestCase):
res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertEqual(res["display_name"], "new display")
def test_update_device_too_long_display_name(self):
def test_update_device_too_long_display_name(self) -> None:
"""Update a device with a display name that is invalid (too long)."""
self._record_users()
# Request to update a device display name with a new value that is longer than allowed.
update = {
"display_name": "a"
* (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
}
update = {"display_name": "a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1)}
self.get_failure(
self.handler.update_device(user1, "abc", update),
synapse.api.errors.SynapseError,
SynapseError,
)
# Ensure the display name was not updated.
res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertEqual(res["display_name"], "display 2")
def test_update_unknown_device(self):
def test_update_unknown_device(self) -> None:
update = {"display_name": "new_display"}
self.get_failure(
self.handler.update_device("user_id", "unknown_device_id", update),
synapse.api.errors.NotFoundError,
NotFoundError,
)
def _record_users(self):
def _record_users(self) -> None:
# check this works for both devices which have a recorded client_ip,
# and those which don't.
self._record_user(user1, "xyz", "display 0")
@@ -238,8 +237,13 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.reactor.advance(10000)
def _record_user(
self, user_id, device_id, display_name, access_token=None, ip=None
):
self,
user_id: str,
device_id: str,
display_name: str,
access_token: Optional[str] = None,
ip: Optional[str] = None,
) -> None:
device_id = self.get_success(
self.handler.check_device_registered(
user_id=user_id,
@@ -248,7 +252,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
)
if ip is not None:
if access_token is not None and ip is not None:
self.get_success(
self.store.insert_client_ip(
user_id, access_token, ip, "user_agent", device_id
@@ -258,7 +262,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler()
self.registration = hs.get_registration_handler()
@@ -266,7 +270,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
return hs
def test_dehydrate_and_rehydrate_device(self):
def test_dehydrate_and_rehydrate_device(self) -> None:
user_id = "@boris:dehydration"
self.get_success(self.store.register_user(user_id, "foobar"))
@@ -303,7 +307,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
access_token=access_token,
device_id="not the right device ID",
),
synapse.api.errors.NotFoundError,
NotFoundError,
)
# dehydrating the right devices should succeed and change our device ID
@@ -331,7 +335,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
# make sure that the device ID that we were initially assigned no longer exists
self.get_failure(
self.handler.get_device(user_id, device_id),
synapse.api.errors.NotFoundError,
NotFoundError,
)
# make sure that there's no device available for dehydrating now

View File

@@ -12,14 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
import synapse.api.errors
import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client import directory, login, room
from synapse.types import RoomAlias, create_requester
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, create_requester
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -28,13 +32,15 @@ from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase):
"""Tests the directory service."""
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock()
self.mock_registry = Mock()
self.query_handlers = {}
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
def register_query_handler(query_type, handler):
def register_query_handler(
query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
) -> None:
self.query_handlers[query_type] = handler
self.mock_registry.register_query_handler = register_query_handler
@@ -54,7 +60,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
return hs
def test_get_local_association(self):
def test_get_local_association(self) -> None:
self.get_success(
self.store.create_room_alias_association(
self.my_room, "!8765qwer:test", ["test"]
@@ -65,7 +71,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
def test_get_remote_association(self):
def test_get_remote_association(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
)
@@ -83,7 +89,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
ignore_backoff=True,
)
def test_incoming_fed_query(self):
def test_incoming_fed_query(self) -> None:
self.get_success(
self.store.create_room_alias_association(
self.your_room, "!8765asdf:test", ["test"]
@@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
directory.register_servlets,
]
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_directory_handler()
# Create user
@@ -125,7 +131,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
def test_create_alias_joined_room(self):
def test_create_alias_joined_room(self) -> None:
"""A user can create an alias for a room they're in."""
self.get_success(
self.handler.create_association(
@@ -135,7 +141,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
)
)
def test_create_alias_other_room(self):
def test_create_alias_other_room(self) -> None:
"""A user cannot create an alias for a room they're NOT in."""
other_room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok
@@ -150,7 +156,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError,
)
def test_create_alias_admin(self):
def test_create_alias_admin(self) -> None:
"""An admin can create an alias for a room they're NOT in."""
other_room_id = self.helper.create_room_as(
self.test_user, tok=self.test_user_tok
@@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
directory.register_servlets,
]
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
@@ -195,7 +201,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
def _create_alias(self, user):
def _create_alias(self, user) -> None:
# Create a new alias to this room.
self.get_success(
self.store.create_room_alias_association(
@@ -203,7 +209,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
)
)
def test_delete_alias_not_allowed(self):
def test_delete_alias_not_allowed(self) -> None:
"""A user that doesn't meet the expected guidelines cannot delete an alias."""
self._create_alias(self.admin_user)
self.get_failure(
@@ -213,7 +219,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.AuthError,
)
def test_delete_alias_creator(self):
def test_delete_alias_creator(self) -> None:
"""An alias creator can delete their own alias."""
# Create an alias from a different user.
self._create_alias(self.test_user)
@@ -232,7 +238,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError,
)
def test_delete_alias_admin(self):
def test_delete_alias_admin(self) -> None:
"""A server admin can delete an alias created by another user."""
# Create an alias from a different user.
self._create_alias(self.test_user)
@@ -251,7 +257,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError,
)
def test_delete_alias_sufficient_power(self):
def test_delete_alias_sufficient_power(self) -> None:
"""A user with a sufficient power level should be able to delete an alias."""
self._create_alias(self.admin_user)
@@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
directory.register_servlets,
]
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
@@ -317,7 +323,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
return room_alias
def _set_canonical_alias(self, content):
def _set_canonical_alias(self, content) -> None:
"""Configure the canonical alias state on the room."""
self.helper.send_state(
self.room_id,
@@ -334,7 +340,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
)
def test_remove_alias(self):
def test_remove_alias(self) -> None:
"""Removing an alias that is the canonical alias should remove it there too."""
# Set this new alias as the canonical alias for this room
self._set_canonical_alias(
@@ -356,7 +362,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
self.assertNotIn("alias", data["content"])
self.assertNotIn("alt_aliases", data["content"])
def test_remove_other_alias(self):
def test_remove_other_alias(self) -> None:
"""Removing an alias listed as in alt_aliases should remove it there too."""
# Create a second alias.
other_test_alias = "#test2:test"
@@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets]
def default_config(self):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
# Add custom alias creation rules to the config.
@@ -403,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
return config
def test_denied(self):
def test_denied(self) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -413,7 +419,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
)
self.assertEqual(403, channel.code, channel.result)
def test_allowed(self):
def test_allowed(self) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -423,7 +429,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
)
self.assertEqual(200, channel.code, channel.result)
def test_denied_during_creation(self):
def test_denied_during_creation(self) -> None:
"""A room alias that is not allowed should be rejected during creation."""
# Invalid room alias.
self.helper.create_room_as(
@@ -432,7 +438,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
extra_content={"room_alias_name": "foo"},
)
def test_allowed_during_creation(self):
def test_allowed_during_creation(self) -> None:
"""A valid room alias should be allowed during creation."""
room_id = self.helper.create_room_as(
self.user_id,
@@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
data = {"room_alias_name": "unofficial_test"}
allowed_localpart = "allowed"
def default_config(self):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
# Add custom room list publication rules to the config.
@@ -474,7 +480,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return config
def prepare(self, reactor, clock, hs):
def prepare(
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
) -> HomeServer:
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
@@ -483,7 +491,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return hs
def test_denied_without_publication_permission(self):
def test_denied_without_publication_permission(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
as a user without permission to publish rooms.
@@ -497,7 +505,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=403,
)
def test_allowed_when_creating_private_room(self):
def test_allowed_when_creating_private_room(self) -> None:
"""
Try to create a room, register an alias for it, and NOT publish it,
as a user without permission to publish rooms.
@@ -511,7 +519,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=200,
)
def test_allowed_with_publication_permission(self):
def test_allowed_with_publication_permission(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
as a user WITH permission to publish rooms.
@@ -525,7 +533,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=200,
)
def test_denied_publication_with_invalid_alias(self):
def test_denied_publication_with_invalid_alias(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
as a user WITH permission to publish rooms.
@@ -538,7 +546,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=403,
)
def test_can_create_as_private_room_after_rejection(self):
def test_can_create_as_private_room_after_rejection(self) -> None:
"""
After failing to publish a room with an alias as a user without publish permission,
retry as the same user, but without publishing the room.
@@ -549,7 +557,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
self.test_denied_without_publication_permission()
self.test_allowed_when_creating_private_room()
def test_can_create_with_permission_after_rejection(self):
def test_can_create_with_permission_after_rejection(self) -> None:
"""
After failing to publish a room with an alias as a user without publish permission,
retry as someone with permission, using the same alias.
@@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets]
def prepare(self, reactor, clock, hs):
def prepare(
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
) -> HomeServer:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -579,7 +589,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
return hs
def test_disabling_room_list(self):
def test_disabling_room_list(self) -> None:
self.room_list_handler.enable_room_list_search = True
self.directory_handler.enable_room_list_search = True

View File

@@ -20,33 +20,37 @@ from parameterized import parameterized
from signedjson import key as key, sign as sign
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=mock.Mock())
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_keys_handler()
self.store = self.hs.get_datastores().main
def test_query_local_devices_no_devices(self):
def test_query_local_devices_no_devices(self) -> None:
"""If the user has no devices, we expect an empty list."""
local_user = "@boris:" + self.hs.hostname
res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
def test_reupload_one_time_keys(self):
def test_reupload_one_time_keys(self) -> None:
"""we should be able to re-upload the same keys"""
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
keys = {
keys: JsonDict = {
"alg1:k1": "key1",
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"alg2:k3": {"key": "key3"},
@@ -74,7 +78,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
)
def test_change_one_time_keys(self):
def test_change_one_time_keys(self) -> None:
"""attempts to change one-time-keys should be rejected"""
local_user = "@boris:" + self.hs.hostname
@@ -134,7 +138,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
SynapseError,
)
def test_claim_one_time_key(self):
def test_claim_one_time_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
keys = {"alg1:k1": "key1"}
@@ -161,7 +165,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
def test_fallback_key(self):
def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
fallback_key = {"alg1:k1": "fallback_key1"}
@@ -294,7 +298,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
)
def test_replace_master_key(self):
def test_replace_master_key(self) -> None:
"""uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname
keys1 = {
@@ -328,7 +332,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
def test_reupload_signatures(self):
def test_reupload_signatures(self) -> None:
"""re-uploading a signature should not fail"""
local_user = "@boris:" + self.hs.hostname
keys1 = {
@@ -433,7 +437,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
def test_self_signing_key_doesnt_show_up_as_device(self):
def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
"""signing keys should be hidden when fetching a user's devices"""
local_user = "@boris:" + self.hs.hostname
keys1 = {
@@ -462,7 +466,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
def test_upload_signatures(self):
def test_upload_signatures(self) -> None:
"""should check signatures that are uploaded"""
# set up a user with cross-signing keys and a device. This user will
# try uploading signatures
@@ -686,7 +690,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
)
def test_query_devices_remote_no_sync(self):
def test_query_devices_remote_no_sync(self) -> None:
"""Tests that querying keys for a remote user that we don't share a room
with returns the cross signing keys correctly.
"""
@@ -759,7 +763,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
def test_query_devices_remote_sync(self):
def test_query_devices_remote_sync(self) -> None:
"""Tests that querying keys for a remote user that we share a room with,
but haven't yet fetched the keys for, returns the cross signing keys
correctly.
@@ -845,7 +849,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
(["device_1", "device_2"],),
]
)
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None:
"""Test that requests for all of a remote user's devices are cached.
We do this by asserting that only one call over federation was made, and that
@@ -853,7 +857,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
"""
local_user_id = "@test:test"
remote_user_id = "@test:other"
request_body = {"device_keys": {remote_user_id: []}}
request_body: JsonDict = {"device_keys": {remote_user_id: []}}
response_devices = [
{
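
Several hunks in this file annotate dict literals as JsonDict (an alias for Dict[str, Any] in synapse.types). Without the annotation, mypy infers the value type from the first entry and then rejects heterogeneous values. A standalone sketch of the problem and the fix (the alias is restated here so the snippet runs outside the Synapse tree):

from typing import Any, Dict

JsonDict = Dict[str, Any]  # stand-in for synapse.types.JsonDict

# Unannotated, mypy would infer Dict[str, str] from "key1" and flag the
# nested dict that follows; the JsonDict annotation keeps values as Any.
keys: JsonDict = {
    "alg1:k1": "key1",
    "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
}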


@@ -13,14 +13,18 @@
# limitations under the License.
import json
import os
from typing import Any, Dict
from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse
import pymacaroons
from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.types import JsonDict, UserID
from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
@@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider):
}
async def get_json(url):
async def get_json(url: str) -> JsonDict:
# Mock get_json calls to handle jwks & oidc discovery endpoints
if url == WELL_KNOWN:
# Minimal discovery document, as defined in OpenID.Discovery
@@ -116,6 +120,8 @@ async def get_json(url):
elif url == JWKS_URI:
return {"keys": []}
return {}
def _key_file_path() -> str:
"""path to a file containing the private half of a test key"""
@@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
if not HAS_OIDC:
skip = "requires OIDC"
def default_config(self):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
return config
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = b"Synapse Test"
@@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
sso_handler.render_error = self.render_error
sso_handler.render_error = self.render_error # type: ignore[assignment]
# Reduce the number of attempts when generating MXIDs.
sso_handler._MAP_USERNAME_RETRIES = 3
@@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
return args
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_config(self):
def test_config(self) -> None:
"""Basic config correctly sets up the callback URL and client auth correctly."""
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
def test_discovery(self):
def test_discovery(self) -> None:
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
metadata = self.get_success(self.provider.load_metadata())
@@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_no_discovery(self):
def test_no_discovery(self) -> None:
"""When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_load_jwks(self):
def test_load_jwks(self) -> None:
"""JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI)
@@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_validate_config(self):
def test_validate_config(self) -> None:
"""Provider metadatas are extensively validated."""
h = self.provider
@@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
force_load_metadata()
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
def test_skip_verification(self):
def test_skip_verification(self) -> None:
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
get_awaitable_result(self.provider.load_metadata())
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_redirect_request(self):
def test_redirect_request(self) -> None:
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["cookies"])
req.cookies = []
@@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(redirect, "http://client/redirect")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_error(self):
def test_callback_error(self) -> None:
"""Errors from the provider returned in the callback are displayed."""
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
@@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("invalid_client", "some description")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback(self):
def test_callback(self) -> None:
"""Code callback works and display errors if something went wrong.
A lot of scenarios are tested here:
@@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": username,
}
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
self.provider._exchange_code = simple_async_mock(return_value=token)
self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("mapping_error")
# Handle ID token errors
self.provider._parse_id_token = simple_async_mock(raises=Exception())
self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
@@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"type": "bearer",
"access_token": "access_token",
}
self.provider._exchange_code = simple_async_mock(return_value=token)
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
@@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
id_token = {
"sid": "abcdefgh",
}
self.provider._parse_id_token = simple_async_mock(return_value=id_token)
self.provider._exchange_code = simple_async_mock(return_value=token)
self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
auth_handler.complete_sso_login.reset_mock()
self.provider._fetch_userinfo.reset_mock()
self.get_success(self.handler.handle_oidc_callback(request))
@@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.render_error.assert_not_called()
# Handle userinfo fetching error
self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
# Handle code exchange failure
from synapse.handlers.oidc import OidcError
self.provider._exchange_code = simple_async_mock(
self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
raises=OidcError("invalid_request")
)
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_session(self):
def test_callback_session(self) -> None:
"""The callback verifies the session presence and validity"""
request = Mock(spec=["args", "getCookie", "cookies"])
@@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config(
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
)
def test_exchange_code(self):
def test_exchange_code(self) -> None:
"""Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"}
token_json = json.dumps(token).encode("utf-8")
@@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
def test_exchange_code_jwt_key(self):
def test_exchange_code_jwt_key(self) -> None:
"""Test that code exchange works with a JWK client secret."""
from authlib.jose import jwt
@@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
def test_exchange_code_no_auth(self):
def test_exchange_code_no_auth(self) -> None:
"""Test that code exchange works with no client secret."""
token = {"type": "bearer"}
self.http_client.request = simple_async_mock(
@@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
def test_extra_attributes(self):
def test_extra_attributes(self) -> None:
"""
Login while using a mapping provider that implements get_extra_attributes.
"""
@@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "foo",
"phone": "1234567",
}
self.provider._exchange_code = simple_async_mock(return_value=token)
self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_user(self):
def test_map_userinfo_to_user(self) -> None:
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
userinfo = {
userinfo: dict = {
"sub": "test_user",
"username": "test_user",
}
@@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self):
def test_map_userinfo_to_existing_user(self) -> None:
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastores().main
user = UserID.from_string("@test_user:test")
@@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_invalid_localpart(self):
def test_map_userinfo_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected."""
self.get_success(
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
@@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
def test_map_userinfo_to_user_retries(self):
def test_map_userinfo_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_empty_localpart(self):
def test_empty_localpart(self) -> None:
"""Attempts to map onto an empty localpart should be rejected."""
userinfo = {
"sub": "tester",
@@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
def test_null_localpart(self):
def test_null_localpart(self) -> None:
"""Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
userinfo = {
"sub": "tester",
@@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
def test_attribute_requirements(self):
def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the OIDC userinfo response."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
def test_attribute_requirements_contains(self):
def test_attribute_requirements_contains(self) -> None:
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
def test_attribute_requirements_mismatch(self):
def test_attribute_requirements_mismatch(self) -> None:
"""
Test that auth fails if attributes exist but don't match,
or are non-string values.
@@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# userinfo with "test": "not_foobar" attribute should fail
userinfo = {
userinfo: dict = {
"sub": "tester",
"username": "tester",
"test": "not_foobar",
@@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo(
handler = hs.get_oidc_handler()
provider = handler._providers["oidc"]
provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
provider._parse_id_token = simple_async_mock(return_value=userinfo)
provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
state = "state"
session = handler._token_generator.generate_oidc_session_token(
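
The repeated # type: ignore[assignment] comments above arise because these tests monkeypatch real methods with mocks, and mypy rejects assigning a Mock to an attribute declared as a method. A minimal sketch of the pattern, using the stdlib AsyncMock in place of the simple_async_mock test helper (the class and method names are illustrative):

from unittest import mock


class Provider:
    async def exchange_code(self, code: str) -> dict:
        return {"id_token": ""}


provider = Provider()
# mypy: a Mock is not a compatible method type, so the narrow ignore code
# documents exactly which error is being suppressed.
provider.exchange_code = mock.AsyncMock(  # type: ignore[assignment]
    return_value={"id_token": "stub"}
)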


@@ -124,7 +124,6 @@ class PasswordCustomAuthProvider:
("m.login.password", ("password",)): self.check_auth,
}
)
pass
def check_auth(self, *args):
return mock_password_provider.check_auth(*args)


@@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
import synapse.types
from synapse.api.errors import AuthError, SynapseError
from synapse.rest import admin
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
servlets = [admin.register_servlets]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock()
self.mock_registry = Mock()
self.query_handlers = {}
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
def register_query_handler(query_type, handler):
def register_query_handler(
query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
) -> None:
self.query_handlers[query_type] = handler
self.mock_registry.register_query_handler = register_query_handler
@@ -47,7 +52,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
return hs
def prepare(self, reactor, clock, hs: HomeServer):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.frank = UserID.from_string("@1234abcd:test")
@@ -58,7 +63,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.handler = hs.get_profile_handler()
def test_get_my_name(self):
def test_get_my_name(self) -> None:
self.get_success(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)
@@ -67,7 +72,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual("Frank", displayname)
def test_set_my_name(self):
def test_set_my_name(self) -> None:
self.get_success(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
@@ -110,7 +115,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
)
def test_set_my_name_if_disabled(self):
def test_set_my_name_if_disabled(self) -> None:
self.hs.config.registration.enable_set_displayname = False
# Setting displayname for the first time is allowed
@@ -135,7 +140,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
SynapseError,
)
def test_set_my_name_noauth(self):
def test_set_my_name_noauth(self) -> None:
self.get_failure(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
@@ -143,7 +148,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
AuthError,
)
def test_get_other_name(self):
def test_get_other_name(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable(
{"displayname": "Alice"}
)
@@ -158,7 +163,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
ignore_backoff=True,
)
def test_incoming_fed_query(self):
def test_incoming_fed_query(self) -> None:
self.get_success(self.store.create_profile("caroline"))
self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
@@ -174,7 +179,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual({"displayname": "Caroline"}, response)
def test_get_my_avatar(self):
def test_get_my_avatar(self) -> None:
self.get_success(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
@@ -184,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual("http://my.server/me.png", avatar_url)
def test_set_my_avatar(self):
def test_set_my_avatar(self) -> None:
self.get_success(
self.handler.set_avatar_url(
self.frank,
@@ -225,7 +230,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
)
def test_set_my_avatar_if_disabled(self):
def test_set_my_avatar_if_disabled(self) -> None:
self.hs.config.registration.enable_set_avatar_url = False
# Setting the avatar for the first time is allowed
@@ -250,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
SynapseError,
)
def test_avatar_constraints_no_config(self):
def test_avatar_constraints_no_config(self) -> None:
"""Tests that the method to check an avatar against configured constraints skips
all of its checks if no constraint is configured.
"""
@@ -263,7 +268,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertTrue(res)
@unittest.override_config({"max_avatar_size": 50})
def test_avatar_constraints_missing(self):
def test_avatar_constraints_missing(self) -> None:
"""Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
be found.
"""
@@ -273,7 +278,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertFalse(res)
@unittest.override_config({"max_avatar_size": 50})
def test_avatar_constraints_file_size(self):
def test_avatar_constraints_file_size(self) -> None:
"""Tests that a file that's above the allowed file size is forbidden but one
that's below it is allowed.
"""
@@ -295,7 +300,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertFalse(res)
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
def test_avatar_constraint_mime_type(self):
def test_avatar_constraint_mime_type(self) -> None:
"""Tests that a file with an unauthorised MIME type is forbidden but one with
an authorised content type is allowed.
"""


@@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Any, Dict, Optional
from unittest.mock import Mock
import attr
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import RedirectException
from synapse.server import HomeServer
from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
@@ -81,10 +85,10 @@ class TestRedirectMappingProvider(TestMappingProvider):
class SamlHandlerTestCase(HomeserverTestCase):
def default_config(self):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
saml_config = {
saml_config: Dict[str, Any] = {
"sp_config": {"metadata": {}},
# Disable grandfathering.
"grandfathered_mxid_source_attribute": None,
@@ -98,7 +102,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
return config
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver()
self.handler = hs.get_saml_handler()
@@ -114,7 +118,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
elif not has_xmlsec1:
skip = "Requires xmlsec1"
def test_map_saml_response_to_user(self):
def test_map_saml_response_to_user(self) -> None:
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
# stub out the auth handler
@@ -140,7 +144,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
)
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
def test_map_saml_response_to_existing_user(self):
def test_map_saml_response_to_existing_user(self) -> None:
"""Existing users can log in with SAML account."""
store = self.hs.get_datastores().main
self.get_success(
@@ -186,7 +190,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
def test_map_saml_response_to_invalid_localpart(self):
def test_map_saml_response_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected."""
# stub out the auth handler
@@ -207,7 +211,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
)
auth_handler.complete_sso_login.assert_not_called()
def test_map_saml_response_to_user_retries(self):
def test_map_saml_response_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
# stub out the auth handler and error renderer
@@ -271,7 +275,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
}
}
)
def test_map_saml_response_redirect(self):
def test_map_saml_response_redirect(self) -> None:
"""Test a mapping provider that raises a RedirectException"""
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
@@ -292,7 +296,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
},
}
)
def test_attribute_requirements(self):
def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the SAML response."""
# stub out the auth handler


@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
from unittest.mock import Mock
from twisted.internet.defer import Deferred
@@ -18,7 +19,7 @@ from twisted.internet.defer import Deferred
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfigException
from synapse.rest.client import login, receipts, room
from synapse.rest.client import login, push_rule, receipts, room
from tests.unittest import HomeserverTestCase, override_config
@@ -29,6 +30,7 @@ class HTTPPusherTests(HomeserverTestCase):
room.register_servlets,
login.register_servlets,
receipts.register_servlets,
push_rule.register_servlets,
]
user_id = True
hijack_auth = False
@@ -39,12 +41,12 @@ class HTTPPusherTests(HomeserverTestCase):
return config
def make_homeserver(self, reactor, clock):
self.push_attempts = []
self.push_attempts: List[Tuple[Deferred, str, dict]] = []
m = Mock()
def post_json_get_json(url, body):
d = Deferred()
d: Deferred = Deferred()
self.push_attempts.append((d, url, body))
return make_deferred_yieldable(d)
@@ -719,3 +721,67 @@ class HTTPPusherTests(HomeserverTestCase):
access_token=access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
def _make_user_with_pusher(self, username: str) -> Tuple[str, str]:
user_id = self.register_user(username, "pass")
access_token = self.login(username, "pass")
# Register the pusher
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
app_id="m.http",
app_display_name="HTTP Push Notifications",
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
return user_id, access_token
def test_dont_notify_rule_overrides_message(self):
"""
The override push rule should suppress the notification.
"""
user_id, access_token = self._make_user_with_pusher("user")
other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
# Create a room
room = self.helper.create_room_as(user_id, tok=access_token)
# Disable user notifications for this room -> user
body = {
"conditions": [{"kind": "event_match", "key": "room_id", "pattern": room}],
"actions": ["dont_notify"],
}
channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend",
body,
access_token=access_token,
)
self.assertEqual(channel.code, 200)
# Check we start with no pushes
self.assertEqual(len(self.push_attempts), 0)
# The other user joins
self.helper.join(room=room, user=other_user_id, tok=other_access_token)
# The other user sends a message (ignored by the dont_notify push rule set above)
self.helper.send(room, body="Hi!", tok=other_access_token)
self.assertEqual(len(self.push_attempts), 0)
# The user sends a message back (sends a notification)
self.helper.send(room, body="Hello", tok=access_token)
self.assertEqual(len(self.push_attempts), 1)
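
The test above relies on the standard push-rule evaluation order: override rules run before everything else, and a matching rule with a "dont_notify" action short-circuits the pusher. A rough sketch of the matching semantics only (this mirrors the behaviour, not Synapse's actual evaluator, and the room ID is a made-up literal):

rule = {
    "conditions": [
        {"kind": "event_match", "key": "room_id", "pattern": "!room:test"}
    ],
    "actions": ["dont_notify"],
}


def rule_matches(rule: dict, event: dict) -> bool:
    # event_match compares an event field against a glob pattern; a bare
    # room ID like the one the test uses is effectively an exact match.
    return all(
        event.get(cond["key"]) == cond["pattern"]
        for cond in rule["conditions"]
        if cond["kind"] == "event_match"
    )


event = {"room_id": "!room:test", "type": "m.room.message"}
if rule_matches(rule, event):
    actions = rule["actions"]  # ["dont_notify"]: no push attempt is made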


@@ -39,6 +39,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.updater = BackgroundUpdater(hs, self.store.db_pool)
@parameterized.expand(
[
@@ -135,10 +136,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"""Test the status API works with a background update."""
# Create a new background update
self._register_bg_update()
self.store.db_pool.updates.start_doing_background_updates()
self.reactor.pump([1.0, 1.0, 1.0])
channel = self.make_request(
@@ -158,7 +159,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"average_items_per_ms": 0.1,
"total_duration_ms": 1000.0,
"total_item_count": (
BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE
self.updater.default_background_batch_size
),
}
},
@@ -213,7 +214,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"average_items_per_ms": 0.1,
"total_duration_ms": 1000.0,
"total_item_count": (
BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE
self.updater.default_background_batch_size
),
}
},
@@ -242,7 +243,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"average_items_per_ms": 0.1,
"total_duration_ms": 1000.0,
"total_item_count": (
BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE
self.updater.default_background_batch_size
),
}
},
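
These hunks replace the class constant BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE with the instance attribute default_background_batch_size, so the batch size can vary per homeserver instead of being baked into the class. A sketch of that shape (the config keys and fallback value are assumptions, chosen to follow the diff's naming):

class BackgroundUpdaterSketch:
    # The old scheme: one constant shared by every instance.
    DEFAULT_BACKGROUND_BATCH_SIZE = 100

    def __init__(self, config: dict) -> None:
        # The new scheme: the effective batch size lives on the instance
        # and can be driven by configuration, falling back to the default.
        self.default_background_batch_size = config.get(
            "background_updates", {}
        ).get("default_batch_size", self.DEFAULT_BACKGROUND_BATCH_SIZE)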

File diff suppressed because it is too large


@@ -24,6 +24,7 @@ from synapse.util import Clock
from synapse.visibility import filter_events_for_client
from tests import unittest
from tests.unittest import override_config
one_hour_ms = 3600000
one_day_ms = one_hour_ms * 24
@@ -38,7 +39,10 @@ class RetentionTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["retention"] = {
# merge this default retention config with anything that was specified in
# @override_config
retention_config = {
"enabled": True,
"default_policy": {
"min_lifetime": one_day_ms,
@@ -47,6 +51,8 @@ class RetentionTestCase(unittest.HomeserverTestCase):
"allowed_lifetime_min": one_day_ms,
"allowed_lifetime_max": one_day_ms * 3,
}
retention_config.update(config.get("retention", {}))
config["retention"] = retention_config
self.hs = self.setup_test_homeserver(config=config)
@@ -115,22 +121,20 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self._test_retention_event_purged(room_id, one_day_ms * 2)
@override_config({"retention": {"purge_jobs": [{"interval": "5d"}]}})
def test_visibility(self) -> None:
"""Tests that synapse.visibility.filter_events_for_client correctly filters out
outdated events
outdated events, even if the purge job hasn't got to them yet.
We do this by setting a very long time between purge jobs.
"""
store = self.hs.get_datastores().main
storage = self.hs.get_storage()
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
events = []
# Send a first event, which should be filtered out at the end of the test.
resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
# Get the event from the store so that we end up with a FrozenEvent that we can
# give to filter_events_for_client. We need to do this now because the event won't
# be in the database anymore after it has expired.
events.append(self.get_success(store.get_event(resp.get("event_id"))))
first_event_id = resp.get("event_id")
# Advance the time by 2 days. We're using the default retention policy, therefore
# after this the first event will still be valid.
@@ -138,16 +142,17 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Send another event, which shouldn't get filtered out.
resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
valid_event_id = resp.get("event_id")
events.append(self.get_success(store.get_event(valid_event_id)))
# Advance the time by another 2 days. After this, the first event should be
# outdated but not the second one.
self.reactor.advance(one_day_ms * 2 / 1000)
# Run filter_events_for_client with our list of FrozenEvents.
# Fetch the events, and run filter_events_for_client on them
events = self.get_success(
store.get_events_as_list([first_event_id, valid_event_id])
)
self.assertEqual(2, len(events), "events retrieved from database")
filtered_events = self.get_success(
filter_events_for_client(storage, self.user_id, events)
)
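
The retention fixture change is a plain shallow merge: the defaults come first, then whatever @override_config supplied wins at the top level. A small sketch of the same dict.update behaviour:

defaults = {"enabled": True, "default_policy": {"min_lifetime": 1}}
override = {"purge_jobs": [{"interval": "5d"}]}

retention_config = dict(defaults)
retention_config.update(override)

# Top-level keys from the override win; nested dicts are replaced wholesale,
# not merged recursively.
assert retention_config["enabled"] is True
assert retention_config["purge_jobs"] == [{"interval": "5d"}]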


@@ -1,3 +1,18 @@
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from unittest.mock import Mock, call
from twisted.internet import defer, reactor
@@ -11,14 +26,14 @@ from tests.utils import MockClock
class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.clock = MockClock()
self.hs = Mock()
self.hs.get_clock = Mock(return_value=self.clock)
self.hs.get_auth = Mock()
self.cache = HttpTransactionCache(self.hs)
self.mock_http_response = (200, "GOOD JOB!")
self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
self.mock_key = "foo"
@defer.inlineCallbacks
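
The last hunk swaps the bare 200 for HTTPStatus.OK. Members of http.HTTPStatus are ints, so equality checks against numeric literals keep working and the change is purely cosmetic; a quick sketch:

from http import HTTPStatus

# HTTPStatus members compare equal to their integer values.
assert HTTPStatus.OK == 200
assert int(HTTPStatus.NOT_FOUND) == 404

mock_http_response = (HTTPStatus.OK, "GOOD JOB!")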

Some files were not shown because too many files have changed in this diff