Compare commits
23 Commits
anoa/morga
...
travis/gro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b024acffea | ||
|
|
acfb7c3b5d | ||
|
|
3c01724b33 | ||
|
|
5cf7c12995 | ||
|
|
408aef8276 | ||
|
|
2f4d60a5ba | ||
|
|
25e55d2598 | ||
|
|
8b6c176aee | ||
|
|
050e20e7ca | ||
|
|
e04e465b4d | ||
|
|
8390e00c7f | ||
|
|
ad6190c925 | ||
|
|
ac77cdb64e | ||
|
|
b069b78bb4 | ||
|
|
e8861957d9 | ||
|
|
dc22090a67 | ||
|
|
6b7ce1d332 | ||
|
|
894dae74fe | ||
|
|
7bdf9828d5 | ||
|
|
bfd79c2988 | ||
|
|
53834bb9c4 | ||
|
|
ff0e894656 | ||
|
|
dd8f28bd3f |
33
CHANGES.md
33
CHANGES.md
@@ -1,28 +1,15 @@
|
||||
Synapse 1.19.2 (2020-09-16)
|
||||
===========================
|
||||
For the next release
|
||||
====================
|
||||
|
||||
Due to the issue below server admins are encouraged to upgrade as soon as possible.
|
||||
Removal warning
|
||||
---------------
|
||||
|
||||
Bugfixes
|
||||
--------
|
||||
|
||||
- Fix joining rooms over federation that include malformed events. ([\#8324](https://github.com/matrix-org/synapse/issues/8324))
|
||||
|
||||
|
||||
Synapse 1.19.1 (2020-08-27)
|
||||
===========================
|
||||
|
||||
No significant changes.
|
||||
|
||||
|
||||
Synapse 1.19.1rc1 (2020-08-25)
|
||||
==============================
|
||||
|
||||
Bugfixes
|
||||
--------
|
||||
|
||||
- Fix a bug introduced in v1.19.0 where appservices with ratelimiting disabled would still be ratelimited when joining rooms. ([\#8139](https://github.com/matrix-org/synapse/issues/8139))
|
||||
- Fix a bug introduced in v1.19.0 that would cause e.g. profile updates to fail due to incorrect application of rate limits on join requests. ([\#8153](https://github.com/matrix-org/synapse/issues/8153))
|
||||
Some older clients used a
|
||||
[disallowed character](https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-register-email-requesttoken)
|
||||
(`:`) in the `client_secret` parameter of various endpoints. The incorrect
|
||||
behaviour was allowed for backwards compatibility, but is now being removed
|
||||
from Synapse as most users have updated their client. Further context can be
|
||||
found at [\#6766](https://github.com/matrix-org/synapse/issues/6766).
|
||||
|
||||
|
||||
Synapse 1.19.0 (2020-08-17)
|
||||
|
||||
1
changelog.d/7864.bugfix
Normal file
1
changelog.d/7864.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a memory leak by limiting the length of time that messages will be queued for a remote server that has been unreachable.
|
||||
1
changelog.d/8013.feature
Normal file
1
changelog.d/8013.feature
Normal file
@@ -0,0 +1 @@
|
||||
Iteratively encode JSON to avoid blocking the reactor.
|
||||
1
changelog.d/8037.feature
Normal file
1
changelog.d/8037.feature
Normal file
@@ -0,0 +1 @@
|
||||
Use the default template file when its equivalent is not found in a custom template directory.
|
||||
1
changelog.d/8072.misc
Normal file
1
changelog.d/8072.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
1
changelog.d/8074.misc
Normal file
1
changelog.d/8074.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
1
changelog.d/8075.misc
Normal file
1
changelog.d/8075.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
1
changelog.d/8076.misc
Normal file
1
changelog.d/8076.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
1
changelog.d/8081.bugfix
Normal file
1
changelog.d/8081.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix `Re-starting finished log context PUT-nnnn` warning when event persistence failed.
|
||||
1
changelog.d/8085.misc
Normal file
1
changelog.d/8085.misc
Normal file
@@ -0,0 +1 @@
|
||||
Remove some unused database functions.
|
||||
1
changelog.d/8087.misc
Normal file
1
changelog.d/8087.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
1
changelog.d/8090.misc
Normal file
1
changelog.d/8090.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add type hints to `synapse.handlers.room`.
|
||||
1
changelog.d/8092.feature
Normal file
1
changelog.d/8092.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add support for shadow-banning users (ignoring any message send requests).
|
||||
1
changelog.d/8093.misc
Normal file
1
changelog.d/8093.misc
Normal file
@@ -0,0 +1 @@
|
||||
Return the previous stream token if a non-member event is a duplicate.
|
||||
1
changelog.d/8100.misc
Normal file
1
changelog.d/8100.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
1
changelog.d/8101.bugfix
Normal file
1
changelog.d/8101.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Synapse now correctly enforces the valid characters in the `client_secret` parameter used in various endpoints.
|
||||
1
changelog.d/8107.feature
Normal file
1
changelog.d/8107.feature
Normal file
@@ -0,0 +1 @@
|
||||
Use the default template file when its equivalent is not found in a custom template directory.
|
||||
1
changelog.d/8111.doc
Normal file
1
changelog.d/8111.doc
Normal file
@@ -0,0 +1 @@
|
||||
Link to matrix-synapse-rest-password-provider in the password provider documentation.
|
||||
1
changelog.d/8112.misc
Normal file
1
changelog.d/8112.misc
Normal file
@@ -0,0 +1 @@
|
||||
Return the previous stream token if a non-member event is a duplicate.
|
||||
12
debian/changelog
vendored
12
debian/changelog
vendored
@@ -1,15 +1,3 @@
|
||||
matrix-synapse-py3 (1.19.2) stable; urgency=medium
|
||||
|
||||
* New synapse release 1.19.2.
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Wed, 16 Sep 2020 12:50:30 +0100
|
||||
|
||||
matrix-synapse-py3 (1.19.1) stable; urgency=medium
|
||||
|
||||
* New synapse release 1.19.1.
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Thu, 27 Aug 2020 10:50:19 +0100
|
||||
|
||||
matrix-synapse-py3 (1.19.0) stable; urgency=medium
|
||||
|
||||
[ Synapse Packaging team ]
|
||||
|
||||
@@ -14,6 +14,7 @@ password auth provider module implementations:
|
||||
|
||||
* [matrix-synapse-ldap3](https://github.com/matrix-org/matrix-synapse-ldap3/)
|
||||
* [matrix-synapse-shared-secret-auth](https://github.com/devture/matrix-synapse-shared-secret-auth)
|
||||
* [matrix-synapse-rest-password-provider](https://github.com/ma1uta/matrix-synapse-rest-password-provider)
|
||||
|
||||
## Required methods
|
||||
|
||||
|
||||
@@ -2002,9 +2002,7 @@ email:
|
||||
# Directory in which Synapse will try to find the template files below.
|
||||
# If not set, default templates from within the Synapse package will be used.
|
||||
#
|
||||
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
|
||||
# If you *do* uncomment it, you will need to make sure that all the templates
|
||||
# below are in the directory.
|
||||
# Do not uncomment this setting unless you want to customise the templates.
|
||||
#
|
||||
# Synapse will look for the following templates in this directory:
|
||||
#
|
||||
|
||||
@@ -48,7 +48,7 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
__version__ = "1.19.2"
|
||||
__version__ = "1.19.0"
|
||||
|
||||
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
||||
# We import here so that we don't have to install a bunch of deps when
|
||||
|
||||
@@ -213,6 +213,7 @@ class Auth(object):
|
||||
user = user_info["user"]
|
||||
token_id = user_info["token_id"]
|
||||
is_guest = user_info["is_guest"]
|
||||
shadow_banned = user_info["shadow_banned"]
|
||||
|
||||
# Deny the request if the user account has expired.
|
||||
if self._account_validity.enabled and not allow_expired:
|
||||
@@ -252,7 +253,12 @@ class Auth(object):
|
||||
opentracing.set_tag("device_id", device_id)
|
||||
|
||||
return synapse.types.create_requester(
|
||||
user, token_id, is_guest, device_id, app_service=app_service
|
||||
user,
|
||||
token_id,
|
||||
is_guest,
|
||||
shadow_banned,
|
||||
device_id,
|
||||
app_service=app_service,
|
||||
)
|
||||
except KeyError:
|
||||
raise MissingClientTokenError()
|
||||
@@ -297,6 +303,7 @@ class Auth(object):
|
||||
dict that includes:
|
||||
`user` (UserID)
|
||||
`is_guest` (bool)
|
||||
`shadow_banned` (bool)
|
||||
`token_id` (int|None): access token id. May be None if guest
|
||||
`device_id` (str|None): device corresponding to access token
|
||||
Raises:
|
||||
@@ -356,6 +363,7 @@ class Auth(object):
|
||||
ret = {
|
||||
"user": user,
|
||||
"is_guest": True,
|
||||
"shadow_banned": False,
|
||||
"token_id": None,
|
||||
# all guests get the same device id
|
||||
"device_id": GUEST_DEVICE_ID,
|
||||
@@ -365,6 +373,7 @@ class Auth(object):
|
||||
ret = {
|
||||
"user": user,
|
||||
"is_guest": False,
|
||||
"shadow_banned": False,
|
||||
"token_id": None,
|
||||
"device_id": None,
|
||||
}
|
||||
@@ -488,6 +497,7 @@ class Auth(object):
|
||||
"user": UserID.from_string(ret.get("name")),
|
||||
"token_id": ret.get("token_id", None),
|
||||
"is_guest": False,
|
||||
"shadow_banned": ret.get("shadow_banned"),
|
||||
"device_id": ret.get("device_id"),
|
||||
"valid_until_ms": ret.get("valid_until_ms"),
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ from jsonschema import FormatChecker
|
||||
|
||||
from synapse.api.constants import EventContentFields
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.types import RoomID, UserID
|
||||
|
||||
FILTER_SCHEMA = {
|
||||
|
||||
@@ -17,7 +17,6 @@ from collections import OrderedDict
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from synapse.api.errors import LimitExceededError
|
||||
from synapse.types import Requester
|
||||
from synapse.util import Clock
|
||||
|
||||
|
||||
@@ -44,42 +43,6 @@ class Ratelimiter(object):
|
||||
# * The rate_hz of this particular entry. This can vary per request
|
||||
self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]]
|
||||
|
||||
def can_requester_do_action(
|
||||
self,
|
||||
requester: Requester,
|
||||
rate_hz: Optional[float] = None,
|
||||
burst_count: Optional[int] = None,
|
||||
update: bool = True,
|
||||
_time_now_s: Optional[int] = None,
|
||||
) -> Tuple[bool, float]:
|
||||
"""Can the requester perform the action?
|
||||
|
||||
Args:
|
||||
requester: The requester to key off when rate limiting. The user property
|
||||
will be used.
|
||||
rate_hz: The long term number of actions that can be performed in a second.
|
||||
Overrides the value set during instantiation if set.
|
||||
burst_count: How many actions that can be performed before being limited.
|
||||
Overrides the value set during instantiation if set.
|
||||
update: Whether to count this check as performing the action
|
||||
_time_now_s: The current time. Optional, defaults to the current time according
|
||||
to self.clock. Only used by tests.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
* A bool indicating if they can perform the action now
|
||||
* The reactor timestamp for when the action can be performed next.
|
||||
-1 if rate_hz is less than or equal to zero
|
||||
"""
|
||||
# Disable rate limiting of users belonging to any AS that is configured
|
||||
# not to be rate limited in its registration file (rate_limited: true|false).
|
||||
if requester.app_service and not requester.app_service.is_rate_limited():
|
||||
return True, -1.0
|
||||
|
||||
return self.can_do_action(
|
||||
requester.user.to_string(), rate_hz, burst_count, update, _time_now_s
|
||||
)
|
||||
|
||||
def can_do_action(
|
||||
self,
|
||||
key: Any,
|
||||
|
||||
@@ -18,12 +18,16 @@
|
||||
import argparse
|
||||
import errno
|
||||
import os
|
||||
import time
|
||||
import urllib.parse
|
||||
from collections import OrderedDict
|
||||
from hashlib import sha256
|
||||
from textwrap import dedent
|
||||
from typing import Any, List, MutableMapping, Optional
|
||||
from typing import Any, Callable, List, MutableMapping, Optional
|
||||
|
||||
import attr
|
||||
import jinja2
|
||||
import pkg_resources
|
||||
import yaml
|
||||
|
||||
|
||||
@@ -100,6 +104,11 @@ class Config(object):
|
||||
def __init__(self, root_config=None):
|
||||
self.root = root_config
|
||||
|
||||
# Get the path to the default Synapse template directory
|
||||
self.default_template_dir = pkg_resources.resource_filename(
|
||||
"synapse", "res/templates"
|
||||
)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Try and fetch a configuration option that does not exist on this class.
|
||||
@@ -184,6 +193,95 @@ class Config(object):
|
||||
with open(file_path) as file_stream:
|
||||
return file_stream.read()
|
||||
|
||||
def read_templates(
|
||||
self, filenames: List[str], custom_template_directory: Optional[str] = None,
|
||||
) -> List[jinja2.Template]:
|
||||
"""Load a list of template files from disk using the given variables.
|
||||
|
||||
This function will attempt to load the given templates from the default Synapse
|
||||
template directory. If `custom_template_directory` is supplied, that directory
|
||||
is tried first.
|
||||
|
||||
Files read are treated as Jinja templates. These templates are not rendered yet.
|
||||
|
||||
Args:
|
||||
filenames: A list of template filenames to read.
|
||||
|
||||
custom_template_directory: A directory to try to look for the templates
|
||||
before using the default Synapse template directory instead.
|
||||
|
||||
Raises:
|
||||
ConfigError: if the file's path is incorrect or otherwise cannot be read.
|
||||
|
||||
Returns:
|
||||
A list of jinja2 templates.
|
||||
"""
|
||||
templates = []
|
||||
search_directories = [self.default_template_dir]
|
||||
|
||||
# The loader will first look in the custom template directory (if specified) for the
|
||||
# given filename. If it doesn't find it, it will use the default template dir instead
|
||||
if custom_template_directory:
|
||||
# Check that the given template directory exists
|
||||
if not self.path_exists(custom_template_directory):
|
||||
raise ConfigError(
|
||||
"Configured template directory does not exist: %s"
|
||||
% (custom_template_directory,)
|
||||
)
|
||||
|
||||
# Search the custom template directory as well
|
||||
search_directories.insert(0, custom_template_directory)
|
||||
|
||||
loader = jinja2.FileSystemLoader(search_directories)
|
||||
env = jinja2.Environment(loader=loader, autoescape=True)
|
||||
|
||||
# Update the environment with our custom filters
|
||||
env.filters.update(
|
||||
{
|
||||
"format_ts": _format_ts_filter,
|
||||
"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
|
||||
}
|
||||
)
|
||||
|
||||
for filename in filenames:
|
||||
# Load the template
|
||||
template = env.get_template(filename)
|
||||
templates.append(template)
|
||||
|
||||
return templates
|
||||
|
||||
|
||||
def _format_ts_filter(value: int, format: str):
|
||||
return time.strftime(format, time.localtime(value / 1000))
|
||||
|
||||
|
||||
def _create_mxc_to_http_filter(public_baseurl: str) -> Callable:
|
||||
"""Create and return a jinja2 filter that converts MXC urls to HTTP
|
||||
|
||||
Args:
|
||||
public_baseurl: The public, accessible base URL of the homeserver
|
||||
"""
|
||||
|
||||
def mxc_to_http_filter(value, width, height, resize_method="crop"):
|
||||
if value[0:6] != "mxc://":
|
||||
return ""
|
||||
|
||||
server_and_media_id = value[6:]
|
||||
fragment = None
|
||||
if "#" in server_and_media_id:
|
||||
server_and_media_id, fragment = server_and_media_id.split("#", 1)
|
||||
fragment = "#" + fragment
|
||||
|
||||
params = {"width": width, "height": height, "method": resize_method}
|
||||
return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
|
||||
public_baseurl,
|
||||
server_and_media_id,
|
||||
urllib.parse.urlencode(params),
|
||||
fragment or "",
|
||||
)
|
||||
|
||||
return mxc_to_http_filter
|
||||
|
||||
|
||||
class RootConfig(object):
|
||||
"""
|
||||
|
||||
@@ -23,7 +23,6 @@ from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import attr
|
||||
import pkg_resources
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
@@ -98,21 +97,18 @@ class EmailConfig(Config):
|
||||
if parsed[1] == "":
|
||||
raise RuntimeError("Invalid notif_from address")
|
||||
|
||||
# A user-configurable template directory
|
||||
template_dir = email_config.get("template_dir")
|
||||
# we need an absolute path, because we change directory after starting (and
|
||||
# we don't yet know what auxiliary templates like mail.css we will need).
|
||||
# (Note that loading as package_resources with jinja.PackageLoader doesn't
|
||||
# work for the same reason.)
|
||||
if not template_dir:
|
||||
template_dir = pkg_resources.resource_filename("synapse", "res/templates")
|
||||
|
||||
self.email_template_dir = os.path.abspath(template_dir)
|
||||
if isinstance(template_dir, str):
|
||||
# We need an absolute path, because we change directory after starting (and
|
||||
# we don't yet know what auxiliary templates like mail.css we will need).
|
||||
template_dir = os.path.abspath(template_dir)
|
||||
elif template_dir is not None:
|
||||
# If template_dir is something other than a str or None, warn the user
|
||||
raise ConfigError("Config option email.template_dir must be type str")
|
||||
|
||||
self.email_enable_notifs = email_config.get("enable_notifs", False)
|
||||
|
||||
account_validity_config = config.get("account_validity") or {}
|
||||
account_validity_renewal_enabled = account_validity_config.get("renew_at")
|
||||
|
||||
self.threepid_behaviour_email = (
|
||||
# Have Synapse handle the email sending if account_threepid_delegates.email
|
||||
# is not defined
|
||||
@@ -166,19 +162,6 @@ class EmailConfig(Config):
|
||||
email_config.get("validation_token_lifetime", "1h")
|
||||
)
|
||||
|
||||
if (
|
||||
self.email_enable_notifs
|
||||
or account_validity_renewal_enabled
|
||||
or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
|
||||
):
|
||||
# make sure we can import the required deps
|
||||
import bleach
|
||||
import jinja2
|
||||
|
||||
# prevent unused warnings
|
||||
jinja2
|
||||
bleach
|
||||
|
||||
if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||
missing = []
|
||||
if not self.email_notif_from:
|
||||
@@ -196,49 +179,49 @@ class EmailConfig(Config):
|
||||
|
||||
# These email templates have placeholders in them, and thus must be
|
||||
# parsed using a templating engine during a request
|
||||
self.email_password_reset_template_html = email_config.get(
|
||||
password_reset_template_html = email_config.get(
|
||||
"password_reset_template_html", "password_reset.html"
|
||||
)
|
||||
self.email_password_reset_template_text = email_config.get(
|
||||
password_reset_template_text = email_config.get(
|
||||
"password_reset_template_text", "password_reset.txt"
|
||||
)
|
||||
self.email_registration_template_html = email_config.get(
|
||||
registration_template_html = email_config.get(
|
||||
"registration_template_html", "registration.html"
|
||||
)
|
||||
self.email_registration_template_text = email_config.get(
|
||||
registration_template_text = email_config.get(
|
||||
"registration_template_text", "registration.txt"
|
||||
)
|
||||
self.email_add_threepid_template_html = email_config.get(
|
||||
add_threepid_template_html = email_config.get(
|
||||
"add_threepid_template_html", "add_threepid.html"
|
||||
)
|
||||
self.email_add_threepid_template_text = email_config.get(
|
||||
add_threepid_template_text = email_config.get(
|
||||
"add_threepid_template_text", "add_threepid.txt"
|
||||
)
|
||||
|
||||
self.email_password_reset_template_failure_html = email_config.get(
|
||||
password_reset_template_failure_html = email_config.get(
|
||||
"password_reset_template_failure_html", "password_reset_failure.html"
|
||||
)
|
||||
self.email_registration_template_failure_html = email_config.get(
|
||||
registration_template_failure_html = email_config.get(
|
||||
"registration_template_failure_html", "registration_failure.html"
|
||||
)
|
||||
self.email_add_threepid_template_failure_html = email_config.get(
|
||||
add_threepid_template_failure_html = email_config.get(
|
||||
"add_threepid_template_failure_html", "add_threepid_failure.html"
|
||||
)
|
||||
|
||||
# These templates do not support any placeholder variables, so we
|
||||
# will read them from disk once during setup
|
||||
email_password_reset_template_success_html = email_config.get(
|
||||
password_reset_template_success_html = email_config.get(
|
||||
"password_reset_template_success_html", "password_reset_success.html"
|
||||
)
|
||||
email_registration_template_success_html = email_config.get(
|
||||
registration_template_success_html = email_config.get(
|
||||
"registration_template_success_html", "registration_success.html"
|
||||
)
|
||||
email_add_threepid_template_success_html = email_config.get(
|
||||
add_threepid_template_success_html = email_config.get(
|
||||
"add_threepid_template_success_html", "add_threepid_success.html"
|
||||
)
|
||||
|
||||
# Check templates exist
|
||||
for f in [
|
||||
# Read all templates from disk
|
||||
(
|
||||
self.email_password_reset_template_html,
|
||||
self.email_password_reset_template_text,
|
||||
self.email_registration_template_html,
|
||||
@@ -248,32 +231,36 @@ class EmailConfig(Config):
|
||||
self.email_password_reset_template_failure_html,
|
||||
self.email_registration_template_failure_html,
|
||||
self.email_add_threepid_template_failure_html,
|
||||
email_password_reset_template_success_html,
|
||||
email_registration_template_success_html,
|
||||
email_add_threepid_template_success_html,
|
||||
]:
|
||||
p = os.path.join(self.email_template_dir, f)
|
||||
if not os.path.isfile(p):
|
||||
raise ConfigError("Unable to find template file %s" % (p,))
|
||||
password_reset_template_success_html_template,
|
||||
registration_template_success_html_template,
|
||||
add_threepid_template_success_html_template,
|
||||
) = self.read_templates(
|
||||
[
|
||||
password_reset_template_html,
|
||||
password_reset_template_text,
|
||||
registration_template_html,
|
||||
registration_template_text,
|
||||
add_threepid_template_html,
|
||||
add_threepid_template_text,
|
||||
password_reset_template_failure_html,
|
||||
registration_template_failure_html,
|
||||
add_threepid_template_failure_html,
|
||||
password_reset_template_success_html,
|
||||
registration_template_success_html,
|
||||
add_threepid_template_success_html,
|
||||
],
|
||||
template_dir,
|
||||
)
|
||||
|
||||
# Retrieve content of web templates
|
||||
filepath = os.path.join(
|
||||
self.email_template_dir, email_password_reset_template_success_html
|
||||
# Render templates that do not contain any placeholders
|
||||
self.email_password_reset_template_success_html_content = (
|
||||
password_reset_template_success_html_template.render()
|
||||
)
|
||||
self.email_password_reset_template_success_html = self.read_file(
|
||||
filepath, "email.password_reset_template_success_html"
|
||||
self.email_registration_template_success_html_content = (
|
||||
registration_template_success_html_template.render()
|
||||
)
|
||||
filepath = os.path.join(
|
||||
self.email_template_dir, email_registration_template_success_html
|
||||
)
|
||||
self.email_registration_template_success_html_content = self.read_file(
|
||||
filepath, "email.registration_template_success_html"
|
||||
)
|
||||
filepath = os.path.join(
|
||||
self.email_template_dir, email_add_threepid_template_success_html
|
||||
)
|
||||
self.email_add_threepid_template_success_html_content = self.read_file(
|
||||
filepath, "email.add_threepid_template_success_html"
|
||||
self.email_add_threepid_template_success_html_content = (
|
||||
add_threepid_template_success_html_template.render()
|
||||
)
|
||||
|
||||
if self.email_enable_notifs:
|
||||
@@ -290,17 +277,19 @@ class EmailConfig(Config):
|
||||
% (", ".join(missing),)
|
||||
)
|
||||
|
||||
self.email_notif_template_html = email_config.get(
|
||||
notif_template_html = email_config.get(
|
||||
"notif_template_html", "notif_mail.html"
|
||||
)
|
||||
self.email_notif_template_text = email_config.get(
|
||||
notif_template_text = email_config.get(
|
||||
"notif_template_text", "notif_mail.txt"
|
||||
)
|
||||
|
||||
for f in self.email_notif_template_text, self.email_notif_template_html:
|
||||
p = os.path.join(self.email_template_dir, f)
|
||||
if not os.path.isfile(p):
|
||||
raise ConfigError("Unable to find email template file %s" % (p,))
|
||||
(
|
||||
self.email_notif_template_html,
|
||||
self.email_notif_template_text,
|
||||
) = self.read_templates(
|
||||
[notif_template_html, notif_template_text], template_dir,
|
||||
)
|
||||
|
||||
self.email_notif_for_new_users = email_config.get(
|
||||
"notif_for_new_users", True
|
||||
@@ -309,18 +298,20 @@ class EmailConfig(Config):
|
||||
"client_base_url", email_config.get("riot_base_url", None)
|
||||
)
|
||||
|
||||
if account_validity_renewal_enabled:
|
||||
self.email_expiry_template_html = email_config.get(
|
||||
if self.account_validity.renew_by_email_enabled:
|
||||
expiry_template_html = email_config.get(
|
||||
"expiry_template_html", "notice_expiry.html"
|
||||
)
|
||||
self.email_expiry_template_text = email_config.get(
|
||||
expiry_template_text = email_config.get(
|
||||
"expiry_template_text", "notice_expiry.txt"
|
||||
)
|
||||
|
||||
for f in self.email_expiry_template_text, self.email_expiry_template_html:
|
||||
p = os.path.join(self.email_template_dir, f)
|
||||
if not os.path.isfile(p):
|
||||
raise ConfigError("Unable to find email template file %s" % (p,))
|
||||
(
|
||||
self.account_validity_template_html,
|
||||
self.account_validity_template_text,
|
||||
) = self.read_templates(
|
||||
[expiry_template_html, expiry_template_text], template_dir,
|
||||
)
|
||||
|
||||
subjects_config = email_config.get("subjects", {})
|
||||
subjects = {}
|
||||
@@ -400,9 +391,7 @@ class EmailConfig(Config):
|
||||
# Directory in which Synapse will try to find the template files below.
|
||||
# If not set, default templates from within the Synapse package will be used.
|
||||
#
|
||||
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
|
||||
# If you *do* uncomment it, you will need to make sure that all the templates
|
||||
# below are in the directory.
|
||||
# Do not uncomment this setting unless you want to customise the templates.
|
||||
#
|
||||
# Synapse will look for the following templates in this directory:
|
||||
#
|
||||
|
||||
@@ -18,8 +18,6 @@ import logging
|
||||
from typing import Any, List
|
||||
|
||||
import attr
|
||||
import jinja2
|
||||
import pkg_resources
|
||||
|
||||
from synapse.python_dependencies import DependencyException, check_requirements
|
||||
from synapse.util.module_loader import load_module, load_python_module
|
||||
@@ -171,15 +169,9 @@ class SAML2Config(Config):
|
||||
saml2_config.get("saml_session_lifetime", "15m")
|
||||
)
|
||||
|
||||
template_dir = saml2_config.get("template_dir")
|
||||
if not template_dir:
|
||||
template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
|
||||
|
||||
loader = jinja2.FileSystemLoader(template_dir)
|
||||
# enable auto-escape here, to having to remember to escape manually in the
|
||||
# template
|
||||
env = jinja2.Environment(loader=loader, autoescape=True)
|
||||
self.saml2_error_html_template = env.get_template("saml_error.html")
|
||||
self.saml2_error_html_template = self.read_templates(
|
||||
["saml_error.html"], saml2_config.get("template_dir")
|
||||
)
|
||||
|
||||
def _default_saml_config_dict(
|
||||
self, required_attributes: set, optional_attributes: set
|
||||
|
||||
@@ -26,7 +26,6 @@ import yaml
|
||||
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.http.endpoint import parse_and_validate_server_name
|
||||
from synapse.python_dependencies import DependencyException, check_requirements
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
@@ -508,8 +507,6 @@ class ServerConfig(Config):
|
||||
)
|
||||
)
|
||||
|
||||
_check_resource_config(self.listeners)
|
||||
|
||||
self.cleanup_extremities_with_dummy_events = config.get(
|
||||
"cleanup_extremities_with_dummy_events", True
|
||||
)
|
||||
@@ -1133,20 +1130,3 @@ def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None:
|
||||
if name == "webclient":
|
||||
logger.warning(NO_MORE_WEB_CLIENT_WARNING)
|
||||
return
|
||||
|
||||
|
||||
def _check_resource_config(listeners: Iterable[ListenerConfig]) -> None:
|
||||
resource_names = {
|
||||
res_name
|
||||
for listener in listeners
|
||||
if listener.http_options
|
||||
for res in listener.http_options.resources
|
||||
for res_name in res.names
|
||||
}
|
||||
|
||||
for resource in resource_names:
|
||||
if resource == "consent":
|
||||
try:
|
||||
check_requirements("resources.consent")
|
||||
except DependencyException as e:
|
||||
raise ConfigError(e.message)
|
||||
|
||||
@@ -12,11 +12,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from ._base import Config
|
||||
|
||||
|
||||
@@ -29,22 +26,32 @@ class SSOConfig(Config):
|
||||
def read_config(self, config, **kwargs):
|
||||
sso_config = config.get("sso") or {} # type: Dict[str, Any]
|
||||
|
||||
# Pick a template directory in order of:
|
||||
# * The sso-specific template_dir
|
||||
# * /path/to/synapse/install/res/templates
|
||||
# The sso-specific template_dir
|
||||
template_dir = sso_config.get("template_dir")
|
||||
if not template_dir:
|
||||
template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
|
||||
|
||||
self.sso_template_dir = template_dir
|
||||
self.sso_account_deactivated_template = self.read_file(
|
||||
os.path.join(self.sso_template_dir, "sso_account_deactivated.html"),
|
||||
"sso_account_deactivated_template",
|
||||
# Read templates from disk
|
||||
(
|
||||
self.sso_redirect_confirm_template,
|
||||
self.sso_auth_confirm_template,
|
||||
self.sso_error_template,
|
||||
sso_account_deactivated_template,
|
||||
sso_auth_success_template,
|
||||
) = self.read_templates(
|
||||
[
|
||||
"sso_redirect_confirm.html",
|
||||
"sso_auth_confirm.html",
|
||||
"sso_error.html",
|
||||
"sso_account_deactivated.html",
|
||||
"sso_auth_success.html",
|
||||
],
|
||||
template_dir,
|
||||
)
|
||||
self.sso_auth_success_template = self.read_file(
|
||||
os.path.join(self.sso_template_dir, "sso_auth_success.html"),
|
||||
"sso_auth_success_template",
|
||||
|
||||
# These templates have no placeholders, so render them here
|
||||
self.sso_account_deactivated_template = (
|
||||
sso_account_deactivated_template.render()
|
||||
)
|
||||
self.sso_auth_success_template = sso_auth_success_template.render()
|
||||
|
||||
self.sso_client_whitelist = sso_config.get("client_whitelist") or []
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ from synapse.events import EventBase, builder
|
||||
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
||||
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
||||
from synapse.logging.utils import log_function
|
||||
from synapse.types import JsonDict, get_domain_from_id
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
@@ -217,15 +217,13 @@ class FederationClient(FederationBase):
|
||||
for p in transaction_data["pdus"]
|
||||
]
|
||||
|
||||
pdus[:] = await self._check_sigs_and_hash_and_fetch(
|
||||
dest,
|
||||
list(pdus),
|
||||
outlier=True,
|
||||
room_version=room_version,
|
||||
# FIXME: We should handle signature failures more gracefully.
|
||||
pdus[:] = await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
)
|
||||
|
||||
logger.info("DDD pdus ended up as: %s", pdus)
|
||||
|
||||
return pdus
|
||||
|
||||
async def get_pdu(
|
||||
@@ -388,11 +386,10 @@ class FederationClient(FederationBase):
|
||||
pdu.event_id, allow_rejected=True, allow_none=True
|
||||
)
|
||||
|
||||
pdu_origin = get_domain_from_id(pdu.sender)
|
||||
if not res and pdu_origin != origin:
|
||||
if not res and pdu.origin != origin:
|
||||
try:
|
||||
res = await self.get_pdu(
|
||||
destinations=[pdu_origin],
|
||||
destinations=[pdu.origin],
|
||||
event_id=pdu.event_id,
|
||||
room_version=room_version,
|
||||
outlier=outlier,
|
||||
|
||||
@@ -37,8 +37,8 @@ from sortedcontainers import SortedDict
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
from .units import Edu
|
||||
|
||||
@@ -22,6 +22,7 @@ from twisted.internet import defer
|
||||
|
||||
import synapse
|
||||
import synapse.metrics
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.events import EventBase
|
||||
from synapse.federation.sender.per_destination_queue import PerDestinationQueue
|
||||
from synapse.federation.sender.transaction_manager import TransactionManager
|
||||
@@ -39,7 +40,6 @@ from synapse.metrics import (
|
||||
events_processed_counter,
|
||||
)
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.types import ReadReceipt
|
||||
from synapse.util.metrics import Measure, measure_func
|
||||
|
||||
|
||||
@@ -24,12 +24,12 @@ from synapse.api.errors import (
|
||||
HttpResponseException,
|
||||
RequestSendFailed,
|
||||
)
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.events import EventBase
|
||||
from synapse.federation.units import Edu
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.metrics import sent_transactions_counter
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.types import ReadReceipt
|
||||
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
||||
|
||||
@@ -337,6 +337,28 @@ class PerDestinationQueue(object):
|
||||
(e.retry_last_ts + e.retry_interval) / 1000.0
|
||||
),
|
||||
)
|
||||
|
||||
if e.retry_interval > 60 * 60 * 1000:
|
||||
# we won't retry for another hour!
|
||||
# (this suggests a significant outage)
|
||||
# We drop pending PDUs and EDUs because otherwise they will
|
||||
# rack up indefinitely.
|
||||
# Note that:
|
||||
# - the EDUs that are being dropped here are those that we can
|
||||
# afford to drop (specifically, only typing notifications,
|
||||
# read receipts and presence updates are being dropped here)
|
||||
# - Other EDUs such as to_device messages are queued with a
|
||||
# different mechanism
|
||||
# - this is all volatile state that would be lost if the
|
||||
# federation sender restarted anyway
|
||||
|
||||
# dropping read receipts is a bit sad but should be solved
|
||||
# through another mechanism, because this is all volatile!
|
||||
self._pending_pdus = []
|
||||
self._pending_edus = []
|
||||
self._pending_edus_keyed = {}
|
||||
self._pending_presence = {}
|
||||
self._pending_rrs = {}
|
||||
except FederationDeniedError as e:
|
||||
logger.info(e)
|
||||
except HttpResponseException as e:
|
||||
|
||||
@@ -719,6 +719,27 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
async def change_user_admin_in_group(
|
||||
self, group_id, user_id, want_admin, requester_user_id, content
|
||||
):
|
||||
"""Promotes or demotes a user in a group.
|
||||
"""
|
||||
|
||||
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
|
||||
if requester_user_id == user_id:
|
||||
raise SynapseError(400, "User cannot target themselves")
|
||||
|
||||
is_admin = await self.store.is_user_admin_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
if not is_admin:
|
||||
raise SynapseError(403, "User is not admin in group")
|
||||
|
||||
await self.store.change_user_admin_in_group(group_id, user_id, want_admin)
|
||||
|
||||
return {}
|
||||
|
||||
async def remove_user_from_group(
|
||||
self, group_id, user_id, requester_user_id, content
|
||||
):
|
||||
|
||||
@@ -26,11 +26,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.types import UserID
|
||||
from synapse.util import stringutils
|
||||
|
||||
try:
|
||||
from synapse.push.mailer import load_jinja2_templates
|
||||
except ImportError:
|
||||
load_jinja2_templates = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -47,9 +42,11 @@ class AccountValidityHandler(object):
|
||||
if (
|
||||
self._account_validity.enabled
|
||||
and self._account_validity.renew_by_email_enabled
|
||||
and load_jinja2_templates
|
||||
):
|
||||
# Don't do email-specific configuration if renewal by email is disabled.
|
||||
self._template_html = self.config.account_validity_template_html
|
||||
self._template_text = self.config.account_validity_template_text
|
||||
|
||||
try:
|
||||
app_name = self.hs.config.email_app_name
|
||||
|
||||
@@ -65,17 +62,6 @@ class AccountValidityHandler(object):
|
||||
|
||||
self._raw_from = email.utils.parseaddr(self._from_string)[1]
|
||||
|
||||
self._template_html, self._template_text = load_jinja2_templates(
|
||||
self.config.email_template_dir,
|
||||
[
|
||||
self.config.email_expiry_template_html,
|
||||
self.config.email_expiry_template_text,
|
||||
],
|
||||
apply_format_ts_filter=True,
|
||||
apply_mxc_to_http_filter=True,
|
||||
public_baseurl=self.config.public_baseurl,
|
||||
)
|
||||
|
||||
# Check the renewal emails to send and send them every 30min.
|
||||
def send_emails():
|
||||
# run as a background process to make sure that the database transactions
|
||||
|
||||
@@ -42,7 +42,6 @@ from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import defer_to_thread
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.push.mailer import load_jinja2_templates
|
||||
from synapse.types import Requester, UserID
|
||||
from synapse.util import stringutils as stringutils
|
||||
from synapse.util.threepids import canonicalise_email
|
||||
@@ -132,18 +131,17 @@ class AuthHandler(BaseHandler):
|
||||
# after the SSO completes and before redirecting them back to their client.
|
||||
# It notifies the user they are about to give access to their matrix account
|
||||
# to the client.
|
||||
self._sso_redirect_confirm_template = load_jinja2_templates(
|
||||
hs.config.sso_template_dir, ["sso_redirect_confirm.html"],
|
||||
)[0]
|
||||
self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template
|
||||
|
||||
# The following template is shown during user interactive authentication
|
||||
# in the fallback auth scenario. It notifies the user that they are
|
||||
# authenticating for an operation to occur on their account.
|
||||
self._sso_auth_confirm_template = load_jinja2_templates(
|
||||
hs.config.sso_template_dir, ["sso_auth_confirm.html"],
|
||||
)[0]
|
||||
self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
|
||||
|
||||
# The following template is shown after a successful user interactive
|
||||
# authentication session. It tells the user they can close the window.
|
||||
self._sso_auth_success_template = hs.config.sso_auth_success_template
|
||||
|
||||
# The following template is shown during the SSO authentication process if
|
||||
# the account is deactivated.
|
||||
self._sso_account_deactivated_template = (
|
||||
|
||||
@@ -461,6 +461,25 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
||||
|
||||
return {"state": "invite", "user_profile": user_profile}
|
||||
|
||||
async def change_user_admin_in_group(
|
||||
self, group_id, user_id, want_admin, requester_user_id, content
|
||||
):
|
||||
"""Promotes or demotes a user in a group.
|
||||
"""
|
||||
|
||||
if not self.is_mine_id(user_id):
|
||||
raise SynapseError(400, "User not on this server")
|
||||
|
||||
# TODO: We should probably support federation, but this is fine for now
|
||||
if not self.is_mine_id(group_id):
|
||||
raise SynapseError(400, "Group not on this server")
|
||||
|
||||
res = await self.groups_server_handler.change_user_admin_in_group(
|
||||
group_id, user_id, want_admin, requester_user_id, content
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
async def remove_user_from_group(
|
||||
self, group_id, user_id, requester_user_id, content
|
||||
):
|
||||
|
||||
@@ -667,14 +667,14 @@ class EventCreationHandler(object):
|
||||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||
|
||||
if event.is_state():
|
||||
prev_state = await self.deduplicate_state_event(event, context)
|
||||
if prev_state is not None:
|
||||
prev_event = await self.deduplicate_state_event(event, context)
|
||||
if prev_event is not None:
|
||||
logger.info(
|
||||
"Not bothering to persist state event %s duplicated by %s",
|
||||
event.event_id,
|
||||
prev_state.event_id,
|
||||
prev_event.event_id,
|
||||
)
|
||||
return prev_state
|
||||
return await self.store.get_stream_id_for_event(prev_event.event_id)
|
||||
|
||||
return await self.handle_new_client_event(
|
||||
requester=requester, event=event, context=context, ratelimit=ratelimit
|
||||
@@ -682,27 +682,32 @@ class EventCreationHandler(object):
|
||||
|
||||
async def deduplicate_state_event(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> None:
|
||||
) -> Optional[EventBase]:
|
||||
"""
|
||||
Checks whether event is in the latest resolved state in context.
|
||||
|
||||
If so, returns the version of the event in context.
|
||||
Otherwise, returns None.
|
||||
Args:
|
||||
event: The event to check for duplication.
|
||||
context: The event context.
|
||||
|
||||
Returns:
|
||||
The previous verion of the event is returned, if it is found in the
|
||||
event context. Otherwise, None is returned.
|
||||
"""
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
prev_event_id = prev_state_ids.get((event.type, event.state_key))
|
||||
if not prev_event_id:
|
||||
return
|
||||
return None
|
||||
prev_event = await self.store.get_event(prev_event_id, allow_none=True)
|
||||
if not prev_event:
|
||||
return
|
||||
return None
|
||||
|
||||
if prev_event and event.user_id == prev_event.user_id:
|
||||
prev_content = encode_canonical_json(prev_event.content)
|
||||
next_content = encode_canonical_json(event.content)
|
||||
if prev_content == next_content:
|
||||
return prev_event
|
||||
return
|
||||
return None
|
||||
|
||||
async def create_and_send_nonmember_event(
|
||||
self,
|
||||
@@ -891,9 +896,7 @@ class EventCreationHandler(object):
|
||||
except Exception:
|
||||
# Ensure that we actually remove the entries in the push actions
|
||||
# staging area, if we calculated them.
|
||||
run_in_background(
|
||||
self.store.remove_push_actions_from_staging, event.event_id
|
||||
)
|
||||
await self.store.remove_push_actions_from_staging(event.event_id)
|
||||
raise
|
||||
|
||||
async def _validate_canonical_alias(
|
||||
|
||||
@@ -38,7 +38,6 @@ from synapse.config import ConfigError
|
||||
from synapse.http.server import respond_with_html
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.push.mailer import load_jinja2_templates
|
||||
from synapse.types import UserID, map_username_to_mxid_localpart
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -123,9 +122,7 @@ class OidcHandler:
|
||||
self._hostname = hs.hostname # type: str
|
||||
self._server_name = hs.config.server_name # type: str
|
||||
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
||||
self._error_template = load_jinja2_templates(
|
||||
hs.config.sso_template_dir, ["sso_error.html"]
|
||||
)[0]
|
||||
self._error_template = hs.config.sso_error_template
|
||||
|
||||
# identifier for the external_ids table
|
||||
self._auth_provider_id = "oidc"
|
||||
|
||||
@@ -33,13 +33,13 @@ from typing_extensions import ContextManager
|
||||
import synapse.metrics
|
||||
from synapse.api.constants import EventTypes, Membership, PresenceState
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.logging.context import run_in_background
|
||||
from synapse.logging.utils import log_function
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.state import StateHandler
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.types import JsonDict, UserID, get_domain_from_id
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
@@ -142,6 +142,7 @@ class RegistrationHandler(BaseHandler):
|
||||
address=None,
|
||||
bind_emails=[],
|
||||
by_admin=False,
|
||||
shadow_banned=False,
|
||||
):
|
||||
"""Registers a new client on the server.
|
||||
|
||||
@@ -159,6 +160,7 @@ class RegistrationHandler(BaseHandler):
|
||||
bind_emails (List[str]): list of emails to bind to this account.
|
||||
by_admin (bool): True if this registration is being made via the
|
||||
admin api, otherwise False.
|
||||
shadow_banned (bool): Shadow-ban the created user.
|
||||
Returns:
|
||||
str: user_id
|
||||
Raises:
|
||||
@@ -194,6 +196,7 @@ class RegistrationHandler(BaseHandler):
|
||||
admin=admin,
|
||||
user_type=user_type,
|
||||
address=address,
|
||||
shadow_banned=shadow_banned,
|
||||
)
|
||||
|
||||
if self.hs.config.user_directory_search_all_users:
|
||||
@@ -224,6 +227,7 @@ class RegistrationHandler(BaseHandler):
|
||||
make_guest=make_guest,
|
||||
create_profile_with_displayname=default_display_name,
|
||||
address=address,
|
||||
shadow_banned=shadow_banned,
|
||||
)
|
||||
|
||||
# Successfully registered
|
||||
@@ -529,6 +533,7 @@ class RegistrationHandler(BaseHandler):
|
||||
admin=False,
|
||||
user_type=None,
|
||||
address=None,
|
||||
shadow_banned=False,
|
||||
):
|
||||
"""Register user in the datastore.
|
||||
|
||||
@@ -546,6 +551,7 @@ class RegistrationHandler(BaseHandler):
|
||||
user_type (str|None): type of user. One of the values from
|
||||
api.constants.UserTypes, or None for a normal user.
|
||||
address (str|None): the IP address used to perform the registration.
|
||||
shadow_banned (bool): Whether to shadow-ban the user
|
||||
|
||||
Returns:
|
||||
Awaitable
|
||||
@@ -561,6 +567,7 @@ class RegistrationHandler(BaseHandler):
|
||||
admin=admin,
|
||||
user_type=user_type,
|
||||
address=address,
|
||||
shadow_banned=shadow_banned,
|
||||
)
|
||||
else:
|
||||
return self.store.register_user(
|
||||
@@ -572,6 +579,7 @@ class RegistrationHandler(BaseHandler):
|
||||
create_profile_with_displayname=create_profile_with_displayname,
|
||||
admin=admin,
|
||||
user_type=user_type,
|
||||
shadow_banned=shadow_banned,
|
||||
)
|
||||
|
||||
async def register_device(
|
||||
|
||||
@@ -22,7 +22,7 @@ import logging
|
||||
import math
|
||||
import string
|
||||
from collections import OrderedDict
|
||||
from typing import Awaitable, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import (
|
||||
EventTypes,
|
||||
@@ -32,11 +32,14 @@ from synapse.api.constants import (
|
||||
RoomEncryptionAlgorithms,
|
||||
)
|
||||
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import copy_power_levels_contents
|
||||
from synapse.http.endpoint import parse_and_validate_server_name
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
Requester,
|
||||
RoomAlias,
|
||||
RoomID,
|
||||
@@ -53,6 +56,9 @@ from synapse.visibility import filter_events_for_client
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
id_server_scheme = "https://"
|
||||
@@ -61,7 +67,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000
|
||||
|
||||
|
||||
class RoomCreationHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super(RoomCreationHandler, self).__init__(hs)
|
||||
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
@@ -92,7 +98,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
"guest_can_join": False,
|
||||
"power_level_content_override": {},
|
||||
},
|
||||
}
|
||||
} # type: Dict[str, Dict[str, Any]]
|
||||
|
||||
# Modify presets to selectively enable encryption by default per homeserver config
|
||||
for preset_name, preset_config in self._presets_dict.items():
|
||||
@@ -215,6 +221,9 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
old_room_state = await tombstone_context.get_current_state_ids()
|
||||
|
||||
# We know the tombstone event isn't an outlier so it has current state.
|
||||
assert old_room_state is not None
|
||||
|
||||
# update any aliases
|
||||
await self._move_aliases_to_new_room(
|
||||
requester, old_room_id, new_room_id, old_room_state
|
||||
@@ -528,17 +537,21 @@ class RoomCreationHandler(BaseHandler):
|
||||
logger.error("Unable to send updated alias events in new room: %s", e)
|
||||
|
||||
async def create_room(
|
||||
self, requester, config, ratelimit=True, creator_join_profile=None
|
||||
self,
|
||||
requester: Requester,
|
||||
config: JsonDict,
|
||||
ratelimit: bool = True,
|
||||
creator_join_profile: Optional[JsonDict] = None,
|
||||
) -> Tuple[dict, int]:
|
||||
""" Creates a new room.
|
||||
|
||||
Args:
|
||||
requester (synapse.types.Requester):
|
||||
requester:
|
||||
The user who requested the room creation.
|
||||
config (dict) : A dict of configuration options.
|
||||
ratelimit (bool): set to False to disable the rate limiter
|
||||
config : A dict of configuration options.
|
||||
ratelimit: set to False to disable the rate limiter
|
||||
|
||||
creator_join_profile (dict|None):
|
||||
creator_join_profile:
|
||||
Set to override the displayname and avatar for the creating
|
||||
user in this room. If unset, displayname and avatar will be
|
||||
derived from the user's profile. If set, should contain the
|
||||
@@ -601,6 +614,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
Codes.UNSUPPORTED_ROOM_VERSION,
|
||||
)
|
||||
|
||||
room_alias = None
|
||||
if "room_alias_name" in config:
|
||||
for wchar in string.whitespace:
|
||||
if wchar in config["room_alias_name"]:
|
||||
@@ -611,8 +625,6 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
if mapping:
|
||||
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
|
||||
else:
|
||||
room_alias = None
|
||||
|
||||
invite_list = config.get("invite", [])
|
||||
for i in invite_list:
|
||||
@@ -771,23 +783,30 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
async def _send_events_for_new_room(
|
||||
self,
|
||||
creator, # A Requester object.
|
||||
room_id,
|
||||
preset_config,
|
||||
invite_list,
|
||||
initial_state,
|
||||
creation_content,
|
||||
room_alias=None,
|
||||
power_level_content_override=None, # Doesn't apply when initial state has power level state event content
|
||||
creator_join_profile=None,
|
||||
creator: Requester,
|
||||
room_id: str,
|
||||
preset_config: str,
|
||||
invite_list: List[str],
|
||||
initial_state: StateMap,
|
||||
creation_content: JsonDict,
|
||||
room_alias: Optional[RoomAlias] = None,
|
||||
power_level_content_override: Optional[JsonDict] = None,
|
||||
creator_join_profile: Optional[JsonDict] = None,
|
||||
) -> int:
|
||||
"""Sends the initial events into a new room.
|
||||
|
||||
`power_level_content_override` doesn't apply when initial state has
|
||||
power level state event content.
|
||||
|
||||
Returns:
|
||||
The stream_id of the last event persisted.
|
||||
"""
|
||||
|
||||
def create(etype, content, **kwargs):
|
||||
creator_id = creator.user.to_string()
|
||||
|
||||
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
|
||||
|
||||
def create(etype: str, content: JsonDict, **kwargs) -> JsonDict:
|
||||
e = {"type": etype, "content": content}
|
||||
|
||||
e.update(event_keys)
|
||||
@@ -795,7 +814,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
return e
|
||||
|
||||
async def send(etype, content, **kwargs) -> int:
|
||||
async def send(etype: str, content: JsonDict, **kwargs) -> int:
|
||||
event = create(etype, content, **kwargs)
|
||||
logger.debug("Sending %s in new room", etype)
|
||||
(
|
||||
@@ -808,10 +827,6 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
config = self._presets_dict[preset_config]
|
||||
|
||||
creator_id = creator.user.to_string()
|
||||
|
||||
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
|
||||
|
||||
creation_content.update({"creator": creator_id})
|
||||
await send(etype=EventTypes.Create, content=creation_content)
|
||||
|
||||
@@ -852,7 +867,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
"kick": 50,
|
||||
"redact": 50,
|
||||
"invite": 50,
|
||||
}
|
||||
} # type: JsonDict
|
||||
|
||||
if config["original_invitees_have_ops"]:
|
||||
for invitee in invite_list:
|
||||
@@ -906,7 +921,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
return last_sent_stream_id
|
||||
|
||||
async def _generate_room_id(
|
||||
self, creator_id: str, is_public: str, room_version: RoomVersion,
|
||||
self, creator_id: str, is_public: bool, room_version: RoomVersion,
|
||||
):
|
||||
# autogen room IDs and try to create it. We may clash, so just
|
||||
# try a few times till one goes through, giving up eventually.
|
||||
@@ -930,23 +945,30 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
|
||||
class RoomContextHandler(object):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.storage = hs.get_storage()
|
||||
self.state_store = self.storage.state
|
||||
|
||||
async def get_event_context(self, user, room_id, event_id, limit, event_filter):
|
||||
async def get_event_context(
|
||||
self,
|
||||
user: UserID,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
limit: int,
|
||||
event_filter: Optional[Filter],
|
||||
) -> Optional[JsonDict]:
|
||||
"""Retrieves events, pagination tokens and state around a given event
|
||||
in a room.
|
||||
|
||||
Args:
|
||||
user (UserID)
|
||||
room_id (str)
|
||||
event_id (str)
|
||||
limit (int): The maximum number of events to return in total
|
||||
user
|
||||
room_id
|
||||
event_id
|
||||
limit: The maximum number of events to return in total
|
||||
(excluding state).
|
||||
event_filter (Filter|None): the filter to apply to the events returned
|
||||
event_filter: the filter to apply to the events returned
|
||||
(excluding the target event_id)
|
||||
|
||||
Returns:
|
||||
@@ -1033,12 +1055,18 @@ class RoomContextHandler(object):
|
||||
|
||||
|
||||
class RoomEventSource(object):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def get_new_events(
|
||||
self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
|
||||
):
|
||||
self,
|
||||
user: UserID,
|
||||
from_key: str,
|
||||
limit: int,
|
||||
room_ids: List[str],
|
||||
is_guest: bool,
|
||||
explicit_room_id: Optional[str] = None,
|
||||
) -> Tuple[List[EventBase], str]:
|
||||
# We just ignore the key for now.
|
||||
|
||||
to_key = self.get_current_key()
|
||||
@@ -1096,7 +1124,7 @@ class RoomShutdownHandler(object):
|
||||
)
|
||||
DEFAULT_ROOM_NAME = "Content Violation Notification"
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.room_member_handler = hs.get_room_member_handler()
|
||||
self._room_creation_handler = hs.get_room_creation_handler()
|
||||
|
||||
@@ -210,40 +210,24 @@ class RoomMemberHandler(object):
|
||||
_, stream_id = await self.store.get_event_ordering(duplicate.event_id)
|
||||
return duplicate.event_id, stream_id
|
||||
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
|
||||
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
|
||||
|
||||
newly_joined = False
|
||||
if event.membership == Membership.JOIN:
|
||||
newly_joined = True
|
||||
if prev_member_event_id:
|
||||
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||
|
||||
# Only rate-limit if the user actually joined the room, otherwise we'll end
|
||||
# up blocking profile updates.
|
||||
if newly_joined:
|
||||
time_now_s = self.clock.time()
|
||||
(
|
||||
allowed,
|
||||
time_allowed,
|
||||
) = self._join_rate_limiter_local.can_requester_do_action(requester)
|
||||
|
||||
if not allowed:
|
||||
raise LimitExceededError(
|
||||
retry_after_ms=int(1000 * (time_allowed - time_now_s))
|
||||
)
|
||||
|
||||
stream_id = await self.event_creation_handler.handle_new_client_event(
|
||||
requester, event, context, extra_users=[target], ratelimit=ratelimit,
|
||||
)
|
||||
|
||||
if event.membership == Membership.JOIN and newly_joined:
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
|
||||
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
|
||||
|
||||
if event.membership == Membership.JOIN:
|
||||
# Only fire user_joined_room if the user has actually joined the
|
||||
# room. Don't bother if the user is just changing their profile
|
||||
# info.
|
||||
await self._user_joined_room(target, room_id)
|
||||
newly_joined = True
|
||||
if prev_member_event_id:
|
||||
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||
if newly_joined:
|
||||
await self._user_joined_room(target, room_id)
|
||||
elif event.membership == Membership.LEAVE:
|
||||
if prev_member_event_id:
|
||||
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||
@@ -473,12 +457,22 @@ class RoomMemberHandler(object):
|
||||
# so don't really fit into the general auth process.
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
|
||||
if not is_host_in_room:
|
||||
if is_host_in_room:
|
||||
time_now_s = self.clock.time()
|
||||
(
|
||||
allowed,
|
||||
time_allowed,
|
||||
) = self._join_rate_limiter_remote.can_requester_do_action(requester,)
|
||||
allowed, time_allowed = self._join_rate_limiter_local.can_do_action(
|
||||
requester.user.to_string(),
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise LimitExceededError(
|
||||
retry_after_ms=int(1000 * (time_allowed - time_now_s))
|
||||
)
|
||||
|
||||
else:
|
||||
time_now_s = self.clock.time()
|
||||
allowed, time_allowed = self._join_rate_limiter_remote.can_do_action(
|
||||
requester.user.to_string(),
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise LimitExceededError(
|
||||
|
||||
@@ -441,7 +441,6 @@ class MatrixFederationHttpClient(object):
|
||||
|
||||
headers_dict[b"Authorization"] = auth_headers
|
||||
|
||||
"""
|
||||
logger.debug(
|
||||
"{%s} [%s] Sending request: %s %s; timeout %fs",
|
||||
request.txn_id,
|
||||
@@ -450,7 +449,6 @@ class MatrixFederationHttpClient(object):
|
||||
url_str,
|
||||
_sec_timeout,
|
||||
)
|
||||
"""
|
||||
|
||||
outgoing_requests_counter.labels(request.method).inc()
|
||||
|
||||
|
||||
@@ -22,12 +22,13 @@ import types
|
||||
import urllib
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO
|
||||
from typing import Any, Callable, Dict, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
|
||||
|
||||
import jinja2
|
||||
from canonicaljson import encode_canonical_json, encode_pretty_printed_json
|
||||
from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet import defer, interfaces
|
||||
from twisted.python import failure
|
||||
from twisted.web import resource
|
||||
from twisted.web.server import NOT_DONE_YET, Request
|
||||
@@ -499,6 +500,78 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect):
|
||||
pass
|
||||
|
||||
|
||||
@implementer(interfaces.IPullProducer)
|
||||
class _ByteProducer:
|
||||
"""
|
||||
Iteratively write bytes to the request.
|
||||
"""
|
||||
|
||||
# The minimum number of bytes for each chunk. Note that the last chunk will
|
||||
# usually be smaller than this.
|
||||
min_chunk_size = 1024
|
||||
|
||||
def __init__(
|
||||
self, request: Request, iterator: Iterator[bytes],
|
||||
):
|
||||
self._request = request
|
||||
self._iterator = iterator
|
||||
|
||||
def start(self) -> None:
|
||||
self._request.registerProducer(self, False)
|
||||
|
||||
def _send_data(self, data: List[bytes]) -> None:
|
||||
"""
|
||||
Send a list of strings as a response to the request.
|
||||
"""
|
||||
if not data:
|
||||
return
|
||||
self._request.write(b"".join(data))
|
||||
|
||||
def resumeProducing(self) -> None:
|
||||
# We've stopped producing in the meantime (note that this might be
|
||||
# re-entrant after calling write).
|
||||
if not self._request:
|
||||
return
|
||||
|
||||
# Get the next chunk and write it to the request.
|
||||
#
|
||||
# The output of the JSON encoder is coalesced until min_chunk_size is
|
||||
# reached. (This is because JSON encoders produce a very small output
|
||||
# per iteration.)
|
||||
#
|
||||
# Note that buffer stores a list of bytes (instead of appending to
|
||||
# bytes) to hopefully avoid many allocations.
|
||||
buffer = []
|
||||
buffered_bytes = 0
|
||||
while buffered_bytes < self.min_chunk_size:
|
||||
try:
|
||||
data = next(self._iterator)
|
||||
buffer.append(data)
|
||||
buffered_bytes += len(data)
|
||||
except StopIteration:
|
||||
# The entire JSON object has been serialized, write any
|
||||
# remaining data, finalize the producer and the request, and
|
||||
# clean-up any references.
|
||||
self._send_data(buffer)
|
||||
self._request.unregisterProducer()
|
||||
self._request.finish()
|
||||
self.stopProducing()
|
||||
return
|
||||
|
||||
self._send_data(buffer)
|
||||
|
||||
def stopProducing(self) -> None:
|
||||
self._request = None
|
||||
|
||||
|
||||
def _encode_json_bytes(json_object: Any) -> Iterator[bytes]:
|
||||
"""
|
||||
Encode an object into JSON. Returns an iterator of bytes.
|
||||
"""
|
||||
for chunk in json_encoder.iterencode(json_object):
|
||||
yield chunk.encode("utf-8")
|
||||
|
||||
|
||||
def respond_with_json(
|
||||
request: Request,
|
||||
code: int,
|
||||
@@ -533,15 +606,23 @@ def respond_with_json(
|
||||
return None
|
||||
|
||||
if pretty_print:
|
||||
json_bytes = encode_pretty_printed_json(json_object) + b"\n"
|
||||
encoder = iterencode_pretty_printed_json
|
||||
else:
|
||||
if canonical_json or synapse.events.USE_FROZEN_DICTS:
|
||||
# canonicaljson already encodes to bytes
|
||||
json_bytes = encode_canonical_json(json_object)
|
||||
encoder = iterencode_canonical_json
|
||||
else:
|
||||
json_bytes = json_encoder.encode(json_object).encode("utf-8")
|
||||
encoder = _encode_json_bytes
|
||||
|
||||
return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors)
|
||||
request.setResponseCode(code)
|
||||
request.setHeader(b"Content-Type", b"application/json")
|
||||
request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
|
||||
|
||||
if send_cors:
|
||||
set_cors_headers(request)
|
||||
|
||||
producer = _ByteProducer(request, encoder(json_object))
|
||||
producer.start()
|
||||
return NOT_DONE_YET
|
||||
|
||||
|
||||
def respond_with_json_bytes(
|
||||
|
||||
@@ -22,7 +22,6 @@ _TIME_FUNC_ID = 0
|
||||
|
||||
|
||||
def _log_debug_as_f(f, msg, msg_args):
|
||||
return
|
||||
name = f.__module__
|
||||
logger = logging.getLogger(name)
|
||||
|
||||
|
||||
@@ -16,8 +16,7 @@
|
||||
import email.mime.multipart
|
||||
import email.utils
|
||||
import logging
|
||||
import time
|
||||
import urllib
|
||||
import urllib.parse
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Iterable, List, TypeVar
|
||||
@@ -640,72 +639,3 @@ def string_ordinal_total(s):
|
||||
for c in s:
|
||||
tot += ord(c)
|
||||
return tot
|
||||
|
||||
|
||||
def format_ts_filter(value, format):
|
||||
return time.strftime(format, time.localtime(value / 1000))
|
||||
|
||||
|
||||
def load_jinja2_templates(
|
||||
template_dir,
|
||||
template_filenames,
|
||||
apply_format_ts_filter=False,
|
||||
apply_mxc_to_http_filter=False,
|
||||
public_baseurl=None,
|
||||
):
|
||||
"""Loads and returns one or more jinja2 templates and applies optional filters
|
||||
|
||||
Args:
|
||||
template_dir (str): The directory where templates are stored
|
||||
template_filenames (list[str]): A list of template filenames
|
||||
apply_format_ts_filter (bool): Whether to apply a template filter that formats
|
||||
timestamps
|
||||
apply_mxc_to_http_filter (bool): Whether to apply a template filter that converts
|
||||
mxc urls to http urls
|
||||
public_baseurl (str|None): The public baseurl of the server. Required for
|
||||
apply_mxc_to_http_filter to be enabled
|
||||
|
||||
Returns:
|
||||
A list of jinja2 templates corresponding to the given list of filenames,
|
||||
with order preserved
|
||||
"""
|
||||
logger.info(
|
||||
"loading email templates %s from '%s'", template_filenames, template_dir
|
||||
)
|
||||
loader = jinja2.FileSystemLoader(template_dir)
|
||||
env = jinja2.Environment(loader=loader)
|
||||
|
||||
if apply_format_ts_filter:
|
||||
env.filters["format_ts"] = format_ts_filter
|
||||
|
||||
if apply_mxc_to_http_filter and public_baseurl:
|
||||
env.filters["mxc_to_http"] = _create_mxc_to_http_filter(public_baseurl)
|
||||
|
||||
templates = []
|
||||
for template_filename in template_filenames:
|
||||
template = env.get_template(template_filename)
|
||||
templates.append(template)
|
||||
|
||||
return templates
|
||||
|
||||
|
||||
def _create_mxc_to_http_filter(public_baseurl):
|
||||
def mxc_to_http_filter(value, width, height, resize_method="crop"):
|
||||
if value[0:6] != "mxc://":
|
||||
return ""
|
||||
|
||||
serverAndMediaId = value[6:]
|
||||
fragment = None
|
||||
if "#" in serverAndMediaId:
|
||||
(serverAndMediaId, fragment) = serverAndMediaId.split("#", 1)
|
||||
fragment = "#" + fragment
|
||||
|
||||
params = {"width": width, "height": height, "method": resize_method}
|
||||
return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
|
||||
public_baseurl,
|
||||
serverAndMediaId,
|
||||
urllib.parse.urlencode(params),
|
||||
fragment or "",
|
||||
)
|
||||
|
||||
return mxc_to_http_filter
|
||||
|
||||
@@ -15,22 +15,13 @@
|
||||
|
||||
import logging
|
||||
|
||||
from synapse.push.emailpusher import EmailPusher
|
||||
from synapse.push.mailer import Mailer
|
||||
|
||||
from .httppusher import HttpPusher
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# We try importing this if we can (it will fail if we don't
|
||||
# have the optional email dependencies installed). We don't
|
||||
# yet have the config to know if we need the email pusher,
|
||||
# but importing this after daemonizing seems to fail
|
||||
# (even though a simple test of importing from a daemonized
|
||||
# process works fine)
|
||||
try:
|
||||
from synapse.push.emailpusher import EmailPusher
|
||||
from synapse.push.mailer import Mailer, load_jinja2_templates
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class PusherFactory(object):
|
||||
def __init__(self, hs):
|
||||
@@ -43,16 +34,8 @@ class PusherFactory(object):
|
||||
if hs.config.email_enable_notifs:
|
||||
self.mailers = {} # app_name -> Mailer
|
||||
|
||||
self.notif_template_html, self.notif_template_text = load_jinja2_templates(
|
||||
self.config.email_template_dir,
|
||||
[
|
||||
self.config.email_notif_template_html,
|
||||
self.config.email_notif_template_text,
|
||||
],
|
||||
apply_format_ts_filter=True,
|
||||
apply_mxc_to_http_filter=True,
|
||||
public_baseurl=self.config.public_baseurl,
|
||||
)
|
||||
self._notif_template_html = hs.config.email_notif_template_html
|
||||
self._notif_template_text = hs.config.email_notif_template_text
|
||||
|
||||
self.pusher_types["email"] = self._create_email_pusher
|
||||
|
||||
@@ -73,8 +56,8 @@ class PusherFactory(object):
|
||||
mailer = Mailer(
|
||||
hs=self.hs,
|
||||
app_name=app_name,
|
||||
template_html=self.notif_template_html,
|
||||
template_text=self.notif_template_text,
|
||||
template_html=self._notif_template_html,
|
||||
template_text=self._notif_template_text,
|
||||
)
|
||||
self.mailers[app_name] = mailer
|
||||
return EmailPusher(self.hs, pusherdict, mailer)
|
||||
|
||||
@@ -43,7 +43,7 @@ REQUIREMENTS = [
|
||||
"jsonschema>=2.5.1",
|
||||
"frozendict>=1",
|
||||
"unpaddedbase64>=1.1.0",
|
||||
"canonicaljson>=1.2.0",
|
||||
"canonicaljson>=1.3.0",
|
||||
# we use the type definitions added in signedjson 1.1.
|
||||
"signedjson>=1.1.0",
|
||||
"pynacl>=1.2.1",
|
||||
@@ -78,8 +78,6 @@ CONDITIONAL_REQUIREMENTS = {
|
||||
"matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
|
||||
# we use execute_batch, which arrived in psycopg 2.7.
|
||||
"postgres": ["psycopg2>=2.7"],
|
||||
# ConsentResource uses select_autoescape, which arrived in jinja 2.9
|
||||
"resources.consent": ["Jinja2>=2.9"],
|
||||
# ACME support is required to provision TLS certificates from authorities
|
||||
# that use the protocol, such as Let's Encrypt.
|
||||
"acme": [
|
||||
|
||||
@@ -44,6 +44,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
||||
admin,
|
||||
user_type,
|
||||
address,
|
||||
shadow_banned,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -60,6 +61,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
||||
user_type (str|None): type of user. One of the values from
|
||||
api.constants.UserTypes, or None for a normal user.
|
||||
address (str|None): the IP address used to perform the regitration.
|
||||
shadow_banned (bool): Whether to shadow-ban the user
|
||||
"""
|
||||
return {
|
||||
"password_hash": password_hash,
|
||||
@@ -70,6 +72,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
||||
"admin": admin,
|
||||
"user_type": user_type,
|
||||
"address": address,
|
||||
"shadow_banned": shadow_banned,
|
||||
}
|
||||
|
||||
async def _handle_request(self, request, user_id):
|
||||
@@ -87,6 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
||||
admin=content["admin"],
|
||||
user_type=content["user_type"],
|
||||
address=content["address"],
|
||||
shadow_banned=content["shadow_banned"],
|
||||
)
|
||||
|
||||
return 200, {}
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from synapse.api.errors import (
|
||||
NotFoundError,
|
||||
StoreError,
|
||||
@@ -163,7 +162,7 @@ class PushRuleRestServlet(RestServlet):
|
||||
stream_id, _ = self.store.get_push_rules_stream_token()
|
||||
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
|
||||
|
||||
def set_rule_attr(self, user_id, spec, val):
|
||||
async def set_rule_attr(self, user_id, spec, val):
|
||||
if spec["attr"] == "enabled":
|
||||
if isinstance(val, dict) and "enabled" in val:
|
||||
val = val["enabled"]
|
||||
@@ -173,7 +172,9 @@ class PushRuleRestServlet(RestServlet):
|
||||
# bools directly, so let's not break them.
|
||||
raise SynapseError(400, "Value for 'enabled' must be boolean")
|
||||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||
return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val)
|
||||
return await self.store.set_push_rule_enabled(
|
||||
user_id, namespaced_rule_id, val
|
||||
)
|
||||
elif spec["attr"] == "actions":
|
||||
actions = val.get("actions")
|
||||
_check_actions(actions)
|
||||
@@ -188,7 +189,7 @@ class PushRuleRestServlet(RestServlet):
|
||||
|
||||
if namespaced_rule_id not in rule_ids:
|
||||
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
|
||||
return self.store.set_push_rule_actions(
|
||||
return await self.store.set_push_rule_actions(
|
||||
user_id, namespaced_rule_id, actions, is_default_rule
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -32,7 +32,7 @@ from synapse.http.servlet import (
|
||||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.push.mailer import Mailer, load_jinja2_templates
|
||||
from synapse.push.mailer import Mailer
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.stringutils import assert_valid_client_secret, random_string
|
||||
from synapse.util.threepids import canonicalise_email, check_3pid_allowed
|
||||
@@ -53,21 +53,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
|
||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||
template_html, template_text = load_jinja2_templates(
|
||||
self.config.email_template_dir,
|
||||
[
|
||||
self.config.email_password_reset_template_html,
|
||||
self.config.email_password_reset_template_text,
|
||||
],
|
||||
apply_format_ts_filter=True,
|
||||
apply_mxc_to_http_filter=True,
|
||||
public_baseurl=self.config.public_baseurl,
|
||||
)
|
||||
self.mailer = Mailer(
|
||||
hs=self.hs,
|
||||
app_name=self.config.email_app_name,
|
||||
template_html=template_html,
|
||||
template_text=template_text,
|
||||
template_html=self.config.email_password_reset_template_html,
|
||||
template_text=self.config.email_password_reset_template_text,
|
||||
)
|
||||
|
||||
async def on_POST(self, request):
|
||||
@@ -169,9 +159,8 @@ class PasswordResetSubmitTokenServlet(RestServlet):
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||
(self.failure_email_template,) = load_jinja2_templates(
|
||||
self.config.email_template_dir,
|
||||
[self.config.email_password_reset_template_failure_html],
|
||||
self._failure_email_template = (
|
||||
self.config.email_password_reset_template_failure_html
|
||||
)
|
||||
|
||||
async def on_GET(self, request, medium):
|
||||
@@ -214,14 +203,14 @@ class PasswordResetSubmitTokenServlet(RestServlet):
|
||||
return None
|
||||
|
||||
# Otherwise show the success template
|
||||
html = self.config.email_password_reset_template_success_html
|
||||
html = self.config.email_password_reset_template_success_html_content
|
||||
status_code = 200
|
||||
except ThreepidValidationError as e:
|
||||
status_code = e.code
|
||||
|
||||
# Show a failure page with a reason
|
||||
template_vars = {"failure_reason": e.msg}
|
||||
html = self.failure_email_template.render(**template_vars)
|
||||
html = self._failure_email_template.render(**template_vars)
|
||||
|
||||
respond_with_html(request, status_code, html)
|
||||
|
||||
@@ -411,19 +400,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||
self.store = self.hs.get_datastore()
|
||||
|
||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||
template_html, template_text = load_jinja2_templates(
|
||||
self.config.email_template_dir,
|
||||
[
|
||||
self.config.email_add_threepid_template_html,
|
||||
self.config.email_add_threepid_template_text,
|
||||
],
|
||||
public_baseurl=self.config.public_baseurl,
|
||||
)
|
||||
self.mailer = Mailer(
|
||||
hs=self.hs,
|
||||
app_name=self.config.email_app_name,
|
||||
template_html=template_html,
|
||||
template_text=template_text,
|
||||
template_html=self.config.email_add_threepid_template_html,
|
||||
template_text=self.config.email_add_threepid_template_text,
|
||||
)
|
||||
|
||||
async def on_POST(self, request):
|
||||
@@ -578,9 +559,8 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||
(self.failure_email_template,) = load_jinja2_templates(
|
||||
self.config.email_template_dir,
|
||||
[self.config.email_add_threepid_template_failure_html],
|
||||
self._failure_email_template = (
|
||||
self.config.email_add_threepid_template_failure_html
|
||||
)
|
||||
|
||||
async def on_GET(self, request):
|
||||
@@ -631,7 +611,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
|
||||
|
||||
# Show a failure page with a reason
|
||||
template_vars = {"failure_reason": e.msg}
|
||||
html = self.failure_email_template.render(**template_vars)
|
||||
html = self._failure_email_template.render(**template_vars)
|
||||
|
||||
respond_with_html(request, status_code, html)
|
||||
|
||||
|
||||
@@ -548,6 +548,31 @@ class GroupAdminUsersKickServlet(RestServlet):
|
||||
|
||||
return 200, result
|
||||
|
||||
class GroupAdminChangeAdminServlet(RestServlet):
|
||||
"""Promote or demote a user in the group
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/admin/users/admins/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupAdminChangeAdminServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
async def on_POST(self, request, group_id, user_id):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
want_admin = content["is_admin"]
|
||||
result = await self.groups_handler.change_user_admin_in_group(
|
||||
group_id, user_id, want_admin, requester_user_id, content
|
||||
)
|
||||
|
||||
return 200, result
|
||||
|
||||
class GroupSelfLeaveServlet(RestServlet):
|
||||
"""Leave a joined group
|
||||
@@ -722,6 +747,7 @@ def register_servlets(hs, http_server):
|
||||
GroupAdminRoomsConfigServlet(hs).register(http_server)
|
||||
GroupAdminUsersInviteServlet(hs).register(http_server)
|
||||
GroupAdminUsersKickServlet(hs).register(http_server)
|
||||
GroupAdminChangeAdminServlet(hs).register(http_server)
|
||||
GroupSelfLeaveServlet(hs).register(http_server)
|
||||
GroupSelfJoinServlet(hs).register(http_server)
|
||||
GroupSelfAcceptInviteServlet(hs).register(http_server)
|
||||
|
||||
@@ -44,7 +44,7 @@ from synapse.http.servlet import (
|
||||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.push.mailer import load_jinja2_templates
|
||||
from synapse.push.mailer import Mailer
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
from synapse.util.stringutils import assert_valid_client_secret, random_string
|
||||
@@ -81,23 +81,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
||||
self.config = hs.config
|
||||
|
||||
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||
from synapse.push.mailer import Mailer, load_jinja2_templates
|
||||
|
||||
template_html, template_text = load_jinja2_templates(
|
||||
self.config.email_template_dir,
|
||||
[
|
||||
self.config.email_registration_template_html,
|
||||
self.config.email_registration_template_text,
|
||||
],
|
||||
apply_format_ts_filter=True,
|
||||
apply_mxc_to_http_filter=True,
|
||||
public_baseurl=self.config.public_baseurl,
|
||||
)
|
||||
self.mailer = Mailer(
|
||||
hs=self.hs,
|
||||
app_name=self.config.email_app_name,
|
||||
template_html=template_html,
|
||||
template_text=template_text,
|
||||
template_html=self.config.email_registration_template_html,
|
||||
template_text=self.config.email_registration_template_text,
|
||||
)
|
||||
|
||||
async def on_POST(self, request):
|
||||
@@ -262,15 +250,8 @@ class RegistrationSubmitTokenServlet(RestServlet):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||
(self.failure_email_template,) = load_jinja2_templates(
|
||||
self.config.email_template_dir,
|
||||
[self.config.email_registration_template_failure_html],
|
||||
)
|
||||
|
||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||
(self.failure_email_template,) = load_jinja2_templates(
|
||||
self.config.email_template_dir,
|
||||
[self.config.email_registration_template_failure_html],
|
||||
self._failure_email_template = (
|
||||
self.config.email_registration_template_failure_html
|
||||
)
|
||||
|
||||
async def on_GET(self, request, medium):
|
||||
@@ -318,7 +299,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
|
||||
|
||||
# Show a failure page with a reason
|
||||
template_vars = {"failure_reason": e.msg}
|
||||
html = self.failure_email_template.render(**template_vars)
|
||||
html = self._failure_email_template.render(**template_vars)
|
||||
|
||||
respond_with_html(request, status_code, html)
|
||||
|
||||
|
||||
@@ -15,12 +15,12 @@
|
||||
import logging
|
||||
from typing import Dict, Set
|
||||
|
||||
from canonicaljson import encode_canonical_json, json
|
||||
from canonicaljson import json
|
||||
from signedjson.sign import sign_json
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.crypto.keyring import ServerKeyFetcher
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -223,4 +223,4 @@ class RemoteKey(DirectServeJsonResource):
|
||||
|
||||
results = {"server_keys": signed_keys}
|
||||
|
||||
respond_with_json_bytes(request, 200, encode_canonical_json(results))
|
||||
respond_with_json(request, 200, results, canonical_json=True)
|
||||
|
||||
@@ -18,8 +18,6 @@ from typing import Optional
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
|
||||
from . import engines
|
||||
@@ -308,9 +306,8 @@ class BackgroundUpdater(object):
|
||||
update_name (str): Name of update
|
||||
"""
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def noop_update(progress, batch_size):
|
||||
yield self._end_background_update(update_name)
|
||||
async def noop_update(progress, batch_size):
|
||||
await self._end_background_update(update_name)
|
||||
return 1
|
||||
|
||||
self.register_background_update_handler(update_name, noop_update)
|
||||
@@ -409,12 +406,11 @@ class BackgroundUpdater(object):
|
||||
else:
|
||||
runner = create_index_sqlite
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def updater(progress, batch_size):
|
||||
async def updater(progress, batch_size):
|
||||
if runner is not None:
|
||||
logger.info("Adding index %s to %s", index_name, table)
|
||||
yield self.db_pool.runWithConnection(runner)
|
||||
yield self._end_background_update(update_name)
|
||||
await self.db_pool.runWithConnection(runner)
|
||||
await self._end_background_update(update_name)
|
||||
return 1
|
||||
|
||||
self.register_background_update_handler(update_name, updater)
|
||||
|
||||
@@ -332,8 +332,7 @@ class DatabasePool(object):
|
||||
"""
|
||||
return self._db_pool.running
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_safe_to_upsert(self):
|
||||
async def _check_safe_to_upsert(self):
|
||||
"""
|
||||
Is it safe to use native UPSERT?
|
||||
|
||||
@@ -342,7 +341,7 @@ class DatabasePool(object):
|
||||
|
||||
If the background updates have not completed, wait 15 sec and check again.
|
||||
"""
|
||||
updates = yield self.simple_select_list(
|
||||
updates = await self.simple_select_list(
|
||||
"background_updates",
|
||||
keyvalues=None,
|
||||
retcols=["update_name"],
|
||||
@@ -614,8 +613,7 @@ class DatabasePool(object):
|
||||
# "Simple" SQL API methods that operate on a single table with no JOINs,
|
||||
# no complex WHERE clauses, just a dict of values for columns.
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
|
||||
async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
|
||||
"""Executes an INSERT query on the named table.
|
||||
|
||||
Args:
|
||||
@@ -631,7 +629,7 @@ class DatabasePool(object):
|
||||
`or_ignore` is True
|
||||
"""
|
||||
try:
|
||||
yield self.runInteraction(desc, self.simple_insert_txn, table, values)
|
||||
await self.runInteraction(desc, self.simple_insert_txn, table, values)
|
||||
except self.engine.module.IntegrityError:
|
||||
# We have to do or_ignore flag at this layer, since we can't reuse
|
||||
# a cursor after we receive an error from the db.
|
||||
@@ -684,8 +682,7 @@ class DatabasePool(object):
|
||||
|
||||
txn.executemany(sql, vals)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def simple_upsert(
|
||||
async def simple_upsert(
|
||||
self,
|
||||
table,
|
||||
keyvalues,
|
||||
@@ -714,14 +711,14 @@ class DatabasePool(object):
|
||||
inserting
|
||||
lock (bool): True to lock the table when doing the upsert.
|
||||
Returns:
|
||||
Deferred(None or bool): Native upserts always return None. Emulated
|
||||
None or bool: Native upserts always return None. Emulated
|
||||
upserts return True if a new entry was created, False if an existing
|
||||
one was updated.
|
||||
"""
|
||||
attempts = 0
|
||||
while True:
|
||||
try:
|
||||
result = yield self.runInteraction(
|
||||
return await self.runInteraction(
|
||||
desc,
|
||||
self.simple_upsert_txn,
|
||||
table,
|
||||
@@ -730,7 +727,6 @@ class DatabasePool(object):
|
||||
insertion_values,
|
||||
lock=lock,
|
||||
)
|
||||
return result
|
||||
except self.engine.module.IntegrityError as e:
|
||||
attempts += 1
|
||||
if attempts >= 5:
|
||||
@@ -1121,8 +1117,7 @@ class DatabasePool(object):
|
||||
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def simple_select_many_batch(
|
||||
async def simple_select_many_batch(
|
||||
self,
|
||||
table,
|
||||
column,
|
||||
@@ -1156,7 +1151,7 @@ class DatabasePool(object):
|
||||
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
|
||||
]
|
||||
for chunk in chunks:
|
||||
rows = yield self.runInteraction(
|
||||
rows = await self.runInteraction(
|
||||
desc,
|
||||
self.simple_select_many_txn,
|
||||
table,
|
||||
|
||||
@@ -169,7 +169,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
service(ApplicationService): The service whose state to set.
|
||||
state(ApplicationServiceState): The connectivity state to apply.
|
||||
Returns:
|
||||
A Deferred which resolves when the state was set successfully.
|
||||
An Awaitable which resolves when the state was set successfully.
|
||||
"""
|
||||
return self.db_pool.simple_upsert(
|
||||
"application_services_state", {"as_id": service.id}, {"state": state}
|
||||
|
||||
@@ -671,10 +671,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
@cachedList(
|
||||
cached_method_name="get_device_list_last_stream_id_for_remote",
|
||||
list_name="user_ids",
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="device_lists_remote_extremeties",
|
||||
column="user_id",
|
||||
iterable=user_ids,
|
||||
|
||||
@@ -257,11 +257,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
# Return all events where not all sets can reach them.
|
||||
return {eid for eid, n in event_to_missing_sets.items() if n}
|
||||
|
||||
def get_oldest_events_in_room(self, room_id):
|
||||
return self.db_pool.runInteraction(
|
||||
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
|
||||
)
|
||||
|
||||
def get_oldest_events_with_depth_in_room(self, room_id):
|
||||
return self.db_pool.runInteraction(
|
||||
"get_oldest_events_with_depth_in_room",
|
||||
@@ -303,14 +298,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
else:
|
||||
return max(row["depth"] for row in rows)
|
||||
|
||||
def _get_oldest_events_in_room_txn(self, txn, room_id):
|
||||
return self.db_pool.simple_select_onecol_txn(
|
||||
txn,
|
||||
table="event_backward_extremities",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcol="event_id",
|
||||
)
|
||||
|
||||
def get_prev_events_for_room(self, room_id: str):
|
||||
"""
|
||||
Gets a subset of the current forward extremities in the given room.
|
||||
|
||||
@@ -21,7 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -86,18 +86,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
self._rotate_delay = 3
|
||||
self._rotate_count = 10000
|
||||
|
||||
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
|
||||
def get_unread_event_push_actions_by_room_for_user(
|
||||
@cached(num_args=3, tree=True, max_entries=5000)
|
||||
async def get_unread_event_push_actions_by_room_for_user(
|
||||
self, room_id, user_id, last_read_event_id
|
||||
):
|
||||
ret = yield self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_unread_event_push_actions_by_room",
|
||||
self._get_unread_counts_by_receipt_txn,
|
||||
room_id,
|
||||
user_id,
|
||||
last_read_event_id,
|
||||
)
|
||||
return ret
|
||||
|
||||
def _get_unread_counts_by_receipt_txn(
|
||||
self, txn, room_id, user_id, last_read_event_id
|
||||
|
||||
@@ -17,13 +17,11 @@
|
||||
import itertools
|
||||
import logging
|
||||
from collections import OrderedDict, namedtuple
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
|
||||
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.metrics
|
||||
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
@@ -113,15 +111,14 @@ class PersistEventsStore:
|
||||
hs.config.worker.writers.events == hs.get_instance_name()
|
||||
), "Can only instantiate EventsStore on master"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _persist_events_and_state_updates(
|
||||
async def _persist_events_and_state_updates(
|
||||
self,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
current_state_for_room: Dict[str, StateMap[str]],
|
||||
state_delta_for_room: Dict[str, DeltaState],
|
||||
new_forward_extremeties: Dict[str, List[str]],
|
||||
backfilled: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
"""Persist a set of events alongside updates to the current state and
|
||||
forward extremities tables.
|
||||
|
||||
@@ -136,7 +133,7 @@ class PersistEventsStore:
|
||||
backfilled
|
||||
|
||||
Returns:
|
||||
Deferred: resolves when the events have been persisted
|
||||
Resolves when the events have been persisted
|
||||
"""
|
||||
|
||||
# We want to calculate the stream orderings as late as possible, as
|
||||
@@ -168,7 +165,7 @@ class PersistEventsStore:
|
||||
for (event, context), stream in zip(events_and_contexts, stream_orderings):
|
||||
event.internal_metadata.stream_ordering = stream
|
||||
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"persist_events",
|
||||
self._persist_events_txn,
|
||||
events_and_contexts=events_and_contexts,
|
||||
@@ -206,16 +203,15 @@ class PersistEventsStore:
|
||||
(room_id,), list(latest_event_ids)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_events_which_are_prevs(self, event_ids):
|
||||
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
|
||||
"""Filter the supplied list of event_ids to get those which are prev_events of
|
||||
existing (non-outlier/rejected) events.
|
||||
|
||||
Args:
|
||||
event_ids (Iterable[str]): event ids to filter
|
||||
event_ids: event ids to filter
|
||||
|
||||
Returns:
|
||||
Deferred[List[str]]: filtered event ids
|
||||
Filtered event ids
|
||||
"""
|
||||
results = []
|
||||
|
||||
@@ -240,14 +236,13 @@ class PersistEventsStore:
|
||||
results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
|
||||
|
||||
for chunk in batch_iter(event_ids, 100):
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_prevs_before_rejected(self, event_ids):
|
||||
async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]:
|
||||
"""Get soft-failed ancestors to remove from the extremities.
|
||||
|
||||
Given a set of events, find all those that have been soft-failed or
|
||||
@@ -259,11 +254,11 @@ class PersistEventsStore:
|
||||
are separated by soft failed events.
|
||||
|
||||
Args:
|
||||
event_ids (Iterable[str]): Events to find prev events for. Note
|
||||
that these must have already been persisted.
|
||||
event_ids: Events to find prev events for. Note that these must have
|
||||
already been persisted.
|
||||
|
||||
Returns:
|
||||
Deferred[set[str]]
|
||||
The previous events.
|
||||
"""
|
||||
|
||||
# The set of event_ids to return. This includes all soft-failed events
|
||||
@@ -304,7 +299,7 @@ class PersistEventsStore:
|
||||
existing_prevs.add(prev_event_id)
|
||||
|
||||
for chunk in batch_iter(event_ids, 100):
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
|
||||
)
|
||||
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventContentFields
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool
|
||||
@@ -94,8 +92,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||
where_clause="NOT have_censored",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_reindex_fields_sender(self, progress, batch_size):
|
||||
async def _background_reindex_fields_sender(self, progress, batch_size):
|
||||
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||
max_stream_id = progress["max_stream_id_exclusive"]
|
||||
rows_inserted = progress.get("rows_inserted", 0)
|
||||
@@ -155,19 +152,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||
|
||||
return len(rows)
|
||||
|
||||
result = yield self.db_pool.runInteraction(
|
||||
result = await self.db_pool.runInteraction(
|
||||
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
|
||||
)
|
||||
|
||||
if not result:
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_reindex_origin_server_ts(self, progress, batch_size):
|
||||
async def _background_reindex_origin_server_ts(self, progress, batch_size):
|
||||
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||
max_stream_id = progress["max_stream_id_exclusive"]
|
||||
rows_inserted = progress.get("rows_inserted", 0)
|
||||
@@ -234,19 +230,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||
|
||||
return len(rows_to_update)
|
||||
|
||||
result = yield self.db_pool.runInteraction(
|
||||
result = await self.db_pool.runInteraction(
|
||||
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
|
||||
)
|
||||
|
||||
if not result:
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.EVENT_ORIGIN_SERVER_TS_NAME
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _cleanup_extremities_bg_update(self, progress, batch_size):
|
||||
async def _cleanup_extremities_bg_update(self, progress, batch_size):
|
||||
"""Background update to clean out extremities that should have been
|
||||
deleted previously.
|
||||
|
||||
@@ -414,26 +409,25 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||
|
||||
return len(original_set)
|
||||
|
||||
num_handled = yield self.db_pool.runInteraction(
|
||||
num_handled = await self.db_pool.runInteraction(
|
||||
"_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
|
||||
)
|
||||
|
||||
if not num_handled:
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.DELETE_SOFT_FAILED_EXTREMITIES
|
||||
)
|
||||
|
||||
def _drop_table_txn(txn):
|
||||
txn.execute("DROP TABLE _extremities_to_check")
|
||||
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn
|
||||
)
|
||||
|
||||
return num_handled
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _redactions_received_ts(self, progress, batch_size):
|
||||
async def _redactions_received_ts(self, progress, batch_size):
|
||||
"""Handles filling out the `received_ts` column in redactions.
|
||||
"""
|
||||
last_event_id = progress.get("last_event_id", "")
|
||||
@@ -480,17 +474,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||
|
||||
return len(rows)
|
||||
|
||||
count = yield self.db_pool.runInteraction(
|
||||
count = await self.db_pool.runInteraction(
|
||||
"_redactions_received_ts", _redactions_received_ts_txn
|
||||
)
|
||||
|
||||
if not count:
|
||||
yield self.db_pool.updates._end_background_update("redactions_received_ts")
|
||||
await self.db_pool.updates._end_background_update("redactions_received_ts")
|
||||
|
||||
return count
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _event_fix_redactions_bytes(self, progress, batch_size):
|
||||
async def _event_fix_redactions_bytes(self, progress, batch_size):
|
||||
"""Undoes hex encoded censored redacted event JSON.
|
||||
"""
|
||||
|
||||
@@ -511,16 +504,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||
|
||||
txn.execute("DROP INDEX redactions_censored_redacts")
|
||||
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
|
||||
)
|
||||
|
||||
yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
|
||||
await self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
|
||||
|
||||
return 1
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _event_store_labels(self, progress, batch_size):
|
||||
async def _event_store_labels(self, progress, batch_size):
|
||||
"""Background update handler which will store labels for existing events."""
|
||||
last_event_id = progress.get("last_event_id", "")
|
||||
|
||||
@@ -575,11 +567,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||
|
||||
return nbrows
|
||||
|
||||
num_rows = yield self.db_pool.runInteraction(
|
||||
num_rows = await self.db_pool.runInteraction(
|
||||
desc="event_store_labels", func=_event_store_labels_txn
|
||||
)
|
||||
|
||||
if not num_rows:
|
||||
yield self.db_pool.updates._end_background_update("event_store_labels")
|
||||
await self.db_pool.updates._end_background_update("event_store_labels")
|
||||
|
||||
return num_rows
|
||||
|
||||
@@ -43,7 +43,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
|
||||
from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
@@ -137,42 +137,6 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
desc="get_received_ts",
|
||||
)
|
||||
|
||||
def get_received_ts_by_stream_pos(self, stream_ordering):
|
||||
"""Given a stream ordering get an approximate timestamp of when it
|
||||
happened.
|
||||
|
||||
This is done by simply taking the received ts of the first event that
|
||||
has a stream ordering greater than or equal to the given stream pos.
|
||||
If none exists returns the current time, on the assumption that it must
|
||||
have happened recently.
|
||||
|
||||
Args:
|
||||
stream_ordering (int)
|
||||
|
||||
Returns:
|
||||
Deferred[int]
|
||||
"""
|
||||
|
||||
def _get_approximate_received_ts_txn(txn):
|
||||
sql = """
|
||||
SELECT received_ts FROM events
|
||||
WHERE stream_ordering >= ?
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
txn.execute(sql, (stream_ordering,))
|
||||
row = txn.fetchone()
|
||||
if row and row[0]:
|
||||
ts = row[0]
|
||||
else:
|
||||
ts = self.clock.time_msec()
|
||||
|
||||
return ts
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
"get_approximate_received_ts", _get_approximate_received_ts_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_event(
|
||||
self,
|
||||
@@ -883,13 +847,15 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
"""Given a list of event ids, check if we have already processed and
|
||||
stored them as non outliers.
|
||||
"""
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
table="events",
|
||||
retcols=("event_id",),
|
||||
column="event_id",
|
||||
iterable=list(event_ids),
|
||||
keyvalues={"outlier": False},
|
||||
desc="have_events_in_timeline",
|
||||
rows = yield defer.ensureDeferred(
|
||||
self.db_pool.simple_select_many_batch(
|
||||
table="events",
|
||||
retcols=("event_id",),
|
||||
column="event_id",
|
||||
iterable=list(event_ids),
|
||||
keyvalues={"outlier": False},
|
||||
desc="have_events_in_timeline",
|
||||
)
|
||||
)
|
||||
|
||||
return {r["event_id"] for r in rows}
|
||||
@@ -923,36 +889,6 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
)
|
||||
return results
|
||||
|
||||
def _get_total_state_event_counts_txn(self, txn, room_id):
|
||||
"""
|
||||
See get_total_state_event_counts.
|
||||
"""
|
||||
# We join against the events table as that has an index on room_id
|
||||
sql = """
|
||||
SELECT COUNT(*) FROM state_events
|
||||
INNER JOIN events USING (room_id, event_id)
|
||||
WHERE room_id=?
|
||||
"""
|
||||
txn.execute(sql, (room_id,))
|
||||
row = txn.fetchone()
|
||||
return row[0] if row else 0
|
||||
|
||||
def get_total_state_event_counts(self, room_id):
|
||||
"""
|
||||
Gets the total number of state events in a room.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
|
||||
Returns:
|
||||
Deferred[int]
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
"get_total_state_event_counts",
|
||||
self._get_total_state_event_counts_txn,
|
||||
room_id,
|
||||
)
|
||||
|
||||
def _get_current_state_event_counts_txn(self, txn, room_id):
|
||||
"""
|
||||
See get_current_state_event_counts.
|
||||
@@ -1222,97 +1158,6 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
return rows, to_token, True
|
||||
|
||||
@cached(num_args=5, max_entries=10)
|
||||
def get_all_new_events(
|
||||
self,
|
||||
last_backfill_id,
|
||||
last_forward_id,
|
||||
current_backfill_id,
|
||||
current_forward_id,
|
||||
limit,
|
||||
):
|
||||
"""Get all the new events that have arrived at the server either as
|
||||
new events or as backfilled events"""
|
||||
have_backfill_events = last_backfill_id != current_backfill_id
|
||||
have_forward_events = last_forward_id != current_forward_id
|
||||
|
||||
if not have_backfill_events and not have_forward_events:
|
||||
return defer.succeed(AllNewEventsResult([], [], [], [], []))
|
||||
|
||||
def get_all_new_events_txn(txn):
|
||||
sql = (
|
||||
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts"
|
||||
" FROM events AS e"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" WHERE ? < stream_ordering AND stream_ordering <= ?"
|
||||
" ORDER BY stream_ordering ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
if have_forward_events:
|
||||
txn.execute(sql, (last_forward_id, current_forward_id, limit))
|
||||
new_forward_events = txn.fetchall()
|
||||
|
||||
if len(new_forward_events) == limit:
|
||||
upper_bound = new_forward_events[-1][0]
|
||||
else:
|
||||
upper_bound = current_forward_id
|
||||
|
||||
sql = (
|
||||
"SELECT event_stream_ordering, event_id, state_group"
|
||||
" FROM ex_outlier_stream"
|
||||
" WHERE ? > event_stream_ordering"
|
||||
" AND event_stream_ordering >= ?"
|
||||
" ORDER BY event_stream_ordering DESC"
|
||||
)
|
||||
txn.execute(sql, (last_forward_id, upper_bound))
|
||||
forward_ex_outliers = txn.fetchall()
|
||||
else:
|
||||
new_forward_events = []
|
||||
forward_ex_outliers = []
|
||||
|
||||
sql = (
|
||||
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts"
|
||||
" FROM events AS e"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" WHERE ? > stream_ordering AND stream_ordering >= ?"
|
||||
" ORDER BY stream_ordering DESC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
if have_backfill_events:
|
||||
txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
|
||||
new_backfill_events = txn.fetchall()
|
||||
|
||||
if len(new_backfill_events) == limit:
|
||||
upper_bound = new_backfill_events[-1][0]
|
||||
else:
|
||||
upper_bound = current_backfill_id
|
||||
|
||||
sql = (
|
||||
"SELECT -event_stream_ordering, event_id, state_group"
|
||||
" FROM ex_outlier_stream"
|
||||
" WHERE ? > event_stream_ordering"
|
||||
" AND event_stream_ordering >= ?"
|
||||
" ORDER BY event_stream_ordering DESC"
|
||||
)
|
||||
txn.execute(sql, (-last_backfill_id, -upper_bound))
|
||||
backward_ex_outliers = txn.fetchall()
|
||||
else:
|
||||
new_backfill_events = []
|
||||
backward_ex_outliers = []
|
||||
|
||||
return AllNewEventsResult(
|
||||
new_forward_events,
|
||||
new_backfill_events,
|
||||
forward_ex_outliers,
|
||||
backward_ex_outliers,
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn)
|
||||
|
||||
async def is_event_after(self, event_id1, event_id2):
|
||||
"""Returns True if event_id1 is after event_id2 in the stream
|
||||
"""
|
||||
@@ -1357,14 +1202,3 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
return self.db_pool.runInteraction(
|
||||
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
|
||||
)
|
||||
|
||||
|
||||
AllNewEventsResult = namedtuple(
|
||||
"AllNewEventsResult",
|
||||
[
|
||||
"new_forward_events",
|
||||
"new_backfill_events",
|
||||
"forward_ex_outliers",
|
||||
"backward_ex_outliers",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1038,6 +1038,14 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||
"remove_user_from_group", _remove_user_from_group_txn
|
||||
)
|
||||
|
||||
def change_user_admin_in_group(self, group_id, user_id, is_admin):
|
||||
return self.db_pool.simple_update(
|
||||
table="group_users",
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
updatevalues={"is_admin": is_admin},
|
||||
desc="change_user_admin_in_group"
|
||||
)
|
||||
|
||||
def add_room_to_group(self, group_id, room_id, is_public):
|
||||
return self.db_pool.simple_insert(
|
||||
table="group_rooms",
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
@@ -130,13 +130,10 @@ class PresenceStore(SQLBaseStore):
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="_get_presence_for_user",
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
inlineCallbacks=True,
|
||||
cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
|
||||
)
|
||||
def get_presence_for_users(self, user_ids):
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
async def get_presence_for_users(self, user_ids):
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="presence_stream",
|
||||
column="user_id",
|
||||
iterable=user_ids,
|
||||
@@ -160,24 +157,3 @@ class PresenceStore(SQLBaseStore):
|
||||
|
||||
def get_current_presence_token(self):
|
||||
return self._presence_id_gen.get_current_token()
|
||||
|
||||
def allow_presence_visible(self, observed_localpart, observer_userid):
|
||||
return self.db_pool.simple_insert(
|
||||
table="presence_allow_inbound",
|
||||
values={
|
||||
"observed_user_id": observed_localpart,
|
||||
"observer_user_id": observer_userid,
|
||||
},
|
||||
desc="allow_presence_visible",
|
||||
or_ignore=True,
|
||||
)
|
||||
|
||||
def disallow_presence_visible(self, observed_localpart, observer_userid):
|
||||
return self.db_pool.simple_delete_one(
|
||||
table="presence_allow_inbound",
|
||||
keyvalues={
|
||||
"observed_user_id": observed_localpart,
|
||||
"observer_user_id": observer_userid,
|
||||
},
|
||||
desc="disallow_presence_visible",
|
||||
)
|
||||
|
||||
@@ -32,7 +32,7 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
|
||||
from synapse.storage.util.id_generators import ChainedIdGenerator
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -115,9 +115,9 @@ class PushRulesWorkerStore(
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedInlineCallbacks(max_entries=5000)
|
||||
def get_push_rules_for_user(self, user_id):
|
||||
rows = yield self.db_pool.simple_select_list(
|
||||
@cached(max_entries=5000)
|
||||
async def get_push_rules_for_user(self, user_id):
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
table="push_rules",
|
||||
keyvalues={"user_name": user_id},
|
||||
retcols=(
|
||||
@@ -133,17 +133,15 @@ class PushRulesWorkerStore(
|
||||
|
||||
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
|
||||
|
||||
enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
|
||||
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
|
||||
|
||||
use_new_defaults = user_id in self._users_new_default_push_rules
|
||||
|
||||
rules = _load_rules(rows, enabled_map, use_new_defaults)
|
||||
return _load_rules(rows, enabled_map, use_new_defaults)
|
||||
|
||||
return rules
|
||||
|
||||
@cachedInlineCallbacks(max_entries=5000)
|
||||
def get_push_rules_enabled_for_user(self, user_id):
|
||||
results = yield self.db_pool.simple_select_list(
|
||||
@cached(max_entries=5000)
|
||||
async def get_push_rules_enabled_for_user(self, user_id):
|
||||
results = await self.db_pool.simple_select_list(
|
||||
table="push_rules_enable",
|
||||
keyvalues={"user_name": user_id},
|
||||
retcols=("user_name", "rule_id", "enabled"),
|
||||
@@ -170,18 +168,15 @@ class PushRulesWorkerStore(
|
||||
)
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="get_push_rules_for_user",
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
inlineCallbacks=True,
|
||||
cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
|
||||
)
|
||||
def bulk_get_push_rules(self, user_ids):
|
||||
async def bulk_get_push_rules(self, user_ids):
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
results = {user_id: [] for user_id in user_ids}
|
||||
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="push_rules",
|
||||
column="user_name",
|
||||
iterable=user_ids,
|
||||
@@ -194,7 +189,7 @@ class PushRulesWorkerStore(
|
||||
for row in rows:
|
||||
results.setdefault(row["user_name"], []).append(row)
|
||||
|
||||
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
|
||||
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
|
||||
|
||||
for user_id, rules in results.items():
|
||||
use_new_defaults = user_id in self._users_new_default_push_rules
|
||||
@@ -205,14 +200,15 @@ class PushRulesWorkerStore(
|
||||
|
||||
return results
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
|
||||
async def copy_push_rule_from_room_to_room(
|
||||
self, new_room_id: str, user_id: str, rule: dict
|
||||
) -> None:
|
||||
"""Copy a single push rule from one room to another for a specific user.
|
||||
|
||||
Args:
|
||||
new_room_id (str): ID of the new room.
|
||||
user_id (str): ID of user the push rule belongs to.
|
||||
rule (Dict): A push rule.
|
||||
new_room_id: ID of the new room.
|
||||
user_id : ID of user the push rule belongs to.
|
||||
rule: A push rule.
|
||||
"""
|
||||
# Create new rule id
|
||||
rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
|
||||
@@ -224,7 +220,7 @@ class PushRulesWorkerStore(
|
||||
condition["pattern"] = new_room_id
|
||||
|
||||
# Add the rule for the new room
|
||||
yield self.add_push_rule(
|
||||
await self.add_push_rule(
|
||||
user_id=user_id,
|
||||
rule_id=new_rule_id,
|
||||
priority_class=rule["priority_class"],
|
||||
@@ -232,20 +228,19 @@ class PushRulesWorkerStore(
|
||||
actions=rule["actions"],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def copy_push_rules_from_room_to_room_for_user(
|
||||
self, old_room_id, new_room_id, user_id
|
||||
):
|
||||
async def copy_push_rules_from_room_to_room_for_user(
|
||||
self, old_room_id: str, new_room_id: str, user_id: str
|
||||
) -> None:
|
||||
"""Copy all of the push rules from one room to another for a specific
|
||||
user.
|
||||
|
||||
Args:
|
||||
old_room_id (str): ID of the old room.
|
||||
new_room_id (str): ID of the new room.
|
||||
user_id (str): ID of user to copy push rules for.
|
||||
old_room_id: ID of the old room.
|
||||
new_room_id: ID of the new room.
|
||||
user_id: ID of user to copy push rules for.
|
||||
"""
|
||||
# Retrieve push rules for this user
|
||||
user_push_rules = yield self.get_push_rules_for_user(user_id)
|
||||
user_push_rules = await self.get_push_rules_for_user(user_id)
|
||||
|
||||
# Get rules relating to the old room and copy them to the new room
|
||||
for rule in user_push_rules:
|
||||
@@ -254,21 +249,20 @@ class PushRulesWorkerStore(
|
||||
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
|
||||
for c in conditions
|
||||
):
|
||||
yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
|
||||
await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="get_push_rules_enabled_for_user",
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def bulk_get_push_rules_enabled(self, user_ids):
|
||||
async def bulk_get_push_rules_enabled(self, user_ids):
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
results = {user_id: {} for user_id in user_ids}
|
||||
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="push_rules_enable",
|
||||
column="user_name",
|
||||
iterable=user_ids,
|
||||
@@ -332,8 +326,7 @@ class PushRulesWorkerStore(
|
||||
|
||||
|
||||
class PushRuleStore(PushRulesWorkerStore):
|
||||
@defer.inlineCallbacks
|
||||
def add_push_rule(
|
||||
async def add_push_rule(
|
||||
self,
|
||||
user_id,
|
||||
rule_id,
|
||||
@@ -342,13 +335,13 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
actions,
|
||||
before=None,
|
||||
after=None,
|
||||
):
|
||||
) -> None:
|
||||
conditions_json = json_encoder.encode(conditions)
|
||||
actions_json = json_encoder.encode(actions)
|
||||
with self._push_rules_stream_id_gen.get_next() as ids:
|
||||
stream_id, event_stream_ordering = ids
|
||||
if before or after:
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"_add_push_rule_relative_txn",
|
||||
self._add_push_rule_relative_txn,
|
||||
stream_id,
|
||||
@@ -362,7 +355,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
after,
|
||||
)
|
||||
else:
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"_add_push_rule_highest_priority_txn",
|
||||
self._add_push_rule_highest_priority_txn,
|
||||
stream_id,
|
||||
@@ -546,16 +539,15 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
},
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_push_rule(self, user_id, rule_id):
|
||||
async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
|
||||
"""
|
||||
Delete a push rule. Args specify the row to be deleted and can be
|
||||
any of the columns in the push_rule table, but below are the
|
||||
standard ones
|
||||
|
||||
Args:
|
||||
user_id (str): The matrix ID of the push rule owner
|
||||
rule_id (str): The rule_id of the rule to be deleted
|
||||
user_id: The matrix ID of the push rule owner
|
||||
rule_id: The rule_id of the rule to be deleted
|
||||
"""
|
||||
|
||||
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
|
||||
@@ -569,18 +561,17 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
|
||||
with self._push_rules_stream_id_gen.get_next() as ids:
|
||||
stream_id, event_stream_ordering = ids
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_push_rule",
|
||||
delete_push_rule_txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_push_rule_enabled(self, user_id, rule_id, enabled):
|
||||
async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
|
||||
with self._push_rules_stream_id_gen.get_next() as ids:
|
||||
stream_id, event_stream_ordering = ids
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"_set_push_rule_enabled_txn",
|
||||
self._set_push_rule_enabled_txn,
|
||||
stream_id,
|
||||
@@ -611,8 +602,9 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
op="ENABLE" if enabled else "DISABLE",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
|
||||
async def set_push_rule_actions(
|
||||
self, user_id, rule_id, actions, is_default_rule
|
||||
) -> None:
|
||||
actions_json = json_encoder.encode(actions)
|
||||
|
||||
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
|
||||
@@ -653,7 +645,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
|
||||
with self._push_rules_stream_id_gen.get_next() as ids:
|
||||
stream_id, event_stream_ordering = ids
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"set_push_rule_actions",
|
||||
set_push_rule_actions_txn,
|
||||
stream_id,
|
||||
|
||||
@@ -19,10 +19,8 @@ from typing import Iterable, Iterator, List, Tuple
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,23 +32,22 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
Drops any rows whose data cannot be decoded
|
||||
"""
|
||||
for r in rows:
|
||||
dataJson = r["data"]
|
||||
data_json = r["data"]
|
||||
try:
|
||||
r["data"] = db_to_json(dataJson)
|
||||
r["data"] = db_to_json(data_json)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Invalid JSON in data for pusher %d: %s, %s",
|
||||
r["id"],
|
||||
dataJson,
|
||||
data_json,
|
||||
e.args[0],
|
||||
)
|
||||
continue
|
||||
|
||||
yield r
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_has_pusher(self, user_id):
|
||||
ret = yield self.db_pool.simple_select_one_onecol(
|
||||
async def user_has_pusher(self, user_id):
|
||||
ret = await self.db_pool.simple_select_one_onecol(
|
||||
"pushers", {"user_name": user_id}, "id", allow_none=True
|
||||
)
|
||||
return ret is not None
|
||||
@@ -61,9 +58,8 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
def get_pushers_by_user_id(self, user_id):
|
||||
return self.get_pushers_by({"user_name": user_id})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_pushers_by(self, keyvalues):
|
||||
ret = yield self.db_pool.simple_select_list(
|
||||
async def get_pushers_by(self, keyvalues):
|
||||
ret = await self.db_pool.simple_select_list(
|
||||
"pushers",
|
||||
keyvalues,
|
||||
[
|
||||
@@ -87,16 +83,14 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
)
|
||||
return self._decode_pushers_rows(ret)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_all_pushers(self):
|
||||
async def get_all_pushers(self):
|
||||
def get_pushers(txn):
|
||||
txn.execute("SELECT * FROM pushers")
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
return self._decode_pushers_rows(rows)
|
||||
|
||||
rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers)
|
||||
return rows
|
||||
return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
|
||||
|
||||
async def get_all_updated_pushers_rows(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
@@ -164,19 +158,16 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(num_args=1, max_entries=15000)
|
||||
def get_if_user_has_pusher(self, user_id):
|
||||
@cached(num_args=1, max_entries=15000)
|
||||
async def get_if_user_has_pusher(self, user_id):
|
||||
# This only exists for the cachedList decorator
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="get_if_user_has_pusher",
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
inlineCallbacks=True,
|
||||
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
|
||||
)
|
||||
def get_if_users_have_pushers(self, user_ids):
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
async def get_if_users_have_pushers(self, user_ids):
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="pushers",
|
||||
column="user_name",
|
||||
iterable=user_ids,
|
||||
@@ -189,34 +180,38 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
|
||||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_pusher_last_stream_ordering(
|
||||
async def update_pusher_last_stream_ordering(
|
||||
self, app_id, pushkey, user_id, last_stream_ordering
|
||||
):
|
||||
yield self.db_pool.simple_update_one(
|
||||
) -> None:
|
||||
await self.db_pool.simple_update_one(
|
||||
"pushers",
|
||||
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
|
||||
{"last_stream_ordering": last_stream_ordering},
|
||||
desc="update_pusher_last_stream_ordering",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_pusher_last_stream_ordering_and_success(
|
||||
self, app_id, pushkey, user_id, last_stream_ordering, last_success
|
||||
):
|
||||
async def update_pusher_last_stream_ordering_and_success(
|
||||
self,
|
||||
app_id: str,
|
||||
pushkey: str,
|
||||
user_id: str,
|
||||
last_stream_ordering: int,
|
||||
last_success: int,
|
||||
) -> bool:
|
||||
"""Update the last stream ordering position we've processed up to for
|
||||
the given pusher.
|
||||
|
||||
Args:
|
||||
app_id (str)
|
||||
pushkey (str)
|
||||
last_stream_ordering (int)
|
||||
last_success (int)
|
||||
app_id
|
||||
pushkey
|
||||
user_id
|
||||
last_stream_ordering
|
||||
last_success
|
||||
|
||||
Returns:
|
||||
Deferred[bool]: True if the pusher still exists; False if it has been deleted.
|
||||
True if the pusher still exists; False if it has been deleted.
|
||||
"""
|
||||
updated = yield self.db_pool.simple_update(
|
||||
updated = await self.db_pool.simple_update(
|
||||
table="pushers",
|
||||
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
|
||||
updatevalues={
|
||||
@@ -228,18 +223,18 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
|
||||
return bool(updated)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
|
||||
yield self.db_pool.simple_update(
|
||||
async def update_pusher_failing_since(
|
||||
self, app_id, pushkey, user_id, failing_since
|
||||
) -> None:
|
||||
await self.db_pool.simple_update(
|
||||
table="pushers",
|
||||
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
|
||||
updatevalues={"failing_since": failing_since},
|
||||
desc="update_pusher_failing_since",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_throttle_params_by_room(self, pusher_id):
|
||||
res = yield self.db_pool.simple_select_list(
|
||||
async def get_throttle_params_by_room(self, pusher_id):
|
||||
res = await self.db_pool.simple_select_list(
|
||||
"pusher_throttle",
|
||||
{"pusher": pusher_id},
|
||||
["room_id", "last_sent_ts", "throttle_ms"],
|
||||
@@ -255,11 +250,10 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
|
||||
return params_by_room
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_throttle_params(self, pusher_id, room_id, params):
|
||||
async def set_throttle_params(self, pusher_id, room_id, params) -> None:
|
||||
# no need to lock because `pusher_throttle` has a primary key on
|
||||
# (pusher, room_id) so simple_upsert will retry
|
||||
yield self.db_pool.simple_upsert(
|
||||
await self.db_pool.simple_upsert(
|
||||
"pusher_throttle",
|
||||
{"pusher": pusher_id, "room_id": room_id},
|
||||
params,
|
||||
@@ -272,8 +266,7 @@ class PusherStore(PusherWorkerStore):
|
||||
def get_pushers_stream_token(self):
|
||||
return self._pushers_id_gen.get_current_token()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_pusher(
|
||||
async def add_pusher(
|
||||
self,
|
||||
user_id,
|
||||
access_token,
|
||||
@@ -287,11 +280,11 @@ class PusherStore(PusherWorkerStore):
|
||||
data,
|
||||
last_stream_ordering,
|
||||
profile_tag="",
|
||||
):
|
||||
) -> None:
|
||||
with self._pushers_id_gen.get_next() as stream_id:
|
||||
# no need to lock because `pushers` has a unique key on
|
||||
# (app_id, pushkey, user_name) so simple_upsert will retry
|
||||
yield self.db_pool.simple_upsert(
|
||||
await self.db_pool.simple_upsert(
|
||||
table="pushers",
|
||||
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
|
||||
values={
|
||||
@@ -316,15 +309,16 @@ class PusherStore(PusherWorkerStore):
|
||||
|
||||
if user_has_pusher is not True:
|
||||
# invalidate, since we the user might not have had a pusher before
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"add_pusher",
|
||||
self._invalidate_cache_and_stream,
|
||||
self.get_if_user_has_pusher,
|
||||
(user_id,),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
|
||||
async def delete_pusher_by_app_id_pushkey_user_id(
|
||||
self, app_id, pushkey, user_id
|
||||
) -> None:
|
||||
def delete_pusher_txn(txn, stream_id):
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_if_user_has_pusher, (user_id,)
|
||||
@@ -351,6 +345,6 @@ class PusherStore(PusherWorkerStore):
|
||||
)
|
||||
|
||||
with self._pushers_id_gen.get_next() as stream_id:
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_pusher", delete_pusher_txn, stream_id
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
@@ -25,7 +25,7 @@ from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -56,9 +56,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def get_users_with_read_receipts_in_room(self, room_id):
|
||||
receipts = yield self.get_receipts_for_room(room_id, "m.read")
|
||||
@cached()
|
||||
async def get_users_with_read_receipts_in_room(self, room_id):
|
||||
receipts = await self.get_receipts_for_room(room_id, "m.read")
|
||||
return {r["user_id"] for r in receipts}
|
||||
|
||||
@cached(num_args=2)
|
||||
@@ -84,9 +84,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(num_args=2)
|
||||
def get_receipts_for_user(self, user_id, receipt_type):
|
||||
rows = yield self.db_pool.simple_select_list(
|
||||
@cached(num_args=2)
|
||||
async def get_receipts_for_user(self, user_id, receipt_type):
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
table="receipts_linearized",
|
||||
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
|
||||
retcols=("room_id", "event_id"),
|
||||
@@ -95,8 +95,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
|
||||
return {row["room_id"]: row["event_id"] for row in rows}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
|
||||
async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT rl.room_id, rl.event_id,"
|
||||
@@ -110,7 +109,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
txn.execute(sql, (user_id,))
|
||||
return txn.fetchall()
|
||||
|
||||
rows = yield self.db_pool.runInteraction(
|
||||
rows = await self.db_pool.runInteraction(
|
||||
"get_receipts_for_user_with_orderings", f
|
||||
)
|
||||
return {
|
||||
@@ -122,56 +121,61 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
for row in rows
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||
async def get_linearized_receipts_for_rooms(
|
||||
self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
|
||||
) -> List[dict]:
|
||||
"""Get receipts for multiple rooms for sending to clients.
|
||||
|
||||
Args:
|
||||
room_ids (list): List of room_ids.
|
||||
to_key (int): Max stream id to fetch receipts upto.
|
||||
from_key (int): Min stream id to fetch receipts from. None fetches
|
||||
room_id: List of room_ids.
|
||||
to_key: Max stream id to fetch receipts upto.
|
||||
from_key: Min stream id to fetch receipts from. None fetches
|
||||
from the start.
|
||||
|
||||
Returns:
|
||||
list: A list of receipts.
|
||||
A list of receipts.
|
||||
"""
|
||||
room_ids = set(room_ids)
|
||||
|
||||
if from_key is not None:
|
||||
# Only ask the database about rooms where there have been new
|
||||
# receipts added since `from_key`
|
||||
room_ids = yield self._receipts_stream_cache.get_entities_changed(
|
||||
room_ids = self._receipts_stream_cache.get_entities_changed(
|
||||
room_ids, from_key
|
||||
)
|
||||
|
||||
results = yield self._get_linearized_receipts_for_rooms(
|
||||
results = await self._get_linearized_receipts_for_rooms(
|
||||
room_ids, to_key, from_key=from_key
|
||||
)
|
||||
|
||||
return [ev for res in results.values() for ev in res]
|
||||
|
||||
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
||||
async def get_linearized_receipts_for_room(
|
||||
self, room_id: str, to_key: int, from_key: Optional[int] = None
|
||||
) -> List[dict]:
|
||||
"""Get receipts for a single room for sending to clients.
|
||||
|
||||
Args:
|
||||
room_ids (str): The room id.
|
||||
to_key (int): Max stream id to fetch receipts upto.
|
||||
from_key (int): Min stream id to fetch receipts from. None fetches
|
||||
room_ids: The room id.
|
||||
to_key: Max stream id to fetch receipts upto.
|
||||
from_key: Min stream id to fetch receipts from. None fetches
|
||||
from the start.
|
||||
|
||||
Returns:
|
||||
Deferred[list]: A list of receipts.
|
||||
A list of receipts.
|
||||
"""
|
||||
if from_key is not None:
|
||||
# Check the cache first to see if any new receipts have been added
|
||||
# since`from_key`. If not we can no-op.
|
||||
if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
|
||||
defer.succeed([])
|
||||
return []
|
||||
|
||||
return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
|
||||
return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
|
||||
|
||||
@cachedInlineCallbacks(num_args=3, tree=True)
|
||||
def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
||||
@cached(num_args=3, tree=True)
|
||||
async def _get_linearized_receipts_for_room(
|
||||
self, room_id: str, to_key: int, from_key: Optional[int] = None
|
||||
) -> List[dict]:
|
||||
"""See get_linearized_receipts_for_room
|
||||
"""
|
||||
|
||||
@@ -195,7 +199,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
|
||||
return rows
|
||||
|
||||
rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
|
||||
rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
|
||||
|
||||
if not rows:
|
||||
return []
|
||||
@@ -212,9 +216,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
cached_method_name="_get_linearized_receipts_for_room",
|
||||
list_name="room_ids",
|
||||
num_args=3,
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||
async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||
if not room_ids:
|
||||
return {}
|
||||
|
||||
@@ -243,7 +246,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
txn_results = yield self.db_pool.runInteraction(
|
||||
txn_results = await self.db_pool.runInteraction(
|
||||
"_get_linearized_receipts_for_rooms", f
|
||||
)
|
||||
|
||||
@@ -346,7 +349,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
def _invalidate_get_users_with_receipts_in_room(
|
||||
self, room_id, receipt_type, user_id
|
||||
self, room_id: str, receipt_type: str, user_id: str
|
||||
):
|
||||
if receipt_type != "m.read":
|
||||
return
|
||||
@@ -472,15 +475,21 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
|
||||
return rx_ts
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
|
||||
async def insert_receipt(
|
||||
self,
|
||||
room_id: str,
|
||||
receipt_type: str,
|
||||
user_id: str,
|
||||
event_ids: List[str],
|
||||
data: dict,
|
||||
) -> Optional[Tuple[int, int]]:
|
||||
"""Insert a receipt, either from local client or remote server.
|
||||
|
||||
Automatically does conversion between linearized and graph
|
||||
representations.
|
||||
"""
|
||||
if not event_ids:
|
||||
return
|
||||
return None
|
||||
|
||||
if len(event_ids) == 1:
|
||||
linearized_event_id = event_ids[0]
|
||||
@@ -507,13 +516,13 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
else:
|
||||
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
|
||||
|
||||
linearized_event_id = yield self.db_pool.runInteraction(
|
||||
linearized_event_id = await self.db_pool.runInteraction(
|
||||
"insert_receipt_conv", graph_to_linear
|
||||
)
|
||||
|
||||
stream_id_manager = self._receipts_id_gen.get_next()
|
||||
with stream_id_manager as stream_id:
|
||||
event_ts = yield self.db_pool.runInteraction(
|
||||
event_ts = await self.db_pool.runInteraction(
|
||||
"insert_linearized_receipt",
|
||||
self.insert_linearized_receipt_txn,
|
||||
room_id,
|
||||
@@ -535,7 +544,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
now - event_ts,
|
||||
)
|
||||
|
||||
yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
|
||||
await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
|
||||
|
||||
max_persisted_id = self._receipts_id_gen.get_current_token()
|
||||
|
||||
|
||||
@@ -17,9 +17,7 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
from typing import Awaitable, Dict, List, Optional
|
||||
|
||||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
||||
@@ -304,7 +302,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
|
||||
def _query_for_auth(self, txn, token):
|
||||
sql = (
|
||||
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
|
||||
"SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
|
||||
" access_tokens.device_id, access_tokens.valid_until_ms"
|
||||
" FROM users"
|
||||
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
|
||||
@@ -563,7 +561,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
id_server (str)
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
Awaitable
|
||||
"""
|
||||
# We need to use an upsert, in case they user had already bound the
|
||||
# threepid
|
||||
@@ -952,6 +950,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
create_profile_with_displayname=None,
|
||||
admin=False,
|
||||
user_type=None,
|
||||
shadow_banned=False,
|
||||
):
|
||||
"""Attempts to register an account.
|
||||
|
||||
@@ -968,6 +967,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
admin (boolean): is an admin user?
|
||||
user_type (str|None): type of user. One of the values from
|
||||
api.constants.UserTypes, or None for a normal user.
|
||||
shadow_banned (bool): Whether the user is shadow-banned,
|
||||
i.e. they may be told their requests succeeded but we ignore them.
|
||||
|
||||
Raises:
|
||||
StoreError if the user_id could not be registered.
|
||||
@@ -986,6 +987,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
create_profile_with_displayname,
|
||||
admin,
|
||||
user_type,
|
||||
shadow_banned,
|
||||
)
|
||||
|
||||
def _register_user(
|
||||
@@ -999,6 +1001,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
create_profile_with_displayname,
|
||||
admin,
|
||||
user_type,
|
||||
shadow_banned,
|
||||
):
|
||||
user_id_obj = UserID.from_string(user_id)
|
||||
|
||||
@@ -1028,6 +1031,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
"appservice_id": appservice_id,
|
||||
"admin": 1 if admin else 0,
|
||||
"user_type": user_type,
|
||||
"shadow_banned": shadow_banned,
|
||||
},
|
||||
)
|
||||
else:
|
||||
@@ -1042,6 +1046,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
"appservice_id": appservice_id,
|
||||
"admin": 1 if admin else 0,
|
||||
"user_type": user_type,
|
||||
"shadow_banned": shadow_banned,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1077,7 +1082,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
|
||||
def record_user_external_id(
|
||||
self, auth_provider: str, external_id: str, user_id: str
|
||||
) -> Deferred:
|
||||
) -> Awaitable:
|
||||
"""Record a mapping from an external user id to a mxid
|
||||
|
||||
Args:
|
||||
@@ -1345,43 +1350,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
"validate_threepid_session_txn", validate_threepid_session_txn
|
||||
)
|
||||
|
||||
def upsert_threepid_validation_session(
|
||||
self,
|
||||
medium,
|
||||
address,
|
||||
client_secret,
|
||||
send_attempt,
|
||||
session_id,
|
||||
validated_at=None,
|
||||
):
|
||||
"""Upsert a threepid validation session
|
||||
Args:
|
||||
medium (str): The medium of the 3PID
|
||||
address (str): The address of the 3PID
|
||||
client_secret (str): A unique string provided by the client to
|
||||
help identify this validation attempt
|
||||
send_attempt (int): The latest send_attempt on this session
|
||||
session_id (str): The id of this validation session
|
||||
validated_at (int|None): The unix timestamp in milliseconds of
|
||||
when the session was marked as valid
|
||||
"""
|
||||
insertion_values = {
|
||||
"medium": medium,
|
||||
"address": address,
|
||||
"client_secret": client_secret,
|
||||
}
|
||||
|
||||
if validated_at:
|
||||
insertion_values["validated_at"] = validated_at
|
||||
|
||||
return self.db_pool.simple_upsert(
|
||||
table="threepid_validation_session",
|
||||
keyvalues={"session_id": session_id},
|
||||
values={"last_send_attempt": send_attempt},
|
||||
insertion_values=insertion_values,
|
||||
desc="upsert_threepid_validation_session",
|
||||
)
|
||||
|
||||
def start_or_continue_validation_session(
|
||||
self,
|
||||
medium,
|
||||
|
||||
@@ -35,10 +35,6 @@ from synapse.util.caches.descriptors import cached
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
OpsLevel = collections.namedtuple(
|
||||
"OpsLevel", ("ban_level", "kick_level", "redact_level")
|
||||
)
|
||||
|
||||
RatelimitOverride = collections.namedtuple(
|
||||
"RatelimitOverride", ("messages_per_second", "burst_count")
|
||||
)
|
||||
|
||||
@@ -17,8 +17,6 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
@@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
lambda: self._known_servers_count,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _count_known_servers(self):
|
||||
async def _count_known_servers(self):
|
||||
"""
|
||||
Count the servers that this server knows about.
|
||||
|
||||
@@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
txn.execute(query)
|
||||
return list(txn)[0][0]
|
||||
|
||||
count = yield self.db_pool.runInteraction("get_known_servers", _transact)
|
||||
count = await self.db_pool.runInteraction("get_known_servers", _transact)
|
||||
|
||||
# We always know about ourselves, even if we have nothing in
|
||||
# room_memberships (for example, the server is new).
|
||||
@@ -589,11 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="_get_joined_profile_from_event_id",
|
||||
list_name="event_ids",
|
||||
inlineCallbacks=True,
|
||||
cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
|
||||
)
|
||||
def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
|
||||
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
|
||||
"""For given set of member event_ids check if they point to a join
|
||||
event and if so return the associated user and profile info.
|
||||
|
||||
@@ -601,11 +596,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
event_ids: The member event IDs to lookup
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
|
||||
dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
|
||||
to `user_id` and ProfileInfo (or None if not join event).
|
||||
"""
|
||||
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="room_memberships",
|
||||
column="event_id",
|
||||
iterable=event_ids,
|
||||
@@ -772,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
|
||||
return set(room_ids)
|
||||
|
||||
def get_membership_from_event_ids(
|
||||
async def get_membership_from_event_ids(
|
||||
self, member_event_ids: Iterable[str]
|
||||
) -> List[dict]:
|
||||
"""Get user_id and membership of a set of event IDs.
|
||||
"""
|
||||
|
||||
return self.db_pool.simple_select_many_batch(
|
||||
return await self.db_pool.simple_select_many_batch(
|
||||
table="room_memberships",
|
||||
column="event_id",
|
||||
iterable=member_event_ids,
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- A shadow-banned user may be told that their requests succeeded when they were
|
||||
-- actually ignored.
|
||||
ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN;
|
||||
@@ -0,0 +1,17 @@
|
||||
/* Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- This table is no longer used.
|
||||
DROP TABLE IF EXISTS presence_allow_inbound;
|
||||
@@ -273,12 +273,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
cached_method_name="_get_state_group_for_event",
|
||||
list_name="event_ids",
|
||||
num_args=1,
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def _get_state_group_for_events(self, event_ids):
|
||||
async def _get_state_group_for_events(self, event_ids):
|
||||
"""Returns mapping event_id -> state_group
|
||||
"""
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="event_to_state_groups",
|
||||
column="event_id",
|
||||
iterable=event_ids,
|
||||
|
||||
@@ -39,15 +39,17 @@ what sort order was used:
|
||||
import abc
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Optional
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||
from synapse.types import RoomStreamToken
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
@@ -68,8 +70,12 @@ _EventDictReturn = namedtuple(
|
||||
|
||||
|
||||
def generate_pagination_where_clause(
|
||||
direction, column_names, from_token, to_token, engine
|
||||
):
|
||||
direction: str,
|
||||
column_names: Tuple[str, str],
|
||||
from_token: Optional[Tuple[int, int]],
|
||||
to_token: Optional[Tuple[int, int]],
|
||||
engine: BaseDatabaseEngine,
|
||||
) -> str:
|
||||
"""Creates an SQL expression to bound the columns by the pagination
|
||||
tokens.
|
||||
|
||||
@@ -90,21 +96,19 @@ def generate_pagination_where_clause(
|
||||
token, but include those that match the to token.
|
||||
|
||||
Args:
|
||||
direction (str): Whether we're paginating backwards("b") or
|
||||
forwards ("f").
|
||||
column_names (tuple[str, str]): The column names to bound. Must *not*
|
||||
be user defined as these get inserted directly into the SQL
|
||||
statement without escapes.
|
||||
from_token (tuple[int, int]|None): The start point for the pagination.
|
||||
This is an exclusive minimum bound if direction is "f", and an
|
||||
inclusive maximum bound if direction is "b".
|
||||
to_token (tuple[int, int]|None): The endpoint point for the pagination.
|
||||
This is an inclusive maximum bound if direction is "f", and an
|
||||
exclusive minimum bound if direction is "b".
|
||||
direction: Whether we're paginating backwards("b") or forwards ("f").
|
||||
column_names: The column names to bound. Must *not* be user defined as
|
||||
these get inserted directly into the SQL statement without escapes.
|
||||
from_token: The start point for the pagination. This is an exclusive
|
||||
minimum bound if direction is "f", and an inclusive maximum bound if
|
||||
direction is "b".
|
||||
to_token: The endpoint point for the pagination. This is an inclusive
|
||||
maximum bound if direction is "f", and an exclusive minimum bound if
|
||||
direction is "b".
|
||||
engine: The database engine to generate the clauses for
|
||||
|
||||
Returns:
|
||||
str: The sql expression
|
||||
The sql expression
|
||||
"""
|
||||
assert direction in ("b", "f")
|
||||
|
||||
@@ -132,7 +136,12 @@ def generate_pagination_where_clause(
|
||||
return " AND ".join(where_clause)
|
||||
|
||||
|
||||
def _make_generic_sql_bound(bound, column_names, values, engine):
|
||||
def _make_generic_sql_bound(
|
||||
bound: str,
|
||||
column_names: Tuple[str, str],
|
||||
values: Tuple[Optional[int], int],
|
||||
engine: BaseDatabaseEngine,
|
||||
) -> str:
|
||||
"""Create an SQL expression that bounds the given column names by the
|
||||
values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
|
||||
|
||||
@@ -142,18 +151,18 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
|
||||
out manually.
|
||||
|
||||
Args:
|
||||
bound (str): The comparison operator to use. One of ">", "<", ">=",
|
||||
bound: The comparison operator to use. One of ">", "<", ">=",
|
||||
"<=", where the values are on the left and columns on the right.
|
||||
names (tuple[str, str]): The column names. Must *not* be user defined
|
||||
names: The column names. Must *not* be user defined
|
||||
as these get inserted directly into the SQL statement without
|
||||
escapes.
|
||||
values (tuple[int|None, int]): The values to bound the columns by. If
|
||||
values: The values to bound the columns by. If
|
||||
the first value is None then only creates a bound on the second
|
||||
column.
|
||||
engine: The database engine to generate the SQL for
|
||||
|
||||
Returns:
|
||||
str
|
||||
The SQL statement
|
||||
"""
|
||||
|
||||
assert bound in (">", "<", ">=", "<=")
|
||||
@@ -193,7 +202,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
|
||||
)
|
||||
|
||||
|
||||
def filter_to_clause(event_filter):
|
||||
def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
|
||||
# NB: This may create SQL clauses that don't optimise well (and we don't
|
||||
# have indices on all possible clauses). E.g. it may create
|
||||
# "room_id == X AND room_id != X", which postgres doesn't optimise.
|
||||
@@ -291,34 +300,35 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
def get_room_min_stream_ordering(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_events_stream_for_rooms(
|
||||
self, room_ids, from_key, to_key, limit=0, order="DESC"
|
||||
):
|
||||
async def get_room_events_stream_for_rooms(
|
||||
self,
|
||||
room_ids: Iterable[str],
|
||||
from_key: str,
|
||||
to_key: str,
|
||||
limit: int = 0,
|
||||
order: str = "DESC",
|
||||
) -> Dict[str, Tuple[List[EventBase], str]]:
|
||||
"""Get new room events in stream ordering since `from_key`.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
from_key (str): Token from which no events are returned before
|
||||
to_key (str): Token from which no events are returned after. (This
|
||||
room_ids
|
||||
from_key: Token from which no events are returned before
|
||||
to_key: Token from which no events are returned after. (This
|
||||
is typically the current stream token)
|
||||
limit (int): Maximum number of events to return
|
||||
order (str): Either "DESC" or "ASC". Determines which events are
|
||||
limit: Maximum number of events to return
|
||||
order: Either "DESC" or "ASC". Determines which events are
|
||||
returned when the result is limited. If "DESC" then the most
|
||||
recent `limit` events are returned, otherwise returns the
|
||||
oldest `limit` events.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str,tuple[list[FrozenEvent], str]]]
|
||||
A map from room id to a tuple containing:
|
||||
- list of recent events in the room
|
||||
- stream ordering key for the start of the chunk of events returned.
|
||||
A map from room id to a tuple containing:
|
||||
- list of recent events in the room
|
||||
- stream ordering key for the start of the chunk of events returned.
|
||||
"""
|
||||
from_id = RoomStreamToken.parse_stream_token(from_key).stream
|
||||
|
||||
room_ids = yield self._events_stream_cache.get_entities_changed(
|
||||
room_ids, from_id
|
||||
)
|
||||
room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
|
||||
|
||||
if not room_ids:
|
||||
return {}
|
||||
@@ -326,7 +336,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
results = {}
|
||||
room_ids = list(room_ids)
|
||||
for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
|
||||
res = yield make_deferred_yieldable(
|
||||
res = await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
run_in_background(
|
||||
@@ -361,28 +371,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
if self._events_stream_cache.has_entity_changed(room_id, from_key)
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_events_stream_for_room(
|
||||
self, room_id, from_key, to_key, limit=0, order="DESC"
|
||||
):
|
||||
async def get_room_events_stream_for_room(
|
||||
self,
|
||||
room_id: str,
|
||||
from_key: str,
|
||||
to_key: str,
|
||||
limit: int = 0,
|
||||
order: str = "DESC",
|
||||
) -> Tuple[List[EventBase], str]:
|
||||
|
||||
"""Get new room events in stream ordering since `from_key`.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
from_key (str): Token from which no events are returned before
|
||||
to_key (str): Token from which no events are returned after. (This
|
||||
room_id
|
||||
from_key: Token from which no events are returned before
|
||||
to_key: Token from which no events are returned after. (This
|
||||
is typically the current stream token)
|
||||
limit (int): Maximum number of events to return
|
||||
order (str): Either "DESC" or "ASC". Determines which events are
|
||||
limit: Maximum number of events to return
|
||||
order: Either "DESC" or "ASC". Determines which events are
|
||||
returned when the result is limited. If "DESC" then the most
|
||||
recent `limit` events are returned, otherwise returns the
|
||||
oldest `limit` events.
|
||||
|
||||
Returns:
|
||||
Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
|
||||
events (in ascending order) and the token from the start of
|
||||
the chunk of events returned.
|
||||
The list of events (in ascending order) and the token from the start
|
||||
of the chunk of events returned.
|
||||
"""
|
||||
if from_key == to_key:
|
||||
return [], from_key
|
||||
@@ -390,9 +403,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
from_id = RoomStreamToken.parse_stream_token(from_key).stream
|
||||
to_id = RoomStreamToken.parse_stream_token(to_key).stream
|
||||
|
||||
has_changed = yield self._events_stream_cache.has_entity_changed(
|
||||
room_id, from_id
|
||||
)
|
||||
has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
|
||||
|
||||
if not has_changed:
|
||||
return [], from_key
|
||||
@@ -410,9 +421,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
|
||||
return rows
|
||||
|
||||
rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f)
|
||||
rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
|
||||
|
||||
ret = yield self.get_events_as_list(
|
||||
ret = await self.get_events_as_list(
|
||||
[r.event_id for r in rows], get_prev_content=True
|
||||
)
|
||||
|
||||
@@ -430,8 +441,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
return ret, key
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_membership_changes_for_user(self, user_id, from_key, to_key):
|
||||
async def get_membership_changes_for_user(self, user_id, from_key, to_key):
|
||||
from_id = RoomStreamToken.parse_stream_token(from_key).stream
|
||||
to_id = RoomStreamToken.parse_stream_token(to_key).stream
|
||||
|
||||
@@ -460,9 +470,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
return rows
|
||||
|
||||
rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f)
|
||||
rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f)
|
||||
|
||||
ret = yield self.get_events_as_list(
|
||||
ret = await self.get_events_as_list(
|
||||
[r.event_id for r in rows], get_prev_content=True
|
||||
)
|
||||
|
||||
@@ -470,27 +480,26 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
return ret
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_recent_events_for_room(self, room_id, limit, end_token):
|
||||
async def get_recent_events_for_room(
|
||||
self, room_id: str, limit: int, end_token: str
|
||||
) -> Tuple[List[EventBase], str]:
|
||||
"""Get the most recent events in the room in topological ordering.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
limit (int)
|
||||
end_token (str): The stream token representing now.
|
||||
room_id
|
||||
limit
|
||||
end_token: The stream token representing now.
|
||||
|
||||
Returns:
|
||||
Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
|
||||
events and a token pointing to the start of the returned
|
||||
events.
|
||||
The events returned are in ascending order.
|
||||
A list of events and a token pointing to the start of the returned
|
||||
events. The events returned are in ascending order.
|
||||
"""
|
||||
|
||||
rows, token = yield self.get_recent_event_ids_for_room(
|
||||
rows, token = await self.get_recent_event_ids_for_room(
|
||||
room_id, limit, end_token
|
||||
)
|
||||
|
||||
events = yield self.get_events_as_list(
|
||||
events = await self.get_events_as_list(
|
||||
[r.event_id for r in rows], get_prev_content=True
|
||||
)
|
||||
|
||||
@@ -498,20 +507,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
return (events, token)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_recent_event_ids_for_room(self, room_id, limit, end_token):
|
||||
async def get_recent_event_ids_for_room(
|
||||
self, room_id: str, limit: int, end_token: str
|
||||
) -> Tuple[List[_EventDictReturn], str]:
|
||||
"""Get the most recent events in the room in topological ordering.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
limit (int)
|
||||
end_token (str): The stream token representing now.
|
||||
room_id
|
||||
limit
|
||||
end_token: The stream token representing now.
|
||||
|
||||
Returns:
|
||||
Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
|
||||
_EventDictReturn and a token pointing to the start of the returned
|
||||
events.
|
||||
The events returned are in ascending order.
|
||||
A list of _EventDictReturn and a token pointing to the start of the
|
||||
returned events. The events returned are in ascending order.
|
||||
"""
|
||||
# Allow a zero limit here, and no-op.
|
||||
if limit == 0:
|
||||
@@ -519,7 +527,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
end_token = RoomStreamToken.parse(end_token)
|
||||
|
||||
rows, token = yield self.db_pool.runInteraction(
|
||||
rows, token = await self.db_pool.runInteraction(
|
||||
"get_recent_event_ids_for_room",
|
||||
self._paginate_room_events_txn,
|
||||
room_id,
|
||||
@@ -532,12 +540,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
return rows, token
|
||||
|
||||
def get_room_event_before_stream_ordering(self, room_id, stream_ordering):
|
||||
def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
|
||||
"""Gets details of the first event in a room at or before a stream ordering
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
stream_ordering (int):
|
||||
room_id:
|
||||
stream_ordering:
|
||||
|
||||
Returns:
|
||||
Deferred[(int, int, str)]:
|
||||
@@ -574,55 +582,67 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
)
|
||||
return "t%d-%d" % (topo, token)
|
||||
|
||||
def get_stream_token_for_event(self, event_id):
|
||||
"""The stream token for an event
|
||||
async def get_stream_id_for_event(self, event_id: str) -> int:
|
||||
"""The stream ID for an event
|
||||
Args:
|
||||
event_id(str): The id of the event to look up a stream token for.
|
||||
event_id: The id of the event to look up a stream token for.
|
||||
Raises:
|
||||
StoreError if the event wasn't in the database.
|
||||
Returns:
|
||||
A deferred "s%d" stream token.
|
||||
A stream ID.
|
||||
"""
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
|
||||
).addCallback(lambda row: "s%d" % (row,))
|
||||
)
|
||||
|
||||
def get_topological_token_for_event(self, event_id):
|
||||
async def get_stream_token_for_event(self, event_id: str) -> str:
|
||||
"""The stream token for an event
|
||||
Args:
|
||||
event_id(str): The id of the event to look up a stream token for.
|
||||
event_id: The id of the event to look up a stream token for.
|
||||
Raises:
|
||||
StoreError if the event wasn't in the database.
|
||||
Returns:
|
||||
A deferred "t%d-%d" topological token.
|
||||
A "s%d" stream token.
|
||||
"""
|
||||
return self.db_pool.simple_select_one(
|
||||
stream_id = await self.get_stream_id_for_event(event_id)
|
||||
return "s%d" % (stream_id,)
|
||||
|
||||
async def get_topological_token_for_event(self, event_id: str) -> str:
|
||||
"""The stream token for an event
|
||||
Args:
|
||||
event_id: The id of the event to look up a stream token for.
|
||||
Raises:
|
||||
StoreError if the event wasn't in the database.
|
||||
Returns:
|
||||
A "t%d-%d" topological token.
|
||||
"""
|
||||
row = await self.db_pool.simple_select_one(
|
||||
table="events",
|
||||
keyvalues={"event_id": event_id},
|
||||
retcols=("stream_ordering", "topological_ordering"),
|
||||
desc="get_topological_token_for_event",
|
||||
).addCallback(
|
||||
lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
|
||||
)
|
||||
return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
|
||||
|
||||
def get_max_topological_token(self, room_id, stream_key):
|
||||
async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
|
||||
"""Get the max topological token in a room before the given stream
|
||||
ordering.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
stream_key (int)
|
||||
room_id
|
||||
stream_key
|
||||
|
||||
Returns:
|
||||
Deferred[int]
|
||||
The maximum topological token.
|
||||
"""
|
||||
sql = (
|
||||
"SELECT coalesce(max(topological_ordering), 0) FROM events"
|
||||
" WHERE room_id = ? AND stream_ordering < ?"
|
||||
)
|
||||
return self.db_pool.execute(
|
||||
row = await self.db_pool.execute(
|
||||
"get_max_topological_token", None, sql, room_id, stream_key
|
||||
).addCallback(lambda r: r[0][0] if r else 0)
|
||||
)
|
||||
return row[0][0] if row else 0
|
||||
|
||||
def _get_max_topological_txn(self, txn, room_id):
|
||||
txn.execute(
|
||||
@@ -634,16 +654,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
return rows[0][0] if rows else 0
|
||||
|
||||
@staticmethod
|
||||
def _set_before_and_after(events, rows, topo_order=True):
|
||||
def _set_before_and_after(
|
||||
events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
|
||||
):
|
||||
"""Inserts ordering information to events' internal metadata from
|
||||
the DB rows.
|
||||
|
||||
Args:
|
||||
events (list[FrozenEvent])
|
||||
rows (list[_EventDictReturn])
|
||||
topo_order (bool): Whether the events were ordered topologically
|
||||
or by stream ordering. If true then all rows should have a non
|
||||
null topological_ordering.
|
||||
events
|
||||
rows
|
||||
topo_order: Whether the events were ordered topologically or by stream
|
||||
ordering. If true then all rows should have a non null
|
||||
topological_ordering.
|
||||
"""
|
||||
for event, row in zip(events, rows):
|
||||
stream = row.stream_ordering
|
||||
@@ -656,25 +678,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
internal.after = str(RoomStreamToken(topo, stream))
|
||||
internal.order = (int(topo) if topo else 0, int(stream))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_events_around(
|
||||
self, room_id, event_id, before_limit, after_limit, event_filter=None
|
||||
):
|
||||
async def get_events_around(
|
||||
self,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
before_limit: int,
|
||||
after_limit: int,
|
||||
event_filter: Optional[Filter] = None,
|
||||
) -> dict:
|
||||
"""Retrieve events and pagination tokens around a given event in a
|
||||
room.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
event_id (str)
|
||||
before_limit (int)
|
||||
after_limit (int)
|
||||
event_filter (Filter|None)
|
||||
|
||||
Returns:
|
||||
dict
|
||||
"""
|
||||
|
||||
results = yield self.db_pool.runInteraction(
|
||||
results = await self.db_pool.runInteraction(
|
||||
"get_events_around",
|
||||
self._get_events_around_txn,
|
||||
room_id,
|
||||
@@ -684,11 +700,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
event_filter,
|
||||
)
|
||||
|
||||
events_before = yield self.get_events_as_list(
|
||||
events_before = await self.get_events_as_list(
|
||||
list(results["before"]["event_ids"]), get_prev_content=True
|
||||
)
|
||||
|
||||
events_after = yield self.get_events_as_list(
|
||||
events_after = await self.get_events_as_list(
|
||||
list(results["after"]["event_ids"]), get_prev_content=True
|
||||
)
|
||||
|
||||
@@ -700,17 +716,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
}
|
||||
|
||||
def _get_events_around_txn(
|
||||
self, txn, room_id, event_id, before_limit, after_limit, event_filter
|
||||
):
|
||||
self,
|
||||
txn,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
before_limit: int,
|
||||
after_limit: int,
|
||||
event_filter: Optional[Filter],
|
||||
) -> dict:
|
||||
"""Retrieves event_ids and pagination tokens around a given event in a
|
||||
room.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
event_id (str)
|
||||
before_limit (int)
|
||||
after_limit (int)
|
||||
event_filter (Filter|None)
|
||||
room_id
|
||||
event_id
|
||||
before_limit
|
||||
after_limit
|
||||
event_filter
|
||||
|
||||
Returns:
|
||||
dict
|
||||
@@ -758,22 +780,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
"after": {"event_ids": events_after, "token": end_token},
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_all_new_events_stream(self, from_id, current_id, limit):
|
||||
async def get_all_new_events_stream(
|
||||
self, from_id: int, current_id: int, limit: int
|
||||
) -> Tuple[int, List[EventBase]]:
|
||||
"""Get all new events
|
||||
|
||||
Returns all events with from_id < stream_ordering <= current_id.
|
||||
|
||||
Args:
|
||||
from_id (int): the stream_ordering of the last event we processed
|
||||
current_id (int): the stream_ordering of the most recently processed event
|
||||
limit (int): the maximum number of events to return
|
||||
from_id: the stream_ordering of the last event we processed
|
||||
current_id: the stream_ordering of the most recently processed event
|
||||
limit: the maximum number of events to return
|
||||
|
||||
Returns:
|
||||
Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where
|
||||
`next_id` is the next value to pass as `from_id` (it will either be the
|
||||
stream_ordering of the last returned event, or, if fewer than `limit` events
|
||||
were found, `current_id`.
|
||||
A tuple of (next_id, events), where `next_id` is the next value to
|
||||
pass as `from_id` (it will either be the stream_ordering of the
|
||||
last returned event, or, if fewer than `limit` events were found,
|
||||
the `current_id`).
|
||||
"""
|
||||
|
||||
def get_all_new_events_stream_txn(txn):
|
||||
@@ -795,11 +818,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
return upper_bound, [row[1] for row in rows]
|
||||
|
||||
upper_bound, event_ids = yield self.db_pool.runInteraction(
|
||||
upper_bound, event_ids = await self.db_pool.runInteraction(
|
||||
"get_all_new_events_stream", get_all_new_events_stream_txn
|
||||
)
|
||||
|
||||
events = yield self.get_events_as_list(event_ids)
|
||||
events = await self.get_events_as_list(event_ids)
|
||||
|
||||
return upper_bound, events
|
||||
|
||||
@@ -817,21 +840,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
desc="get_federation_out_pos",
|
||||
)
|
||||
|
||||
async def update_federation_out_pos(self, typ, stream_id):
|
||||
async def update_federation_out_pos(self, typ: str, stream_id: int) -> None:
|
||||
if self._need_to_reset_federation_stream_positions:
|
||||
await self.db_pool.runInteraction(
|
||||
"_reset_federation_positions_txn", self._reset_federation_positions_txn
|
||||
)
|
||||
self._need_to_reset_federation_stream_positions = False
|
||||
|
||||
return await self.db_pool.simple_update_one(
|
||||
await self.db_pool.simple_update_one(
|
||||
table="federation_stream_position",
|
||||
keyvalues={"type": typ, "instance_name": self._instance_name},
|
||||
updatevalues={"stream_id": stream_id},
|
||||
desc="update_federation_out_pos",
|
||||
)
|
||||
|
||||
def _reset_federation_positions_txn(self, txn):
|
||||
def _reset_federation_positions_txn(self, txn) -> None:
|
||||
"""Fiddles with the `federation_stream_position` table to make it match
|
||||
the configured federation sender instances during start up.
|
||||
"""
|
||||
@@ -892,39 +915,37 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
values={"stream_id": stream_id},
|
||||
)
|
||||
|
||||
def has_room_changed_since(self, room_id, stream_id):
|
||||
def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:
|
||||
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
|
||||
|
||||
def _paginate_room_events_txn(
|
||||
self,
|
||||
txn,
|
||||
room_id,
|
||||
from_token,
|
||||
to_token=None,
|
||||
direction="b",
|
||||
limit=-1,
|
||||
event_filter=None,
|
||||
):
|
||||
room_id: str,
|
||||
from_token: RoomStreamToken,
|
||||
to_token: Optional[RoomStreamToken] = None,
|
||||
direction: str = "b",
|
||||
limit: int = -1,
|
||||
event_filter: Optional[Filter] = None,
|
||||
) -> Tuple[List[_EventDictReturn], str]:
|
||||
"""Returns list of events before or after a given token.
|
||||
|
||||
Args:
|
||||
txn
|
||||
room_id (str)
|
||||
from_token (RoomStreamToken): The token used to stream from
|
||||
to_token (RoomStreamToken|None): A token which if given limits the
|
||||
results to only those before
|
||||
direction(char): Either 'b' or 'f' to indicate whether we are
|
||||
paginating forwards or backwards from `from_key`.
|
||||
limit (int): The maximum number of events to return.
|
||||
event_filter (Filter|None): If provided filters the events to
|
||||
room_id
|
||||
from_token: The token used to stream from
|
||||
to_token: A token which if given limits the results to only those before
|
||||
direction: Either 'b' or 'f' to indicate whether we are paginating
|
||||
forwards or backwards from `from_key`.
|
||||
limit: The maximum number of events to return.
|
||||
event_filter: If provided filters the events to
|
||||
those that match the filter.
|
||||
|
||||
Returns:
|
||||
Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
|
||||
as a list of _EventDictReturn and a token that points to the end
|
||||
of the result set. If no events are returned then the end of the
|
||||
stream has been reached (i.e. there are no events between
|
||||
`from_token` and `to_token`), or `limit` is zero.
|
||||
A list of _EventDictReturn and a token that points to the end of the
|
||||
result set. If no events are returned then the end of the stream has
|
||||
been reached (i.e. there are no events between `from_token` and
|
||||
`to_token`), or `limit` is zero.
|
||||
"""
|
||||
|
||||
assert int(limit) >= 0
|
||||
@@ -1008,35 +1029,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
return rows, str(next_token)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def paginate_room_events(
|
||||
self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None
|
||||
):
|
||||
async def paginate_room_events(
|
||||
self,
|
||||
room_id: str,
|
||||
from_key: str,
|
||||
to_key: Optional[str] = None,
|
||||
direction: str = "b",
|
||||
limit: int = -1,
|
||||
event_filter: Optional[Filter] = None,
|
||||
) -> Tuple[List[EventBase], str]:
|
||||
"""Returns list of events before or after a given token.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
from_key (str): The token used to stream from
|
||||
to_key (str|None): A token which if given limits the results to
|
||||
only those before
|
||||
direction(char): Either 'b' or 'f' to indicate whether we are
|
||||
paginating forwards or backwards from `from_key`.
|
||||
limit (int): The maximum number of events to return.
|
||||
event_filter (Filter|None): If provided filters the events to
|
||||
those that match the filter.
|
||||
room_id
|
||||
from_key: The token used to stream from
|
||||
to_key: A token which if given limits the results to only those before
|
||||
direction: Either 'b' or 'f' to indicate whether we are paginating
|
||||
forwards or backwards from `from_key`.
|
||||
limit: The maximum number of events to return.
|
||||
event_filter: If provided filters the events to those that match the filter.
|
||||
|
||||
Returns:
|
||||
tuple[list[FrozenEvent], str]: Returns the results as a list of
|
||||
events and a token that points to the end of the result set. If no
|
||||
events are returned then the end of the stream has been reached
|
||||
(i.e. there are no events between `from_key` and `to_key`).
|
||||
The results as a list of events and a token that points to the end
|
||||
of the result set. If no events are returned then the end of the
|
||||
stream has been reached (i.e. there are no events between `from_key`
|
||||
and `to_key`).
|
||||
"""
|
||||
|
||||
from_key = RoomStreamToken.parse(from_key)
|
||||
if to_key:
|
||||
to_key = RoomStreamToken.parse(to_key)
|
||||
|
||||
rows, token = yield self.db_pool.runInteraction(
|
||||
rows, token = await self.db_pool.runInteraction(
|
||||
"paginate_room_events",
|
||||
self._paginate_room_events_txn,
|
||||
room_id,
|
||||
@@ -1047,7 +1071,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
event_filter,
|
||||
)
|
||||
|
||||
events = yield self.get_events_as_list(
|
||||
events = await self.get_events_as_list(
|
||||
[r.event_id for r in rows], get_prev_content=True
|
||||
)
|
||||
|
||||
@@ -1057,8 +1081,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
|
||||
class StreamStore(StreamWorkerStore):
|
||||
def get_room_max_stream_ordering(self):
|
||||
def get_room_max_stream_ordering(self) -> int:
|
||||
return self._stream_id_gen.get_current_token()
|
||||
|
||||
def get_room_min_stream_ordering(self):
|
||||
def get_room_min_stream_ordering(self) -> int:
|
||||
return self._backfill_id_gen.get_current_token()
|
||||
|
||||
@@ -38,10 +38,8 @@ class UserErasureWorkerStore(SQLBaseStore):
|
||||
desc="is_user_erased",
|
||||
).addCallback(operator.truth)
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
|
||||
)
|
||||
def are_users_erased(self, user_ids):
|
||||
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
|
||||
async def are_users_erased(self, user_ids):
|
||||
"""
|
||||
Checks which users in a list have requested erasure
|
||||
|
||||
@@ -49,14 +47,14 @@ class UserErasureWorkerStore(SQLBaseStore):
|
||||
user_ids (iterable[str]): full user id to check
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str, bool]]:
|
||||
dict[str, bool]:
|
||||
for each user, whether the user has requested erasure.
|
||||
"""
|
||||
# this serves the dual purpose of (a) making sure we can do len and
|
||||
# iterate it multiple times, and (b) avoiding duplicates.
|
||||
user_ids = tuple(set(user_ids))
|
||||
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="erased_users",
|
||||
column="user_id",
|
||||
iterable=user_ids,
|
||||
@@ -65,8 +63,7 @@ class UserErasureWorkerStore(SQLBaseStore):
|
||||
)
|
||||
erased_users = {row["user_id"] for row in rows}
|
||||
|
||||
res = {u: u in erased_users for u in user_ids}
|
||||
return res
|
||||
return {u: u in erased_users for u in user_ids}
|
||||
|
||||
|
||||
class UserErasureStore(UserErasureWorkerStore):
|
||||
|
||||
@@ -51,7 +51,15 @@ JsonDict = Dict[str, Any]
|
||||
|
||||
class Requester(
|
||||
namedtuple(
|
||||
"Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]
|
||||
"Requester",
|
||||
[
|
||||
"user",
|
||||
"access_token_id",
|
||||
"is_guest",
|
||||
"shadow_banned",
|
||||
"device_id",
|
||||
"app_service",
|
||||
],
|
||||
)
|
||||
):
|
||||
"""
|
||||
@@ -62,6 +70,7 @@ class Requester(
|
||||
access_token_id (int|None): *ID* of the access token used for this
|
||||
request, or None if it came via the appservice API or similar
|
||||
is_guest (bool): True if the user making this request is a guest user
|
||||
shadow_banned (bool): True if the user making this request has been shadow-banned.
|
||||
device_id (str|None): device_id which was set at authentication time
|
||||
app_service (ApplicationService|None): the AS requesting on behalf of the user
|
||||
"""
|
||||
@@ -77,6 +86,7 @@ class Requester(
|
||||
"user_id": self.user.to_string(),
|
||||
"access_token_id": self.access_token_id,
|
||||
"is_guest": self.is_guest,
|
||||
"shadow_banned": self.shadow_banned,
|
||||
"device_id": self.device_id,
|
||||
"app_server_id": self.app_service.id if self.app_service else None,
|
||||
}
|
||||
@@ -101,13 +111,19 @@ class Requester(
|
||||
user=UserID.from_string(input["user_id"]),
|
||||
access_token_id=input["access_token_id"],
|
||||
is_guest=input["is_guest"],
|
||||
shadow_banned=input["shadow_banned"],
|
||||
device_id=input["device_id"],
|
||||
app_service=appservice,
|
||||
)
|
||||
|
||||
|
||||
def create_requester(
|
||||
user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None
|
||||
user_id,
|
||||
access_token_id=None,
|
||||
is_guest=False,
|
||||
shadow_banned=False,
|
||||
device_id=None,
|
||||
app_service=None,
|
||||
):
|
||||
"""
|
||||
Create a new ``Requester`` object
|
||||
@@ -117,6 +133,7 @@ def create_requester(
|
||||
access_token_id (int|None): *ID* of the access token used for this
|
||||
request, or None if it came via the appservice API or similar
|
||||
is_guest (bool): True if the user making this request is a guest user
|
||||
shadow_banned (bool): True if the user making this request is shadow-banned.
|
||||
device_id (str|None): device_id which was set at authentication time
|
||||
app_service (ApplicationService|None): the AS requesting on behalf of the user
|
||||
|
||||
@@ -125,7 +142,9 @@ def create_requester(
|
||||
"""
|
||||
if not isinstance(user_id, UserID):
|
||||
user_id = UserID.from_string(user_id)
|
||||
return Requester(user_id, access_token_id, is_guest, device_id, app_service)
|
||||
return Requester(
|
||||
user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
|
||||
)
|
||||
|
||||
|
||||
def get_domain_from_id(string):
|
||||
|
||||
@@ -32,7 +32,6 @@ json_encoder = json.JSONEncoder(separators=(",", ":"))
|
||||
def unwrapFirstError(failure):
|
||||
# defer.gatherResults and DeferredLists wrap failures.
|
||||
failure.trap(defer.FirstError)
|
||||
logger.info("DDD failure.value.subFailure: %s", failure.value.subFailure)
|
||||
return failure.value.subFailure
|
||||
|
||||
|
||||
|
||||
@@ -24,9 +24,7 @@ from synapse.api.errors import Codes, SynapseError
|
||||
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
|
||||
|
||||
# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
|
||||
# Note: The : character is allowed here for older clients, but will be removed in a
|
||||
# future release. Context: https://github.com/matrix-org/synapse/issues/6766
|
||||
client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-\:]+$")
|
||||
client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
|
||||
|
||||
# random_string and random_string_with_symbols are used for a range of things,
|
||||
# some cryptographically important, some less so. We use SystemRandom to make sure
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.types import create_requester
|
||||
|
||||
from tests import unittest
|
||||
|
||||
@@ -20,77 +18,6 @@ class TestRatelimiter(unittest.TestCase):
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(20.0, time_allowed)
|
||||
|
||||
def test_allowed_user_via_can_requester_do_action(self):
|
||||
user_requester = create_requester("@user:example.com")
|
||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
user_requester, _time_now_s=0
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(10.0, time_allowed)
|
||||
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
user_requester, _time_now_s=5
|
||||
)
|
||||
self.assertFalse(allowed)
|
||||
self.assertEquals(10.0, time_allowed)
|
||||
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
user_requester, _time_now_s=10
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(20.0, time_allowed)
|
||||
|
||||
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
|
||||
appservice = ApplicationService(
|
||||
None, "example.com", id="foo", rate_limited=True,
|
||||
)
|
||||
as_requester = create_requester("@user:example.com", app_service=appservice)
|
||||
|
||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
as_requester, _time_now_s=0
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(10.0, time_allowed)
|
||||
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
as_requester, _time_now_s=5
|
||||
)
|
||||
self.assertFalse(allowed)
|
||||
self.assertEquals(10.0, time_allowed)
|
||||
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
as_requester, _time_now_s=10
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(20.0, time_allowed)
|
||||
|
||||
def test_allowed_appservice_via_can_requester_do_action(self):
|
||||
appservice = ApplicationService(
|
||||
None, "example.com", id="foo", rate_limited=False,
|
||||
)
|
||||
as_requester = create_requester("@user:example.com", app_service=appservice)
|
||||
|
||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
as_requester, _time_now_s=0
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(-1, time_allowed)
|
||||
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
as_requester, _time_now_s=5
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(-1, time_allowed)
|
||||
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
as_requester, _time_now_s=10
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(-1, time_allowed)
|
||||
|
||||
def test_allowed_via_ratelimit(self):
|
||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
||||
|
||||
|
||||
82
tests/config/test_base.py
Normal file
82
tests/config/test_base.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os.path
|
||||
import tempfile
|
||||
|
||||
from synapse.config import ConfigError
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class BaseConfigTestCase(unittest.HomeserverTestCase):
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.hs = hs
|
||||
|
||||
def test_loading_missing_templates(self):
|
||||
# Use a temporary directory that exists on the system, but that isn't likely to
|
||||
# contain template files
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Attempt to load an HTML template from our custom template directory
|
||||
template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0]
|
||||
|
||||
# If no errors, we should've gotten the default template instead
|
||||
|
||||
# Render the template
|
||||
a_random_string = random_string(5)
|
||||
html_content = template.render({"error_description": a_random_string})
|
||||
|
||||
# Check that our string exists in the template
|
||||
self.assertIn(
|
||||
a_random_string,
|
||||
html_content,
|
||||
"Template file did not contain our test string",
|
||||
)
|
||||
|
||||
def test_loading_custom_templates(self):
|
||||
# Use a temporary directory that exists on the system
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Create a temporary bogus template file
|
||||
with tempfile.NamedTemporaryFile(dir=tmp_dir) as tmp_template:
|
||||
# Get temporary file's filename
|
||||
template_filename = os.path.basename(tmp_template.name)
|
||||
|
||||
# Write a custom HTML template
|
||||
contents = b"{{ test_variable }}"
|
||||
tmp_template.write(contents)
|
||||
tmp_template.flush()
|
||||
|
||||
# Attempt to load the template from our custom template directory
|
||||
template = (
|
||||
self.hs.config.read_templates([template_filename], tmp_dir)
|
||||
)[0]
|
||||
|
||||
# Render the template
|
||||
a_random_string = random_string(5)
|
||||
html_content = template.render({"test_variable": a_random_string})
|
||||
|
||||
# Check that our string exists in the template
|
||||
self.assertIn(
|
||||
a_random_string,
|
||||
html_content,
|
||||
"Template file did not contain our test string",
|
||||
)
|
||||
|
||||
def test_loading_template_from_nonexistent_custom_directory(self):
|
||||
with self.assertRaises(ConfigError):
|
||||
self.hs.config.read_templates(
|
||||
["some_filename.html"], "a_nonexistent_directory"
|
||||
)
|
||||
@@ -79,9 +79,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
||||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||
fed_transport.client.get_json = Mock(
|
||||
side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
|
||||
)
|
||||
handler.federation_handler.do_invite_join = Mock(
|
||||
return_value=make_awaitable(("", 1))
|
||||
side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
|
||||
)
|
||||
|
||||
d = handler._remote_join(
|
||||
@@ -110,9 +112,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
||||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||
fed_transport.client.get_json = Mock(
|
||||
side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
|
||||
)
|
||||
handler.federation_handler.do_invite_join = Mock(
|
||||
return_value=make_awaitable(("", 1))
|
||||
side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
|
||||
)
|
||||
|
||||
d = handler._remote_join(
|
||||
@@ -148,9 +152,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
||||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
|
||||
fed_transport.client.get_json = Mock(
|
||||
side_effect=lambda *args, **kwargs: make_awaitable(None)
|
||||
)
|
||||
handler.federation_handler.do_invite_join = Mock(
|
||||
return_value=make_awaitable(("", 1))
|
||||
side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
|
||||
)
|
||||
|
||||
# Artificially raise the complexity
|
||||
@@ -204,9 +210,11 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
|
||||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||
fed_transport.client.get_json = Mock(
|
||||
side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
|
||||
)
|
||||
handler.federation_handler.do_invite_join = Mock(
|
||||
return_value=make_awaitable(("", 1))
|
||||
side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
|
||||
)
|
||||
|
||||
d = handler._remote_join(
|
||||
@@ -234,9 +242,11 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
|
||||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||
fed_transport.client.get_json = Mock(
|
||||
side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
|
||||
)
|
||||
handler.federation_handler.do_invite_join = Mock(
|
||||
return_value=make_awaitable(("", 1))
|
||||
side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
|
||||
)
|
||||
|
||||
d = handler._remote_join(
|
||||
|
||||
@@ -19,6 +19,7 @@ from mock import Mock, call
|
||||
from signedjson.key import generate_signing_key
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, PresenceState
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events.builder import EventBuilder
|
||||
from synapse.handlers.presence import (
|
||||
@@ -32,7 +33,6 @@ from synapse.handlers.presence import (
|
||||
handle_update,
|
||||
)
|
||||
from synapse.rest.client.v1 import room
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
|
||||
from tests import unittest
|
||||
|
||||
@@ -64,7 +64,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||
self.bob = UserID.from_string("@4567:test")
|
||||
self.alice = UserID.from_string("@alice:remote")
|
||||
|
||||
yield self.store.create_profile(self.frank.localpart)
|
||||
yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
|
||||
|
||||
self.handler = hs.get_profile_handler()
|
||||
self.hs = hs
|
||||
@@ -157,7 +157,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_incoming_fed_query(self):
|
||||
yield self.store.create_profile("caroline")
|
||||
yield defer.ensureDeferred(self.store.create_profile("caroline"))
|
||||
yield self.store.set_profile_displayname("caroline", "Caroline")
|
||||
|
||||
response = yield defer.ensureDeferred(
|
||||
|
||||
@@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||
([], 0)
|
||||
)
|
||||
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
|
||||
self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
|
||||
self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
|
||||
None
|
||||
)
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership
|
||||
from synapse.handlers.pagination import PurgeStatus
|
||||
from synapse.rest.client.v1 import directory, login, profile, room
|
||||
from synapse.rest.client.v2_alpha import account
|
||||
from synapse.types import JsonDict, RoomAlias, UserID
|
||||
from synapse.types import JsonDict, RoomAlias
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from tests import unittest
|
||||
@@ -675,91 +675,6 @@ class RoomMemberStateTestCase(RoomBase):
|
||||
self.assertEquals(json.loads(content), channel.json_body)
|
||||
|
||||
|
||||
class RoomJoinRatelimitTestCase(RoomBase):
|
||||
user_id = "@sid1:red"
|
||||
|
||||
servlets = [
|
||||
profile.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
@unittest.override_config(
|
||||
{"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
|
||||
)
|
||||
def test_join_local_ratelimit(self):
|
||||
"""Tests that local joins are actually rate-limited."""
|
||||
for i in range(5):
|
||||
self.helper.create_room_as(self.user_id)
|
||||
|
||||
self.helper.create_room_as(self.user_id, expect_code=429)
|
||||
|
||||
@unittest.override_config(
|
||||
{"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
|
||||
)
|
||||
def test_join_local_ratelimit_profile_change(self):
|
||||
"""Tests that sending a profile update into all of the user's joined rooms isn't
|
||||
rate-limited by the rate-limiter on joins."""
|
||||
|
||||
# Create and join more rooms than the rate-limiting config allows in a second.
|
||||
room_ids = [
|
||||
self.helper.create_room_as(self.user_id),
|
||||
self.helper.create_room_as(self.user_id),
|
||||
self.helper.create_room_as(self.user_id),
|
||||
]
|
||||
self.reactor.advance(1)
|
||||
room_ids = room_ids + [
|
||||
self.helper.create_room_as(self.user_id),
|
||||
self.helper.create_room_as(self.user_id),
|
||||
self.helper.create_room_as(self.user_id),
|
||||
]
|
||||
|
||||
# Create a profile for the user, since it hasn't been done on registration.
|
||||
store = self.hs.get_datastore()
|
||||
store.create_profile(UserID.from_string(self.user_id).localpart)
|
||||
|
||||
# Update the display name for the user.
|
||||
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
|
||||
request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
|
||||
self.render(request)
|
||||
self.assertEquals(channel.code, 200, channel.json_body)
|
||||
|
||||
# Check that all the rooms have been sent a profile update into.
|
||||
for room_id in room_ids:
|
||||
path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
|
||||
room_id,
|
||||
self.user_id,
|
||||
)
|
||||
|
||||
request, channel = self.make_request("GET", path)
|
||||
self.render(request)
|
||||
self.assertEquals(channel.code, 200)
|
||||
|
||||
self.assertIn("displayname", channel.json_body)
|
||||
self.assertEquals(channel.json_body["displayname"], "John Doe")
|
||||
|
||||
@unittest.override_config(
|
||||
{"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
|
||||
)
|
||||
def test_join_local_ratelimit_idempotent(self):
|
||||
"""Tests that the room join endpoints remain idempotent despite rate-limiting
|
||||
on room joins."""
|
||||
room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
# Let's test both paths to be sure.
|
||||
paths_to_test = [
|
||||
"/_matrix/client/r0/rooms/%s/join",
|
||||
"/_matrix/client/r0/join/%s",
|
||||
]
|
||||
|
||||
for path in paths_to_test:
|
||||
# Make sure we send more requests than the rate-limiting config would allow
|
||||
# if all of these requests ended up joining the user to a room.
|
||||
for i in range(6):
|
||||
request, channel = self.make_request("POST", path % room_id, {})
|
||||
self.render(request)
|
||||
self.assertEquals(channel.code, 200)
|
||||
|
||||
|
||||
class RoomMessagesTestCase(RoomBase):
|
||||
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
|
||||
|
||||
|
||||
@@ -39,9 +39,7 @@ class RestHelper(object):
|
||||
resource = attr.ib()
|
||||
auth_user_id = attr.ib()
|
||||
|
||||
def create_room_as(
|
||||
self, room_creator=None, is_public=True, tok=None, expect_code=200,
|
||||
):
|
||||
def create_room_as(self, room_creator=None, is_public=True, tok=None):
|
||||
temp_id = self.auth_user_id
|
||||
self.auth_user_id = room_creator
|
||||
path = "/_matrix/client/r0/createRoom"
|
||||
@@ -56,11 +54,9 @@ class RestHelper(object):
|
||||
)
|
||||
render(request, self.resource, self.hs.get_reactor())
|
||||
|
||||
assert channel.result["code"] == b"%d" % expect_code, channel.result
|
||||
assert channel.result["code"] == b"200", channel.result
|
||||
self.auth_user_id = temp_id
|
||||
|
||||
if expect_code == 200:
|
||||
return channel.json_body["room_id"]
|
||||
return channel.json_body["room_id"]
|
||||
|
||||
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
|
||||
self.change_membership(
|
||||
|
||||
@@ -207,7 +207,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def test_set_appservices_state_down(self):
|
||||
service = Mock(id=self.as_list[1]["id"])
|
||||
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
|
||||
)
|
||||
rows = yield self.db_pool.runQuery(
|
||||
self.engine.convert_param_style(
|
||||
"SELECT as_id FROM application_services_state WHERE state=?"
|
||||
@@ -219,9 +221,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def test_set_appservices_state_multiple_up(self):
|
||||
service = Mock(id=self.as_list[1]["id"])
|
||||
yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
|
||||
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
|
||||
yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_appservice_state(service, ApplicationServiceState.UP)
|
||||
)
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
|
||||
)
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_appservice_state(service, ApplicationServiceState.UP)
|
||||
)
|
||||
rows = yield self.db_pool.runQuery(
|
||||
self.engine.convert_param_style(
|
||||
"SELECT as_id FROM application_services_state WHERE state=?"
|
||||
|
||||
@@ -66,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||
def test_insert_1col(self):
|
||||
self.mock_txn.rowcount = 1
|
||||
|
||||
yield self.datastore.db_pool.simple_insert(
|
||||
table="tablename", values={"columname": "Value"}
|
||||
yield defer.ensureDeferred(
|
||||
self.datastore.db_pool.simple_insert(
|
||||
table="tablename", values={"columname": "Value"}
|
||||
)
|
||||
)
|
||||
|
||||
self.mock_txn.execute.assert_called_with(
|
||||
@@ -78,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||
def test_insert_3cols(self):
|
||||
self.mock_txn.rowcount = 1
|
||||
|
||||
yield self.datastore.db_pool.simple_insert(
|
||||
table="tablename",
|
||||
# Use OrderedDict() so we can assert on the SQL generated
|
||||
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
|
||||
yield defer.ensureDeferred(
|
||||
self.datastore.db_pool.simple_insert(
|
||||
table="tablename",
|
||||
# Use OrderedDict() so we can assert on the SQL generated
|
||||
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
|
||||
)
|
||||
)
|
||||
|
||||
self.mock_txn.execute.assert_called_with(
|
||||
|
||||
@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||
|
||||
# Create a test user and room
|
||||
self.user = UserID("alice", "test")
|
||||
self.requester = Requester(self.user, None, False, None, None)
|
||||
self.requester = Requester(self.user, None, False, False, None, None)
|
||||
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
|
||||
self.room_id = info["room_id"]
|
||||
|
||||
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
|
||||
# Create a test user and room
|
||||
self.user = UserID.from_string(self.register_user("user1", "password"))
|
||||
self.token1 = self.login("user1", "password")
|
||||
self.requester = Requester(self.user, None, False, None, None)
|
||||
self.requester = Requester(self.user, None, False, False, None, None)
|
||||
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
|
||||
self.room_id = info["room_id"]
|
||||
self.event_creator = homeserver.get_event_creation_handler()
|
||||
|
||||
@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
|
||||
room_creator = self.hs.get_room_creation_handler()
|
||||
|
||||
user = UserID("alice", "test")
|
||||
requester = Requester(user, None, False, None, None)
|
||||
requester = Requester(user, None, False, False, None, None)
|
||||
|
||||
# Real events, forward extremities
|
||||
events = [(3, 2), (6, 2), (4, 6)]
|
||||
|
||||
@@ -142,20 +142,22 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def test_find_first_stream_ordering_after_ts(self):
|
||||
def add_event(so, ts):
|
||||
return self.store.db_pool.simple_insert(
|
||||
"events",
|
||||
{
|
||||
"stream_ordering": so,
|
||||
"received_ts": ts,
|
||||
"event_id": "event%i" % so,
|
||||
"type": "",
|
||||
"room_id": "",
|
||||
"content": "",
|
||||
"processed": True,
|
||||
"outlier": False,
|
||||
"topological_ordering": 0,
|
||||
"depth": 0,
|
||||
},
|
||||
return defer.ensureDeferred(
|
||||
self.store.db_pool.simple_insert(
|
||||
"events",
|
||||
{
|
||||
"stream_ordering": so,
|
||||
"received_ts": ts,
|
||||
"event_id": "event%i" % so,
|
||||
"type": "",
|
||||
"room_id": "",
|
||||
"content": "",
|
||||
"processed": True,
|
||||
"outlier": False,
|
||||
"topological_ordering": 0,
|
||||
"depth": 0,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# start with the base case where there are no events in the table
|
||||
|
||||
@@ -35,7 +35,7 @@ class DataStoreTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def test_get_users_paginate(self):
|
||||
yield self.store.register_user(self.user.to_string(), "pass")
|
||||
yield self.store.create_profile(self.user.localpart)
|
||||
yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
|
||||
yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
|
||||
|
||||
users, total = yield self.store.get_users_paginate(
|
||||
|
||||
@@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_displayname(self):
|
||||
yield self.store.create_profile(self.u_frank.localpart)
|
||||
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
|
||||
|
||||
yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
|
||||
|
||||
@@ -43,7 +43,7 @@ class ProfileStoreTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_avatar_url(self):
|
||||
yield self.store.create_profile(self.u_frank.localpart)
|
||||
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
|
||||
|
||||
yield self.store.set_profile_avatar_url(
|
||||
self.u_frank.localpart, "http://my.site/here"
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import NotFoundError
|
||||
from synapse.rest.client.v1 import room
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
@@ -46,30 +47,19 @@ class PurgeTests(HomeserverTestCase):
|
||||
storage = self.hs.get_storage()
|
||||
|
||||
# Get the topological token
|
||||
event = store.get_topological_token_for_event(last["event_id"])
|
||||
self.pump()
|
||||
event = self.successResultOf(event)
|
||||
event = self.get_success(
|
||||
store.get_topological_token_for_event(last["event_id"])
|
||||
)
|
||||
|
||||
# Purge everything before this topological token
|
||||
purge = defer.ensureDeferred(
|
||||
storage.purge_events.purge_history(self.room_id, event, True)
|
||||
)
|
||||
self.pump()
|
||||
self.assertEqual(self.successResultOf(purge), None)
|
||||
|
||||
# Try and get the events
|
||||
get_first = store.get_event(first["event_id"])
|
||||
get_second = store.get_event(second["event_id"])
|
||||
get_third = store.get_event(third["event_id"])
|
||||
get_last = store.get_event(last["event_id"])
|
||||
self.pump()
|
||||
self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
|
||||
|
||||
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
|
||||
# and last is not.
|
||||
self.failureResultOf(get_first)
|
||||
self.failureResultOf(get_second)
|
||||
self.failureResultOf(get_third)
|
||||
self.successResultOf(get_last)
|
||||
self.get_failure(store.get_event(first["event_id"]), NotFoundError)
|
||||
self.get_failure(store.get_event(second["event_id"]), NotFoundError)
|
||||
self.get_failure(store.get_event(third["event_id"]), NotFoundError)
|
||||
self.get_success(store.get_event(last["event_id"]))
|
||||
|
||||
def test_purge_wont_delete_extrems(self):
|
||||
"""
|
||||
@@ -84,9 +74,9 @@ class PurgeTests(HomeserverTestCase):
|
||||
storage = self.hs.get_datastore()
|
||||
|
||||
# Set the topological token higher than it should be
|
||||
event = storage.get_topological_token_for_event(last["event_id"])
|
||||
self.pump()
|
||||
event = self.successResultOf(event)
|
||||
event = self.get_success(
|
||||
storage.get_topological_token_for_event(last["event_id"])
|
||||
)
|
||||
event = "t{}-{}".format(
|
||||
*list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
|
||||
)
|
||||
@@ -98,14 +88,7 @@ class PurgeTests(HomeserverTestCase):
|
||||
self.assertIn("greater than forward", f.value.args[0])
|
||||
|
||||
# Try and get the events
|
||||
get_first = storage.get_event(first["event_id"])
|
||||
get_second = storage.get_event(second["event_id"])
|
||||
get_third = storage.get_event(third["event_id"])
|
||||
get_last = storage.get_event(last["event_id"])
|
||||
self.pump()
|
||||
|
||||
# Nothing is deleted.
|
||||
self.successResultOf(get_first)
|
||||
self.successResultOf(get_second)
|
||||
self.successResultOf(get_third)
|
||||
self.successResultOf(get_last)
|
||||
self.get_success(storage.get_event(first["event_id"]))
|
||||
self.get_success(storage.get_event(second["event_id"]))
|
||||
self.get_success(storage.get_event(third["event_id"]))
|
||||
self.get_success(storage.get_event(last["event_id"]))
|
||||
|
||||
@@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Now let's create a room, which will insert a membership
|
||||
user = UserID("alice", "test")
|
||||
requester = Requester(user, None, False, None, None)
|
||||
requester = Requester(user, None, False, False, None, None)
|
||||
self.get_success(self.room_creator.create_room(requester, {}))
|
||||
|
||||
# Register the background update to run again.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user