Compare commits

23 Commits

Author SHA1 Message Date
Travis Ralston
b024acffea Add rudimentary API for promoting/demoting other people in a group
For https://github.com/matrix-org/synapse/issues/2855 (initial)
2020-08-18 15:21:30 -06:00
Patrick Cloke
acfb7c3b5d Add a link to the matrix-synapse-rest-password-provider. (#8111) 2020-08-18 09:54:35 -04:00
Patrick Cloke
3c01724b33 Fix the return type of send_nonmember_events. (#8112) 2020-08-18 09:53:13 -04:00
Andrew Morgan
5cf7c12995 Remove : from allowed client_secret chars (#8101)
Closes: https://github.com/matrix-org/synapse/issues/6766

Equivalent Sydent PR: https://github.com/matrix-org/sydent/pull/309

I believe it's now time to remove the extra allowed `:` from `client_secret` parameters.
2020-08-18 14:14:27 +01:00
Patrick Cloke
408aef8276 Rename changelog from bugfix to misc. 2020-08-18 09:09:11 -04:00
Patrick Cloke
2f4d60a5ba Iteratively encode JSON responses to avoid blocking the reactor. (#8013) 2020-08-18 08:49:59 -04:00
Patrick Cloke
25e55d2598 Return the previous stream token if a non-member event is a duplicate. (#8093) 2020-08-18 07:53:23 -04:00
Andrew Morgan
8b6c176aee Add resources.consent conditional dependency back (#8107)
Turns out that part of the codebase (synapse.config.server) checks for this key explicitly. Remove that check.
2020-08-18 10:59:54 +01:00
Patrick Cloke
050e20e7ca Convert some of the general database methods to async (#8100) 2020-08-17 12:18:01 -04:00
Andrew Morgan
e04e465b4d Use the default templates when a custom template file cannot be found (#8037)
Fixes https://github.com/matrix-org/synapse/issues/6583
2020-08-17 17:05:00 +01:00
Olivier Wilkinson (reivilibre)
8390e00c7f Merge branch 'master' into develop 2020-08-17 14:28:49 +01:00
Patrick Cloke
ad6190c925 Convert stream database to async/await. (#8074) 2020-08-17 07:24:46 -04:00
Patrick Cloke
ac77cdb64e Add a shadow-banned flag to users. (#8092) 2020-08-14 12:37:59 -04:00
Patrick Cloke
b069b78bb4 Convert pusher databases to async/await. (#8075) 2020-08-14 10:30:16 -04:00
Patrick Cloke
e8861957d9 Convert receipts and events databases to async/await. (#8076) 2020-08-14 10:05:19 -04:00
Erik Johnston
dc22090a67 Add type hints to synapse.handlers.room (#8090) 2020-08-14 14:47:53 +01:00
Patrick Cloke
6b7ce1d332 Remove some unused database functions. (#8085) 2020-08-14 09:25:40 -04:00
Patrick Cloke
894dae74fe Convert misc database code to async (#8087) 2020-08-14 07:24:26 -04:00
Patrick Cloke
7bdf9828d5 Remove a space at the start of a changelog entry. 2020-08-13 14:16:18 -04:00
Olivier Wilkinson (reivilibre)
bfd79c2988 Merge tag 'v1.19.0rc1' into develop
Synapse 1.19.0rc1 (2020-08-13)
==============================

Removal warning
---------------

As outlined in the [previous release](https://github.com/matrix-org/synapse/releases/tag/v1.18.0), we are no longer publishing Docker images with the `-py3` tag suffix. On top of that, we have also removed the `latest-py3` tag. Please see [the announcement in the upgrade notes for 1.18.0](https://github.com/matrix-org/synapse/blob/develop/UPGRADE.rst#upgrading-to-v1180).

Features
--------

- Add option to allow server admins to join rooms which fail complexity checks. Contributed by @lugino-emeritus. ([\#7902](https://github.com/matrix-org/synapse/issues/7902))
- Add an option to choose whether to purge a room when using the delete room admin endpoint (`POST /_synapse/admin/v1/rooms/<room_id>/delete`). Contributed by @dklimpel. ([\#7964](https://github.com/matrix-org/synapse/issues/7964))
- Add rate limiting to users joining rooms. ([\#8008](https://github.com/matrix-org/synapse/issues/8008))
- Add a `/health` endpoint to every configured HTTP listener that can be used as a health check endpoint by load balancers; see the probe sketch after this list. ([\#8048](https://github.com/matrix-org/synapse/issues/8048))
- Allow login to be blocked based on the values of SAML attributes. ([\#8052](https://github.com/matrix-org/synapse/issues/8052))
- Allow guest access to the `GET /_matrix/client/r0/rooms/{room_id}/members` endpoint, according to MSC2689. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#7314](https://github.com/matrix-org/synapse/issues/7314))
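The `/health` endpoint takes no parameters; a load balancer only needs to see a successful HTTP response from it. A minimal probe, as a hedged sketch (the address, port and use of plain HTTP are assumptions about a local listener, not part of the change itself):

```python
import urllib.request

def synapse_is_healthy(base_url: str = "http://127.0.0.1:8008") -> bool:
    """Return True if the homeserver's /health endpoint answers with HTTP 200."""
    try:
        with urllib.request.urlopen(f"{base_url}/health", timeout=2) as resp:
            return resp.status == 200
    except OSError:
        return False

if __name__ == "__main__":
    print("healthy" if synapse_is_healthy() else "unhealthy")
```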

Bugfixes
--------

- Fix a bug introduced in Synapse v1.7.2 which caused inaccurate membership counts in the room directory. ([\#7977](https://github.com/matrix-org/synapse/issues/7977))
- Fix a long-standing bug: 'Duplicate key value violates unique constraint "event_relations_id"' when message retention is configured. ([\#7978](https://github.com/matrix-org/synapse/issues/7978))
- Fix "no create event in auth events" when trying to reject invitation after inviter leaves. Bug introduced in Synapse v1.10.0. ([\#7980](https://github.com/matrix-org/synapse/issues/7980))
- Fix various comments and minor discrepancies in server notices code. ([\#7996](https://github.com/matrix-org/synapse/issues/7996))
- Fix a long-standing bug where HTTP HEAD requests resulted in a 400 error. ([\#7999](https://github.com/matrix-org/synapse/issues/7999))
- Fix a long-standing bug which caused two copies of some log lines to be written when synctl was used along with a MemoryHandler logger. ([\#8011](https://github.com/matrix-org/synapse/issues/8011), [\#8012](https://github.com/matrix-org/synapse/issues/8012))

Updates to the Docker image
---------------------------

- We no longer publish Docker images with the `-py3` tag suffix, as [announced in the upgrade notes](https://github.com/matrix-org/synapse/blob/develop/UPGRADE.rst#upgrading-to-v1180). ([\#8056](https://github.com/matrix-org/synapse/issues/8056))

Improved Documentation
----------------------

- Document how to set up a client .well-known file and fix several pieces of outdated documentation. ([\#7899](https://github.com/matrix-org/synapse/issues/7899))
- Improve workers docs. ([\#7990](https://github.com/matrix-org/synapse/issues/7990), [\#8000](https://github.com/matrix-org/synapse/issues/8000))
- Fix typo in `docs/workers.md`. ([\#7992](https://github.com/matrix-org/synapse/issues/7992))
- Add documentation for how to undo a room shutdown. ([\#7998](https://github.com/matrix-org/synapse/issues/7998), [\#8010](https://github.com/matrix-org/synapse/issues/8010))

Internal Changes
----------------

- Reduce the amount of whitespace in JSON stored and sent in responses. Contributed by David Vo. ([\#7372](https://github.com/matrix-org/synapse/issues/7372))
- Switch to the JSON implementation from the standard library and bump the minimum version of the canonicaljson library to 1.2.0. ([\#7936](https://github.com/matrix-org/synapse/issues/7936), [\#7979](https://github.com/matrix-org/synapse/issues/7979))
- Convert various parts of the codebase to async/await. ([\#7947](https://github.com/matrix-org/synapse/issues/7947), [\#7948](https://github.com/matrix-org/synapse/issues/7948), [\#7949](https://github.com/matrix-org/synapse/issues/7949), [\#7951](https://github.com/matrix-org/synapse/issues/7951), [\#7963](https://github.com/matrix-org/synapse/issues/7963), [\#7973](https://github.com/matrix-org/synapse/issues/7973), [\#7975](https://github.com/matrix-org/synapse/issues/7975), [\#7976](https://github.com/matrix-org/synapse/issues/7976), [\#7981](https://github.com/matrix-org/synapse/issues/7981), [\#7987](https://github.com/matrix-org/synapse/issues/7987), [\#7989](https://github.com/matrix-org/synapse/issues/7989), [\#8003](https://github.com/matrix-org/synapse/issues/8003), [\#8014](https://github.com/matrix-org/synapse/issues/8014), [\#8016](https://github.com/matrix-org/synapse/issues/8016), [\#8027](https://github.com/matrix-org/synapse/issues/8027), [\#8031](https://github.com/matrix-org/synapse/issues/8031), [\#8032](https://github.com/matrix-org/synapse/issues/8032), [\#8035](https://github.com/matrix-org/synapse/issues/8035), [\#8042](https://github.com/matrix-org/synapse/issues/8042), [\#8044](https://github.com/matrix-org/synapse/issues/8044), [\#8045](https://github.com/matrix-org/synapse/issues/8045), [\#8061](https://github.com/matrix-org/synapse/issues/8061), [\#8062](https://github.com/matrix-org/synapse/issues/8062), [\#8063](https://github.com/matrix-org/synapse/issues/8063), [\#8066](https://github.com/matrix-org/synapse/issues/8066), [\#8069](https://github.com/matrix-org/synapse/issues/8069), [\#8070](https://github.com/matrix-org/synapse/issues/8070))
- Move some database-related log lines from the default logger to the database/transaction loggers. ([\#7952](https://github.com/matrix-org/synapse/issues/7952))
- Add a script to detect source code files using non-unix line terminators. ([\#7965](https://github.com/matrix-org/synapse/issues/7965), [\#7970](https://github.com/matrix-org/synapse/issues/7970))
- Log the SAML session ID during creation. ([\#7971](https://github.com/matrix-org/synapse/issues/7971))
- Implement new experimental push rules for some users. ([\#7997](https://github.com/matrix-org/synapse/issues/7997))
- Remove redundant and unreliable signature check for v1 Identity Service lookup responses. ([\#8001](https://github.com/matrix-org/synapse/issues/8001))
- Improve the performance of the register endpoint. ([\#8009](https://github.com/matrix-org/synapse/issues/8009))
- Reduce the amount of less useful output in the newsfragment CI step, and add a link to the changelog section of the contributing guide on error. ([\#8024](https://github.com/matrix-org/synapse/issues/8024))
- Rename storage layer objects to be more sensible. ([\#8033](https://github.com/matrix-org/synapse/issues/8033))
- Change the default log config to reduce disk I/O and storage for new servers. ([\#8040](https://github.com/matrix-org/synapse/issues/8040))
- Add an assertion on `prev_events` in `create_new_client_event`. ([\#8041](https://github.com/matrix-org/synapse/issues/8041))
- Add a comment to `ServerContextFactory` about the use of `SSLv23_METHOD`. ([\#8043](https://github.com/matrix-org/synapse/issues/8043))
- Log `OPTIONS` requests at `DEBUG` rather than `INFO` level to reduce amount logged at `INFO`. ([\#8049](https://github.com/matrix-org/synapse/issues/8049))
- Reduce amount of outbound request logging at `INFO` level. ([\#8050](https://github.com/matrix-org/synapse/issues/8050))
- It is no longer necessary to explicitly define `filters` in the logging configuration. (Continuing to do so is redundant but harmless.) ([\#8051](https://github.com/matrix-org/synapse/issues/8051))
- Add and improve type hints. ([\#8058](https://github.com/matrix-org/synapse/issues/8058), [\#8064](https://github.com/matrix-org/synapse/issues/8064), [\#8060](https://github.com/matrix-org/synapse/issues/8060), [\#8067](https://github.com/matrix-org/synapse/issues/8067))
2020-08-13 18:22:58 +01:00
Richard van der Hoff
53834bb9c4 Run remove_push_actions_from_staging in foreground (#8081)
If we got an error persisting an event, we would try to remove the push actions
asynchronously, which would lead to a 'Re-starting finished log context'
warning.

I don't think there's any need for this to be asynchronous.
2020-08-13 17:05:31 +01:00
reivilibre
ff0e894656 Drop federation transmission queues during a significant remote outage. (#7864)
* Empty federation transmission queues when we are backing off.

Fixes #7828.

Signed-off-by: Olivier Wilkinson (reivilibre) <olivier@librepush.net>

* Address feedback

Signed-off-by: Olivier Wilkinson (reivilibre) <olivier@librepush.net>

* Reword newsfile
2020-08-13 12:35:04 +01:00
Patrick Cloke
dd8f28bd3f Fix unawaited coroutine error in tests. (#8072) 2020-08-13 07:11:39 -04:00
105 changed files with 1280 additions and 1461 deletions


@@ -1,28 +1,15 @@
Synapse 1.19.2 (2020-09-16)
===========================
For the next release
====================
Due to the issue below server admins are encouraged to upgrade as soon as possible.
Removal warning
---------------
Bugfixes
--------
- Fix joining rooms over federation that include malformed events. ([\#8324](https://github.com/matrix-org/synapse/issues/8324))
Synapse 1.19.1 (2020-08-27)
===========================
No significant changes.
Synapse 1.19.1rc1 (2020-08-25)
==============================
Bugfixes
--------
- Fix a bug introduced in v1.19.0 where appservices with ratelimiting disabled would still be ratelimited when joining rooms. ([\#8139](https://github.com/matrix-org/synapse/issues/8139))
- Fix a bug introduced in v1.19.0 that would cause e.g. profile updates to fail due to incorrect application of rate limits on join requests. ([\#8153](https://github.com/matrix-org/synapse/issues/8153))
Some older clients used a
[disallowed character](https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-register-email-requesttoken)
(`:`) in the `client_secret` parameter of various endpoints. The incorrect
behaviour was allowed for backwards compatibility, but is now being removed
from Synapse as most users have updated their client. Further context can be
found at [\#6766](https://github.com/matrix-org/synapse/issues/6766).
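As a hedged illustration of the character set this change re-enforces (the regex mirrors the spec link above; the function name is illustrative, not Synapse's own helper):

```python
import re

# Characters the spec allows in client_secret; ":" is no longer tolerated.
CLIENT_SECRET_RE = re.compile(r"^[0-9a-zA-Z.=_-]+$")

def is_valid_client_secret(client_secret: str) -> bool:
    return bool(CLIENT_SECRET_RE.match(client_secret))

assert is_valid_client_secret("this.is_a-valid=secret123")
assert not is_valid_client_secret("sid:12345")  # rejected again
```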
Synapse 1.19.0 (2020-08-17)

changelog.d/7864.bugfix (new file)

@@ -0,0 +1 @@
Fix a memory leak by limiting the length of time that messages will be queued for a remote server that has been unreachable.

changelog.d/8013.feature (new file)

@@ -0,0 +1 @@
Iteratively encode JSON to avoid blocking the reactor.
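This entry corresponds to the iterative-encoding change (#8013) in the commit list above. A hedged, generic sketch of the technique (not Synapse's implementation; the helper below is illustrative):

```python
import json

from twisted.internet import task


def write_json_cooperatively(request, obj, chunk_size=4096):
    """Encode `obj` with JSONEncoder.iterencode() and write it to a Twisted
    request in chunks, yielding to the reactor between chunks so one very
    large response cannot block it."""
    encoder = json.JSONEncoder()

    def _write_chunks():
        buffer, buffered = [], 0
        for piece in encoder.iterencode(obj):
            buffer.append(piece)
            buffered += len(piece)
            if buffered >= chunk_size:
                request.write("".join(buffer).encode("utf-8"))
                buffer, buffered = [], 0
                yield None  # let the reactor run other work
        if buffer:
            request.write("".join(buffer).encode("utf-8"))

    # cooperate() runs the generator in small slices on the reactor.
    return task.cooperate(_write_chunks()).whenDone()
```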

changelog.d/8037.feature (new file)

@@ -0,0 +1 @@
Use the default template file when its equivalent is not found in a custom template directory.

changelog.d/8072.misc (new file)

@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/8074.misc (new file)

@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/8075.misc (new file)

@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/8076.misc (new file)

@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/8081.bugfix (new file)

@@ -0,0 +1 @@
Fix `Re-starting finished log context PUT-nnnn` warning when event persistence failed.

changelog.d/8085.misc (new file)

@@ -0,0 +1 @@
Remove some unused database functions.

changelog.d/8087.misc (new file)

@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/8090.misc (new file)

@@ -0,0 +1 @@
Add type hints to `synapse.handlers.room`.

changelog.d/8092.feature (new file)

@@ -0,0 +1 @@
Add support for shadow-banning users (ignoring any message send requests).
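This compare only adds the `shadow_banned` flag and threads it through auth and registration (see the diffs below); enforcement is described only as "ignoring any message send requests". Purely as a hypothetical sketch of how a send path could consult such a flag (names and behaviour are not taken from this diff):

```python
import secrets

async def maybe_send_event(shadow_banned: bool, send_event) -> str:
    if shadow_banned:
        # Pretend the send succeeded: hand back a plausible-looking event ID
        # without persisting or federating anything.
        return "$" + secrets.token_urlsafe(18)
    return await send_event()
```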

changelog.d/8093.misc (new file)

@@ -0,0 +1 @@
Return the previous stream token if a non-member event is a duplicate.

changelog.d/8100.misc (new file)

@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/8101.bugfix (new file)

@@ -0,0 +1 @@
Synapse now correctly enforces the valid characters in the `client_secret` parameter used in various endpoints.

changelog.d/8107.feature (new file)

@@ -0,0 +1 @@
Use the default template file when its equivalent is not found in a custom template directory.

changelog.d/8111.doc (new file)

@@ -0,0 +1 @@
Link to matrix-synapse-rest-password-provider in the password provider documentation.

changelog.d/8112.misc (new file)

@@ -0,0 +1 @@
Return the previous stream token if a non-member event is a duplicate.

debian/changelog (vendored)

@@ -1,15 +1,3 @@
matrix-synapse-py3 (1.19.2) stable; urgency=medium
* New synapse release 1.19.2.
-- Synapse Packaging team <packages@matrix.org> Wed, 16 Sep 2020 12:50:30 +0100
matrix-synapse-py3 (1.19.1) stable; urgency=medium
* New synapse release 1.19.1.
-- Synapse Packaging team <packages@matrix.org> Thu, 27 Aug 2020 10:50:19 +0100
matrix-synapse-py3 (1.19.0) stable; urgency=medium
[ Synapse Packaging team ]


@@ -14,6 +14,7 @@ password auth provider module implementations:
* [matrix-synapse-ldap3](https://github.com/matrix-org/matrix-synapse-ldap3/)
* [matrix-synapse-shared-secret-auth](https://github.com/devture/matrix-synapse-shared-secret-auth)
* [matrix-synapse-rest-password-provider](https://github.com/ma1uta/matrix-synapse-rest-password-provider)
## Required methods
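(The hunk ends here. Purely for orientation, and not the text of this documentation section: a password provider module is a Python class exposing hooks such as `parse_config`, `__init__` and `check_password`. A rough, hedged sketch; the class name and config key are made up:)

```python
class ExamplePasswordProvider:
    """Illustrative only; consult the documentation above for the exact
    required and optional hooks."""

    def __init__(self, config, account_handler):
        self.config = config
        self.account_handler = account_handler

    @staticmethod
    def parse_config(config):
        # Validate/normalise this provider's section of homeserver.yaml.
        return config

    async def check_password(self, user_id, password):
        # True accepts the login; False falls through to other providers
        # (or Synapse's own password database).
        return password == self.config.get("shared_password")
```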


@@ -2002,9 +2002,7 @@ email:
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
# Do not uncomment this setting unless you want to customise the templates.
#
# Synapse will look for the following templates in this directory:
#


@@ -48,7 +48,7 @@ try:
except ImportError:
pass
__version__ = "1.19.2"
__version__ = "1.19.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when


@@ -213,6 +213,7 @@ class Auth(object):
user = user_info["user"]
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
shadow_banned = user_info["shadow_banned"]
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
@@ -252,7 +253,12 @@ class Auth(object):
opentracing.set_tag("device_id", device_id)
return synapse.types.create_requester(
user, token_id, is_guest, device_id, app_service=app_service
user,
token_id,
is_guest,
shadow_banned,
device_id,
app_service=app_service,
)
except KeyError:
raise MissingClientTokenError()
@@ -297,6 +303,7 @@ class Auth(object):
dict that includes:
`user` (UserID)
`is_guest` (bool)
`shadow_banned` (bool)
`token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token
Raises:
@@ -356,6 +363,7 @@ class Auth(object):
ret = {
"user": user,
"is_guest": True,
"shadow_banned": False,
"token_id": None,
# all guests get the same device id
"device_id": GUEST_DEVICE_ID,
@@ -365,6 +373,7 @@ class Auth(object):
ret = {
"user": user,
"is_guest": False,
"shadow_banned": False,
"token_id": None,
"device_id": None,
}
@@ -488,6 +497,7 @@ class Auth(object):
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
"is_guest": False,
"shadow_banned": ret.get("shadow_banned"),
"device_id": ret.get("device_id"),
"valid_until_ms": ret.get("valid_until_ms"),
}


@@ -23,7 +23,7 @@ from jsonschema import FormatChecker
from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.storage.presence import UserPresenceState
from synapse.api.presence import UserPresenceState
from synapse.types import RoomID, UserID
FILTER_SCHEMA = {


@@ -17,7 +17,6 @@ from collections import OrderedDict
from typing import Any, Optional, Tuple
from synapse.api.errors import LimitExceededError
from synapse.types import Requester
from synapse.util import Clock
@@ -44,42 +43,6 @@ class Ratelimiter(object):
# * The rate_hz of this particular entry. This can vary per request
self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]]
def can_requester_do_action(
self,
requester: Requester,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
_time_now_s: Optional[int] = None,
) -> Tuple[bool, float]:
"""Can the requester perform the action?
Args:
requester: The requester to key off when rate limiting. The user property
will be used.
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
Overrides the value set during instantiation if set.
update: Whether to count this check as performing the action
_time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Only used by tests.
Returns:
A tuple containing:
* A bool indicating if they can perform the action now
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
"""
# Disable rate limiting of users belonging to any AS that is configured
# not to be rate limited in its registration file (rate_limited: true|false).
if requester.app_service and not requester.app_service.is_rate_limited():
return True, -1.0
return self.can_do_action(
requester.user.to_string(), rate_hz, burst_count, update, _time_now_s
)
def can_do_action(
self,
key: Any,


@@ -18,12 +18,16 @@
import argparse
import errno
import os
import time
import urllib.parse
from collections import OrderedDict
from hashlib import sha256
from textwrap import dedent
from typing import Any, List, MutableMapping, Optional
from typing import Any, Callable, List, MutableMapping, Optional
import attr
import jinja2
import pkg_resources
import yaml
@@ -100,6 +104,11 @@ class Config(object):
def __init__(self, root_config=None):
self.root = root_config
# Get the path to the default Synapse template directory
self.default_template_dir = pkg_resources.resource_filename(
"synapse", "res/templates"
)
def __getattr__(self, item: str) -> Any:
"""
Try and fetch a configuration option that does not exist on this class.
@@ -184,6 +193,95 @@ class Config(object):
with open(file_path) as file_stream:
return file_stream.read()
def read_templates(
self, filenames: List[str], custom_template_directory: Optional[str] = None,
) -> List[jinja2.Template]:
"""Load a list of template files from disk using the given variables.
This function will attempt to load the given templates from the default Synapse
template directory. If `custom_template_directory` is supplied, that directory
is tried first.
Files read are treated as Jinja templates. These templates are not rendered yet.
Args:
filenames: A list of template filenames to read.
custom_template_directory: A directory to try to look for the templates
before using the default Synapse template directory instead.
Raises:
ConfigError: if the file's path is incorrect or otherwise cannot be read.
Returns:
A list of jinja2 templates.
"""
templates = []
search_directories = [self.default_template_dir]
# The loader will first look in the custom template directory (if specified) for the
# given filename. If it doesn't find it, it will use the default template dir instead
if custom_template_directory:
# Check that the given template directory exists
if not self.path_exists(custom_template_directory):
raise ConfigError(
"Configured template directory does not exist: %s"
% (custom_template_directory,)
)
# Search the custom template directory as well
search_directories.insert(0, custom_template_directory)
loader = jinja2.FileSystemLoader(search_directories)
env = jinja2.Environment(loader=loader, autoescape=True)
# Update the environment with our custom filters
env.filters.update(
{
"format_ts": _format_ts_filter,
"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
}
)
for filename in filenames:
# Load the template
template = env.get_template(filename)
templates.append(template)
return templates
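In isolation, the fallback behaviour described above boils down to giving Jinja2 a search path with the custom directory first. A generic sketch (not Synapse code; directory and file names are illustrative):

```python
import jinja2

def load_templates(filenames, default_dir, custom_dir=None):
    # Search the custom directory first (if given); any file missing there
    # falls back to the default directory.
    search_dirs = [custom_dir, default_dir] if custom_dir else [default_dir]
    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(search_dirs), autoescape=True
    )
    return [env.get_template(name) for name in filenames]
```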
def _format_ts_filter(value: int, format: str):
return time.strftime(format, time.localtime(value / 1000))
def _create_mxc_to_http_filter(public_baseurl: str) -> Callable:
"""Create and return a jinja2 filter that converts MXC urls to HTTP
Args:
public_baseurl: The public, accessible base URL of the homeserver
"""
def mxc_to_http_filter(value, width, height, resize_method="crop"):
if value[0:6] != "mxc://":
return ""
server_and_media_id = value[6:]
fragment = None
if "#" in server_and_media_id:
server_and_media_id, fragment = server_and_media_id.split("#", 1)
fragment = "#" + fragment
params = {"width": width, "height": height, "method": resize_method}
return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
public_baseurl,
server_and_media_id,
urllib.parse.urlencode(params),
fragment or "",
)
return mxc_to_http_filter
class RootConfig(object):
"""


@@ -23,7 +23,6 @@ from enum import Enum
from typing import Optional
import attr
import pkg_resources
from ._base import Config, ConfigError
@@ -98,21 +97,18 @@ class EmailConfig(Config):
if parsed[1] == "":
raise RuntimeError("Invalid notif_from address")
# A user-configurable template directory
template_dir = email_config.get("template_dir")
# we need an absolute path, because we change directory after starting (and
# we don't yet know what auxiliary templates like mail.css we will need).
# (Note that loading as package_resources with jinja.PackageLoader doesn't
# work for the same reason.)
if not template_dir:
template_dir = pkg_resources.resource_filename("synapse", "res/templates")
self.email_template_dir = os.path.abspath(template_dir)
if isinstance(template_dir, str):
# We need an absolute path, because we change directory after starting (and
# we don't yet know what auxiliary templates like mail.css we will need).
template_dir = os.path.abspath(template_dir)
elif template_dir is not None:
# If template_dir is something other than a str or None, warn the user
raise ConfigError("Config option email.template_dir must be type str")
self.email_enable_notifs = email_config.get("enable_notifs", False)
account_validity_config = config.get("account_validity") or {}
account_validity_renewal_enabled = account_validity_config.get("renew_at")
self.threepid_behaviour_email = (
# Have Synapse handle the email sending if account_threepid_delegates.email
# is not defined
@@ -166,19 +162,6 @@ class EmailConfig(Config):
email_config.get("validation_token_lifetime", "1h")
)
if (
self.email_enable_notifs
or account_validity_renewal_enabled
or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
):
# make sure we can import the required deps
import bleach
import jinja2
# prevent unused warnings
jinja2
bleach
if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
missing = []
if not self.email_notif_from:
@@ -196,49 +179,49 @@ class EmailConfig(Config):
# These email templates have placeholders in them, and thus must be
# parsed using a templating engine during a request
self.email_password_reset_template_html = email_config.get(
password_reset_template_html = email_config.get(
"password_reset_template_html", "password_reset.html"
)
self.email_password_reset_template_text = email_config.get(
password_reset_template_text = email_config.get(
"password_reset_template_text", "password_reset.txt"
)
self.email_registration_template_html = email_config.get(
registration_template_html = email_config.get(
"registration_template_html", "registration.html"
)
self.email_registration_template_text = email_config.get(
registration_template_text = email_config.get(
"registration_template_text", "registration.txt"
)
self.email_add_threepid_template_html = email_config.get(
add_threepid_template_html = email_config.get(
"add_threepid_template_html", "add_threepid.html"
)
self.email_add_threepid_template_text = email_config.get(
add_threepid_template_text = email_config.get(
"add_threepid_template_text", "add_threepid.txt"
)
self.email_password_reset_template_failure_html = email_config.get(
password_reset_template_failure_html = email_config.get(
"password_reset_template_failure_html", "password_reset_failure.html"
)
self.email_registration_template_failure_html = email_config.get(
registration_template_failure_html = email_config.get(
"registration_template_failure_html", "registration_failure.html"
)
self.email_add_threepid_template_failure_html = email_config.get(
add_threepid_template_failure_html = email_config.get(
"add_threepid_template_failure_html", "add_threepid_failure.html"
)
# These templates do not support any placeholder variables, so we
# will read them from disk once during setup
email_password_reset_template_success_html = email_config.get(
password_reset_template_success_html = email_config.get(
"password_reset_template_success_html", "password_reset_success.html"
)
email_registration_template_success_html = email_config.get(
registration_template_success_html = email_config.get(
"registration_template_success_html", "registration_success.html"
)
email_add_threepid_template_success_html = email_config.get(
add_threepid_template_success_html = email_config.get(
"add_threepid_template_success_html", "add_threepid_success.html"
)
# Check templates exist
for f in [
# Read all templates from disk
(
self.email_password_reset_template_html,
self.email_password_reset_template_text,
self.email_registration_template_html,
@@ -248,32 +231,36 @@ class EmailConfig(Config):
self.email_password_reset_template_failure_html,
self.email_registration_template_failure_html,
self.email_add_threepid_template_failure_html,
email_password_reset_template_success_html,
email_registration_template_success_html,
email_add_threepid_template_success_html,
]:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
raise ConfigError("Unable to find template file %s" % (p,))
password_reset_template_success_html_template,
registration_template_success_html_template,
add_threepid_template_success_html_template,
) = self.read_templates(
[
password_reset_template_html,
password_reset_template_text,
registration_template_html,
registration_template_text,
add_threepid_template_html,
add_threepid_template_text,
password_reset_template_failure_html,
registration_template_failure_html,
add_threepid_template_failure_html,
password_reset_template_success_html,
registration_template_success_html,
add_threepid_template_success_html,
],
template_dir,
)
# Retrieve content of web templates
filepath = os.path.join(
self.email_template_dir, email_password_reset_template_success_html
# Render templates that do not contain any placeholders
self.email_password_reset_template_success_html_content = (
password_reset_template_success_html_template.render()
)
self.email_password_reset_template_success_html = self.read_file(
filepath, "email.password_reset_template_success_html"
self.email_registration_template_success_html_content = (
registration_template_success_html_template.render()
)
filepath = os.path.join(
self.email_template_dir, email_registration_template_success_html
)
self.email_registration_template_success_html_content = self.read_file(
filepath, "email.registration_template_success_html"
)
filepath = os.path.join(
self.email_template_dir, email_add_threepid_template_success_html
)
self.email_add_threepid_template_success_html_content = self.read_file(
filepath, "email.add_threepid_template_success_html"
self.email_add_threepid_template_success_html_content = (
add_threepid_template_success_html_template.render()
)
if self.email_enable_notifs:
@@ -290,17 +277,19 @@ class EmailConfig(Config):
% (", ".join(missing),)
)
self.email_notif_template_html = email_config.get(
notif_template_html = email_config.get(
"notif_template_html", "notif_mail.html"
)
self.email_notif_template_text = email_config.get(
notif_template_text = email_config.get(
"notif_template_text", "notif_mail.txt"
)
for f in self.email_notif_template_text, self.email_notif_template_html:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
raise ConfigError("Unable to find email template file %s" % (p,))
(
self.email_notif_template_html,
self.email_notif_template_text,
) = self.read_templates(
[notif_template_html, notif_template_text], template_dir,
)
self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True
@@ -309,18 +298,20 @@ class EmailConfig(Config):
"client_base_url", email_config.get("riot_base_url", None)
)
if account_validity_renewal_enabled:
self.email_expiry_template_html = email_config.get(
if self.account_validity.renew_by_email_enabled:
expiry_template_html = email_config.get(
"expiry_template_html", "notice_expiry.html"
)
self.email_expiry_template_text = email_config.get(
expiry_template_text = email_config.get(
"expiry_template_text", "notice_expiry.txt"
)
for f in self.email_expiry_template_text, self.email_expiry_template_html:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
raise ConfigError("Unable to find email template file %s" % (p,))
(
self.account_validity_template_html,
self.account_validity_template_text,
) = self.read_templates(
[expiry_template_html, expiry_template_text], template_dir,
)
subjects_config = email_config.get("subjects", {})
subjects = {}
@@ -400,9 +391,7 @@ class EmailConfig(Config):
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
# Do not uncomment this setting unless you want to customise the templates.
#
# Synapse will look for the following templates in this directory:
#


@@ -18,8 +18,6 @@ import logging
from typing import Any, List
import attr
import jinja2
import pkg_resources
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module, load_python_module
@@ -171,15 +169,9 @@ class SAML2Config(Config):
saml2_config.get("saml_session_lifetime", "15m")
)
template_dir = saml2_config.get("template_dir")
if not template_dir:
template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
loader = jinja2.FileSystemLoader(template_dir)
# enable auto-escape here, to avoid having to remember to escape manually in the
# template
env = jinja2.Environment(loader=loader, autoescape=True)
self.saml2_error_html_template = env.get_template("saml_error.html")
self.saml2_error_html_template = self.read_templates(
["saml_error.html"], saml2_config.get("template_dir")
)
def _default_saml_config_dict(
self, required_attributes: set, optional_attributes: set


@@ -26,7 +26,6 @@ import yaml
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.python_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError
@@ -508,8 +507,6 @@ class ServerConfig(Config):
)
)
_check_resource_config(self.listeners)
self.cleanup_extremities_with_dummy_events = config.get(
"cleanup_extremities_with_dummy_events", True
)
@@ -1133,20 +1130,3 @@ def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None:
if name == "webclient":
logger.warning(NO_MORE_WEB_CLIENT_WARNING)
return
def _check_resource_config(listeners: Iterable[ListenerConfig]) -> None:
resource_names = {
res_name
for listener in listeners
if listener.http_options
for res in listener.http_options.resources
for res_name in res.names
}
for resource in resource_names:
if resource == "consent":
try:
check_requirements("resources.consent")
except DependencyException as e:
raise ConfigError(e.message)


@@ -12,11 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict
import pkg_resources
from ._base import Config
@@ -29,22 +26,32 @@ class SSOConfig(Config):
def read_config(self, config, **kwargs):
sso_config = config.get("sso") or {} # type: Dict[str, Any]
# Pick a template directory in order of:
# * The sso-specific template_dir
# * /path/to/synapse/install/res/templates
# The sso-specific template_dir
template_dir = sso_config.get("template_dir")
if not template_dir:
template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
self.sso_template_dir = template_dir
self.sso_account_deactivated_template = self.read_file(
os.path.join(self.sso_template_dir, "sso_account_deactivated.html"),
"sso_account_deactivated_template",
# Read templates from disk
(
self.sso_redirect_confirm_template,
self.sso_auth_confirm_template,
self.sso_error_template,
sso_account_deactivated_template,
sso_auth_success_template,
) = self.read_templates(
[
"sso_redirect_confirm.html",
"sso_auth_confirm.html",
"sso_error.html",
"sso_account_deactivated.html",
"sso_auth_success.html",
],
template_dir,
)
self.sso_auth_success_template = self.read_file(
os.path.join(self.sso_template_dir, "sso_auth_success.html"),
"sso_auth_success_template",
# These templates have no placeholders, so render them here
self.sso_account_deactivated_template = (
sso_account_deactivated_template.render()
)
self.sso_auth_success_template = sso_auth_success_template.render()
self.sso_client_whitelist = sso_config.get("client_whitelist") or []


@@ -54,7 +54,7 @@ from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.logging.utils import log_function
from synapse.types import JsonDict, get_domain_from_id
from synapse.types import JsonDict
from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
@@ -217,15 +217,13 @@ class FederationClient(FederationBase):
for p in transaction_data["pdus"]
]
pdus[:] = await self._check_sigs_and_hash_and_fetch(
dest,
list(pdus),
outlier=True,
room_version=room_version,
# FIXME: We should handle signature failures more gracefully.
pdus[:] = await make_deferred_yieldable(
defer.gatherResults(
self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True,
).addErrback(unwrapFirstError)
)
logger.info("DDD pdus ended up as: %s", pdus)
return pdus
async def get_pdu(
@@ -388,11 +386,10 @@ class FederationClient(FederationBase):
pdu.event_id, allow_rejected=True, allow_none=True
)
pdu_origin = get_domain_from_id(pdu.sender)
if not res and pdu_origin != origin:
if not res and pdu.origin != origin:
try:
res = await self.get_pdu(
destinations=[pdu_origin],
destinations=[pdu.origin],
event_id=pdu.event_id,
room_version=room_version,
outlier=outlier,


@@ -37,8 +37,8 @@ from sortedcontainers import SortedDict
from twisted.internet import defer
from synapse.api.presence import UserPresenceState
from synapse.metrics import LaterGauge
from synapse.storage.presence import UserPresenceState
from synapse.util.metrics import Measure
from .units import Edu


@@ -22,6 +22,7 @@ from twisted.internet import defer
import synapse
import synapse.metrics
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.federation.sender.per_destination_queue import PerDestinationQueue
from synapse.federation.sender.transaction_manager import TransactionManager
@@ -39,7 +40,6 @@ from synapse.metrics import (
events_processed_counter,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.metrics import Measure, measure_func


@@ -24,12 +24,12 @@ from synapse.api.errors import (
HttpResponseException,
RequestSendFailed,
)
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.federation.units import Edu
from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
@@ -337,6 +337,28 @@ class PerDestinationQueue(object):
(e.retry_last_ts + e.retry_interval) / 1000.0
),
)
if e.retry_interval > 60 * 60 * 1000:
# we won't retry for another hour!
# (this suggests a significant outage)
# We drop pending PDUs and EDUs because otherwise they will
# rack up indefinitely.
# Note that:
# - the EDUs that are being dropped here are those that we can
# afford to drop (specifically, only typing notifications,
# read receipts and presence updates are being dropped here)
# - Other EDUs such as to_device messages are queued with a
# different mechanism
# - this is all volatile state that would be lost if the
# federation sender restarted anyway
# dropping read receipts is a bit sad but should be solved
# through another mechanism, because this is all volatile!
self._pending_pdus = []
self._pending_edus = []
self._pending_edus_keyed = {}
self._pending_presence = {}
self._pending_rrs = {}
except FederationDeniedError as e:
logger.info(e)
except HttpResponseException as e:


@@ -719,6 +719,27 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
raise NotImplementedError()
async def change_user_admin_in_group(
self, group_id, user_id, want_admin, requester_user_id, content
):
"""Promotes or demotes a user in a group.
"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
if requester_user_id == user_id:
raise SynapseError(400, "User cannot target themselves")
is_admin = await self.store.is_user_admin_in_group(
group_id, requester_user_id
)
if not is_admin:
raise SynapseError(403, "User is not admin in group")
await self.store.change_user_admin_in_group(group_id, user_id, want_admin)
return {}
async def remove_user_from_group(
self, group_id, user_id, requester_user_id, content
):


@@ -26,11 +26,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
try:
from synapse.push.mailer import load_jinja2_templates
except ImportError:
load_jinja2_templates = None
logger = logging.getLogger(__name__)
@@ -47,9 +42,11 @@ class AccountValidityHandler(object):
if (
self._account_validity.enabled
and self._account_validity.renew_by_email_enabled
and load_jinja2_templates
):
# Don't do email-specific configuration if renewal by email is disabled.
self._template_html = self.config.account_validity_template_html
self._template_text = self.config.account_validity_template_text
try:
app_name = self.hs.config.email_app_name
@@ -65,17 +62,6 @@ class AccountValidityHandler(object):
self._raw_from = email.utils.parseaddr(self._from_string)[1]
self._template_html, self._template_text = load_jinja2_templates(
self.config.email_template_dir,
[
self.config.email_expiry_template_html,
self.config.email_expiry_template_text,
],
apply_format_ts_filter=True,
apply_mxc_to_http_filter=True,
public_baseurl=self.config.public_baseurl,
)
# Check the renewal emails to send and send them every 30min.
def send_emails():
# run as a background process to make sure that the database transactions


@@ -42,7 +42,6 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
from synapse.types import Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.threepids import canonicalise_email
@@ -132,18 +131,17 @@ class AuthHandler(BaseHandler):
# after the SSO completes and before redirecting them back to their client.
# It notifies the user they are about to give access to their matrix account
# to the client.
self._sso_redirect_confirm_template = load_jinja2_templates(
hs.config.sso_template_dir, ["sso_redirect_confirm.html"],
)[0]
self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template
# The following template is shown during user interactive authentication
# in the fallback auth scenario. It notifies the user that they are
# authenticating for an operation to occur on their account.
self._sso_auth_confirm_template = load_jinja2_templates(
hs.config.sso_template_dir, ["sso_auth_confirm.html"],
)[0]
self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
self._sso_auth_success_template = hs.config.sso_auth_success_template
# The following template is shown during the SSO authentication process if
# the account is deactivated.
self._sso_account_deactivated_template = (


@@ -461,6 +461,25 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {"state": "invite", "user_profile": user_profile}
async def change_user_admin_in_group(
self, group_id, user_id, want_admin, requester_user_id, content
):
"""Promotes or demotes a user in a group.
"""
if not self.is_mine_id(user_id):
raise SynapseError(400, "User not on this server")
# TODO: We should probably support federation, but this is fine for now
if not self.is_mine_id(group_id):
raise SynapseError(400, "Group not on this server")
res = await self.groups_server_handler.change_user_admin_in_group(
group_id, user_id, want_admin, requester_user_id, content
)
return res
async def remove_user_from_group(
self, group_id, user_id, requester_user_id, content
):


@@ -667,14 +667,14 @@ class EventCreationHandler(object):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
prev_state = await self.deduplicate_state_event(event, context)
if prev_state is not None:
prev_event = await self.deduplicate_state_event(event, context)
if prev_event is not None:
logger.info(
"Not bothering to persist state event %s duplicated by %s",
event.event_id,
prev_state.event_id,
prev_event.event_id,
)
return prev_state
return await self.store.get_stream_id_for_event(prev_event.event_id)
return await self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit
@@ -682,27 +682,32 @@ class EventCreationHandler(object):
async def deduplicate_state_event(
self, event: EventBase, context: EventContext
) -> None:
) -> Optional[EventBase]:
"""
Checks whether event is in the latest resolved state in context.
If so, returns the version of the event in context.
Otherwise, returns None.
Args:
event: The event to check for duplication.
context: The event context.
Returns:
The previous version of the event is returned, if it is found in the
event context. Otherwise, None is returned.
"""
prev_state_ids = await context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id:
return
return None
prev_event = await self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
return None
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
return prev_event
return
return None
async def create_and_send_nonmember_event(
self,
@@ -891,9 +896,7 @@ class EventCreationHandler(object):
except Exception:
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
run_in_background(
self.store.remove_push_actions_from_staging, event.event_id
)
await self.store.remove_push_actions_from_staging(event.event_id)
raise
async def _validate_canonical_alias(


@@ -38,7 +38,6 @@ from synapse.config import ConfigError
from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.push.mailer import load_jinja2_templates
from synapse.types import UserID, map_username_to_mxid_localpart
if TYPE_CHECKING:
@@ -123,9 +122,7 @@ class OidcHandler:
self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
self._error_template = load_jinja2_templates(
hs.config.sso_template_dir, ["sso_error.html"]
)[0]
self._error_template = hs.config.sso_error_template
# identifier for the external_ids table
self._auth_provider_id = "oidc"


@@ -33,13 +33,13 @@ from typing_extensions import ContextManager
import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.logging.context import run_in_background
from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
from synapse.storage.presence import UserPresenceState
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached


@@ -142,6 +142,7 @@ class RegistrationHandler(BaseHandler):
address=None,
bind_emails=[],
by_admin=False,
shadow_banned=False,
):
"""Registers a new client on the server.
@@ -159,6 +160,7 @@ class RegistrationHandler(BaseHandler):
bind_emails (List[str]): list of emails to bind to this account.
by_admin (bool): True if this registration is being made via the
admin api, otherwise False.
shadow_banned (bool): Shadow-ban the created user.
Returns:
str: user_id
Raises:
@@ -194,6 +196,7 @@ class RegistrationHandler(BaseHandler):
admin=admin,
user_type=user_type,
address=address,
shadow_banned=shadow_banned,
)
if self.hs.config.user_directory_search_all_users:
@@ -224,6 +227,7 @@ class RegistrationHandler(BaseHandler):
make_guest=make_guest,
create_profile_with_displayname=default_display_name,
address=address,
shadow_banned=shadow_banned,
)
# Successfully registered
@@ -529,6 +533,7 @@ class RegistrationHandler(BaseHandler):
admin=False,
user_type=None,
address=None,
shadow_banned=False,
):
"""Register user in the datastore.
@@ -546,6 +551,7 @@ class RegistrationHandler(BaseHandler):
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the registration.
shadow_banned (bool): Whether to shadow-ban the user
Returns:
Awaitable
@@ -561,6 +567,7 @@ class RegistrationHandler(BaseHandler):
admin=admin,
user_type=user_type,
address=address,
shadow_banned=shadow_banned,
)
else:
return self.store.register_user(
@@ -572,6 +579,7 @@ class RegistrationHandler(BaseHandler):
create_profile_with_displayname=create_profile_with_displayname,
admin=admin,
user_type=user_type,
shadow_banned=shadow_banned,
)
async def register_device(


@@ -22,7 +22,7 @@ import logging
import math
import string
from collections import OrderedDict
from typing import Awaitable, Optional, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
from synapse.api.constants import (
EventTypes,
@@ -32,11 +32,14 @@ from synapse.api.constants import (
RoomEncryptionAlgorithms,
)
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
from synapse.types import (
JsonDict,
Requester,
RoomAlias,
RoomID,
@@ -53,6 +56,9 @@ from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
@@ -61,7 +67,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000
class RoomCreationHandler(BaseHandler):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super(RoomCreationHandler, self).__init__(hs)
self.spam_checker = hs.get_spam_checker()
@@ -92,7 +98,7 @@ class RoomCreationHandler(BaseHandler):
"guest_can_join": False,
"power_level_content_override": {},
},
}
} # type: Dict[str, Dict[str, Any]]
# Modify presets to selectively enable encryption by default per homeserver config
for preset_name, preset_config in self._presets_dict.items():
@@ -215,6 +221,9 @@ class RoomCreationHandler(BaseHandler):
old_room_state = await tombstone_context.get_current_state_ids()
# We know the tombstone event isn't an outlier so it has current state.
assert old_room_state is not None
# update any aliases
await self._move_aliases_to_new_room(
requester, old_room_id, new_room_id, old_room_state
@@ -528,17 +537,21 @@ class RoomCreationHandler(BaseHandler):
logger.error("Unable to send updated alias events in new room: %s", e)
async def create_room(
self, requester, config, ratelimit=True, creator_join_profile=None
self,
requester: Requester,
config: JsonDict,
ratelimit: bool = True,
creator_join_profile: Optional[JsonDict] = None,
) -> Tuple[dict, int]:
""" Creates a new room.
Args:
requester (synapse.types.Requester):
requester:
The user who requested the room creation.
config (dict) : A dict of configuration options.
ratelimit (bool): set to False to disable the rate limiter
config : A dict of configuration options.
ratelimit: set to False to disable the rate limiter
creator_join_profile (dict|None):
creator_join_profile:
Set to override the displayname and avatar for the creating
user in this room. If unset, displayname and avatar will be
derived from the user's profile. If set, should contain the
@@ -601,6 +614,7 @@ class RoomCreationHandler(BaseHandler):
Codes.UNSUPPORTED_ROOM_VERSION,
)
room_alias = None
if "room_alias_name" in config:
for wchar in string.whitespace:
if wchar in config["room_alias_name"]:
@@ -611,8 +625,6 @@ class RoomCreationHandler(BaseHandler):
if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
else:
room_alias = None
invite_list = config.get("invite", [])
for i in invite_list:
@@ -771,23 +783,30 @@ class RoomCreationHandler(BaseHandler):
async def _send_events_for_new_room(
self,
creator, # A Requester object.
room_id,
preset_config,
invite_list,
initial_state,
creation_content,
room_alias=None,
power_level_content_override=None, # Doesn't apply when initial state has power level state event content
creator_join_profile=None,
creator: Requester,
room_id: str,
preset_config: str,
invite_list: List[str],
initial_state: StateMap,
creation_content: JsonDict,
room_alias: Optional[RoomAlias] = None,
power_level_content_override: Optional[JsonDict] = None,
creator_join_profile: Optional[JsonDict] = None,
) -> int:
"""Sends the initial events into a new room.
`power_level_content_override` doesn't apply when initial state has
power level state event content.
Returns:
The stream_id of the last event persisted.
"""
def create(etype, content, **kwargs):
creator_id = creator.user.to_string()
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
def create(etype: str, content: JsonDict, **kwargs) -> JsonDict:
e = {"type": etype, "content": content}
e.update(event_keys)
@@ -795,7 +814,7 @@ class RoomCreationHandler(BaseHandler):
return e
async def send(etype, content, **kwargs) -> int:
async def send(etype: str, content: JsonDict, **kwargs) -> int:
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
(
@@ -808,10 +827,6 @@ class RoomCreationHandler(BaseHandler):
config = self._presets_dict[preset_config]
creator_id = creator.user.to_string()
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
creation_content.update({"creator": creator_id})
await send(etype=EventTypes.Create, content=creation_content)
@@ -852,7 +867,7 @@ class RoomCreationHandler(BaseHandler):
"kick": 50,
"redact": 50,
"invite": 50,
}
} # type: JsonDict
if config["original_invitees_have_ops"]:
for invitee in invite_list:
@@ -906,7 +921,7 @@ class RoomCreationHandler(BaseHandler):
return last_sent_stream_id
async def _generate_room_id(
self, creator_id: str, is_public: str, room_version: RoomVersion,
self, creator_id: str, is_public: bool, room_version: RoomVersion,
):
# autogen room IDs and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
@@ -930,23 +945,30 @@ class RoomCreationHandler(BaseHandler):
class RoomContextHandler(object):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
async def get_event_context(self, user, room_id, event_id, limit, event_filter):
async def get_event_context(
self,
user: UserID,
room_id: str,
event_id: str,
limit: int,
event_filter: Optional[Filter],
) -> Optional[JsonDict]:
"""Retrieves events, pagination tokens and state around a given event
in a room.
Args:
user (UserID)
room_id (str)
event_id (str)
limit (int): The maximum number of events to return in total
user
room_id
event_id
limit: The maximum number of events to return in total
(excluding state).
event_filter (Filter|None): the filter to apply to the events returned
event_filter: the filter to apply to the events returned
(excluding the target event_id)
Returns:
@@ -1033,12 +1055,18 @@ class RoomContextHandler(object):
class RoomEventSource(object):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
async def get_new_events(
self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
):
self,
user: UserID,
from_key: str,
limit: int,
room_ids: List[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[EventBase], str]:
# We just ignore the key for now.
to_key = self.get_current_key()
@@ -1096,7 +1124,7 @@ class RoomShutdownHandler(object):
)
DEFAULT_ROOM_NAME = "Content Violation Notification"
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.room_member_handler = hs.get_room_member_handler()
self._room_creation_handler = hs.get_room_creation_handler()


@@ -210,40 +210,24 @@ class RoomMemberHandler(object):
_, stream_id = await self.store.get_event_ordering(duplicate.event_id)
return duplicate.event_id, stream_id
prev_state_ids = await context.get_prev_state_ids()
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
newly_joined = False
if event.membership == Membership.JOIN:
newly_joined = True
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
# Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates.
if newly_joined:
time_now_s = self.clock.time()
(
allowed,
time_allowed,
) = self._join_rate_limiter_local.can_requester_do_action(requester)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
stream_id = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit,
)
if event.membership == Membership.JOIN and newly_joined:
prev_state_ids = await context.get_prev_state_ids()
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
if event.membership == Membership.JOIN:
# Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
await self._user_joined_room(target, room_id)
newly_joined = True
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
await self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
@@ -473,12 +457,22 @@ class RoomMemberHandler(object):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
if is_host_in_room:
time_now_s = self.clock.time()
(
allowed,
time_allowed,
) = self._join_rate_limiter_remote.can_requester_do_action(requester,)
allowed, time_allowed = self._join_rate_limiter_local.can_do_action(
requester.user.to_string(),
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
else:
time_now_s = self.clock.time()
allowed, time_allowed = self._join_rate_limiter_remote.can_do_action(
requester.user.to_string(),
)
if not allowed:
raise LimitExceededError(

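Note: the join rate limiting above follows the `allowed, time_allowed = limiter.can_do_action(key)` convention, where `time_allowed` is the earliest time at which the action would next be permitted. A rough, self-contained sketch of that convention only (this is not Synapse's Ratelimiter; names and numbers are illustrative):

import time

class SlidingWindowLimiter:
    """Allow at most `burst_count` actions per key within `window_s` seconds."""

    def __init__(self, burst_count: int, window_s: float):
        self.burst_count = burst_count
        self.window_s = window_s
        self._history = {}  # key -> timestamps of recent actions

    def can_do_action(self, key: str):
        now = time.time()
        recent = [t for t in self._history.get(key, []) if now - t < self.window_s]
        if len(recent) >= self.burst_count:
            # The earliest recorded action leaves the window at this time.
            return False, recent[0] + self.window_s
        recent.append(now)
        self._history[key] = recent
        return True, now

limiter = SlidingWindowLimiter(burst_count=3, window_s=10.0)
allowed, time_allowed = limiter.can_do_action("@alice:example.org")
if not allowed:
    retry_after_ms = int(1000 * (time_allowed - time.time()))  # as used for LimitExceededError above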

@@ -441,7 +441,6 @@ class MatrixFederationHttpClient(object):
headers_dict[b"Authorization"] = auth_headers
"""
logger.debug(
"{%s} [%s] Sending request: %s %s; timeout %fs",
request.txn_id,
@@ -450,7 +449,6 @@ class MatrixFederationHttpClient(object):
url_str,
_sec_timeout,
)
"""
outgoing_requests_counter.labels(request.method).inc()


@@ -22,12 +22,13 @@ import types
import urllib
from http import HTTPStatus
from io import BytesIO
from typing import Any, Callable, Dict, Tuple, Union
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
import jinja2
from canonicaljson import encode_canonical_json, encode_pretty_printed_json
from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet import defer, interfaces
from twisted.python import failure
from twisted.web import resource
from twisted.web.server import NOT_DONE_YET, Request
@@ -499,6 +500,78 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect):
pass
@implementer(interfaces.IPullProducer)
class _ByteProducer:
"""
Iteratively write bytes to the request.
"""
# The minimum number of bytes for each chunk. Note that the last chunk will
# usually be smaller than this.
min_chunk_size = 1024
def __init__(
self, request: Request, iterator: Iterator[bytes],
):
self._request = request
self._iterator = iterator
def start(self) -> None:
self._request.registerProducer(self, False)
def _send_data(self, data: List[bytes]) -> None:
"""
Send a list of strings as a response to the request.
"""
if not data:
return
self._request.write(b"".join(data))
def resumeProducing(self) -> None:
# We've stopped producing in the meantime (note that this might be
# re-entrant after calling write).
if not self._request:
return
# Get the next chunk and write it to the request.
#
# The output of the JSON encoder is coalesced until min_chunk_size is
# reached. (This is because JSON encoders produce a very small output
# per iteration.)
#
# Note that buffer stores a list of bytes (instead of appending to
# bytes) to hopefully avoid many allocations.
buffer = []
buffered_bytes = 0
while buffered_bytes < self.min_chunk_size:
try:
data = next(self._iterator)
buffer.append(data)
buffered_bytes += len(data)
except StopIteration:
# The entire JSON object has been serialized, write any
# remaining data, finalize the producer and the request, and
# clean-up any references.
self._send_data(buffer)
self._request.unregisterProducer()
self._request.finish()
self.stopProducing()
return
self._send_data(buffer)
def stopProducing(self) -> None:
self._request = None
def _encode_json_bytes(json_object: Any) -> Iterator[bytes]:
"""
Encode an object into JSON. Returns an iterator of bytes.
"""
for chunk in json_encoder.iterencode(json_object):
yield chunk.encode("utf-8")
def respond_with_json(
request: Request,
code: int,
@@ -533,15 +606,23 @@ def respond_with_json(
return None
if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + b"\n"
encoder = iterencode_pretty_printed_json
else:
if canonical_json or synapse.events.USE_FROZEN_DICTS:
# canonicaljson already encodes to bytes
json_bytes = encode_canonical_json(json_object)
encoder = iterencode_canonical_json
else:
json_bytes = json_encoder.encode(json_object).encode("utf-8")
encoder = _encode_json_bytes
return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors)
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
if send_cors:
set_cors_headers(request)
producer = _ByteProducer(request, encoder(json_object))
producer.start()
return NOT_DONE_YET
def respond_with_json_bytes(

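Note: `_ByteProducer` above coalesces the very small pieces a JSON encoder yields per iteration into chunks of at least `min_chunk_size` bytes before writing them to the request. A minimal standalone sketch of that coalescing, using the stdlib `json.JSONEncoder.iterencode` (the function name and chunk size here are illustrative, not part of this change):

import json

def iter_json_chunks(obj, min_chunk_size=1024):
    """Yield the JSON encoding of `obj` as byte chunks of at least
    `min_chunk_size` bytes (the final chunk may be smaller)."""
    buffer, buffered = [], 0
    for piece in json.JSONEncoder().iterencode(obj):
        data = piece.encode("utf-8")
        buffer.append(data)
        buffered += len(data)
        if buffered >= min_chunk_size:
            yield b"".join(buffer)
            buffer, buffered = [], 0
    if buffer:
        yield b"".join(buffer)

# Usage: stream a large object without materialising the full byte string first,
# e.g. passing each chunk to request.write().
for chunk in iter_json_chunks({"rows": list(range(10000))}, min_chunk_size=4096):
    pass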

@@ -22,7 +22,6 @@ _TIME_FUNC_ID = 0
def _log_debug_as_f(f, msg, msg_args):
return
name = f.__module__
logger = logging.getLogger(name)


@@ -16,8 +16,7 @@
import email.mime.multipart
import email.utils
import logging
import time
import urllib
import urllib.parse
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Iterable, List, TypeVar
@@ -640,72 +639,3 @@ def string_ordinal_total(s):
for c in s:
tot += ord(c)
return tot
def format_ts_filter(value, format):
return time.strftime(format, time.localtime(value / 1000))
def load_jinja2_templates(
template_dir,
template_filenames,
apply_format_ts_filter=False,
apply_mxc_to_http_filter=False,
public_baseurl=None,
):
"""Loads and returns one or more jinja2 templates and applies optional filters
Args:
template_dir (str): The directory where templates are stored
template_filenames (list[str]): A list of template filenames
apply_format_ts_filter (bool): Whether to apply a template filter that formats
timestamps
apply_mxc_to_http_filter (bool): Whether to apply a template filter that converts
mxc urls to http urls
public_baseurl (str|None): The public baseurl of the server. Required for
apply_mxc_to_http_filter to be enabled
Returns:
A list of jinja2 templates corresponding to the given list of filenames,
with order preserved
"""
logger.info(
"loading email templates %s from '%s'", template_filenames, template_dir
)
loader = jinja2.FileSystemLoader(template_dir)
env = jinja2.Environment(loader=loader)
if apply_format_ts_filter:
env.filters["format_ts"] = format_ts_filter
if apply_mxc_to_http_filter and public_baseurl:
env.filters["mxc_to_http"] = _create_mxc_to_http_filter(public_baseurl)
templates = []
for template_filename in template_filenames:
template = env.get_template(template_filename)
templates.append(template)
return templates
def _create_mxc_to_http_filter(public_baseurl):
def mxc_to_http_filter(value, width, height, resize_method="crop"):
if value[0:6] != "mxc://":
return ""
serverAndMediaId = value[6:]
fragment = None
if "#" in serverAndMediaId:
(serverAndMediaId, fragment) = serverAndMediaId.split("#", 1)
fragment = "#" + fragment
params = {"width": width, "height": height, "method": resize_method}
return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
public_baseurl,
serverAndMediaId,
urllib.parse.urlencode(params),
fragment or "",
)
return mxc_to_http_filter


@@ -15,22 +15,13 @@
import logging
from synapse.push.emailpusher import EmailPusher
from synapse.push.mailer import Mailer
from .httppusher import HttpPusher
logger = logging.getLogger(__name__)
# We try importing this if we can (it will fail if we don't
# have the optional email dependencies installed). We don't
# yet have the config to know if we need the email pusher,
# but importing this after daemonizing seems to fail
# (even though a simple test of importing from a daemonized
# process works fine)
try:
from synapse.push.emailpusher import EmailPusher
from synapse.push.mailer import Mailer, load_jinja2_templates
except Exception:
pass
class PusherFactory(object):
def __init__(self, hs):
@@ -43,16 +34,8 @@ class PusherFactory(object):
if hs.config.email_enable_notifs:
self.mailers = {} # app_name -> Mailer
self.notif_template_html, self.notif_template_text = load_jinja2_templates(
self.config.email_template_dir,
[
self.config.email_notif_template_html,
self.config.email_notif_template_text,
],
apply_format_ts_filter=True,
apply_mxc_to_http_filter=True,
public_baseurl=self.config.public_baseurl,
)
self._notif_template_html = hs.config.email_notif_template_html
self._notif_template_text = hs.config.email_notif_template_text
self.pusher_types["email"] = self._create_email_pusher
@@ -73,8 +56,8 @@ class PusherFactory(object):
mailer = Mailer(
hs=self.hs,
app_name=app_name,
template_html=self.notif_template_html,
template_text=self.notif_template_text,
template_html=self._notif_template_html,
template_text=self._notif_template_text,
)
self.mailers[app_name] = mailer
return EmailPusher(self.hs, pusherdict, mailer)


@@ -43,7 +43,7 @@ REQUIREMENTS = [
"jsonschema>=2.5.1",
"frozendict>=1",
"unpaddedbase64>=1.1.0",
"canonicaljson>=1.2.0",
"canonicaljson>=1.3.0",
# we use the type definitions added in signedjson 1.1.
"signedjson>=1.1.0",
"pynacl>=1.2.1",
@@ -78,8 +78,6 @@ CONDITIONAL_REQUIREMENTS = {
"matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
# we use execute_batch, which arrived in psycopg 2.7.
"postgres": ["psycopg2>=2.7"],
# ConsentResource uses select_autoescape, which arrived in jinja 2.9
"resources.consent": ["Jinja2>=2.9"],
# ACME support is required to provision TLS certificates from authorities
# that use the protocol, such as Let's Encrypt.
"acme": [


@@ -44,6 +44,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
admin,
user_type,
address,
shadow_banned,
):
"""
Args:
@@ -60,6 +61,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the registration.
shadow_banned (bool): Whether to shadow-ban the user
"""
return {
"password_hash": password_hash,
@@ -70,6 +72,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"admin": admin,
"user_type": user_type,
"address": address,
"shadow_banned": shadow_banned,
}
async def _handle_request(self, request, user_id):
@@ -87,6 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
admin=content["admin"],
user_type=content["user_type"],
address=content["address"],
shadow_banned=content["shadow_banned"],
)
return 200, {}


@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.errors import (
NotFoundError,
StoreError,
@@ -163,7 +162,7 @@ class PushRuleRestServlet(RestServlet):
stream_id, _ = self.store.get_push_rules_stream_token()
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
def set_rule_attr(self, user_id, spec, val):
async def set_rule_attr(self, user_id, spec, val):
if spec["attr"] == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
@@ -173,7 +172,9 @@ class PushRuleRestServlet(RestServlet):
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val)
return await self.store.set_push_rule_enabled(
user_id, namespaced_rule_id, val
)
elif spec["attr"] == "actions":
actions = val.get("actions")
_check_actions(actions)
@@ -188,7 +189,7 @@ class PushRuleRestServlet(RestServlet):
if namespaced_rule_id not in rule_ids:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
return self.store.set_push_rule_actions(
return await self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
)
else:


@@ -32,7 +32,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.push.mailer import Mailer, load_jinja2_templates
from synapse.push.mailer import Mailer
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import canonicalise_email, check_3pid_allowed
@@ -53,21 +53,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
template_html, template_text = load_jinja2_templates(
self.config.email_template_dir,
[
self.config.email_password_reset_template_html,
self.config.email_password_reset_template_text,
],
apply_format_ts_filter=True,
apply_mxc_to_http_filter=True,
public_baseurl=self.config.public_baseurl,
)
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
template_html=template_html,
template_text=template_text,
template_html=self.config.email_password_reset_template_html,
template_text=self.config.email_password_reset_template_text,
)
async def on_POST(self, request):
@@ -169,9 +159,8 @@ class PasswordResetSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
(self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_password_reset_template_failure_html],
self._failure_email_template = (
self.config.email_password_reset_template_failure_html
)
async def on_GET(self, request, medium):
@@ -214,14 +203,14 @@ class PasswordResetSubmitTokenServlet(RestServlet):
return None
# Otherwise show the success template
html = self.config.email_password_reset_template_success_html
html = self.config.email_password_reset_template_success_html_content
status_code = 200
except ThreepidValidationError as e:
status_code = e.code
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
html = self.failure_email_template.render(**template_vars)
html = self._failure_email_template.render(**template_vars)
respond_with_html(request, status_code, html)
@@ -411,19 +400,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
self.store = self.hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
template_html, template_text = load_jinja2_templates(
self.config.email_template_dir,
[
self.config.email_add_threepid_template_html,
self.config.email_add_threepid_template_text,
],
public_baseurl=self.config.public_baseurl,
)
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
template_html=template_html,
template_text=template_text,
template_html=self.config.email_add_threepid_template_html,
template_text=self.config.email_add_threepid_template_text,
)
async def on_POST(self, request):
@@ -578,9 +559,8 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
(self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_add_threepid_template_failure_html],
self._failure_email_template = (
self.config.email_add_threepid_template_failure_html
)
async def on_GET(self, request):
@@ -631,7 +611,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
html = self.failure_email_template.render(**template_vars)
html = self._failure_email_template.render(**template_vars)
respond_with_html(request, status_code, html)


@@ -548,6 +548,31 @@ class GroupAdminUsersKickServlet(RestServlet):
return 200, result
class GroupAdminChangeAdminServlet(RestServlet):
"""Promote or demote a user in the group
"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/admins/(?P<user_id>[^/]*)$"
)
def __init__(self, hs):
super(GroupAdminChangeAdminServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
async def on_POST(self, request, group_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
want_admin = content["is_admin"]
result = await self.groups_handler.change_user_admin_in_group(
group_id, user_id, want_admin, requester_user_id, content
)
return 200, result
class GroupSelfLeaveServlet(RestServlet):
"""Leave a joined group
@@ -722,6 +747,7 @@ def register_servlets(hs, http_server):
GroupAdminRoomsConfigServlet(hs).register(http_server)
GroupAdminUsersInviteServlet(hs).register(http_server)
GroupAdminUsersKickServlet(hs).register(http_server)
GroupAdminChangeAdminServlet(hs).register(http_server)
GroupSelfLeaveServlet(hs).register(http_server)
GroupSelfJoinServlet(hs).register(http_server)
GroupSelfAcceptInviteServlet(hs).register(http_server)

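Note: a hypothetical client call against the new group admin endpoint added above; the `/_matrix/client/r0` prefix, homeserver URL and access token are assumptions for illustration (only the path suffix and the `is_admin` body key come from the diff):

import requests

resp = requests.post(
    "https://homeserver.example.org/_matrix/client/r0/groups/"
    "+mygroup:example.org/admin/users/admins/@alice:example.org",
    params={"access_token": "MDAxABC..."},
    json={"is_admin": True},  # False demotes the user again
)
print(resp.status_code, resp.json())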

@@ -44,7 +44,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.push.mailer import load_jinja2_templates
from synapse.push.mailer import Mailer
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -81,23 +81,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
self.config = hs.config
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
from synapse.push.mailer import Mailer, load_jinja2_templates
template_html, template_text = load_jinja2_templates(
self.config.email_template_dir,
[
self.config.email_registration_template_html,
self.config.email_registration_template_text,
],
apply_format_ts_filter=True,
apply_mxc_to_http_filter=True,
public_baseurl=self.config.public_baseurl,
)
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
template_html=template_html,
template_text=template_text,
template_html=self.config.email_registration_template_html,
template_text=self.config.email_registration_template_text,
)
async def on_POST(self, request):
@@ -262,15 +250,8 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
(self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_registration_template_failure_html],
)
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
(self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_registration_template_failure_html],
self._failure_email_template = (
self.config.email_registration_template_failure_html
)
async def on_GET(self, request, medium):
@@ -318,7 +299,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
html = self.failure_email_template.render(**template_vars)
html = self._failure_email_template.render(**template_vars)
respond_with_html(request, status_code, html)


@@ -15,12 +15,12 @@
import logging
from typing import Dict, Set
from canonicaljson import encode_canonical_json, json
from canonicaljson import json
from signedjson.sign import sign_json
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
logger = logging.getLogger(__name__)
@@ -223,4 +223,4 @@ class RemoteKey(DirectServeJsonResource):
results = {"server_keys": signed_keys}
respond_with_json_bytes(request, 200, encode_canonical_json(results))
respond_with_json(request, 200, results, canonical_json=True)


@@ -18,8 +18,6 @@ from typing import Optional
from canonicaljson import json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from . import engines
@@ -308,9 +306,8 @@ class BackgroundUpdater(object):
update_name (str): Name of update
"""
@defer.inlineCallbacks
def noop_update(progress, batch_size):
yield self._end_background_update(update_name)
async def noop_update(progress, batch_size):
await self._end_background_update(update_name)
return 1
self.register_background_update_handler(update_name, noop_update)
@@ -409,12 +406,11 @@ class BackgroundUpdater(object):
else:
runner = create_index_sqlite
@defer.inlineCallbacks
def updater(progress, batch_size):
async def updater(progress, batch_size):
if runner is not None:
logger.info("Adding index %s to %s", index_name, table)
yield self.db_pool.runWithConnection(runner)
yield self._end_background_update(update_name)
await self.db_pool.runWithConnection(runner)
await self._end_background_update(update_name)
return 1
self.register_background_update_handler(update_name, updater)


@@ -332,8 +332,7 @@ class DatabasePool(object):
"""
return self._db_pool.running
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
async def _check_safe_to_upsert(self):
"""
Is it safe to use native UPSERT?
@@ -342,7 +341,7 @@ class DatabasePool(object):
If the background updates have not completed, wait 15 sec and check again.
"""
updates = yield self.simple_select_list(
updates = await self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
@@ -614,8 +613,7 @@ class DatabasePool(object):
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
@defer.inlineCallbacks
def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
"""Executes an INSERT query on the named table.
Args:
@@ -631,7 +629,7 @@ class DatabasePool(object):
`or_ignore` is True
"""
try:
yield self.runInteraction(desc, self.simple_insert_txn, table, values)
await self.runInteraction(desc, self.simple_insert_txn, table, values)
except self.engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db.
@@ -684,8 +682,7 @@ class DatabasePool(object):
txn.executemany(sql, vals)
@defer.inlineCallbacks
def simple_upsert(
async def simple_upsert(
self,
table,
keyvalues,
@@ -714,14 +711,14 @@ class DatabasePool(object):
inserting
lock (bool): True to lock the table when doing the upsert.
Returns:
Deferred(None or bool): Native upserts always return None. Emulated
None or bool: Native upserts always return None. Emulated
upserts return True if a new entry was created, False if an existing
one was updated.
"""
attempts = 0
while True:
try:
result = yield self.runInteraction(
return await self.runInteraction(
desc,
self.simple_upsert_txn,
table,
@@ -730,7 +727,6 @@ class DatabasePool(object):
insertion_values,
lock=lock,
)
return result
except self.engine.module.IntegrityError as e:
attempts += 1
if attempts >= 5:
@@ -1121,8 +1117,7 @@ class DatabasePool(object):
return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
def simple_select_many_batch(
async def simple_select_many_batch(
self,
table,
column,
@@ -1156,7 +1151,7 @@ class DatabasePool(object):
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
]
for chunk in chunks:
rows = yield self.runInteraction(
rows = await self.runInteraction(
desc,
self.simple_select_many_txn,
table,


@@ -169,7 +169,7 @@ class ApplicationServiceTransactionWorkerStore(
service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply.
Returns:
A Deferred which resolves when the state was set successfully.
An Awaitable which resolves when the state was set successfully.
"""
return self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}


@@ -671,10 +671,9 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedList(
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
inlineCallbacks=True,
)
def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
rows = yield self.db_pool.simple_select_many_batch(
async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,


@@ -257,11 +257,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
def get_oldest_events_in_room(self, room_id):
return self.db_pool.runInteraction(
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
)
def get_oldest_events_with_depth_in_room(self, room_id):
return self.db_pool.runInteraction(
"get_oldest_events_with_depth_in_room",
@@ -303,14 +298,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else:
return max(row["depth"] for row in rows)
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self.db_pool.simple_select_onecol_txn(
txn,
table="event_backward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
)
def get_prev_events_for_room(self, room_id: str):
"""
Gets a subset of the current forward extremities in the given room.


@@ -21,7 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -86,18 +86,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self._rotate_delay = 3
self._rotate_count = 10000
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user(
@cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
ret = yield self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
user_id,
last_read_event_id,
)
return ret
def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id


@@ -17,13 +17,11 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
import attr
from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions
@@ -113,15 +111,14 @@ class PersistEventsStore:
hs.config.worker.writers.events == hs.get_instance_name()
), "Can only instantiate EventsStore on master"
@defer.inlineCallbacks
def _persist_events_and_state_updates(
async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
backfilled: bool = False,
):
) -> None:
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -136,7 +133,7 @@ class PersistEventsStore:
backfilled
Returns:
Deferred: resolves when the events have been persisted
Resolves when the events have been persisted
"""
# We want to calculate the stream orderings as late as possible, as
@@ -168,7 +165,7 @@ class PersistEventsStore:
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=events_and_contexts,
@@ -206,16 +203,15 @@ class PersistEventsStore:
(room_id,), list(latest_event_ids)
)
@defer.inlineCallbacks
def _get_events_which_are_prevs(self, event_ids):
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
"""Filter the supplied list of event_ids to get those which are prev_events of
existing (non-outlier/rejected) events.
Args:
event_ids (Iterable[str]): event ids to filter
event_ids: event ids to filter
Returns:
Deferred[List[str]]: filtered event ids
Filtered event ids
"""
results = []
@@ -240,14 +236,13 @@ class PersistEventsStore:
results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
)
return results
@defer.inlineCallbacks
def _get_prevs_before_rejected(self, event_ids):
async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]:
"""Get soft-failed ancestors to remove from the extremities.
Given a set of events, find all those that have been soft-failed or
@@ -259,11 +254,11 @@ class PersistEventsStore:
are separated by soft failed events.
Args:
event_ids (Iterable[str]): Events to find prev events for. Note
that these must have already been persisted.
event_ids: Events to find prev events for. Note that these must have
already been persisted.
Returns:
Deferred[set[str]]
The previous events.
"""
# The set of event_ids to return. This includes all soft-failed events
@@ -304,7 +299,7 @@ class PersistEventsStore:
existing_prevs.add(prev_event_id)
for chunk in batch_iter(event_ids, 100):
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
)


@@ -15,8 +15,6 @@
import logging
from twisted.internet import defer
from synapse.api.constants import EventContentFields
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
@@ -94,8 +92,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
where_clause="NOT have_censored",
)
@defer.inlineCallbacks
def _background_reindex_fields_sender(self, progress, batch_size):
async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -155,19 +152,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows)
result = yield self.db_pool.runInteraction(
result = await self.db_pool.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
yield self.db_pool.updates._end_background_update(
await self.db_pool.updates._end_background_update(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
)
return result
@defer.inlineCallbacks
def _background_reindex_origin_server_ts(self, progress, batch_size):
async def _background_reindex_origin_server_ts(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -234,19 +230,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows_to_update)
result = yield self.db_pool.runInteraction(
result = await self.db_pool.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
if not result:
yield self.db_pool.updates._end_background_update(
await self.db_pool.updates._end_background_update(
self.EVENT_ORIGIN_SERVER_TS_NAME
)
return result
@defer.inlineCallbacks
def _cleanup_extremities_bg_update(self, progress, batch_size):
async def _cleanup_extremities_bg_update(self, progress, batch_size):
"""Background update to clean out extremities that should have been
deleted previously.
@@ -414,26 +409,25 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(original_set)
num_handled = yield self.db_pool.runInteraction(
num_handled = await self.db_pool.runInteraction(
"_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
)
if not num_handled:
yield self.db_pool.updates._end_background_update(
await self.db_pool.updates._end_background_update(
self.DELETE_SOFT_FAILED_EXTREMITIES
)
def _drop_table_txn(txn):
txn.execute("DROP TABLE _extremities_to_check")
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn
)
return num_handled
@defer.inlineCallbacks
def _redactions_received_ts(self, progress, batch_size):
async def _redactions_received_ts(self, progress, batch_size):
"""Handles filling out the `received_ts` column in redactions.
"""
last_event_id = progress.get("last_event_id", "")
@@ -480,17 +474,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows)
count = yield self.db_pool.runInteraction(
count = await self.db_pool.runInteraction(
"_redactions_received_ts", _redactions_received_ts_txn
)
if not count:
yield self.db_pool.updates._end_background_update("redactions_received_ts")
await self.db_pool.updates._end_background_update("redactions_received_ts")
return count
@defer.inlineCallbacks
def _event_fix_redactions_bytes(self, progress, batch_size):
async def _event_fix_redactions_bytes(self, progress, batch_size):
"""Undoes hex encoded censored redacted event JSON.
"""
@@ -511,16 +504,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn.execute("DROP INDEX redactions_censored_redacts")
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
)
yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
await self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
return 1
@defer.inlineCallbacks
def _event_store_labels(self, progress, batch_size):
async def _event_store_labels(self, progress, batch_size):
"""Background update handler which will store labels for existing events."""
last_event_id = progress.get("last_event_id", "")
@@ -575,11 +567,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return nbrows
num_rows = yield self.db_pool.runInteraction(
num_rows = await self.db_pool.runInteraction(
desc="event_store_labels", func=_event_store_labels_txn
)
if not num_rows:
yield self.db_pool.updates._end_background_update("event_store_labels")
await self.db_pool.updates._end_background_update("event_store_labels")
return num_rows


@@ -43,7 +43,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -137,42 +137,6 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_received_ts",
)
def get_received_ts_by_stream_pos(self, stream_ordering):
"""Given a stream ordering get an approximate timestamp of when it
happened.
This is done by simply taking the received ts of the first event that
has a stream ordering greater than or equal to the given stream pos.
If none exists returns the current time, on the assumption that it must
have happened recently.
Args:
stream_ordering (int)
Returns:
Deferred[int]
"""
def _get_approximate_received_ts_txn(txn):
sql = """
SELECT received_ts FROM events
WHERE stream_ordering >= ?
LIMIT 1
"""
txn.execute(sql, (stream_ordering,))
row = txn.fetchone()
if row and row[0]:
ts = row[0]
else:
ts = self.clock.time_msec()
return ts
return self.db_pool.runInteraction(
"get_approximate_received_ts", _get_approximate_received_ts_txn
)
@defer.inlineCallbacks
def get_event(
self,
@@ -883,13 +847,15 @@ class EventsWorkerStore(SQLBaseStore):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
rows = yield self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
rows = yield defer.ensureDeferred(
self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
)
)
return {r["event_id"] for r in rows}
@@ -923,36 +889,6 @@ class EventsWorkerStore(SQLBaseStore):
)
return results
def _get_total_state_event_counts_txn(self, txn, room_id):
"""
See get_total_state_event_counts.
"""
# We join against the events table as that has an index on room_id
sql = """
SELECT COUNT(*) FROM state_events
INNER JOIN events USING (room_id, event_id)
WHERE room_id=?
"""
txn.execute(sql, (room_id,))
row = txn.fetchone()
return row[0] if row else 0
def get_total_state_event_counts(self, room_id):
"""
Gets the total number of state events in a room.
Args:
room_id (str)
Returns:
Deferred[int]
"""
return self.db_pool.runInteraction(
"get_total_state_event_counts",
self._get_total_state_event_counts_txn,
room_id,
)
def _get_current_state_event_counts_txn(self, txn, room_id):
"""
See get_current_state_event_counts.
@@ -1222,97 +1158,6 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
@cached(num_args=5, max_entries=10)
def get_all_new_events(
self,
last_backfill_id,
last_forward_id,
current_backfill_id,
current_forward_id,
limit,
):
"""Get all the new events that have arrived at the server either as
new events or as backfilled events"""
have_backfill_events = last_backfill_id != current_backfill_id
have_forward_events = last_forward_id != current_forward_id
if not have_backfill_events and not have_forward_events:
return defer.succeed(AllNewEventsResult([], [], [], [], []))
def get_all_new_events_txn(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
if have_forward_events:
txn.execute(sql, (last_forward_id, current_forward_id, limit))
new_forward_events = txn.fetchall()
if len(new_forward_events) == limit:
upper_bound = new_forward_events[-1][0]
else:
upper_bound = current_forward_id
sql = (
"SELECT event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (last_forward_id, upper_bound))
forward_ex_outliers = txn.fetchall()
else:
new_forward_events = []
forward_ex_outliers = []
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" ORDER BY stream_ordering DESC"
" LIMIT ?"
)
if have_backfill_events:
txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
new_backfill_events = txn.fetchall()
if len(new_backfill_events) == limit:
upper_bound = new_backfill_events[-1][0]
else:
upper_bound = current_backfill_id
sql = (
"SELECT -event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_backfill_id, -upper_bound))
backward_ex_outliers = txn.fetchall()
else:
new_backfill_events = []
backward_ex_outliers = []
return AllNewEventsResult(
new_forward_events,
new_backfill_events,
forward_ex_outliers,
backward_ex_outliers,
)
return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn)
async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
@@ -1357,14 +1202,3 @@ class EventsWorkerStore(SQLBaseStore):
return self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
AllNewEventsResult = namedtuple(
"AllNewEventsResult",
[
"new_forward_events",
"new_backfill_events",
"forward_ex_outliers",
"backward_ex_outliers",
],
)


@@ -1038,6 +1038,14 @@ class GroupServerStore(GroupServerWorkerStore):
"remove_user_from_group", _remove_user_from_group_txn
)
def change_user_admin_in_group(self, group_id, user_id, is_admin):
return self.db_pool.simple_update(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_admin": is_admin},
desc="change_user_admin_in_group"
)
def add_room_to_group(self, group_id, room_id, is_public):
return self.db_pool.simple_insert(
table="group_rooms",


@@ -15,8 +15,8 @@
from typing import List, Tuple
from synapse.api.presence import UserPresenceState
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -130,13 +130,10 @@ class PresenceStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(
cached_method_name="_get_presence_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
)
def get_presence_for_users(self, user_ids):
rows = yield self.db_pool.simple_select_many_batch(
async def get_presence_for_users(self, user_ids):
rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
@@ -160,24 +157,3 @@ class PresenceStore(SQLBaseStore):
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
def allow_presence_visible(self, observed_localpart, observer_userid):
return self.db_pool.simple_insert(
table="presence_allow_inbound",
values={
"observed_user_id": observed_localpart,
"observer_user_id": observer_userid,
},
desc="allow_presence_visible",
or_ignore=True,
)
def disallow_presence_visible(self, observed_localpart, observer_userid):
return self.db_pool.simple_delete_one(
table="presence_allow_inbound",
keyvalues={
"observed_user_id": observed_localpart,
"observer_user_id": observer_userid,
},
desc="disallow_presence_visible",
)


@@ -32,7 +32,7 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import ChainedIdGenerator
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -115,9 +115,9 @@ class PushRulesWorkerStore(
"""
raise NotImplementedError()
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id):
rows = yield self.db_pool.simple_select_list(
@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id):
rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
retcols=(
@@ -133,17 +133,15 @@ class PushRulesWorkerStore(
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
use_new_defaults = user_id in self._users_new_default_push_rules
rules = _load_rules(rows, enabled_map, use_new_defaults)
return _load_rules(rows, enabled_map, use_new_defaults)
return rules
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_enabled_for_user(self, user_id):
results = yield self.db_pool.simple_select_list(
@cached(max_entries=5000)
async def get_push_rules_enabled_for_user(self, user_id):
results = await self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
retcols=("user_name", "rule_id", "enabled"),
@@ -170,18 +168,15 @@ class PushRulesWorkerStore(
)
@cachedList(
cached_method_name="get_push_rules_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
)
def bulk_get_push_rules(self, user_ids):
async def bulk_get_push_rules(self, user_ids):
if not user_ids:
return {}
results = {user_id: [] for user_id in user_ids}
rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
@@ -194,7 +189,7 @@ class PushRulesWorkerStore(
for row in rows:
results.setdefault(row["user_name"], []).append(row)
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
use_new_defaults = user_id in self._users_new_default_push_rules
@@ -205,14 +200,15 @@ class PushRulesWorkerStore(
return results
@defer.inlineCallbacks
def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
async def copy_push_rule_from_room_to_room(
self, new_room_id: str, user_id: str, rule: dict
) -> None:
"""Copy a single push rule from one room to another for a specific user.
Args:
new_room_id (str): ID of the new room.
user_id (str): ID of user the push rule belongs to.
rule (Dict): A push rule.
new_room_id: ID of the new room.
user_id : ID of user the push rule belongs to.
rule: A push rule.
"""
# Create new rule id
rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
@@ -224,7 +220,7 @@ class PushRulesWorkerStore(
condition["pattern"] = new_room_id
# Add the rule for the new room
yield self.add_push_rule(
await self.add_push_rule(
user_id=user_id,
rule_id=new_rule_id,
priority_class=rule["priority_class"],
@@ -232,20 +228,19 @@ class PushRulesWorkerStore(
actions=rule["actions"],
)
@defer.inlineCallbacks
def copy_push_rules_from_room_to_room_for_user(
self, old_room_id, new_room_id, user_id
):
async def copy_push_rules_from_room_to_room_for_user(
self, old_room_id: str, new_room_id: str, user_id: str
) -> None:
"""Copy all of the push rules from one room to another for a specific
user.
Args:
old_room_id (str): ID of the old room.
new_room_id (str): ID of the new room.
user_id (str): ID of user to copy push rules for.
old_room_id: ID of the old room.
new_room_id: ID of the new room.
user_id: ID of user to copy push rules for.
"""
# Retrieve push rules for this user
user_push_rules = yield self.get_push_rules_for_user(user_id)
user_push_rules = await self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
for rule in user_push_rules:
@@ -254,21 +249,20 @@ class PushRulesWorkerStore(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
):
yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
@cachedList(
cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
def bulk_get_push_rules_enabled(self, user_ids):
async def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
return {}
results = {user_id: {} for user_id in user_ids}
rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
@@ -332,8 +326,7 @@ class PushRulesWorkerStore(
class PushRuleStore(PushRulesWorkerStore):
@defer.inlineCallbacks
def add_push_rule(
async def add_push_rule(
self,
user_id,
rule_id,
@@ -342,13 +335,13 @@ class PushRuleStore(PushRulesWorkerStore):
actions,
before=None,
after=None,
):
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
if before or after:
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
stream_id,
@@ -362,7 +355,7 @@ class PushRuleStore(PushRulesWorkerStore):
after,
)
else:
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
stream_id,
@@ -546,16 +539,15 @@ class PushRuleStore(PushRulesWorkerStore):
},
)
@defer.inlineCallbacks
def delete_push_rule(self, user_id, rule_id):
async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
"""
Delete a push rule. Args specify the row to be deleted and can be
any of the columns in the push_rule table, but below are the
standard ones
Args:
user_id (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted
user_id: The matrix ID of the push rule owner
rule_id: The rule_id of the rule to be deleted
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
@@ -569,18 +561,17 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_push_rule",
delete_push_rule_txn,
stream_id,
event_stream_ordering,
)
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_id, rule_id, enabled):
async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
stream_id,
@@ -611,8 +602,9 @@ class PushRuleStore(PushRulesWorkerStore):
op="ENABLE" if enabled else "DISABLE",
)
@defer.inlineCallbacks
def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
async def set_push_rule_actions(
self, user_id, rule_id, actions, is_default_rule
) -> None:
actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
@@ -653,7 +645,7 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"set_push_rule_actions",
set_push_rule_actions_txn,
stream_id,


@@ -19,10 +19,8 @@ from typing import Iterable, Iterator, List, Tuple
from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.descriptors import cached, cachedList
logger = logging.getLogger(__name__)
@@ -34,23 +32,22 @@ class PusherWorkerStore(SQLBaseStore):
Drops any rows whose data cannot be decoded
"""
for r in rows:
dataJson = r["data"]
data_json = r["data"]
try:
r["data"] = db_to_json(dataJson)
r["data"] = db_to_json(data_json)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
r["id"],
dataJson,
data_json,
e.args[0],
)
continue
yield r
@defer.inlineCallbacks
def user_has_pusher(self, user_id):
ret = yield self.db_pool.simple_select_one_onecol(
async def user_has_pusher(self, user_id):
ret = await self.db_pool.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
@@ -61,9 +58,8 @@ class PusherWorkerStore(SQLBaseStore):
def get_pushers_by_user_id(self, user_id):
return self.get_pushers_by({"user_name": user_id})
@defer.inlineCallbacks
def get_pushers_by(self, keyvalues):
ret = yield self.db_pool.simple_select_list(
async def get_pushers_by(self, keyvalues):
ret = await self.db_pool.simple_select_list(
"pushers",
keyvalues,
[
@@ -87,16 +83,14 @@ class PusherWorkerStore(SQLBaseStore):
)
return self._decode_pushers_rows(ret)
@defer.inlineCallbacks
def get_all_pushers(self):
async def get_all_pushers(self):
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
rows = self.db_pool.cursor_to_dict(txn)
return self._decode_pushers_rows(rows)
rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers)
return rows
return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
async def get_all_updated_pushers_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -164,19 +158,16 @@ class PusherWorkerStore(SQLBaseStore):
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
@cachedInlineCallbacks(num_args=1, max_entries=15000)
def get_if_user_has_pusher(self, user_id):
@cached(num_args=1, max_entries=15000)
async def get_if_user_has_pusher(self, user_id):
# This only exists for the cachedList decorator
raise NotImplementedError()
@cachedList(
cached_method_name="get_if_user_has_pusher",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
)
def get_if_users_have_pushers(self, user_ids):
rows = yield self.db_pool.simple_select_many_batch(
async def get_if_users_have_pushers(self, user_ids):
rows = await self.db_pool.simple_select_many_batch(
table="pushers",
column="user_name",
iterable=user_ids,
@@ -189,34 +180,38 @@ class PusherWorkerStore(SQLBaseStore):
return result
@defer.inlineCallbacks
def update_pusher_last_stream_ordering(
async def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
):
yield self.db_pool.simple_update_one(
) -> None:
await self.db_pool.simple_update_one(
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
{"last_stream_ordering": last_stream_ordering},
desc="update_pusher_last_stream_ordering",
)
@defer.inlineCallbacks
def update_pusher_last_stream_ordering_and_success(
self, app_id, pushkey, user_id, last_stream_ordering, last_success
):
async def update_pusher_last_stream_ordering_and_success(
self,
app_id: str,
pushkey: str,
user_id: str,
last_stream_ordering: int,
last_success: int,
) -> bool:
"""Update the last stream ordering position we've processed up to for
the given pusher.
Args:
app_id (str)
pushkey (str)
last_stream_ordering (int)
last_success (int)
app_id
pushkey
user_id
last_stream_ordering
last_success
Returns:
Deferred[bool]: True if the pusher still exists; False if it has been deleted.
True if the pusher still exists; False if it has been deleted.
"""
updated = yield self.db_pool.simple_update(
updated = await self.db_pool.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={
@@ -228,18 +223,18 @@ class PusherWorkerStore(SQLBaseStore):
return bool(updated)
@defer.inlineCallbacks
def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
yield self.db_pool.simple_update(
async def update_pusher_failing_since(
self, app_id, pushkey, user_id, failing_since
) -> None:
await self.db_pool.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={"failing_since": failing_since},
desc="update_pusher_failing_since",
)
@defer.inlineCallbacks
def get_throttle_params_by_room(self, pusher_id):
res = yield self.db_pool.simple_select_list(
async def get_throttle_params_by_room(self, pusher_id):
res = await self.db_pool.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
@@ -255,11 +250,10 @@ class PusherWorkerStore(SQLBaseStore):
return params_by_room
@defer.inlineCallbacks
def set_throttle_params(self, pusher_id, room_id, params):
async def set_throttle_params(self, pusher_id, room_id, params) -> None:
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
yield self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
params,
@@ -272,8 +266,7 @@ class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
@defer.inlineCallbacks
def add_pusher(
async def add_pusher(
self,
user_id,
access_token,
@@ -287,11 +280,11 @@ class PusherStore(PusherWorkerStore):
data,
last_stream_ordering,
profile_tag="",
):
) -> None:
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
yield self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={
@@ -316,15 +309,16 @@ class PusherStore(PusherWorkerStore):
if user_has_pusher is not True:
# invalidate, since the user might not have had a pusher before
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
self.get_if_user_has_pusher,
(user_id,),
)
@defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
async def delete_pusher_by_app_id_pushkey_user_id(
self, app_id, pushkey, user_id
) -> None:
def delete_pusher_txn(txn, stream_id):
self._invalidate_cache_and_stream(
txn, self.get_if_user_has_pusher, (user_id,)
@@ -351,6 +345,6 @@ class PusherStore(PusherWorkerStore):
)
with self._pushers_id_gen.get_next() as stream_id:
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
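
Since the pusher store methods above are now coroutines, call sites move from `yield` inside `@defer.inlineCallbacks` generators to plain `await`. A minimal caller-side sketch (the `store` argument and the helper name are illustrative, not part of this change):

async def user_has_any_pushers(store, user_id: str) -> bool:
    # Previously: `has = yield store.user_has_pusher(user_id)` inside an
    # @defer.inlineCallbacks generator; now the store method is awaited directly.
    if not await store.user_has_pusher(user_id):
        return False

    # get_pushers_by_user_id delegates to the now-async get_pushers_by, which
    # yields decoded pusher rows.
    pushers = list(await store.get_pushers_by_user_id(user_id))
    return len(pushers) > 0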

View File

@@ -16,7 +16,7 @@
import abc
import logging
from typing import List, Tuple
from typing import List, Optional, Tuple
from twisted.internet import defer
@@ -25,7 +25,7 @@ from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -56,9 +56,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
raise NotImplementedError()
@cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read")
@cached()
async def get_users_with_read_receipts_in_room(self, room_id):
receipts = await self.get_receipts_for_room(room_id, "m.read")
return {r["user_id"] for r in receipts}
@cached(num_args=2)
@@ -84,9 +84,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
allow_none=True,
)
@cachedInlineCallbacks(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type):
rows = yield self.db_pool.simple_select_list(
@cached(num_args=2)
async def get_receipts_for_user(self, user_id, receipt_type):
rows = await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
@@ -95,8 +95,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return {row["room_id"]: row["event_id"] for row in rows}
@defer.inlineCallbacks
def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
def f(txn):
sql = (
"SELECT rl.room_id, rl.event_id,"
@@ -110,7 +109,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return txn.fetchall()
rows = yield self.db_pool.runInteraction(
rows = await self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f
)
return {
@@ -122,56 +121,61 @@ class ReceiptsWorkerStore(SQLBaseStore):
for row in rows
}
@defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
async def get_linearized_receipts_for_rooms(
self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
) -> List[dict]:
"""Get receipts for multiple rooms for sending to clients.
Args:
room_ids (list): List of room_ids.
to_key (int): Max stream id to fetch receipts upto.
from_key (int): Min stream id to fetch receipts from. None fetches
room_ids: List of room_ids.
to_key: Max stream id to fetch receipts up to.
from_key: Min stream id to fetch receipts from. None fetches
from the start.
Returns:
list: A list of receipts.
A list of receipts.
"""
room_ids = set(room_ids)
if from_key is not None:
# Only ask the database about rooms where there have been new
# receipts added since `from_key`
room_ids = yield self._receipts_stream_cache.get_entities_changed(
room_ids = self._receipts_stream_cache.get_entities_changed(
room_ids, from_key
)
results = yield self._get_linearized_receipts_for_rooms(
results = await self._get_linearized_receipts_for_rooms(
room_ids, to_key, from_key=from_key
)
return [ev for res in results.values() for ev in res]
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
async def get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[dict]:
"""Get receipts for a single room for sending to clients.
Args:
room_ids (str): The room id.
to_key (int): Max stream id to fetch receipts upto.
from_key (int): Min stream id to fetch receipts from. None fetches
room_id: The room id.
to_key: Max stream id to fetch receipts up to.
from_key: Min stream id to fetch receipts from. None fetches
from the start.
Returns:
Deferred[list]: A list of receipts.
A list of receipts.
"""
if from_key is not None:
# Check the cache first to see if any new receipts have been added
# since `from_key`. If not we can no-op.
if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
defer.succeed([])
return []
return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
@cachedInlineCallbacks(num_args=3, tree=True)
def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
@cached(num_args=3, tree=True)
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[dict]:
"""See get_linearized_receipts_for_room
"""
@@ -195,7 +199,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return rows
rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
return []
@@ -212,9 +216,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids",
num_args=3,
inlineCallbacks=True,
)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
return {}
@@ -243,7 +246,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return self.db_pool.cursor_to_dict(txn)
txn_results = yield self.db_pool.runInteraction(
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
@@ -346,7 +349,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
def _invalidate_get_users_with_receipts_in_room(
self, room_id, receipt_type, user_id
self, room_id: str, receipt_type: str, user_id: str
):
if receipt_type != "m.read":
return
@@ -472,15 +475,21 @@ class ReceiptsStore(ReceiptsWorkerStore):
return rx_ts
@defer.inlineCallbacks
def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
async def insert_receipt(
self,
room_id: str,
receipt_type: str,
user_id: str,
event_ids: List[str],
data: dict,
) -> Optional[Tuple[int, int]]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
representations.
"""
if not event_ids:
return
return None
if len(event_ids) == 1:
linearized_event_id = event_ids[0]
@@ -507,13 +516,13 @@ class ReceiptsStore(ReceiptsWorkerStore):
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
linearized_event_id = yield self.db_pool.runInteraction(
linearized_event_id = await self.db_pool.runInteraction(
"insert_receipt_conv", graph_to_linear
)
stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
event_ts = yield self.db_pool.runInteraction(
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id,
@@ -535,7 +544,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
now - event_ts,
)
yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
max_persisted_id = self._receipts_id_gen.get_current_token()
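
The new return annotation on `insert_receipt` makes the "nothing to do" case explicit. A rough usage sketch, with the room, user and event ids made up and `store` standing in for a ReceiptsStore:

async def mark_read(store, room_id: str, user_id: str, event_id: str):
    result = await store.insert_receipt(room_id, "m.read", user_id, [event_id], {})
    if result is None:
        # Nothing was persisted (for example, no event ids were supplied).
        return None
    stream_id, max_persisted_id = result
    return stream_id, max_persisted_id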

View File

@@ -17,9 +17,7 @@
import logging
import re
from typing import Dict, List, Optional
from twisted.internet.defer import Deferred
from typing import Awaitable, Dict, List, Optional
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -304,7 +302,7 @@ class RegistrationWorkerStore(SQLBaseStore):
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
"SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
" access_tokens.device_id, access_tokens.valid_until_ms"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
@@ -563,7 +561,7 @@ class RegistrationWorkerStore(SQLBaseStore):
id_server (str)
Returns:
Deferred
Awaitable
"""
# We need to use an upsert, in case the user had already bound the
# threepid
@@ -952,6 +950,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname=None,
admin=False,
user_type=None,
shadow_banned=False,
):
"""Attempts to register an account.
@@ -968,6 +967,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
shadow_banned (bool): Whether the user is shadow-banned,
i.e. they may be told their requests succeeded but we ignore them.
Raises:
StoreError if the user_id could not be registered.
@@ -986,6 +987,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname,
admin,
user_type,
shadow_banned,
)
def _register_user(
@@ -999,6 +1001,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname,
admin,
user_type,
shadow_banned,
):
user_id_obj = UserID.from_string(user_id)
@@ -1028,6 +1031,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
"shadow_banned": shadow_banned,
},
)
else:
@@ -1042,6 +1046,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
"shadow_banned": shadow_banned,
},
)
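
A hypothetical call into the extended `register_user` above; apart from `shadow_banned`, the keyword names are assumed from the wider signature rather than shown in this hunk, and `await` works whether the method returns a coroutine or a Deferred:

async def register_and_shadow_ban(store, user_id: str, password_hash: str) -> None:
    await store.register_user(
        user_id=user_id,
        password_hash=password_hash,
        shadow_banned=True,  # requests will appear to succeed but are ignored
    )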
@@ -1077,7 +1082,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
) -> Deferred:
) -> Awaitable:
"""Record a mapping from an external user id to a mxid
Args:
@@ -1345,43 +1350,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"validate_threepid_session_txn", validate_threepid_session_txn
)
def upsert_threepid_validation_session(
self,
medium,
address,
client_secret,
send_attempt,
session_id,
validated_at=None,
):
"""Upsert a threepid validation session
Args:
medium (str): The medium of the 3PID
address (str): The address of the 3PID
client_secret (str): A unique string provided by the client to
help identify this validation attempt
send_attempt (int): The latest send_attempt on this session
session_id (str): The id of this validation session
validated_at (int|None): The unix timestamp in milliseconds of
when the session was marked as valid
"""
insertion_values = {
"medium": medium,
"address": address,
"client_secret": client_secret,
}
if validated_at:
insertion_values["validated_at"] = validated_at
return self.db_pool.simple_upsert(
table="threepid_validation_session",
keyvalues={"session_id": session_id},
values={"last_send_attempt": send_attempt},
insertion_values=insertion_values,
desc="upsert_threepid_validation_session",
)
def start_or_continue_validation_session(
self,
medium,

View File

@@ -35,10 +35,6 @@ from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
OpsLevel = collections.namedtuple(
"OpsLevel", ("ban_level", "kick_level", "redact_level")
)
RatelimitOverride = collections.namedtuple(
"RatelimitOverride", ("messages_per_second", "burst_count")
)

View File

@@ -17,8 +17,6 @@
import logging
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
@@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
lambda: self._known_servers_count,
)
@defer.inlineCallbacks
def _count_known_servers(self):
async def _count_known_servers(self):
"""
Count the servers that this server knows about.
@@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(query)
return list(txn)[0][0]
count = yield self.db_pool.runInteraction("get_known_servers", _transact)
count = await self.db_pool.runInteraction("get_known_servers", _transact)
# We always know about ourselves, even if we have nothing in
# room_memberships (for example, the server is new).
@@ -589,11 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
raise NotImplementedError()
@cachedList(
cached_method_name="_get_joined_profile_from_event_id",
list_name="event_ids",
inlineCallbacks=True,
cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
)
def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
@@ -601,11 +596,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event_ids: The member event IDs to lookup
Returns:
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
"""
rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
@@ -772,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
def get_membership_from_event_ids(
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs.
"""
return self.db_pool.simple_select_many_batch(
return await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,

View File

@@ -0,0 +1,18 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- A shadow-banned user may be told that their requests succeeded when they were
-- actually ignored.
ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN;
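
Because the column is added as a nullable BOOLEAN, rows that existed before this migration carry NULL rather than FALSE. Any Python code reading the flag should therefore treat both NULL/None and FALSE as "not banned"; a minimal, assumed read pattern:

def is_shadow_banned(user_row: dict) -> bool:
    # None (pre-migration rows) and False both mean the user is not shadow-banned.
    return bool(user_row.get("shadow_banned"))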

View File

@@ -0,0 +1,17 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- This table is no longer used.
DROP TABLE IF EXISTS presence_allow_inbound;

View File

@@ -273,12 +273,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
cached_method_name="_get_state_group_for_event",
list_name="event_ids",
num_args=1,
inlineCallbacks=True,
)
def _get_state_group_for_events(self, event_ids):
async def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,

View File

@@ -39,15 +39,17 @@ what sort order was used:
import abc
import logging
from collections import namedtuple
from typing import Optional
from typing import Dict, Iterable, List, Optional, Tuple
from twisted.internet import defer
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -68,8 +70,12 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause(
direction, column_names, from_token, to_token, engine
):
direction: str,
column_names: Tuple[str, str],
from_token: Optional[Tuple[int, int]],
to_token: Optional[Tuple[int, int]],
engine: BaseDatabaseEngine,
) -> str:
"""Creates an SQL expression to bound the columns by the pagination
tokens.
@@ -90,21 +96,19 @@ def generate_pagination_where_clause(
token, but include those that match the to token.
Args:
direction (str): Whether we're paginating backwards("b") or
forwards ("f").
column_names (tuple[str, str]): The column names to bound. Must *not*
be user defined as these get inserted directly into the SQL
statement without escapes.
from_token (tuple[int, int]|None): The start point for the pagination.
This is an exclusive minimum bound if direction is "f", and an
inclusive maximum bound if direction is "b".
to_token (tuple[int, int]|None): The endpoint point for the pagination.
This is an inclusive maximum bound if direction is "f", and an
exclusive minimum bound if direction is "b".
direction: Whether we're paginating backwards ("b") or forwards ("f").
column_names: The column names to bound. Must *not* be user defined as
these get inserted directly into the SQL statement without escapes.
from_token: The start point for the pagination. This is an exclusive
minimum bound if direction is "f", and an inclusive maximum bound if
direction is "b".
to_token: The end point for the pagination. This is an inclusive
maximum bound if direction is "f", and an exclusive minimum bound if
direction is "b".
engine: The database engine to generate the clauses for
Returns:
str: The sql expression
The sql expression
"""
assert direction in ("b", "f")
@@ -132,7 +136,12 @@ def generate_pagination_where_clause(
return " AND ".join(where_clause)
def _make_generic_sql_bound(bound, column_names, values, engine):
def _make_generic_sql_bound(
bound: str,
column_names: Tuple[str, str],
values: Tuple[Optional[int], int],
engine: BaseDatabaseEngine,
) -> str:
"""Create an SQL expression that bounds the given column names by the
values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
@@ -142,18 +151,18 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
out manually.
Args:
bound (str): The comparison operator to use. One of ">", "<", ">=",
bound: The comparison operator to use. One of ">", "<", ">=",
"<=", where the values are on the left and columns on the right.
names (tuple[str, str]): The column names. Must *not* be user defined
column_names: The column names. Must *not* be user defined
as these get inserted directly into the SQL statement without
escapes.
values (tuple[int|None, int]): The values to bound the columns by. If
values: The values to bound the columns by. If
the first value is None then only creates a bound on the second
column.
engine: The database engine to generate the SQL for
Returns:
str
The SQL statement
"""
assert bound in (">", "<", ">=", "<=")
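
As the docstring notes, some engines cannot compare tuples natively, so the bound has to be written out manually. The comparison being emulated is ordinary lexicographic ordering; a pure-Python illustration of the `(1, 2) < (col1, col2)` case (values and column contents invented, and not the helper's actual SQL output):

def tuple_less_than(values, columns) -> bool:
    # (v1, v2) < (c1, c2)  iff  v1 < c1, or v1 == c1 and v2 < c2.
    v1, v2 = values   # e.g. the pair taken from a pagination token
    c1, c2 = columns  # e.g. a row's (topological_ordering, stream_ordering)
    return v1 < c1 or (v1 == c1 and v2 < c2)

assert tuple_less_than((1, 2), (1, 3)) is True
assert tuple_less_than((1, 2), (1, 2)) is False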
@@ -193,7 +202,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
)
def filter_to_clause(event_filter):
def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -291,34 +300,35 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def get_room_min_stream_ordering(self):
raise NotImplementedError()
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(
self, room_ids, from_key, to_key, limit=0, order="DESC"
):
async def get_room_events_stream_for_rooms(
self,
room_ids: Iterable[str],
from_key: str,
to_key: str,
limit: int = 0,
order: str = "DESC",
) -> Dict[str, Tuple[List[EventBase], str]]:
"""Get new room events in stream ordering since `from_key`.
Args:
room_id (str)
from_key (str): Token from which no events are returned before
to_key (str): Token from which no events are returned after. (This
room_ids
from_key: Token from which no events are returned before
to_key: Token from which no events are returned after. (This
is typically the current stream token)
limit (int): Maximum number of events to return
order (str): Either "DESC" or "ASC". Determines which events are
limit: Maximum number of events to return
order: Either "DESC" or "ASC". Determines which events are
returned when the result is limited. If "DESC" then the most
recent `limit` events are returned, otherwise returns the
oldest `limit` events.
Returns:
Deferred[dict[str,tuple[list[FrozenEvent], str]]]
A map from room id to a tuple containing:
- list of recent events in the room
- stream ordering key for the start of the chunk of events returned.
A map from room id to a tuple containing:
- list of recent events in the room
- stream ordering key for the start of the chunk of events returned.
"""
from_id = RoomStreamToken.parse_stream_token(from_key).stream
room_ids = yield self._events_stream_cache.get_entities_changed(
room_ids, from_id
)
room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
if not room_ids:
return {}
@@ -326,7 +336,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = {}
room_ids = list(room_ids)
for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
res = yield make_deferred_yieldable(
res = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -361,28 +371,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if self._events_stream_cache.has_entity_changed(room_id, from_key)
}
@defer.inlineCallbacks
def get_room_events_stream_for_room(
self, room_id, from_key, to_key, limit=0, order="DESC"
):
async def get_room_events_stream_for_room(
self,
room_id: str,
from_key: str,
to_key: str,
limit: int = 0,
order: str = "DESC",
) -> Tuple[List[EventBase], str]:
"""Get new room events in stream ordering since `from_key`.
Args:
room_id (str)
from_key (str): Token from which no events are returned before
to_key (str): Token from which no events are returned after. (This
room_id
from_key: Token from which no events are returned before
to_key: Token from which no events are returned after. (This
is typically the current stream token)
limit (int): Maximum number of events to return
order (str): Either "DESC" or "ASC". Determines which events are
limit: Maximum number of events to return
order: Either "DESC" or "ASC". Determines which events are
returned when the result is limited. If "DESC" then the most
recent `limit` events are returned, otherwise returns the
oldest `limit` events.
Returns:
Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
events (in ascending order) and the token from the start of
the chunk of events returned.
The list of events (in ascending order) and the token from the start
of the chunk of events returned.
"""
if from_key == to_key:
return [], from_key
@@ -390,9 +403,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
has_changed = yield self._events_stream_cache.has_entity_changed(
room_id, from_id
)
has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
if not has_changed:
return [], from_key
@@ -410,9 +421,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f)
rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
ret = yield self.get_events_as_list(
ret = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -430,8 +441,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
@defer.inlineCallbacks
def get_membership_changes_for_user(self, user_id, from_key, to_key):
async def get_membership_changes_for_user(self, user_id, from_key, to_key):
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
@@ -460,9 +470,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows
rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f)
rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f)
ret = yield self.get_events_as_list(
ret = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -470,27 +480,26 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret
@defer.inlineCallbacks
def get_recent_events_for_room(self, room_id, limit, end_token):
async def get_recent_events_for_room(
self, room_id: str, limit: int, end_token: str
) -> Tuple[List[EventBase], str]:
"""Get the most recent events in the room in topological ordering.
Args:
room_id (str)
limit (int)
end_token (str): The stream token representing now.
room_id
limit
end_token: The stream token representing now.
Returns:
Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
events and a token pointing to the start of the returned
events.
The events returned are in ascending order.
A list of events and a token pointing to the start of the returned
events. The events returned are in ascending order.
"""
rows, token = yield self.get_recent_event_ids_for_room(
rows, token = await self.get_recent_event_ids_for_room(
room_id, limit, end_token
)
events = yield self.get_events_as_list(
events = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -498,20 +507,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token)
@defer.inlineCallbacks
def get_recent_event_ids_for_room(self, room_id, limit, end_token):
async def get_recent_event_ids_for_room(
self, room_id: str, limit: int, end_token: str
) -> Tuple[List[_EventDictReturn], str]:
"""Get the most recent events in the room in topological ordering.
Args:
room_id (str)
limit (int)
end_token (str): The stream token representing now.
room_id
limit
end_token: The stream token representing now.
Returns:
Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
_EventDictReturn and a token pointing to the start of the returned
events.
The events returned are in ascending order.
A list of _EventDictReturn and a token pointing to the start of the
returned events. The events returned are in ascending order.
"""
# Allow a zero limit here, and no-op.
if limit == 0:
@@ -519,7 +527,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token = RoomStreamToken.parse(end_token)
rows, token = yield self.db_pool.runInteraction(
rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
@@ -532,12 +540,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, token
def get_room_event_before_stream_ordering(self, room_id, stream_ordering):
def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
"""Gets details of the first event in a room at or before a stream ordering
Args:
room_id (str):
stream_ordering (int):
room_id:
stream_ordering:
Returns:
Deferred[(int, int, str)]:
@@ -574,55 +582,67 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
return "t%d-%d" % (topo, token)
def get_stream_token_for_event(self, event_id):
"""The stream token for an event
async def get_stream_id_for_event(self, event_id: str) -> int:
"""The stream ID for an event
Args:
event_id(str): The id of the event to look up a stream token for.
event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
A deferred "s%d" stream token.
A stream ID.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
).addCallback(lambda row: "s%d" % (row,))
)
def get_topological_token_for_event(self, event_id):
async def get_stream_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
Args:
event_id(str): The id of the event to look up a stream token for.
event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
A deferred "t%d-%d" topological token.
A "s%d" stream token.
"""
return self.db_pool.simple_select_one(
stream_id = await self.get_stream_id_for_event(event_id)
return "s%d" % (stream_id,)
async def get_topological_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
Args:
event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
A "t%d-%d" topological token.
"""
row = await self.db_pool.simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
).addCallback(
lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
)
return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
def get_max_topological_token(self, room_id, stream_key):
async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
"""Get the max topological token in a room before the given stream
ordering.
Args:
room_id (str)
stream_key (int)
room_id
stream_key
Returns:
Deferred[int]
The maximum topological token.
"""
sql = (
"SELECT coalesce(max(topological_ordering), 0) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
return self.db_pool.execute(
row = await self.db_pool.execute(
"get_max_topological_token", None, sql, room_id, stream_key
).addCallback(lambda r: r[0][0] if r else 0)
)
return row[0][0] if row else 0
def _get_max_topological_txn(self, txn, room_id):
txn.execute(
@@ -634,16 +654,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows[0][0] if rows else 0
@staticmethod
def _set_before_and_after(events, rows, topo_order=True):
def _set_before_and_after(
events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
):
"""Inserts ordering information to events' internal metadata from
the DB rows.
Args:
events (list[FrozenEvent])
rows (list[_EventDictReturn])
topo_order (bool): Whether the events were ordered topologically
or by stream ordering. If true then all rows should have a non
null topological_ordering.
events
rows
topo_order: Whether the events were ordered topologically or by stream
ordering. If true then all rows should have a non null
topological_ordering.
"""
for event, row in zip(events, rows):
stream = row.stream_ordering
@@ -656,25 +678,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
internal.after = str(RoomStreamToken(topo, stream))
internal.order = (int(topo) if topo else 0, int(stream))
@defer.inlineCallbacks
def get_events_around(
self, room_id, event_id, before_limit, after_limit, event_filter=None
):
async def get_events_around(
self,
room_id: str,
event_id: str,
before_limit: int,
after_limit: int,
event_filter: Optional[Filter] = None,
) -> dict:
"""Retrieve events and pagination tokens around a given event in a
room.
Args:
room_id (str)
event_id (str)
before_limit (int)
after_limit (int)
event_filter (Filter|None)
Returns:
dict
"""
results = yield self.db_pool.runInteraction(
results = await self.db_pool.runInteraction(
"get_events_around",
self._get_events_around_txn,
room_id,
@@ -684,11 +700,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
events_before = yield self.get_events_as_list(
events_before = await self.get_events_as_list(
list(results["before"]["event_ids"]), get_prev_content=True
)
events_after = yield self.get_events_as_list(
events_after = await self.get_events_as_list(
list(results["after"]["event_ids"]), get_prev_content=True
)
@@ -700,17 +716,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
}
def _get_events_around_txn(
self, txn, room_id, event_id, before_limit, after_limit, event_filter
):
self,
txn,
room_id: str,
event_id: str,
before_limit: int,
after_limit: int,
event_filter: Optional[Filter],
) -> dict:
"""Retrieves event_ids and pagination tokens around a given event in a
room.
Args:
room_id (str)
event_id (str)
before_limit (int)
after_limit (int)
event_filter (Filter|None)
room_id
event_id
before_limit
after_limit
event_filter
Returns:
dict
@@ -758,22 +780,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"after": {"event_ids": events_after, "token": end_token},
}
@defer.inlineCallbacks
def get_all_new_events_stream(self, from_id, current_id, limit):
async def get_all_new_events_stream(
self, from_id: int, current_id: int, limit: int
) -> Tuple[int, List[EventBase]]:
"""Get all new events
Returns all events with from_id < stream_ordering <= current_id.
Args:
from_id (int): the stream_ordering of the last event we processed
current_id (int): the stream_ordering of the most recently processed event
limit (int): the maximum number of events to return
from_id: the stream_ordering of the last event we processed
current_id: the stream_ordering of the most recently processed event
limit: the maximum number of events to return
Returns:
Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where
`next_id` is the next value to pass as `from_id` (it will either be the
stream_ordering of the last returned event, or, if fewer than `limit` events
were found, `current_id`.
A tuple of (next_id, events), where `next_id` is the next value to
pass as `from_id` (it will either be the stream_ordering of the
last returned event, or, if fewer than `limit` events were found,
the `current_id`).
"""
def get_all_new_events_stream_txn(txn):
@@ -795,11 +818,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.db_pool.runInteraction(
upper_bound, event_ids = await self.db_pool.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn
)
events = yield self.get_events_as_list(event_ids)
events = await self.get_events_as_list(event_ids)
return upper_bound, events
@@ -817,21 +840,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="get_federation_out_pos",
)
async def update_federation_out_pos(self, typ, stream_id):
async def update_federation_out_pos(self, typ: str, stream_id: int) -> None:
if self._need_to_reset_federation_stream_positions:
await self.db_pool.runInteraction(
"_reset_federation_positions_txn", self._reset_federation_positions_txn
)
self._need_to_reset_federation_stream_positions = False
return await self.db_pool.simple_update_one(
await self.db_pool.simple_update_one(
table="federation_stream_position",
keyvalues={"type": typ, "instance_name": self._instance_name},
updatevalues={"stream_id": stream_id},
desc="update_federation_out_pos",
)
def _reset_federation_positions_txn(self, txn):
def _reset_federation_positions_txn(self, txn) -> None:
"""Fiddles with the `federation_stream_position` table to make it match
the configured federation sender instances during start up.
"""
@@ -892,39 +915,37 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
values={"stream_id": stream_id},
)
def has_room_changed_since(self, room_id, stream_id):
def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
def _paginate_room_events_txn(
self,
txn,
room_id,
from_token,
to_token=None,
direction="b",
limit=-1,
event_filter=None,
):
room_id: str,
from_token: RoomStreamToken,
to_token: Optional[RoomStreamToken] = None,
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[_EventDictReturn], str]:
"""Returns list of events before or after a given token.
Args:
txn
room_id (str)
from_token (RoomStreamToken): The token used to stream from
to_token (RoomStreamToken|None): A token which if given limits the
results to only those before
direction(char): Either 'b' or 'f' to indicate whether we are
paginating forwards or backwards from `from_key`.
limit (int): The maximum number of events to return.
event_filter (Filter|None): If provided filters the events to
room_id
from_token: The token used to stream from
to_token: A token which if given limits the results to only those before
direction: Either 'b' or 'f' to indicate whether we are paginating
forwards or backwards from `from_key`.
limit: The maximum number of events to return.
event_filter: If provided filters the events to
those that match the filter.
Returns:
Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
as a list of _EventDictReturn and a token that points to the end
of the result set. If no events are returned then the end of the
stream has been reached (i.e. there are no events between
`from_token` and `to_token`), or `limit` is zero.
A list of _EventDictReturn and a token that points to the end of the
result set. If no events are returned then the end of the stream has
been reached (i.e. there are no events between `from_token` and
`to_token`), or `limit` is zero.
"""
assert int(limit) >= 0
@@ -1008,35 +1029,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, str(next_token)
@defer.inlineCallbacks
def paginate_room_events(
self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None
):
async def paginate_room_events(
self,
room_id: str,
from_key: str,
to_key: Optional[str] = None,
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[EventBase], str]:
"""Returns list of events before or after a given token.
Args:
room_id (str)
from_key (str): The token used to stream from
to_key (str|None): A token which if given limits the results to
only those before
direction(char): Either 'b' or 'f' to indicate whether we are
paginating forwards or backwards from `from_key`.
limit (int): The maximum number of events to return.
event_filter (Filter|None): If provided filters the events to
those that match the filter.
room_id
from_key: The token used to stream from
to_key: A token which if given limits the results to only those before
direction: Either 'b' or 'f' to indicate whether we are paginating
forwards or backwards from `from_key`.
limit: The maximum number of events to return.
event_filter: If provided filters the events to those that match the filter.
Returns:
tuple[list[FrozenEvent], str]: Returns the results as a list of
events and a token that points to the end of the result set. If no
events are returned then the end of the stream has been reached
(i.e. there are no events between `from_key` and `to_key`).
The results as a list of events and a token that points to the end
of the result set. If no events are returned then the end of the
stream has been reached (i.e. there are no events between `from_key`
and `to_key`).
"""
from_key = RoomStreamToken.parse(from_key)
if to_key:
to_key = RoomStreamToken.parse(to_key)
rows, token = yield self.db_pool.runInteraction(
rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
@@ -1047,7 +1071,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
events = yield self.get_events_as_list(
events = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -1057,8 +1081,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
class StreamStore(StreamWorkerStore):
def get_room_max_stream_ordering(self):
def get_room_max_stream_ordering(self) -> int:
return self._stream_id_gen.get_current_token()
def get_room_min_stream_ordering(self):
def get_room_min_stream_ordering(self) -> int:
return self._backfill_id_gen.get_current_token()
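
The token helpers earlier in this file make the two token formats easy to see side by side; a small caller-side sketch, with `store` standing in for a StreamWorkerStore and the event id invented:

async def describe_event_position(store, event_id: str):
    stream_id = await store.get_stream_id_for_event(event_id)           # e.g. 12345
    stream_token = await store.get_stream_token_for_event(event_id)     # "s12345"
    topo_token = await store.get_topological_token_for_event(event_id)  # e.g. "t7-12345"
    return stream_id, stream_token, topo_token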

View File

@@ -38,10 +38,8 @@ class UserErasureWorkerStore(SQLBaseStore):
desc="is_user_erased",
).addCallback(operator.truth)
@cachedList(
cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
)
def are_users_erased(self, user_ids):
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
async def are_users_erased(self, user_ids):
"""
Checks which users in a list have requested erasure
@@ -49,14 +47,14 @@ class UserErasureWorkerStore(SQLBaseStore):
user_ids (iterable[str]): full user id to check
Returns:
Deferred[dict[str, bool]]:
dict[str, bool]:
for each user, whether the user has requested erasure.
"""
# this serves the dual purpose of (a) making sure we can do len and
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
iterable=user_ids,
@@ -65,8 +63,7 @@ class UserErasureWorkerStore(SQLBaseStore):
)
erased_users = {row["user_id"] for row in rows}
res = {u: u in erased_users for u in user_ids}
return res
return {u: u in erased_users for u in user_ids}
class UserErasureStore(UserErasureWorkerStore):
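
A brief assumed usage of the batched lookup above, which now resolves directly to a dict mapping each user id to its erasure flag:

async def drop_erased_users(store, user_ids):
    erased = await store.are_users_erased(user_ids)  # {user_id: bool}
    return [user_id for user_id, is_erased in erased.items() if not is_erased]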

View File

@@ -51,7 +51,15 @@ JsonDict = Dict[str, Any]
class Requester(
namedtuple(
"Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]
"Requester",
[
"user",
"access_token_id",
"is_guest",
"shadow_banned",
"device_id",
"app_service",
],
)
):
"""
@@ -62,6 +70,7 @@ class Requester(
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
shadow_banned (bool): True if the user making this request has been shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
"""
@@ -77,6 +86,7 @@ class Requester(
"user_id": self.user.to_string(),
"access_token_id": self.access_token_id,
"is_guest": self.is_guest,
"shadow_banned": self.shadow_banned,
"device_id": self.device_id,
"app_server_id": self.app_service.id if self.app_service else None,
}
@@ -101,13 +111,19 @@ class Requester(
user=UserID.from_string(input["user_id"]),
access_token_id=input["access_token_id"],
is_guest=input["is_guest"],
shadow_banned=input["shadow_banned"],
device_id=input["device_id"],
app_service=appservice,
)
def create_requester(
user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None
user_id,
access_token_id=None,
is_guest=False,
shadow_banned=False,
device_id=None,
app_service=None,
):
"""
Create a new ``Requester`` object
@@ -117,6 +133,7 @@ def create_requester(
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
shadow_banned (bool): True if the user making this request is shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
@@ -125,7 +142,9 @@ def create_requester(
"""
if not isinstance(user_id, UserID):
user_id = UserID.from_string(user_id)
return Requester(user_id, access_token_id, is_guest, device_id, app_service)
return Requester(
user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
)
def get_domain_from_id(string):
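
How the widened `create_requester` above might be called; the user id is made up and the snippet assumes a Synapse checkout on the import path:

from synapse.types import create_requester

requester = create_requester(
    "@user:example.com",
    is_guest=False,
    shadow_banned=True,  # flag the requester as shadow-banned
)
assert requester.shadow_banned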

View File

@@ -32,7 +32,6 @@ json_encoder = json.JSONEncoder(separators=(",", ":"))
def unwrapFirstError(failure):
# defer.gatherResults and DeferredLists wrap failures.
failure.trap(defer.FirstError)
logger.info("DDD failure.value.subFailure: %s", failure.value.subFailure)
return failure.value.subFailure

View File

@@ -24,9 +24,7 @@ from synapse.api.errors import Codes, SynapseError
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
# Note: The : character is allowed here for older clients, but will be removed in a
# future release. Context: https://github.com/matrix-org/synapse/issues/6766
client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-\:]+$")
client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
# random_string and random_string_with_symbols are used for a range of things,
# some cryptographically important, some less so. We use SystemRandom to make sure
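
The regex tightening above means a colon in a client_secret is now rejected. A standalone check using the same pattern (copied from the new line; the sample secrets are invented):

import re

client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")

assert client_secret_regex.match("this.is_a-valid=secret")
assert client_secret_regex.match("no:longer:allowed") is None  # ':' is now rejected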

View File

@@ -1,6 +1,4 @@
from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
from synapse.appservice import ApplicationService
from synapse.types import create_requester
from tests import unittest
@@ -20,77 +18,6 @@ class TestRatelimiter(unittest.TestCase):
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
def test_allowed_user_via_can_requester_do_action(self):
user_requester = create_requester("@user:example.com")
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
allowed, time_allowed = limiter.can_requester_do_action(
user_requester, _time_now_s=0
)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
user_requester, _time_now_s=5
)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
user_requester, _time_now_s=10
)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
appservice = ApplicationService(
None, "example.com", id="foo", rate_limited=True,
)
as_requester = create_requester("@user:example.com", app_service=appservice)
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=0
)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=5
)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=10
)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
def test_allowed_appservice_via_can_requester_do_action(self):
appservice = ApplicationService(
None, "example.com", id="foo", rate_limited=False,
)
as_requester = create_requester("@user:example.com", app_service=appservice)
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=0
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=5
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=10
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
def test_allowed_via_ratelimit(self):
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)

tests/config/test_base.py (new file, 82 lines added)
View File

@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
import tempfile
from synapse.config import ConfigError
from synapse.util.stringutils import random_string
from tests import unittest
class BaseConfigTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.hs = hs
def test_loading_missing_templates(self):
# Use a temporary directory that exists on the system, but that isn't likely to
# contain template files
with tempfile.TemporaryDirectory() as tmp_dir:
# Attempt to load an HTML template from our custom template directory
template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0]
# If no errors, we should've gotten the default template instead
# Render the template
a_random_string = random_string(5)
html_content = template.render({"error_description": a_random_string})
# Check that our string exists in the template
self.assertIn(
a_random_string,
html_content,
"Template file did not contain our test string",
)
def test_loading_custom_templates(self):
# Use a temporary directory that exists on the system
with tempfile.TemporaryDirectory() as tmp_dir:
# Create a temporary bogus template file
with tempfile.NamedTemporaryFile(dir=tmp_dir) as tmp_template:
# Get temporary file's filename
template_filename = os.path.basename(tmp_template.name)
# Write a custom HTML template
contents = b"{{ test_variable }}"
tmp_template.write(contents)
tmp_template.flush()
# Attempt to load the template from our custom template directory
template = (
self.hs.config.read_templates([template_filename], tmp_dir)
)[0]
# Render the template
a_random_string = random_string(5)
html_content = template.render({"test_variable": a_random_string})
# Check that our string exists in the template
self.assertIn(
a_random_string,
html_content,
"Template file did not contain our test string",
)
def test_loading_template_from_nonexistent_custom_directory(self):
with self.assertRaises(ConfigError):
self.hs.config.read_templates(
["some_filename.html"], "a_nonexistent_directory"
)

View File

@@ -79,9 +79,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
fed_transport.client.get_json = Mock(
side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
)
handler.federation_handler.do_invite_join = Mock(
return_value=make_awaitable(("", 1))
side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -110,9 +112,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
fed_transport.client.get_json = Mock(
side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
)
handler.federation_handler.do_invite_join = Mock(
return_value=make_awaitable(("", 1))
side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -148,9 +152,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
fed_transport.client.get_json = Mock(
side_effect=lambda *args, **kwargs: make_awaitable(None)
)
handler.federation_handler.do_invite_join = Mock(
return_value=make_awaitable(("", 1))
side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
# Artificially raise the complexity
@@ -204,9 +210,11 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
fed_transport.client.get_json = Mock(
side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
)
handler.federation_handler.do_invite_join = Mock(
return_value=make_awaitable(("", 1))
side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -234,9 +242,11 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
fed_transport.client.get_json = Mock(
side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
)
handler.federation_handler.do_invite_join = Mock(
return_value=make_awaitable(("", 1))
side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
d = handler._remote_join(
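
The switch from `return_value` to `side_effect` in these mocks presumably ensures each call gets a fresh awaitable instead of re-awaiting a single one. A minimal sketch of the difference, using a stand-in coroutine factory rather than the real `make_awaitable` helper from the test utilities:

from unittest.mock import Mock

async def make_awaitable(result):
    # Stand-in: the real helper lives in tests.test_utils.
    return result

# Safe to call repeatedly: the lambda builds a new coroutine per call.
get_json = Mock(side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}))

# Fragile: the single coroutine created here can only be awaited once.
get_json_once = Mock(return_value=make_awaitable({"v1": 9999}))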

View File

@@ -19,6 +19,7 @@ from mock import Mock, call
 from signedjson.key import generate_signing_key
 from synapse.api.constants import EventTypes, Membership, PresenceState
+from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events.builder import EventBuilder
 from synapse.handlers.presence import (
@@ -32,7 +33,6 @@ from synapse.handlers.presence import (
     handle_update,
 )
 from synapse.rest.client.v1 import room
-from synapse.storage.presence import UserPresenceState
 from synapse.types import UserID, get_domain_from_id
 from tests import unittest


@@ -64,7 +64,7 @@ class ProfileTestCase(unittest.TestCase):
         self.bob = UserID.from_string("@4567:test")
         self.alice = UserID.from_string("@alice:remote")
-        yield self.store.create_profile(self.frank.localpart)
+        yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
         self.handler = hs.get_profile_handler()
         self.hs = hs
@@ -157,7 +157,7 @@ class ProfileTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_incoming_fed_query(self):
-        yield self.store.create_profile("caroline")
+        yield defer.ensureDeferred(self.store.create_profile("caroline"))
         yield self.store.set_profile_displayname("caroline", "Caroline")
         response = yield defer.ensureDeferred(

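Note on the hunk above (not part of the diff): it shows the recurring recipe in this compare. Once a store method such as create_profile becomes a coroutine, an old-style @defer.inlineCallbacks test can no longer yield it directly and has to wrap the call in defer.ensureDeferred. A minimal sketch of the pattern, using an illustrative async function rather than Synapse's real store:

from twisted.internet import defer
from twisted.trial import unittest

async def create_profile(localpart):
    # Illustrative stand-in for a store method that has been converted to async.
    return {"localpart": localpart}

class ProfileStoreSketchTestCase(unittest.TestCase):
    @defer.inlineCallbacks
    def test_create_profile(self):
        # Yielding the coroutine directly would not be waited on by
        # inlineCallbacks; ensureDeferred turns it into a Deferred the
        # generator can pause on.
        profile = yield defer.ensureDeferred(create_profile("frank"))
        self.assertEqual(profile["localpart"], "frank")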

@@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             ([], 0)
         )
         self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
-        self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
+        self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
             None
         )


@@ -28,7 +28,7 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.handlers.pagination import PurgeStatus
 from synapse.rest.client.v1 import directory, login, profile, room
 from synapse.rest.client.v2_alpha import account
-from synapse.types import JsonDict, RoomAlias, UserID
+from synapse.types import JsonDict, RoomAlias
 from synapse.util.stringutils import random_string
 from tests import unittest
@@ -675,91 +675,6 @@ class RoomMemberStateTestCase(RoomBase):
         self.assertEquals(json.loads(content), channel.json_body)
-class RoomJoinRatelimitTestCase(RoomBase):
-    user_id = "@sid1:red"
-    servlets = [
-        profile.register_servlets,
-        room.register_servlets,
-    ]
-    @unittest.override_config(
-        {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
-    )
-    def test_join_local_ratelimit(self):
-        """Tests that local joins are actually rate-limited."""
-        for i in range(5):
-            self.helper.create_room_as(self.user_id)
-        self.helper.create_room_as(self.user_id, expect_code=429)
-    @unittest.override_config(
-        {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
-    )
-    def test_join_local_ratelimit_profile_change(self):
-        """Tests that sending a profile update into all of the user's joined rooms isn't
-        rate-limited by the rate-limiter on joins."""
-        # Create and join more rooms than the rate-limiting config allows in a second.
-        room_ids = [
-            self.helper.create_room_as(self.user_id),
-            self.helper.create_room_as(self.user_id),
-            self.helper.create_room_as(self.user_id),
-        ]
-        self.reactor.advance(1)
-        room_ids = room_ids + [
-            self.helper.create_room_as(self.user_id),
-            self.helper.create_room_as(self.user_id),
-            self.helper.create_room_as(self.user_id),
-        ]
-        # Create a profile for the user, since it hasn't been done on registration.
-        store = self.hs.get_datastore()
-        store.create_profile(UserID.from_string(self.user_id).localpart)
-        # Update the display name for the user.
-        path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
-        request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
-        self.render(request)
-        self.assertEquals(channel.code, 200, channel.json_body)
-        # Check that all the rooms have been sent a profile update into.
-        for room_id in room_ids:
-            path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
-                room_id,
-                self.user_id,
-            )
-            request, channel = self.make_request("GET", path)
-            self.render(request)
-            self.assertEquals(channel.code, 200)
-            self.assertIn("displayname", channel.json_body)
-            self.assertEquals(channel.json_body["displayname"], "John Doe")
-    @unittest.override_config(
-        {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
-    )
-    def test_join_local_ratelimit_idempotent(self):
-        """Tests that the room join endpoints remain idempotent despite rate-limiting
-        on room joins."""
-        room_id = self.helper.create_room_as(self.user_id)
-        # Let's test both paths to be sure.
-        paths_to_test = [
-            "/_matrix/client/r0/rooms/%s/join",
-            "/_matrix/client/r0/join/%s",
-        ]
-        for path in paths_to_test:
-            # Make sure we send more requests than the rate-limiting config would allow
-            # if all of these requests ended up joining the user to a room.
-            for i in range(6):
-                request, channel = self.make_request("POST", path % room_id, {})
-                self.render(request)
-                self.assertEquals(channel.code, 200)
 class RoomMessagesTestCase(RoomBase):
     """ Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """


@@ -39,9 +39,7 @@ class RestHelper(object):
     resource = attr.ib()
     auth_user_id = attr.ib()
-    def create_room_as(
-        self, room_creator=None, is_public=True, tok=None, expect_code=200,
-    ):
+    def create_room_as(self, room_creator=None, is_public=True, tok=None):
         temp_id = self.auth_user_id
         self.auth_user_id = room_creator
         path = "/_matrix/client/r0/createRoom"
@@ -56,11 +54,9 @@ class RestHelper(object):
         )
         render(request, self.resource, self.hs.get_reactor())
-        assert channel.result["code"] == b"%d" % expect_code, channel.result
+        assert channel.result["code"] == b"200", channel.result
         self.auth_user_id = temp_id
-        if expect_code == 200:
-            return channel.json_body["room_id"]
+        return channel.json_body["room_id"]
     def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
         self.change_membership(

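Note on the hunk above (not part of the diff): in this direction the compare drops the expect_code parameter from create_room_as, which the removed rate-limit tests earlier in this listing relied on (create_room_as(self.user_id, expect_code=429)). A rough, self-contained sketch of what such a parameter buys a caller: assert the status the test expects and only unpack a room ID when a 200 was expected. All names here are hypothetical, not Synapse's actual helper:

from types import SimpleNamespace

def fake_make_request(method, path, body):
    # Stand-in transport that always "succeeds" with a room ID.
    return SimpleNamespace(
        code=200, result={"code": b"200"}, json_body={"room_id": "!abc:test"}
    )

def create_room_as(make_request, room_creator=None, expect_code=200):
    # Assert on the status the caller expects; only return a body on success.
    channel = make_request("POST", "/_matrix/client/r0/createRoom", {"creator": room_creator})
    assert channel.code == expect_code, channel.result
    if expect_code == 200:
        return channel.json_body["room_id"]

print(create_room_as(fake_make_request, room_creator="@sid1:red"))  # -> !abc:test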

@@ -207,7 +207,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_set_appservices_state_down(self):
         service = Mock(id=self.as_list[1]["id"])
-        yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+        )
         rows = yield self.db_pool.runQuery(
             self.engine.convert_param_style(
                 "SELECT as_id FROM application_services_state WHERE state=?"
@@ -219,9 +221,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_set_appservices_state_multiple_up(self):
         service = Mock(id=self.as_list[1]["id"])
-        yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
-        yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
-        yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.UP)
+        )
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+        )
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.UP)
+        )
         rows = yield self.db_pool.runQuery(
             self.engine.convert_param_style(
                 "SELECT as_id FROM application_services_state WHERE state=?"


@@ -66,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_insert_1col(self):
         self.mock_txn.rowcount = 1
-        yield self.datastore.db_pool.simple_insert(
-            table="tablename", values={"columname": "Value"}
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_insert(
+                table="tablename", values={"columname": "Value"}
+            )
         )
         self.mock_txn.execute.assert_called_with(
@@ -78,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_insert_3cols(self):
         self.mock_txn.rowcount = 1
-        yield self.datastore.db_pool.simple_insert(
-            table="tablename",
-            # Use OrderedDict() so we can assert on the SQL generated
-            values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_insert(
+                table="tablename",
+                # Use OrderedDict() so we can assert on the SQL generated
+                values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+            )
         )
         self.mock_txn.execute.assert_called_with(


@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         # Create a test user and room
         self.user = UserID("alice", "test")
-        self.requester = Requester(self.user, None, False, None, None)
+        self.requester = Requester(self.user, None, False, False, None, None)
         info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
         self.room_id = info["room_id"]
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
         # Create a test user and room
         self.user = UserID.from_string(self.register_user("user1", "password"))
         self.token1 = self.login("user1", "password")
-        self.requester = Requester(self.user, None, False, None, None)
+        self.requester = Requester(self.user, None, False, False, None, None)
         info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
         self.room_id = info["room_id"]
         self.event_creator = homeserver.get_event_creation_handler()

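Note on the hunks above and in the two storage tests below (not part of the diff): the Requester(...) calls gain a fourth positional argument. This presumably corresponds to the new shadow-banned flag on users; the field comments in the sketch are assumptions, while the positional form mirrors the diff:

# Illustrative only; the comments name the fields as assumed, not as documented.
from synapse.types import Requester, UserID

user = UserID.from_string("@alice:test")
requester = Requester(
    user,    # the user the request is made on behalf of
    None,    # access_token_id: no access token in these unit tests
    False,   # is_guest
    False,   # assumed to be the new shadow_banned flag
    None,    # device_id
    None,    # app_service
)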

@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
         room_creator = self.hs.get_room_creation_handler()
         user = UserID("alice", "test")
-        requester = Requester(user, None, False, None, None)
+        requester = Requester(user, None, False, False, None, None)
         # Real events, forward extremities
         events = [(3, 2), (6, 2), (4, 6)]


@@ -142,20 +142,22 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
     @defer.inlineCallbacks
     def test_find_first_stream_ordering_after_ts(self):
        def add_event(so, ts):
-            return self.store.db_pool.simple_insert(
-                "events",
-                {
-                    "stream_ordering": so,
-                    "received_ts": ts,
-                    "event_id": "event%i" % so,
-                    "type": "",
-                    "room_id": "",
-                    "content": "",
-                    "processed": True,
-                    "outlier": False,
-                    "topological_ordering": 0,
-                    "depth": 0,
-                },
+            return defer.ensureDeferred(
+                self.store.db_pool.simple_insert(
+                    "events",
+                    {
+                        "stream_ordering": so,
+                        "received_ts": ts,
+                        "event_id": "event%i" % so,
+                        "type": "",
+                        "room_id": "",
+                        "content": "",
+                        "processed": True,
+                        "outlier": False,
+                        "topological_ordering": 0,
+                        "depth": 0,
+                    },
+                )
             )
         # start with the base case where there are no events in the table


@@ -35,7 +35,7 @@ class DataStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_get_users_paginate(self):
         yield self.store.register_user(self.user.to_string(), "pass")
-        yield self.store.create_profile(self.user.localpart)
+        yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
         yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
         users, total = yield self.store.get_users_paginate(


@@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_displayname(self):
-        yield self.store.create_profile(self.u_frank.localpart)
+        yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
         yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
@@ -43,7 +43,7 @@ class ProfileStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_avatar_url(self):
-        yield self.store.create_profile(self.u_frank.localpart)
+        yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
         yield self.store.set_profile_avatar_url(
             self.u_frank.localpart, "http://my.site/here"


@@ -15,6 +15,7 @@
 from twisted.internet import defer
+from synapse.api.errors import NotFoundError
 from synapse.rest.client.v1 import room
 from tests.unittest import HomeserverTestCase
@@ -46,30 +47,19 @@ class PurgeTests(HomeserverTestCase):
         storage = self.hs.get_storage()
         # Get the topological token
-        event = store.get_topological_token_for_event(last["event_id"])
-        self.pump()
-        event = self.successResultOf(event)
+        event = self.get_success(
+            store.get_topological_token_for_event(last["event_id"])
+        )
         # Purge everything before this topological token
-        purge = defer.ensureDeferred(
-            storage.purge_events.purge_history(self.room_id, event, True)
-        )
-        self.pump()
-        self.assertEqual(self.successResultOf(purge), None)
-        # Try and get the events
-        get_first = store.get_event(first["event_id"])
-        get_second = store.get_event(second["event_id"])
-        get_third = store.get_event(third["event_id"])
-        get_last = store.get_event(last["event_id"])
-        self.pump()
+        self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
         # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
         # and last is not.
-        self.failureResultOf(get_first)
-        self.failureResultOf(get_second)
-        self.failureResultOf(get_third)
-        self.successResultOf(get_last)
+        self.get_failure(store.get_event(first["event_id"]), NotFoundError)
+        self.get_failure(store.get_event(second["event_id"]), NotFoundError)
+        self.get_failure(store.get_event(third["event_id"]), NotFoundError)
+        self.get_success(store.get_event(last["event_id"]))
     def test_purge_wont_delete_extrems(self):
         """
@@ -84,9 +74,9 @@ class PurgeTests(HomeserverTestCase):
         storage = self.hs.get_datastore()
         # Set the topological token higher than it should be
-        event = storage.get_topological_token_for_event(last["event_id"])
-        self.pump()
-        event = self.successResultOf(event)
+        event = self.get_success(
+            storage.get_topological_token_for_event(last["event_id"])
+        )
         event = "t{}-{}".format(
             *list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
         )
@@ -98,14 +88,7 @@
         self.assertIn("greater than forward", f.value.args[0])
         # Try and get the events
-        get_first = storage.get_event(first["event_id"])
-        get_second = storage.get_event(second["event_id"])
-        get_third = storage.get_event(third["event_id"])
-        get_last = storage.get_event(last["event_id"])
-        self.pump()
-        # Nothing is deleted.
-        self.successResultOf(get_first)
-        self.successResultOf(get_second)
-        self.successResultOf(get_third)
-        self.successResultOf(get_last)
+        self.get_success(storage.get_event(first["event_id"]))
+        self.get_success(storage.get_event(second["event_id"]))
+        self.get_success(storage.get_event(third["event_id"]))
+        self.get_success(storage.get_event(last["event_id"]))

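Note on the hunks above (not part of the diff): they swap the manual pump() / successResultOf() / failureResultOf() sequence for the get_success and get_failure helpers on HomeserverTestCase. Roughly, those helpers drive the test reactor and unwrap the result (or assert a particular exception) in one step. A simplified sketch of the idea, not the real implementation in tests.unittest:

from twisted.internet import defer

def get_success(test_case, awaitable):
    # Accept a coroutine or a Deferred, run the fake reactor until it
    # resolves, then hand back the unwrapped result.
    d = defer.ensureDeferred(awaitable)
    test_case.pump()
    return test_case.successResultOf(d)

def get_failure(test_case, awaitable, exc_type):
    # Same idea, but assert the Deferred failed with the given exception type
    # and return the Failure for further inspection.
    d = defer.ensureDeferred(awaitable)
    test_case.pump()
    return test_case.failureResultOf(d, exc_type)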

@@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
         # Now let's create a room, which will insert a membership
         user = UserID("alice", "test")
-        requester = Requester(user, None, False, None, None)
+        requester = Requester(user, None, False, False, None, None)
         self.get_success(self.room_creator.create_room(requester, {}))
         # Register the background update to run again.

Some files were not shown because too many files have changed in this diff.