1
0

Compare commits

..

14 Commits

Author SHA1 Message Date
Andrew Morgan bbd04152c9 Something daring 2020-09-18 10:58:14 +01:00
Andrew Morgan 8544a78fa6 Log pdus 2020-09-18 10:53:12 +01:00
Andrew Morgan 41f4242b68 What is a subfailure? 2020-09-18 10:49:47 +01:00
Andrew Morgan dca31cf978 Really kill it 2020-09-18 10:47:40 +01:00
Andrew Morgan f373ce79d7 Comment out failing log line 2020-09-18 10:45:06 +01:00
Erik Johnston 5ffd68dca1 1.19.2 2020-09-16 13:37:03 +01:00
Erik Johnston f1c9ded738 Merge branch 'erikj/fix_origin_check' into release-v1.19.2 2020-09-16 12:40:58 +01:00
Erik Johnston 97659b7489 Newsfile 2020-09-16 12:05:01 +01:00
Erik Johnston c570f24acc Don't assume that an event has an origin field
This fixes #8319.
2020-09-16 11:56:23 +01:00
Brendan Abolivier eadfda3ebc 1.19.1 2020-08-27 10:50:39 +01:00
Brendan Abolivier 0a4e541dc5 Changelog fixes 2020-08-25 15:29:57 +01:00
Brendan Abolivier b79d69796c 1.19.1rc1 2020-08-25 15:24:39 +01:00
Brendan Abolivier 393a811a41 Fix join ratelimiter breaking profile updates and idempotency (#8153) 2020-08-24 18:06:04 +01:00
Will Hunt 2df82ae451 Do not apply ratelimiting on joins to appservices (#8139)
Add new method ratelimiter.can_requester_do_action and ensure that appservices are exempt from being ratelimited.

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
Co-authored-by: Erik Johnston <erik@matrix.org>
2020-08-24 14:53:53 +01:00
105 changed files with 1458 additions and 1277 deletions
+23 -10
View File
@@ -1,15 +1,28 @@
For the next release
====================
Synapse 1.19.2 (2020-09-16)
===========================
Removal warning
---------------
Due to the issue below server admins are encouraged to upgrade as soon as possible.
Some older clients used a
[disallowed character](https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-register-email-requesttoken)
(`:`) in the `client_secret` parameter of various endpoints. The incorrect
behaviour was allowed for backwards compatibility, but is now being removed
from Synapse as most users have updated their client. Further context can be
found at [\#6766](https://github.com/matrix-org/synapse/issues/6766).
Bugfixes
--------
- Fix joining rooms over federation that include malformed events. ([\#8324](https://github.com/matrix-org/synapse/issues/8324))
Synapse 1.19.1 (2020-08-27)
===========================
No significant changes.
Synapse 1.19.1rc1 (2020-08-25)
==============================
Bugfixes
--------
- Fix a bug introduced in v1.19.0 where appservices with ratelimiting disabled would still be ratelimited when joining rooms. ([\#8139](https://github.com/matrix-org/synapse/issues/8139))
- Fix a bug introduced in v1.19.0 that would cause e.g. profile updates to fail due to incorrect application of rate limits on join requests. ([\#8153](https://github.com/matrix-org/synapse/issues/8153))
Synapse 1.19.0 (2020-08-17)
-1
View File
@@ -1 +0,0 @@
Fix a memory leak by limiting the length of time that messages will be queued for a remote server that has been unreachable.
-1
View File
@@ -1 +0,0 @@
Iteratively encode JSON to avoid blocking the reactor.
-1
View File
@@ -1 +0,0 @@
Use the default template file when its equivalent is not found in a custom template directory.
-1
View File
@@ -1 +0,0 @@
Convert various parts of the codebase to async/await.
-1
View File
@@ -1 +0,0 @@
Convert various parts of the codebase to async/await.
-1
View File
@@ -1 +0,0 @@
Convert various parts of the codebase to async/await.
-1
View File
@@ -1 +0,0 @@
Convert various parts of the codebase to async/await.
-1
View File
@@ -1 +0,0 @@
Fix `Re-starting finished log context PUT-nnnn` warning when event persistence failed.
-1
View File
@@ -1 +0,0 @@
Remove some unused database functions.
-1
View File
@@ -1 +0,0 @@
Convert various parts of the codebase to async/await.
-1
View File
@@ -1 +0,0 @@
Add type hints to `synapse.handlers.room`.
-1
View File
@@ -1 +0,0 @@
Add support for shadow-banning users (ignoring any message send requests).
-1
View File
@@ -1 +0,0 @@
Return the previous stream token if a non-member event is a duplicate.
-1
View File
@@ -1 +0,0 @@
Convert various parts of the codebase to async/await.
-1
View File
@@ -1 +0,0 @@
Synapse now correctly enforces the valid characters in the `client_secret` parameter used in various endpoints.
-1
View File
@@ -1 +0,0 @@
Use the default template file when its equivalent is not found in a custom template directory.
-1
View File
@@ -1 +0,0 @@
Link to matrix-synapse-rest-password-provider in the password provider documentation.
-1
View File
@@ -1 +0,0 @@
Return the previous stream token if a non-member event is a duplicate.
+12
View File
@@ -1,3 +1,15 @@
matrix-synapse-py3 (1.19.2) stable; urgency=medium
* New synapse release 1.19.2.
-- Synapse Packaging team <packages@matrix.org> Wed, 16 Sep 2020 12:50:30 +0100
matrix-synapse-py3 (1.19.1) stable; urgency=medium
* New synapse release 1.19.1.
-- Synapse Packaging team <packages@matrix.org> Thu, 27 Aug 2020 10:50:19 +0100
matrix-synapse-py3 (1.19.0) stable; urgency=medium
[ Synapse Packaging team ]
-1
View File
@@ -14,7 +14,6 @@ password auth provider module implementations:
* [matrix-synapse-ldap3](https://github.com/matrix-org/matrix-synapse-ldap3/)
* [matrix-synapse-shared-secret-auth](https://github.com/devture/matrix-synapse-shared-secret-auth)
* [matrix-synapse-rest-password-provider](https://github.com/ma1uta/matrix-synapse-rest-password-provider)
## Required methods
+3 -1
View File
@@ -2002,7 +2002,9 @@ email:
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# Do not uncomment this setting unless you want to customise the templates.
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
#
# Synapse will look for the following templates in this directory:
#
+1 -1
View File
@@ -48,7 +48,7 @@ try:
except ImportError:
pass
__version__ = "1.19.0"
__version__ = "1.19.2"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
+1 -11
View File
@@ -213,7 +213,6 @@ class Auth(object):
user = user_info["user"]
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
shadow_banned = user_info["shadow_banned"]
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
@@ -253,12 +252,7 @@ class Auth(object):
opentracing.set_tag("device_id", device_id)
return synapse.types.create_requester(
user,
token_id,
is_guest,
shadow_banned,
device_id,
app_service=app_service,
user, token_id, is_guest, device_id, app_service=app_service
)
except KeyError:
raise MissingClientTokenError()
@@ -303,7 +297,6 @@ class Auth(object):
dict that includes:
`user` (UserID)
`is_guest` (bool)
`shadow_banned` (bool)
`token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token
Raises:
@@ -363,7 +356,6 @@ class Auth(object):
ret = {
"user": user,
"is_guest": True,
"shadow_banned": False,
"token_id": None,
# all guests get the same device id
"device_id": GUEST_DEVICE_ID,
@@ -373,7 +365,6 @@ class Auth(object):
ret = {
"user": user,
"is_guest": False,
"shadow_banned": False,
"token_id": None,
"device_id": None,
}
@@ -497,7 +488,6 @@ class Auth(object):
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
"is_guest": False,
"shadow_banned": ret.get("shadow_banned"),
"device_id": ret.get("device_id"),
"valid_until_ms": ret.get("valid_until_ms"),
}
+1 -1
View File
@@ -23,7 +23,7 @@ from jsonschema import FormatChecker
from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.storage.presence import UserPresenceState
from synapse.types import RoomID, UserID
FILTER_SCHEMA = {
+37
View File
@@ -17,6 +17,7 @@ from collections import OrderedDict
from typing import Any, Optional, Tuple
from synapse.api.errors import LimitExceededError
from synapse.types import Requester
from synapse.util import Clock
@@ -43,6 +44,42 @@ class Ratelimiter(object):
# * The rate_hz of this particular entry. This can vary per request
self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]]
def can_requester_do_action(
self,
requester: Requester,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
_time_now_s: Optional[int] = None,
) -> Tuple[bool, float]:
"""Can the requester perform the action?
Args:
requester: The requester to key off when rate limiting. The user property
will be used.
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
Overrides the value set during instantiation if set.
update: Whether to count this check as performing the action
_time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Only used by tests.
Returns:
A tuple containing:
* A bool indicating if they can perform the action now
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
"""
# Disable rate limiting of users belonging to any AS that is configured
# not to be rate limited in its registration file (rate_limited: true|false).
if requester.app_service and not requester.app_service.is_rate_limited():
return True, -1.0
return self.can_do_action(
requester.user.to_string(), rate_hz, burst_count, update, _time_now_s
)
def can_do_action(
self,
key: Any,
+1 -99
View File
@@ -18,16 +18,12 @@
import argparse
import errno
import os
import time
import urllib.parse
from collections import OrderedDict
from hashlib import sha256
from textwrap import dedent
from typing import Any, Callable, List, MutableMapping, Optional
from typing import Any, List, MutableMapping, Optional
import attr
import jinja2
import pkg_resources
import yaml
@@ -104,11 +100,6 @@ class Config(object):
def __init__(self, root_config=None):
self.root = root_config
# Get the path to the default Synapse template directory
self.default_template_dir = pkg_resources.resource_filename(
"synapse", "res/templates"
)
def __getattr__(self, item: str) -> Any:
"""
Try and fetch a configuration option that does not exist on this class.
@@ -193,95 +184,6 @@ class Config(object):
with open(file_path) as file_stream:
return file_stream.read()
def read_templates(
self, filenames: List[str], custom_template_directory: Optional[str] = None,
) -> List[jinja2.Template]:
"""Load a list of template files from disk using the given variables.
This function will attempt to load the given templates from the default Synapse
template directory. If `custom_template_directory` is supplied, that directory
is tried first.
Files read are treated as Jinja templates. These templates are not rendered yet.
Args:
filenames: A list of template filenames to read.
custom_template_directory: A directory to try to look for the templates
before using the default Synapse template directory instead.
Raises:
ConfigError: if the file's path is incorrect or otherwise cannot be read.
Returns:
A list of jinja2 templates.
"""
templates = []
search_directories = [self.default_template_dir]
# The loader will first look in the custom template directory (if specified) for the
# given filename. If it doesn't find it, it will use the default template dir instead
if custom_template_directory:
# Check that the given template directory exists
if not self.path_exists(custom_template_directory):
raise ConfigError(
"Configured template directory does not exist: %s"
% (custom_template_directory,)
)
# Search the custom template directory as well
search_directories.insert(0, custom_template_directory)
loader = jinja2.FileSystemLoader(search_directories)
env = jinja2.Environment(loader=loader, autoescape=True)
# Update the environment with our custom filters
env.filters.update(
{
"format_ts": _format_ts_filter,
"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
}
)
for filename in filenames:
# Load the template
template = env.get_template(filename)
templates.append(template)
return templates
def _format_ts_filter(value: int, format: str):
return time.strftime(format, time.localtime(value / 1000))
def _create_mxc_to_http_filter(public_baseurl: str) -> Callable:
"""Create and return a jinja2 filter that converts MXC urls to HTTP
Args:
public_baseurl: The public, accessible base URL of the homeserver
"""
def mxc_to_http_filter(value, width, height, resize_method="crop"):
if value[0:6] != "mxc://":
return ""
server_and_media_id = value[6:]
fragment = None
if "#" in server_and_media_id:
server_and_media_id, fragment = server_and_media_id.split("#", 1)
fragment = "#" + fragment
params = {"width": width, "height": height, "method": resize_method}
return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
public_baseurl,
server_and_media_id,
urllib.parse.urlencode(params),
fragment or "",
)
return mxc_to_http_filter
class RootConfig(object):
"""
+78 -67
View File
@@ -23,6 +23,7 @@ from enum import Enum
from typing import Optional
import attr
import pkg_resources
from ._base import Config, ConfigError
@@ -97,18 +98,21 @@ class EmailConfig(Config):
if parsed[1] == "":
raise RuntimeError("Invalid notif_from address")
# A user-configurable template directory
template_dir = email_config.get("template_dir")
if isinstance(template_dir, str):
# We need an absolute path, because we change directory after starting (and
# we don't yet know what auxiliary templates like mail.css we will need).
template_dir = os.path.abspath(template_dir)
elif template_dir is not None:
# If template_dir is something other than a str or None, warn the user
raise ConfigError("Config option email.template_dir must be type str")
# we need an absolute path, because we change directory after starting (and
# we don't yet know what auxiliary templates like mail.css we will need).
# (Note that loading as package_resources with jinja.PackageLoader doesn't
# work for the same reason.)
if not template_dir:
template_dir = pkg_resources.resource_filename("synapse", "res/templates")
self.email_template_dir = os.path.abspath(template_dir)
self.email_enable_notifs = email_config.get("enable_notifs", False)
account_validity_config = config.get("account_validity") or {}
account_validity_renewal_enabled = account_validity_config.get("renew_at")
self.threepid_behaviour_email = (
# Have Synapse handle the email sending if account_threepid_delegates.email
# is not defined
@@ -162,6 +166,19 @@ class EmailConfig(Config):
email_config.get("validation_token_lifetime", "1h")
)
if (
self.email_enable_notifs
or account_validity_renewal_enabled
or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
):
# make sure we can import the required deps
import bleach
import jinja2
# prevent unused warnings
jinja2
bleach
if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
missing = []
if not self.email_notif_from:
@@ -179,49 +196,49 @@ class EmailConfig(Config):
# These email templates have placeholders in them, and thus must be
# parsed using a templating engine during a request
password_reset_template_html = email_config.get(
self.email_password_reset_template_html = email_config.get(
"password_reset_template_html", "password_reset.html"
)
password_reset_template_text = email_config.get(
self.email_password_reset_template_text = email_config.get(
"password_reset_template_text", "password_reset.txt"
)
registration_template_html = email_config.get(
self.email_registration_template_html = email_config.get(
"registration_template_html", "registration.html"
)
registration_template_text = email_config.get(
self.email_registration_template_text = email_config.get(
"registration_template_text", "registration.txt"
)
add_threepid_template_html = email_config.get(
self.email_add_threepid_template_html = email_config.get(
"add_threepid_template_html", "add_threepid.html"
)
add_threepid_template_text = email_config.get(
self.email_add_threepid_template_text = email_config.get(
"add_threepid_template_text", "add_threepid.txt"
)
password_reset_template_failure_html = email_config.get(
self.email_password_reset_template_failure_html = email_config.get(
"password_reset_template_failure_html", "password_reset_failure.html"
)
registration_template_failure_html = email_config.get(
self.email_registration_template_failure_html = email_config.get(
"registration_template_failure_html", "registration_failure.html"
)
add_threepid_template_failure_html = email_config.get(
self.email_add_threepid_template_failure_html = email_config.get(
"add_threepid_template_failure_html", "add_threepid_failure.html"
)
# These templates do not support any placeholder variables, so we
# will read them from disk once during setup
password_reset_template_success_html = email_config.get(
email_password_reset_template_success_html = email_config.get(
"password_reset_template_success_html", "password_reset_success.html"
)
registration_template_success_html = email_config.get(
email_registration_template_success_html = email_config.get(
"registration_template_success_html", "registration_success.html"
)
add_threepid_template_success_html = email_config.get(
email_add_threepid_template_success_html = email_config.get(
"add_threepid_template_success_html", "add_threepid_success.html"
)
# Read all templates from disk
(
# Check templates exist
for f in [
self.email_password_reset_template_html,
self.email_password_reset_template_text,
self.email_registration_template_html,
@@ -231,36 +248,32 @@ class EmailConfig(Config):
self.email_password_reset_template_failure_html,
self.email_registration_template_failure_html,
self.email_add_threepid_template_failure_html,
password_reset_template_success_html_template,
registration_template_success_html_template,
add_threepid_template_success_html_template,
) = self.read_templates(
[
password_reset_template_html,
password_reset_template_text,
registration_template_html,
registration_template_text,
add_threepid_template_html,
add_threepid_template_text,
password_reset_template_failure_html,
registration_template_failure_html,
add_threepid_template_failure_html,
password_reset_template_success_html,
registration_template_success_html,
add_threepid_template_success_html,
],
template_dir,
)
email_password_reset_template_success_html,
email_registration_template_success_html,
email_add_threepid_template_success_html,
]:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
raise ConfigError("Unable to find template file %s" % (p,))
# Render templates that do not contain any placeholders
self.email_password_reset_template_success_html_content = (
password_reset_template_success_html_template.render()
# Retrieve content of web templates
filepath = os.path.join(
self.email_template_dir, email_password_reset_template_success_html
)
self.email_registration_template_success_html_content = (
registration_template_success_html_template.render()
self.email_password_reset_template_success_html = self.read_file(
filepath, "email.password_reset_template_success_html"
)
self.email_add_threepid_template_success_html_content = (
add_threepid_template_success_html_template.render()
filepath = os.path.join(
self.email_template_dir, email_registration_template_success_html
)
self.email_registration_template_success_html_content = self.read_file(
filepath, "email.registration_template_success_html"
)
filepath = os.path.join(
self.email_template_dir, email_add_threepid_template_success_html
)
self.email_add_threepid_template_success_html_content = self.read_file(
filepath, "email.add_threepid_template_success_html"
)
if self.email_enable_notifs:
@@ -277,19 +290,17 @@ class EmailConfig(Config):
% (", ".join(missing),)
)
notif_template_html = email_config.get(
self.email_notif_template_html = email_config.get(
"notif_template_html", "notif_mail.html"
)
notif_template_text = email_config.get(
self.email_notif_template_text = email_config.get(
"notif_template_text", "notif_mail.txt"
)
(
self.email_notif_template_html,
self.email_notif_template_text,
) = self.read_templates(
[notif_template_html, notif_template_text], template_dir,
)
for f in self.email_notif_template_text, self.email_notif_template_html:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
raise ConfigError("Unable to find email template file %s" % (p,))
self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True
@@ -298,20 +309,18 @@ class EmailConfig(Config):
"client_base_url", email_config.get("riot_base_url", None)
)
if self.account_validity.renew_by_email_enabled:
expiry_template_html = email_config.get(
if account_validity_renewal_enabled:
self.email_expiry_template_html = email_config.get(
"expiry_template_html", "notice_expiry.html"
)
expiry_template_text = email_config.get(
self.email_expiry_template_text = email_config.get(
"expiry_template_text", "notice_expiry.txt"
)
(
self.account_validity_template_html,
self.account_validity_template_text,
) = self.read_templates(
[expiry_template_html, expiry_template_text], template_dir,
)
for f in self.email_expiry_template_text, self.email_expiry_template_html:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
raise ConfigError("Unable to find email template file %s" % (p,))
subjects_config = email_config.get("subjects", {})
subjects = {}
@@ -391,7 +400,9 @@ class EmailConfig(Config):
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# Do not uncomment this setting unless you want to customise the templates.
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
#
# Synapse will look for the following templates in this directory:
#
+11 -3
View File
@@ -18,6 +18,8 @@ import logging
from typing import Any, List
import attr
import jinja2
import pkg_resources
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module, load_python_module
@@ -169,9 +171,15 @@ class SAML2Config(Config):
saml2_config.get("saml_session_lifetime", "15m")
)
self.saml2_error_html_template = self.read_templates(
["saml_error.html"], saml2_config.get("template_dir")
)
template_dir = saml2_config.get("template_dir")
if not template_dir:
template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
loader = jinja2.FileSystemLoader(template_dir)
# enable auto-escape here, to having to remember to escape manually in the
# template
env = jinja2.Environment(loader=loader, autoescape=True)
self.saml2_error_html_template = env.get_template("saml_error.html")
def _default_saml_config_dict(
self, required_attributes: set, optional_attributes: set
+20
View File
@@ -26,6 +26,7 @@ import yaml
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.python_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError
@@ -507,6 +508,8 @@ class ServerConfig(Config):
)
)
_check_resource_config(self.listeners)
self.cleanup_extremities_with_dummy_events = config.get(
"cleanup_extremities_with_dummy_events", True
)
@@ -1130,3 +1133,20 @@ def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None:
if name == "webclient":
logger.warning(NO_MORE_WEB_CLIENT_WARNING)
return
def _check_resource_config(listeners: Iterable[ListenerConfig]) -> None:
resource_names = {
res_name
for listener in listeners
if listener.http_options
for res in listener.http_options.resources
for res_name in res.names
}
for resource in resource_names:
if resource == "consent":
try:
check_requirements("resources.consent")
except DependencyException as e:
raise ConfigError(e.message)
+15 -22
View File
@@ -12,8 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict
import pkg_resources
from ._base import Config
@@ -26,32 +29,22 @@ class SSOConfig(Config):
def read_config(self, config, **kwargs):
sso_config = config.get("sso") or {} # type: Dict[str, Any]
# The sso-specific template_dir
# Pick a template directory in order of:
# * The sso-specific template_dir
# * /path/to/synapse/install/res/templates
template_dir = sso_config.get("template_dir")
if not template_dir:
template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
# Read templates from disk
(
self.sso_redirect_confirm_template,
self.sso_auth_confirm_template,
self.sso_error_template,
sso_account_deactivated_template,
sso_auth_success_template,
) = self.read_templates(
[
"sso_redirect_confirm.html",
"sso_auth_confirm.html",
"sso_error.html",
"sso_account_deactivated.html",
"sso_auth_success.html",
],
template_dir,
self.sso_template_dir = template_dir
self.sso_account_deactivated_template = self.read_file(
os.path.join(self.sso_template_dir, "sso_account_deactivated.html"),
"sso_account_deactivated_template",
)
# These templates have no placeholders, so render them here
self.sso_account_deactivated_template = (
sso_account_deactivated_template.render()
self.sso_auth_success_template = self.read_file(
os.path.join(self.sso_template_dir, "sso_auth_success.html"),
"sso_auth_success_template",
)
self.sso_auth_success_template = sso_auth_success_template.render()
self.sso_client_whitelist = sso_config.get("client_whitelist") or []
+11 -8
View File
@@ -54,7 +54,7 @@ from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.logging.utils import log_function
from synapse.types import JsonDict
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
@@ -217,13 +217,15 @@ class FederationClient(FederationBase):
for p in transaction_data["pdus"]
]
# FIXME: We should handle signature failures more gracefully.
pdus[:] = await make_deferred_yieldable(
defer.gatherResults(
self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True,
).addErrback(unwrapFirstError)
pdus[:] = await self._check_sigs_and_hash_and_fetch(
dest,
list(pdus),
outlier=True,
room_version=room_version,
)
logger.info("DDD pdus ended up as: %s", pdus)
return pdus
async def get_pdu(
@@ -386,10 +388,11 @@ class FederationClient(FederationBase):
pdu.event_id, allow_rejected=True, allow_none=True
)
if not res and pdu.origin != origin:
pdu_origin = get_domain_from_id(pdu.sender)
if not res and pdu_origin != origin:
try:
res = await self.get_pdu(
destinations=[pdu.origin],
destinations=[pdu_origin],
event_id=pdu.event_id,
room_version=room_version,
outlier=outlier,
+1 -1
View File
@@ -37,8 +37,8 @@ from sortedcontainers import SortedDict
from twisted.internet import defer
from synapse.api.presence import UserPresenceState
from synapse.metrics import LaterGauge
from synapse.storage.presence import UserPresenceState
from synapse.util.metrics import Measure
from .units import Edu
+1 -1
View File
@@ -22,7 +22,6 @@ from twisted.internet import defer
import synapse
import synapse.metrics
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.federation.sender.per_destination_queue import PerDestinationQueue
from synapse.federation.sender.transaction_manager import TransactionManager
@@ -40,6 +39,7 @@ from synapse.metrics import (
events_processed_counter,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.metrics import Measure, measure_func
@@ -24,12 +24,12 @@ from synapse.api.errors import (
HttpResponseException,
RequestSendFailed,
)
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.federation.units import Edu
from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
@@ -337,28 +337,6 @@ class PerDestinationQueue(object):
(e.retry_last_ts + e.retry_interval) / 1000.0
),
)
if e.retry_interval > 60 * 60 * 1000:
# we won't retry for another hour!
# (this suggests a significant outage)
# We drop pending PDUs and EDUs because otherwise they will
# rack up indefinitely.
# Note that:
# - the EDUs that are being dropped here are those that we can
# afford to drop (specifically, only typing notifications,
# read receipts and presence updates are being dropped here)
# - Other EDUs such as to_device messages are queued with a
# different mechanism
# - this is all volatile state that would be lost if the
# federation sender restarted anyway
# dropping read receipts is a bit sad but should be solved
# through another mechanism, because this is all volatile!
self._pending_pdus = []
self._pending_edus = []
self._pending_edus_keyed = {}
self._pending_presence = {}
self._pending_rrs = {}
except FederationDeniedError as e:
logger.info(e)
except HttpResponseException as e:
-21
View File
@@ -719,27 +719,6 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
raise NotImplementedError()
async def change_user_admin_in_group(
self, group_id, user_id, want_admin, requester_user_id, content
):
"""Promotes or demotes a user in a group.
"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
if requester_user_id == user_id:
raise SynapseError(400, "User cannot target themselves")
is_admin = await self.store.is_user_admin_in_group(
group_id, requester_user_id
)
if not is_admin:
raise SynapseError(403, "User is not admin in group")
await self.store.change_user_admin_in_group(group_id, user_id, want_admin)
return {}
async def remove_user_from_group(
self, group_id, user_id, requester_user_id, content
):
+17 -3
View File
@@ -26,6 +26,11 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
try:
from synapse.push.mailer import load_jinja2_templates
except ImportError:
load_jinja2_templates = None
logger = logging.getLogger(__name__)
@@ -42,11 +47,9 @@ class AccountValidityHandler(object):
if (
self._account_validity.enabled
and self._account_validity.renew_by_email_enabled
and load_jinja2_templates
):
# Don't do email-specific configuration if renewal by email is disabled.
self._template_html = self.config.account_validity_template_html
self._template_text = self.config.account_validity_template_text
try:
app_name = self.hs.config.email_app_name
@@ -62,6 +65,17 @@ class AccountValidityHandler(object):
self._raw_from = email.utils.parseaddr(self._from_string)[1]
self._template_html, self._template_text = load_jinja2_templates(
self.config.email_template_dir,
[
self.config.email_expiry_template_html,
self.config.email_expiry_template_text,
],
apply_format_ts_filter=True,
apply_mxc_to_http_filter=True,
public_baseurl=self.config.public_baseurl,
)
# Check the renewal emails to send and send them every 30min.
def send_emails():
# run as a background process to make sure that the database transactions
+7 -5
View File
@@ -42,6 +42,7 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
from synapse.types import Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.threepids import canonicalise_email
@@ -131,17 +132,18 @@ class AuthHandler(BaseHandler):
# after the SSO completes and before redirecting them back to their client.
# It notifies the user they are about to give access to their matrix account
# to the client.
self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template
self._sso_redirect_confirm_template = load_jinja2_templates(
hs.config.sso_template_dir, ["sso_redirect_confirm.html"],
)[0]
# The following template is shown during user interactive authentication
# in the fallback auth scenario. It notifies the user that they are
# authenticating for an operation to occur on their account.
self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
self._sso_auth_confirm_template = load_jinja2_templates(
hs.config.sso_template_dir, ["sso_auth_confirm.html"],
)[0]
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
self._sso_auth_success_template = hs.config.sso_auth_success_template
# The following template is shown during the SSO authentication process if
# the account is deactivated.
self._sso_account_deactivated_template = (
-19
View File
@@ -461,25 +461,6 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {"state": "invite", "user_profile": user_profile}
async def change_user_admin_in_group(
self, group_id, user_id, want_admin, requester_user_id, content
):
"""Promotes or demotes a user in a group.
"""
if not self.is_mine_id(user_id):
raise SynapseError(400, "User not on this server")
# TODO: We should probably support federation, but this is fine for now
if not self.is_mine_id(group_id):
raise SynapseError(400, "Group not on this server")
res = await self.groups_server_handler.change_user_admin_in_group(
group_id, user_id, want_admin, requester_user_id, content
)
return res
async def remove_user_from_group(
self, group_id, user_id, requester_user_id, content
):
+13 -16
View File
@@ -667,14 +667,14 @@ class EventCreationHandler(object):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
prev_event = await self.deduplicate_state_event(event, context)
if prev_event is not None:
prev_state = await self.deduplicate_state_event(event, context)
if prev_state is not None:
logger.info(
"Not bothering to persist state event %s duplicated by %s",
event.event_id,
prev_event.event_id,
prev_state.event_id,
)
return await self.store.get_stream_id_for_event(prev_event.event_id)
return prev_state
return await self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit
@@ -682,32 +682,27 @@ class EventCreationHandler(object):
async def deduplicate_state_event(
self, event: EventBase, context: EventContext
) -> Optional[EventBase]:
) -> None:
"""
Checks whether event is in the latest resolved state in context.
Args:
event: The event to check for duplication.
context: The event context.
Returns:
The previous verion of the event is returned, if it is found in the
event context. Otherwise, None is returned.
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_state_ids = await context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id:
return None
return
prev_event = await self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return None
return
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
return prev_event
return None
return
async def create_and_send_nonmember_event(
self,
@@ -896,7 +891,9 @@ class EventCreationHandler(object):
except Exception:
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
await self.store.remove_push_actions_from_staging(event.event_id)
run_in_background(
self.store.remove_push_actions_from_staging, event.event_id
)
raise
async def _validate_canonical_alias(
+4 -1
View File
@@ -38,6 +38,7 @@ from synapse.config import ConfigError
from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.push.mailer import load_jinja2_templates
from synapse.types import UserID, map_username_to_mxid_localpart
if TYPE_CHECKING:
@@ -122,7 +123,9 @@ class OidcHandler:
self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
self._error_template = hs.config.sso_error_template
self._error_template = load_jinja2_templates(
hs.config.sso_template_dir, ["sso_error.html"]
)[0]
# identifier for the external_ids table
self._auth_provider_id = "oidc"
+1 -1
View File
@@ -33,13 +33,13 @@ from typing_extensions import ContextManager
import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.logging.context import run_in_background
from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
from synapse.storage.presence import UserPresenceState
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached
-8
View File
@@ -142,7 +142,6 @@ class RegistrationHandler(BaseHandler):
address=None,
bind_emails=[],
by_admin=False,
shadow_banned=False,
):
"""Registers a new client on the server.
@@ -160,7 +159,6 @@ class RegistrationHandler(BaseHandler):
bind_emails (List[str]): list of emails to bind to this account.
by_admin (bool): True if this registration is being made via the
admin api, otherwise False.
shadow_banned (bool): Shadow-ban the created user.
Returns:
str: user_id
Raises:
@@ -196,7 +194,6 @@ class RegistrationHandler(BaseHandler):
admin=admin,
user_type=user_type,
address=address,
shadow_banned=shadow_banned,
)
if self.hs.config.user_directory_search_all_users:
@@ -227,7 +224,6 @@ class RegistrationHandler(BaseHandler):
make_guest=make_guest,
create_profile_with_displayname=default_display_name,
address=address,
shadow_banned=shadow_banned,
)
# Successfully registered
@@ -533,7 +529,6 @@ class RegistrationHandler(BaseHandler):
admin=False,
user_type=None,
address=None,
shadow_banned=False,
):
"""Register user in the datastore.
@@ -551,7 +546,6 @@ class RegistrationHandler(BaseHandler):
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the registration.
shadow_banned (bool): Whether to shadow-ban the user
Returns:
Awaitable
@@ -567,7 +561,6 @@ class RegistrationHandler(BaseHandler):
admin=admin,
user_type=user_type,
address=address,
shadow_banned=shadow_banned,
)
else:
return self.store.register_user(
@@ -579,7 +572,6 @@ class RegistrationHandler(BaseHandler):
create_profile_with_displayname=create_profile_with_displayname,
admin=admin,
user_type=user_type,
shadow_banned=shadow_banned,
)
async def register_device(
+38 -66
View File
@@ -22,7 +22,7 @@ import logging
import math
import string
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
from typing import Awaitable, Optional, Tuple
from synapse.api.constants import (
EventTypes,
@@ -32,14 +32,11 @@ from synapse.api.constants import (
RoomEncryptionAlgorithms,
)
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
from synapse.types import (
JsonDict,
Requester,
RoomAlias,
RoomID,
@@ -56,9 +53,6 @@ from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
@@ -67,7 +61,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000
class RoomCreationHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
def __init__(self, hs):
super(RoomCreationHandler, self).__init__(hs)
self.spam_checker = hs.get_spam_checker()
@@ -98,7 +92,7 @@ class RoomCreationHandler(BaseHandler):
"guest_can_join": False,
"power_level_content_override": {},
},
} # type: Dict[str, Dict[str, Any]]
}
# Modify presets to selectively enable encryption by default per homeserver config
for preset_name, preset_config in self._presets_dict.items():
@@ -221,9 +215,6 @@ class RoomCreationHandler(BaseHandler):
old_room_state = await tombstone_context.get_current_state_ids()
# We know the tombstone event isn't an outlier so it has current state.
assert old_room_state is not None
# update any aliases
await self._move_aliases_to_new_room(
requester, old_room_id, new_room_id, old_room_state
@@ -537,21 +528,17 @@ class RoomCreationHandler(BaseHandler):
logger.error("Unable to send updated alias events in new room: %s", e)
async def create_room(
self,
requester: Requester,
config: JsonDict,
ratelimit: bool = True,
creator_join_profile: Optional[JsonDict] = None,
self, requester, config, ratelimit=True, creator_join_profile=None
) -> Tuple[dict, int]:
""" Creates a new room.
Args:
requester:
requester (synapse.types.Requester):
The user who requested the room creation.
config : A dict of configuration options.
ratelimit: set to False to disable the rate limiter
config (dict) : A dict of configuration options.
ratelimit (bool): set to False to disable the rate limiter
creator_join_profile:
creator_join_profile (dict|None):
Set to override the displayname and avatar for the creating
user in this room. If unset, displayname and avatar will be
derived from the user's profile. If set, should contain the
@@ -614,7 +601,6 @@ class RoomCreationHandler(BaseHandler):
Codes.UNSUPPORTED_ROOM_VERSION,
)
room_alias = None
if "room_alias_name" in config:
for wchar in string.whitespace:
if wchar in config["room_alias_name"]:
@@ -625,6 +611,8 @@ class RoomCreationHandler(BaseHandler):
if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
else:
room_alias = None
invite_list = config.get("invite", [])
for i in invite_list:
@@ -783,30 +771,23 @@ class RoomCreationHandler(BaseHandler):
async def _send_events_for_new_room(
self,
creator: Requester,
room_id: str,
preset_config: str,
invite_list: List[str],
initial_state: StateMap,
creation_content: JsonDict,
room_alias: Optional[RoomAlias] = None,
power_level_content_override: Optional[JsonDict] = None,
creator_join_profile: Optional[JsonDict] = None,
creator, # A Requester object.
room_id,
preset_config,
invite_list,
initial_state,
creation_content,
room_alias=None,
power_level_content_override=None, # Doesn't apply when initial state has power level state event content
creator_join_profile=None,
) -> int:
"""Sends the initial events into a new room.
`power_level_content_override` doesn't apply when initial state has
power level state event content.
Returns:
The stream_id of the last event persisted.
"""
creator_id = creator.user.to_string()
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
def create(etype: str, content: JsonDict, **kwargs) -> JsonDict:
def create(etype, content, **kwargs):
e = {"type": etype, "content": content}
e.update(event_keys)
@@ -814,7 +795,7 @@ class RoomCreationHandler(BaseHandler):
return e
async def send(etype: str, content: JsonDict, **kwargs) -> int:
async def send(etype, content, **kwargs) -> int:
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
(
@@ -827,6 +808,10 @@ class RoomCreationHandler(BaseHandler):
config = self._presets_dict[preset_config]
creator_id = creator.user.to_string()
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
creation_content.update({"creator": creator_id})
await send(etype=EventTypes.Create, content=creation_content)
@@ -867,7 +852,7 @@ class RoomCreationHandler(BaseHandler):
"kick": 50,
"redact": 50,
"invite": 50,
} # type: JsonDict
}
if config["original_invitees_have_ops"]:
for invitee in invite_list:
@@ -921,7 +906,7 @@ class RoomCreationHandler(BaseHandler):
return last_sent_stream_id
async def _generate_room_id(
self, creator_id: str, is_public: bool, room_version: RoomVersion,
self, creator_id: str, is_public: str, room_version: RoomVersion,
):
# autogen room IDs and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
@@ -945,30 +930,23 @@ class RoomCreationHandler(BaseHandler):
class RoomContextHandler(object):
def __init__(self, hs: "HomeServer"):
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
async def get_event_context(
self,
user: UserID,
room_id: str,
event_id: str,
limit: int,
event_filter: Optional[Filter],
) -> Optional[JsonDict]:
async def get_event_context(self, user, room_id, event_id, limit, event_filter):
"""Retrieves events, pagination tokens and state around a given event
in a room.
Args:
user
room_id
event_id
limit: The maximum number of events to return in total
user (UserID)
room_id (str)
event_id (str)
limit (int): The maximum number of events to return in total
(excluding state).
event_filter: the filter to apply to the events returned
event_filter (Filter|None): the filter to apply to the events returned
(excluding the target event_id)
Returns:
@@ -1055,18 +1033,12 @@ class RoomContextHandler(object):
class RoomEventSource(object):
def __init__(self, hs: "HomeServer"):
def __init__(self, hs):
self.store = hs.get_datastore()
async def get_new_events(
self,
user: UserID,
from_key: str,
limit: int,
room_ids: List[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[EventBase], str]:
self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
):
# We just ignore the key for now.
to_key = self.get_current_key()
@@ -1124,7 +1096,7 @@ class RoomShutdownHandler(object):
)
DEFAULT_ROOM_NAME = "Content Violation Notification"
def __init__(self, hs: "HomeServer"):
def __init__(self, hs):
self.hs = hs
self.room_member_handler = hs.get_room_member_handler()
self._room_creation_handler = hs.get_room_creation_handler()
+29 -23
View File
@@ -210,24 +210,40 @@ class RoomMemberHandler(object):
_, stream_id = await self.store.get_event_ordering(duplicate.event_id)
return duplicate.event_id, stream_id
stream_id = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit,
)
prev_state_ids = await context.get_prev_state_ids()
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
newly_joined = False
if event.membership == Membership.JOIN:
# Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
newly_joined = True
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
# Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates.
if newly_joined:
await self._user_joined_room(target, room_id)
time_now_s = self.clock.time()
(
allowed,
time_allowed,
) = self._join_rate_limiter_local.can_requester_do_action(requester)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
stream_id = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit,
)
if event.membership == Membership.JOIN and newly_joined:
# Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
await self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
@@ -457,22 +473,12 @@ class RoomMemberHandler(object):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if is_host_in_room:
if not is_host_in_room:
time_now_s = self.clock.time()
allowed, time_allowed = self._join_rate_limiter_local.can_do_action(
requester.user.to_string(),
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
else:
time_now_s = self.clock.time()
allowed, time_allowed = self._join_rate_limiter_remote.can_do_action(
requester.user.to_string(),
)
(
allowed,
time_allowed,
) = self._join_rate_limiter_remote.can_requester_do_action(requester,)
if not allowed:
raise LimitExceededError(
+2
View File
@@ -441,6 +441,7 @@ class MatrixFederationHttpClient(object):
headers_dict[b"Authorization"] = auth_headers
"""
logger.debug(
"{%s} [%s] Sending request: %s %s; timeout %fs",
request.txn_id,
@@ -449,6 +450,7 @@ class MatrixFederationHttpClient(object):
url_str,
_sec_timeout,
)
"""
outgoing_requests_counter.labels(request.method).inc()
+8 -89
View File
@@ -22,13 +22,12 @@ import types
import urllib
from http import HTTPStatus
from io import BytesIO
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
from typing import Any, Callable, Dict, Tuple, Union
import jinja2
from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
from zope.interface import implementer
from canonicaljson import encode_canonical_json, encode_pretty_printed_json
from twisted.internet import defer, interfaces
from twisted.internet import defer
from twisted.python import failure
from twisted.web import resource
from twisted.web.server import NOT_DONE_YET, Request
@@ -500,78 +499,6 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect):
pass
@implementer(interfaces.IPullProducer)
class _ByteProducer:
"""
Iteratively write bytes to the request.
"""
# The minimum number of bytes for each chunk. Note that the last chunk will
# usually be smaller than this.
min_chunk_size = 1024
def __init__(
self, request: Request, iterator: Iterator[bytes],
):
self._request = request
self._iterator = iterator
def start(self) -> None:
self._request.registerProducer(self, False)
def _send_data(self, data: List[bytes]) -> None:
"""
Send a list of strings as a response to the request.
"""
if not data:
return
self._request.write(b"".join(data))
def resumeProducing(self) -> None:
# We've stopped producing in the meantime (note that this might be
# re-entrant after calling write).
if not self._request:
return
# Get the next chunk and write it to the request.
#
# The output of the JSON encoder is coalesced until min_chunk_size is
# reached. (This is because JSON encoders produce a very small output
# per iteration.)
#
# Note that buffer stores a list of bytes (instead of appending to
# bytes) to hopefully avoid many allocations.
buffer = []
buffered_bytes = 0
while buffered_bytes < self.min_chunk_size:
try:
data = next(self._iterator)
buffer.append(data)
buffered_bytes += len(data)
except StopIteration:
# The entire JSON object has been serialized, write any
# remaining data, finalize the producer and the request, and
# clean-up any references.
self._send_data(buffer)
self._request.unregisterProducer()
self._request.finish()
self.stopProducing()
return
self._send_data(buffer)
def stopProducing(self) -> None:
self._request = None
def _encode_json_bytes(json_object: Any) -> Iterator[bytes]:
"""
Encode an object into JSON. Returns an iterator of bytes.
"""
for chunk in json_encoder.iterencode(json_object):
yield chunk.encode("utf-8")
def respond_with_json(
request: Request,
code: int,
@@ -606,23 +533,15 @@ def respond_with_json(
return None
if pretty_print:
encoder = iterencode_pretty_printed_json
json_bytes = encode_pretty_printed_json(json_object) + b"\n"
else:
if canonical_json or synapse.events.USE_FROZEN_DICTS:
encoder = iterencode_canonical_json
# canonicaljson already encodes to bytes
json_bytes = encode_canonical_json(json_object)
else:
encoder = _encode_json_bytes
json_bytes = json_encoder.encode(json_object).encode("utf-8")
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
if send_cors:
set_cors_headers(request)
producer = _ByteProducer(request, encoder(json_object))
producer.start()
return NOT_DONE_YET
return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors)
def respond_with_json_bytes(
+1
View File
@@ -22,6 +22,7 @@ _TIME_FUNC_ID = 0
def _log_debug_as_f(f, msg, msg_args):
return
name = f.__module__
logger = logging.getLogger(name)
+71 -1
View File
@@ -16,7 +16,8 @@
import email.mime.multipart
import email.utils
import logging
import urllib.parse
import time
import urllib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Iterable, List, TypeVar
@@ -639,3 +640,72 @@ def string_ordinal_total(s):
for c in s:
tot += ord(c)
return tot
def format_ts_filter(value, format):
return time.strftime(format, time.localtime(value / 1000))
def load_jinja2_templates(
template_dir,
template_filenames,
apply_format_ts_filter=False,
apply_mxc_to_http_filter=False,
public_baseurl=None,
):
"""Loads and returns one or more jinja2 templates and applies optional filters
Args:
template_dir (str): The directory where templates are stored
template_filenames (list[str]): A list of template filenames
apply_format_ts_filter (bool): Whether to apply a template filter that formats
timestamps
apply_mxc_to_http_filter (bool): Whether to apply a template filter that converts
mxc urls to http urls
public_baseurl (str|None): The public baseurl of the server. Required for
apply_mxc_to_http_filter to be enabled
Returns:
A list of jinja2 templates corresponding to the given list of filenames,
with order preserved
"""
logger.info(
"loading email templates %s from '%s'", template_filenames, template_dir
)
loader = jinja2.FileSystemLoader(template_dir)
env = jinja2.Environment(loader=loader)
if apply_format_ts_filter:
env.filters["format_ts"] = format_ts_filter
if apply_mxc_to_http_filter and public_baseurl:
env.filters["mxc_to_http"] = _create_mxc_to_http_filter(public_baseurl)
templates = []
for template_filename in template_filenames:
template = env.get_template(template_filename)
templates.append(template)
return templates
def _create_mxc_to_http_filter(public_baseurl):
def mxc_to_http_filter(value, width, height, resize_method="crop"):
if value[0:6] != "mxc://":
return ""
serverAndMediaId = value[6:]
fragment = None
if "#" in serverAndMediaId:
(serverAndMediaId, fragment) = serverAndMediaId.split("#", 1)
fragment = "#" + fragment
params = {"width": width, "height": height, "method": resize_method}
return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
public_baseurl,
serverAndMediaId,
urllib.parse.urlencode(params),
fragment or "",
)
return mxc_to_http_filter
+24 -7
View File
@@ -15,13 +15,22 @@
import logging
from synapse.push.emailpusher import EmailPusher
from synapse.push.mailer import Mailer
from .httppusher import HttpPusher
logger = logging.getLogger(__name__)
# We try importing this if we can (it will fail if we don't
# have the optional email dependencies installed). We don't
# yet have the config to know if we need the email pusher,
# but importing this after daemonizing seems to fail
# (even though a simple test of importing from a daemonized
# process works fine)
try:
from synapse.push.emailpusher import EmailPusher
from synapse.push.mailer import Mailer, load_jinja2_templates
except Exception:
pass
class PusherFactory(object):
def __init__(self, hs):
@@ -34,8 +43,16 @@ class PusherFactory(object):
if hs.config.email_enable_notifs:
self.mailers = {} # app_name -> Mailer
self._notif_template_html = hs.config.email_notif_template_html
self._notif_template_text = hs.config.email_notif_template_text
self.notif_template_html, self.notif_template_text = load_jinja2_templates(
self.config.email_template_dir,
[
self.config.email_notif_template_html,
self.config.email_notif_template_text,
],
apply_format_ts_filter=True,
apply_mxc_to_http_filter=True,
public_baseurl=self.config.public_baseurl,
)
self.pusher_types["email"] = self._create_email_pusher
@@ -56,8 +73,8 @@ class PusherFactory(object):
mailer = Mailer(
hs=self.hs,
app_name=app_name,
template_html=self._notif_template_html,
template_text=self._notif_template_text,
template_html=self.notif_template_html,
template_text=self.notif_template_text,
)
self.mailers[app_name] = mailer
return EmailPusher(self.hs, pusherdict, mailer)
+3 -1
View File
@@ -43,7 +43,7 @@ REQUIREMENTS = [
"jsonschema>=2.5.1",
"frozendict>=1",
"unpaddedbase64>=1.1.0",
"canonicaljson>=1.3.0",
"canonicaljson>=1.2.0",
# we use the type definitions added in signedjson 1.1.
"signedjson>=1.1.0",
"pynacl>=1.2.1",
@@ -78,6 +78,8 @@ CONDITIONAL_REQUIREMENTS = {
"matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
# we use execute_batch, which arrived in psycopg 2.7.
"postgres": ["psycopg2>=2.7"],
# ConsentResource uses select_autoescape, which arrived in jinja 2.9
"resources.consent": ["Jinja2>=2.9"],
# ACME support is required to provision TLS certificates from authorities
# that use the protocol, such as Let's Encrypt.
"acme": [
-4
View File
@@ -44,7 +44,6 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
admin,
user_type,
address,
shadow_banned,
):
"""
Args:
@@ -61,7 +60,6 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the regitration.
shadow_banned (bool): Whether to shadow-ban the user
"""
return {
"password_hash": password_hash,
@@ -72,7 +70,6 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"admin": admin,
"user_type": user_type,
"address": address,
"shadow_banned": shadow_banned,
}
async def _handle_request(self, request, user_id):
@@ -90,7 +87,6 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
admin=content["admin"],
user_type=content["user_type"],
address=content["address"],
shadow_banned=content["shadow_banned"],
)
return 200, {}
+4 -5
View File
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.errors import (
NotFoundError,
StoreError,
@@ -162,7 +163,7 @@ class PushRuleRestServlet(RestServlet):
stream_id, _ = self.store.get_push_rules_stream_token()
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
async def set_rule_attr(self, user_id, spec, val):
def set_rule_attr(self, user_id, spec, val):
if spec["attr"] == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
@@ -172,9 +173,7 @@ class PushRuleRestServlet(RestServlet):
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
return await self.store.set_push_rule_enabled(
user_id, namespaced_rule_id, val
)
return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val)
elif spec["attr"] == "actions":
actions = val.get("actions")
_check_actions(actions)
@@ -189,7 +188,7 @@ class PushRuleRestServlet(RestServlet):
if namespaced_rule_id not in rule_ids:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
return await self.store.set_push_rule_actions(
return self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
)
else:
+32 -12
View File
@@ -32,7 +32,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.push.mailer import Mailer
from synapse.push.mailer import Mailer, load_jinja2_templates
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import canonicalise_email, check_3pid_allowed
@@ -53,11 +53,21 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
template_html, template_text = load_jinja2_templates(
self.config.email_template_dir,
[
self.config.email_password_reset_template_html,
self.config.email_password_reset_template_text,
],
apply_format_ts_filter=True,
apply_mxc_to_http_filter=True,
public_baseurl=self.config.public_baseurl,
)
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
template_html=self.config.email_password_reset_template_html,
template_text=self.config.email_password_reset_template_text,
template_html=template_html,
template_text=template_text,
)
async def on_POST(self, request):
@@ -159,8 +169,9 @@ class PasswordResetSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self._failure_email_template = (
self.config.email_password_reset_template_failure_html
(self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_password_reset_template_failure_html],
)
async def on_GET(self, request, medium):
@@ -203,14 +214,14 @@ class PasswordResetSubmitTokenServlet(RestServlet):
return None
# Otherwise show the success template
html = self.config.email_password_reset_template_success_html_content
html = self.config.email_password_reset_template_success_html
status_code = 200
except ThreepidValidationError as e:
status_code = e.code
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
html = self._failure_email_template.render(**template_vars)
html = self.failure_email_template.render(**template_vars)
respond_with_html(request, status_code, html)
@@ -400,11 +411,19 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
self.store = self.hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
template_html, template_text = load_jinja2_templates(
self.config.email_template_dir,
[
self.config.email_add_threepid_template_html,
self.config.email_add_threepid_template_text,
],
public_baseurl=self.config.public_baseurl,
)
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
template_html=self.config.email_add_threepid_template_html,
template_text=self.config.email_add_threepid_template_text,
template_html=template_html,
template_text=template_text,
)
async def on_POST(self, request):
@@ -559,8 +578,9 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self._failure_email_template = (
self.config.email_add_threepid_template_failure_html
(self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_add_threepid_template_failure_html],
)
async def on_GET(self, request):
@@ -611,7 +631,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
html = self._failure_email_template.render(**template_vars)
html = self.failure_email_template.render(**template_vars)
respond_with_html(request, status_code, html)
-26
View File
@@ -548,31 +548,6 @@ class GroupAdminUsersKickServlet(RestServlet):
return 200, result
class GroupAdminChangeAdminServlet(RestServlet):
"""Promote or demote a user in the group
"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/admins/(?P<user_id>[^/]*)$"
)
def __init__(self, hs):
super(GroupAdminChangeAdminServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
async def on_POST(self, request, group_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
want_admin = content["is_admin"]
result = await self.groups_handler.change_user_admin_in_group(
group_id, user_id, want_admin, requester_user_id, content
)
return 200, result
class GroupSelfLeaveServlet(RestServlet):
"""Leave a joined group
@@ -747,7 +722,6 @@ def register_servlets(hs, http_server):
GroupAdminRoomsConfigServlet(hs).register(http_server)
GroupAdminUsersInviteServlet(hs).register(http_server)
GroupAdminUsersKickServlet(hs).register(http_server)
GroupAdminChangeAdminServlet(hs).register(http_server)
GroupSelfLeaveServlet(hs).register(http_server)
GroupSelfJoinServlet(hs).register(http_server)
GroupSelfAcceptInviteServlet(hs).register(http_server)
+25 -6
View File
@@ -44,7 +44,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.push.mailer import Mailer
from synapse.push.mailer import load_jinja2_templates
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -81,11 +81,23 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
self.config = hs.config
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
from synapse.push.mailer import Mailer, load_jinja2_templates
template_html, template_text = load_jinja2_templates(
self.config.email_template_dir,
[
self.config.email_registration_template_html,
self.config.email_registration_template_text,
],
apply_format_ts_filter=True,
apply_mxc_to_http_filter=True,
public_baseurl=self.config.public_baseurl,
)
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
template_html=self.config.email_registration_template_html,
template_text=self.config.email_registration_template_text,
template_html=template_html,
template_text=template_text,
)
async def on_POST(self, request):
@@ -250,8 +262,15 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self._failure_email_template = (
self.config.email_registration_template_failure_html
(self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_registration_template_failure_html],
)
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
(self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_registration_template_failure_html],
)
async def on_GET(self, request, medium):
@@ -299,7 +318,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
html = self._failure_email_template.render(**template_vars)
html = self.failure_email_template.render(**template_vars)
respond_with_html(request, status_code, html)
+3 -3
View File
@@ -15,12 +15,12 @@
import logging
from typing import Dict, Set
from canonicaljson import json
from canonicaljson import encode_canonical_json, json
from signedjson.sign import sign_json
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
from synapse.http.servlet import parse_integer, parse_json_object_from_request
logger = logging.getLogger(__name__)
@@ -223,4 +223,4 @@ class RemoteKey(DirectServeJsonResource):
results = {"server_keys": signed_keys}
respond_with_json(request, 200, results, canonical_json=True)
respond_with_json_bytes(request, 200, encode_canonical_json(results))
+9 -5
View File
@@ -18,6 +18,8 @@ from typing import Optional
from canonicaljson import json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from . import engines
@@ -306,8 +308,9 @@ class BackgroundUpdater(object):
update_name (str): Name of update
"""
async def noop_update(progress, batch_size):
await self._end_background_update(update_name)
@defer.inlineCallbacks
def noop_update(progress, batch_size):
yield self._end_background_update(update_name)
return 1
self.register_background_update_handler(update_name, noop_update)
@@ -406,11 +409,12 @@ class BackgroundUpdater(object):
else:
runner = create_index_sqlite
async def updater(progress, batch_size):
@defer.inlineCallbacks
def updater(progress, batch_size):
if runner is not None:
logger.info("Adding index %s to %s", index_name, table)
await self.db_pool.runWithConnection(runner)
await self._end_background_update(update_name)
yield self.db_pool.runWithConnection(runner)
yield self._end_background_update(update_name)
return 1
self.register_background_update_handler(update_name, updater)
+14 -9
View File
@@ -332,7 +332,8 @@ class DatabasePool(object):
"""
return self._db_pool.running
async def _check_safe_to_upsert(self):
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
"""
Is it safe to use native UPSERT?
@@ -341,7 +342,7 @@ class DatabasePool(object):
If the background updates have not completed, wait 15 sec and check again.
"""
updates = await self.simple_select_list(
updates = yield self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
@@ -613,7 +614,8 @@ class DatabasePool(object):
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
@defer.inlineCallbacks
def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
"""Executes an INSERT query on the named table.
Args:
@@ -629,7 +631,7 @@ class DatabasePool(object):
`or_ignore` is True
"""
try:
await self.runInteraction(desc, self.simple_insert_txn, table, values)
yield self.runInteraction(desc, self.simple_insert_txn, table, values)
except self.engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db.
@@ -682,7 +684,8 @@ class DatabasePool(object):
txn.executemany(sql, vals)
async def simple_upsert(
@defer.inlineCallbacks
def simple_upsert(
self,
table,
keyvalues,
@@ -711,14 +714,14 @@ class DatabasePool(object):
inserting
lock (bool): True to lock the table when doing the upsert.
Returns:
None or bool: Native upserts always return None. Emulated
Deferred(None or bool): Native upserts always return None. Emulated
upserts return True if a new entry was created, False if an existing
one was updated.
"""
attempts = 0
while True:
try:
return await self.runInteraction(
result = yield self.runInteraction(
desc,
self.simple_upsert_txn,
table,
@@ -727,6 +730,7 @@ class DatabasePool(object):
insertion_values,
lock=lock,
)
return result
except self.engine.module.IntegrityError as e:
attempts += 1
if attempts >= 5:
@@ -1117,7 +1121,8 @@ class DatabasePool(object):
return cls.cursor_to_dict(txn)
async def simple_select_many_batch(
@defer.inlineCallbacks
def simple_select_many_batch(
self,
table,
column,
@@ -1151,7 +1156,7 @@ class DatabasePool(object):
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
]
for chunk in chunks:
rows = await self.runInteraction(
rows = yield self.runInteraction(
desc,
self.simple_select_many_txn,
table,
+1 -1
View File
@@ -169,7 +169,7 @@ class ApplicationServiceTransactionWorkerStore(
service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply.
Returns:
An Awaitable which resolves when the state was set successfully.
A Deferred which resolves when the state was set successfully.
"""
return self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
+3 -2
View File
@@ -671,9 +671,10 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedList(
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
inlineCallbacks=True,
)
async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
rows = await self.db_pool.simple_select_many_batch(
def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
rows = yield self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
@@ -257,6 +257,11 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
def get_oldest_events_in_room(self, room_id):
return self.db_pool.runInteraction(
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
)
def get_oldest_events_with_depth_in_room(self, room_id):
return self.db_pool.runInteraction(
"get_oldest_events_with_depth_in_room",
@@ -298,6 +303,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else:
return max(row["depth"] for row in rows)
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self.db_pool.simple_select_onecol_txn(
txn,
table="event_backward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
)
def get_prev_events_for_room(self, room_id: str):
"""
Gets a subset of the current forward extremities in the given room.
@@ -21,7 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.descriptors import cachedInlineCallbacks
logger = logging.getLogger(__name__)
@@ -86,17 +86,18 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self._rotate_delay = 3
self._rotate_count = 10000
@cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
return await self.db_pool.runInteraction(
ret = yield self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
user_id,
last_read_event_id,
)
return ret
def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id
+19 -14
View File
@@ -17,11 +17,13 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
import attr
from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions
@@ -111,14 +113,15 @@ class PersistEventsStore:
hs.config.worker.writers.events == hs.get_instance_name()
), "Can only instantiate EventsStore on master"
async def _persist_events_and_state_updates(
@defer.inlineCallbacks
def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
backfilled: bool = False,
) -> None:
):
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -133,7 +136,7 @@ class PersistEventsStore:
backfilled
Returns:
Resolves when the events have been persisted
Deferred: resolves when the events have been persisted
"""
# We want to calculate the stream orderings as late as possible, as
@@ -165,7 +168,7 @@ class PersistEventsStore:
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
await self.db_pool.runInteraction(
yield self.db_pool.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=events_and_contexts,
@@ -203,15 +206,16 @@ class PersistEventsStore:
(room_id,), list(latest_event_ids)
)
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
@defer.inlineCallbacks
def _get_events_which_are_prevs(self, event_ids):
"""Filter the supplied list of event_ids to get those which are prev_events of
existing (non-outlier/rejected) events.
Args:
event_ids: event ids to filter
event_ids (Iterable[str]): event ids to filter
Returns:
Filtered event ids
Deferred[List[str]]: filtered event ids
"""
results = []
@@ -236,13 +240,14 @@ class PersistEventsStore:
results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
await self.db_pool.runInteraction(
yield self.db_pool.runInteraction(
"_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
)
return results
async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]:
@defer.inlineCallbacks
def _get_prevs_before_rejected(self, event_ids):
"""Get soft-failed ancestors to remove from the extremities.
Given a set of events, find all those that have been soft-failed or
@@ -254,11 +259,11 @@ class PersistEventsStore:
are separated by soft failed events.
Args:
event_ids: Events to find prev events for. Note that these must have
already been persisted.
event_ids (Iterable[str]): Events to find prev events for. Note
that these must have already been persisted.
Returns:
The previous events.
Deferred[set[str]]
"""
# The set of event_ids to return. This includes all soft-failed events
@@ -299,7 +304,7 @@ class PersistEventsStore:
existing_prevs.add(prev_event_id)
for chunk in batch_iter(event_ids, 100):
await self.db_pool.runInteraction(
yield self.db_pool.runInteraction(
"_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
)
@@ -15,6 +15,8 @@
import logging
from twisted.internet import defer
from synapse.api.constants import EventContentFields
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
@@ -92,7 +94,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
where_clause="NOT have_censored",
)
async def _background_reindex_fields_sender(self, progress, batch_size):
@defer.inlineCallbacks
def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -152,18 +155,19 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows)
result = await self.db_pool.runInteraction(
result = yield self.db_pool.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
await self.db_pool.updates._end_background_update(
yield self.db_pool.updates._end_background_update(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
)
return result
async def _background_reindex_origin_server_ts(self, progress, batch_size):
@defer.inlineCallbacks
def _background_reindex_origin_server_ts(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -230,18 +234,19 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows_to_update)
result = await self.db_pool.runInteraction(
result = yield self.db_pool.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
if not result:
await self.db_pool.updates._end_background_update(
yield self.db_pool.updates._end_background_update(
self.EVENT_ORIGIN_SERVER_TS_NAME
)
return result
async def _cleanup_extremities_bg_update(self, progress, batch_size):
@defer.inlineCallbacks
def _cleanup_extremities_bg_update(self, progress, batch_size):
"""Background update to clean out extremities that should have been
deleted previously.
@@ -409,25 +414,26 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(original_set)
num_handled = await self.db_pool.runInteraction(
num_handled = yield self.db_pool.runInteraction(
"_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
)
if not num_handled:
await self.db_pool.updates._end_background_update(
yield self.db_pool.updates._end_background_update(
self.DELETE_SOFT_FAILED_EXTREMITIES
)
def _drop_table_txn(txn):
txn.execute("DROP TABLE _extremities_to_check")
await self.db_pool.runInteraction(
yield self.db_pool.runInteraction(
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn
)
return num_handled
async def _redactions_received_ts(self, progress, batch_size):
@defer.inlineCallbacks
def _redactions_received_ts(self, progress, batch_size):
"""Handles filling out the `received_ts` column in redactions.
"""
last_event_id = progress.get("last_event_id", "")
@@ -474,16 +480,17 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows)
count = await self.db_pool.runInteraction(
count = yield self.db_pool.runInteraction(
"_redactions_received_ts", _redactions_received_ts_txn
)
if not count:
await self.db_pool.updates._end_background_update("redactions_received_ts")
yield self.db_pool.updates._end_background_update("redactions_received_ts")
return count
async def _event_fix_redactions_bytes(self, progress, batch_size):
@defer.inlineCallbacks
def _event_fix_redactions_bytes(self, progress, batch_size):
"""Undoes hex encoded censored redacted event JSON.
"""
@@ -504,15 +511,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn.execute("DROP INDEX redactions_censored_redacts")
await self.db_pool.runInteraction(
yield self.db_pool.runInteraction(
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
)
await self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
return 1
async def _event_store_labels(self, progress, batch_size):
@defer.inlineCallbacks
def _event_store_labels(self, progress, batch_size):
"""Background update handler which will store labels for existing events."""
last_event_id = progress.get("last_event_id", "")
@@ -567,11 +575,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return nbrows
num_rows = await self.db_pool.runInteraction(
num_rows = yield self.db_pool.runInteraction(
desc="event_store_labels", func=_event_store_labels_txn
)
if not num_rows:
await self.db_pool.updates._end_background_update("event_store_labels")
yield self.db_pool.updates._end_background_update("event_store_labels")
return num_rows
+176 -10
View File
@@ -43,7 +43,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -137,6 +137,42 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_received_ts",
)
def get_received_ts_by_stream_pos(self, stream_ordering):
"""Given a stream ordering get an approximate timestamp of when it
happened.
This is done by simply taking the received ts of the first event that
has a stream ordering greater than or equal to the given stream pos.
If none exists returns the current time, on the assumption that it must
have happened recently.
Args:
stream_ordering (int)
Returns:
Deferred[int]
"""
def _get_approximate_received_ts_txn(txn):
sql = """
SELECT received_ts FROM events
WHERE stream_ordering >= ?
LIMIT 1
"""
txn.execute(sql, (stream_ordering,))
row = txn.fetchone()
if row and row[0]:
ts = row[0]
else:
ts = self.clock.time_msec()
return ts
return self.db_pool.runInteraction(
"get_approximate_received_ts", _get_approximate_received_ts_txn
)
@defer.inlineCallbacks
def get_event(
self,
@@ -847,15 +883,13 @@ class EventsWorkerStore(SQLBaseStore):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
rows = yield defer.ensureDeferred(
self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
)
rows = yield self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
)
return {r["event_id"] for r in rows}
@@ -889,6 +923,36 @@ class EventsWorkerStore(SQLBaseStore):
)
return results
def _get_total_state_event_counts_txn(self, txn, room_id):
"""
See get_total_state_event_counts.
"""
# We join against the events table as that has an index on room_id
sql = """
SELECT COUNT(*) FROM state_events
INNER JOIN events USING (room_id, event_id)
WHERE room_id=?
"""
txn.execute(sql, (room_id,))
row = txn.fetchone()
return row[0] if row else 0
def get_total_state_event_counts(self, room_id):
"""
Gets the total number of state events in a room.
Args:
room_id (str)
Returns:
Deferred[int]
"""
return self.db_pool.runInteraction(
"get_total_state_event_counts",
self._get_total_state_event_counts_txn,
room_id,
)
def _get_current_state_event_counts_txn(self, txn, room_id):
"""
See get_current_state_event_counts.
@@ -1158,6 +1222,97 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
@cached(num_args=5, max_entries=10)
def get_all_new_events(
self,
last_backfill_id,
last_forward_id,
current_backfill_id,
current_forward_id,
limit,
):
"""Get all the new events that have arrived at the server either as
new events or as backfilled events"""
have_backfill_events = last_backfill_id != current_backfill_id
have_forward_events = last_forward_id != current_forward_id
if not have_backfill_events and not have_forward_events:
return defer.succeed(AllNewEventsResult([], [], [], [], []))
def get_all_new_events_txn(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
if have_forward_events:
txn.execute(sql, (last_forward_id, current_forward_id, limit))
new_forward_events = txn.fetchall()
if len(new_forward_events) == limit:
upper_bound = new_forward_events[-1][0]
else:
upper_bound = current_forward_id
sql = (
"SELECT event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (last_forward_id, upper_bound))
forward_ex_outliers = txn.fetchall()
else:
new_forward_events = []
forward_ex_outliers = []
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" ORDER BY stream_ordering DESC"
" LIMIT ?"
)
if have_backfill_events:
txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
new_backfill_events = txn.fetchall()
if len(new_backfill_events) == limit:
upper_bound = new_backfill_events[-1][0]
else:
upper_bound = current_backfill_id
sql = (
"SELECT -event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_backfill_id, -upper_bound))
backward_ex_outliers = txn.fetchall()
else:
new_backfill_events = []
backward_ex_outliers = []
return AllNewEventsResult(
new_forward_events,
new_backfill_events,
forward_ex_outliers,
backward_ex_outliers,
)
return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn)
async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
@@ -1202,3 +1357,14 @@ class EventsWorkerStore(SQLBaseStore):
return self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
AllNewEventsResult = namedtuple(
"AllNewEventsResult",
[
"new_forward_events",
"new_backfill_events",
"forward_ex_outliers",
"backward_ex_outliers",
],
)
@@ -1038,14 +1038,6 @@ class GroupServerStore(GroupServerWorkerStore):
"remove_user_from_group", _remove_user_from_group_txn
)
def change_user_admin_in_group(self, group_id, user_id, is_admin):
return self.db_pool.simple_update(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_admin": is_admin},
desc="change_user_admin_in_group"
)
def add_room_to_group(self, group_id, room_id, is_public):
return self.db_pool.simple_insert(
table="group_rooms",
+28 -4
View File
@@ -15,8 +15,8 @@
from typing import List, Tuple
from synapse.api.presence import UserPresenceState
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -130,10 +130,13 @@ class PresenceStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(
cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
cached_method_name="_get_presence_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
async def get_presence_for_users(self, user_ids):
rows = await self.db_pool.simple_select_many_batch(
def get_presence_for_users(self, user_ids):
rows = yield self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
@@ -157,3 +160,24 @@ class PresenceStore(SQLBaseStore):
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
def allow_presence_visible(self, observed_localpart, observer_userid):
return self.db_pool.simple_insert(
table="presence_allow_inbound",
values={
"observed_user_id": observed_localpart,
"observer_user_id": observer_userid,
},
desc="allow_presence_visible",
or_ignore=True,
)
def disallow_presence_visible(self, observed_localpart, observer_userid):
return self.db_pool.simple_delete_one(
table="presence_allow_inbound",
keyvalues={
"observed_user_id": observed_localpart,
"observer_user_id": observer_userid,
},
desc="disallow_presence_visible",
)
+52 -44
View File
@@ -32,7 +32,7 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import ChainedIdGenerator
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -115,9 +115,9 @@ class PushRulesWorkerStore(
"""
raise NotImplementedError()
@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id):
rows = await self.db_pool.simple_select_list(
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id):
rows = yield self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
retcols=(
@@ -133,15 +133,17 @@ class PushRulesWorkerStore(
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
use_new_defaults = user_id in self._users_new_default_push_rules
return _load_rules(rows, enabled_map, use_new_defaults)
rules = _load_rules(rows, enabled_map, use_new_defaults)
@cached(max_entries=5000)
async def get_push_rules_enabled_for_user(self, user_id):
results = await self.db_pool.simple_select_list(
return rules
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_enabled_for_user(self, user_id):
results = yield self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
retcols=("user_name", "rule_id", "enabled"),
@@ -168,15 +170,18 @@ class PushRulesWorkerStore(
)
@cachedList(
cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
cached_method_name="get_push_rules_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
async def bulk_get_push_rules(self, user_ids):
def bulk_get_push_rules(self, user_ids):
if not user_ids:
return {}
results = {user_id: [] for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch(
rows = yield self.db_pool.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
@@ -189,7 +194,7 @@ class PushRulesWorkerStore(
for row in rows:
results.setdefault(row["user_name"], []).append(row)
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
use_new_defaults = user_id in self._users_new_default_push_rules
@@ -200,15 +205,14 @@ class PushRulesWorkerStore(
return results
async def copy_push_rule_from_room_to_room(
self, new_room_id: str, user_id: str, rule: dict
) -> None:
@defer.inlineCallbacks
def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
"""Copy a single push rule from one room to another for a specific user.
Args:
new_room_id: ID of the new room.
user_id : ID of user the push rule belongs to.
rule: A push rule.
new_room_id (str): ID of the new room.
user_id (str): ID of user the push rule belongs to.
rule (Dict): A push rule.
"""
# Create new rule id
rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
@@ -220,7 +224,7 @@ class PushRulesWorkerStore(
condition["pattern"] = new_room_id
# Add the rule for the new room
await self.add_push_rule(
yield self.add_push_rule(
user_id=user_id,
rule_id=new_rule_id,
priority_class=rule["priority_class"],
@@ -228,19 +232,20 @@ class PushRulesWorkerStore(
actions=rule["actions"],
)
async def copy_push_rules_from_room_to_room_for_user(
self, old_room_id: str, new_room_id: str, user_id: str
) -> None:
@defer.inlineCallbacks
def copy_push_rules_from_room_to_room_for_user(
self, old_room_id, new_room_id, user_id
):
"""Copy all of the push rules from one room to another for a specific
user.
Args:
old_room_id: ID of the old room.
new_room_id: ID of the new room.
user_id: ID of user to copy push rules for.
old_room_id (str): ID of the old room.
new_room_id (str): ID of the new room.
user_id (str): ID of user to copy push rules for.
"""
# Retrieve push rules for this user
user_push_rules = await self.get_push_rules_for_user(user_id)
user_push_rules = yield self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
for rule in user_push_rules:
@@ -249,20 +254,21 @@ class PushRulesWorkerStore(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
):
await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
@cachedList(
cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
async def bulk_get_push_rules_enabled(self, user_ids):
def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
return {}
results = {user_id: {} for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch(
rows = yield self.db_pool.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
@@ -326,7 +332,8 @@ class PushRulesWorkerStore(
class PushRuleStore(PushRulesWorkerStore):
async def add_push_rule(
@defer.inlineCallbacks
def add_push_rule(
self,
user_id,
rule_id,
@@ -335,13 +342,13 @@ class PushRuleStore(PushRulesWorkerStore):
actions,
before=None,
after=None,
) -> None:
):
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
if before or after:
await self.db_pool.runInteraction(
yield self.db_pool.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
stream_id,
@@ -355,7 +362,7 @@ class PushRuleStore(PushRulesWorkerStore):
after,
)
else:
await self.db_pool.runInteraction(
yield self.db_pool.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
stream_id,
@@ -539,15 +546,16 @@ class PushRuleStore(PushRulesWorkerStore):
},
)
async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
@defer.inlineCallbacks
def delete_push_rule(self, user_id, rule_id):
"""
Delete a push rule. Args specify the row to be deleted and can be
any of the columns in the push_rule table, but below are the
standard ones
Args:
user_id: The matrix ID of the push rule owner
rule_id: The rule_id of the rule to be deleted
user_id (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
@@ -561,17 +569,18 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
await self.db_pool.runInteraction(
yield self.db_pool.runInteraction(
"delete_push_rule",
delete_push_rule_txn,
stream_id,
event_stream_ordering,
)
async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_id, rule_id, enabled):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
await self.db_pool.runInteraction(
yield self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
stream_id,
@@ -602,9 +611,8 @@ class PushRuleStore(PushRulesWorkerStore):
op="ENABLE" if enabled else "DISABLE",
)
async def set_push_rule_actions(
self, user_id, rule_id, actions, is_default_rule
) -> None:
@defer.inlineCallbacks
def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
@@ -645,7 +653,7 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
await self.db_pool.runInteraction(
yield self.db_pool.runInteraction(
"set_push_rule_actions",
set_push_rule_actions_txn,
stream_id,
+55 -49
View File
@@ -19,8 +19,10 @@ from typing import Iterable, Iterator, List, Tuple
from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
logger = logging.getLogger(__name__)
@@ -32,22 +34,23 @@ class PusherWorkerStore(SQLBaseStore):
Drops any rows whose data cannot be decoded
"""
for r in rows:
data_json = r["data"]
dataJson = r["data"]
try:
r["data"] = db_to_json(data_json)
r["data"] = db_to_json(dataJson)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
r["id"],
data_json,
dataJson,
e.args[0],
)
continue
yield r
async def user_has_pusher(self, user_id):
ret = await self.db_pool.simple_select_one_onecol(
@defer.inlineCallbacks
def user_has_pusher(self, user_id):
ret = yield self.db_pool.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
@@ -58,8 +61,9 @@ class PusherWorkerStore(SQLBaseStore):
def get_pushers_by_user_id(self, user_id):
return self.get_pushers_by({"user_name": user_id})
async def get_pushers_by(self, keyvalues):
ret = await self.db_pool.simple_select_list(
@defer.inlineCallbacks
def get_pushers_by(self, keyvalues):
ret = yield self.db_pool.simple_select_list(
"pushers",
keyvalues,
[
@@ -83,14 +87,16 @@ class PusherWorkerStore(SQLBaseStore):
)
return self._decode_pushers_rows(ret)
async def get_all_pushers(self):
@defer.inlineCallbacks
def get_all_pushers(self):
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
rows = self.db_pool.cursor_to_dict(txn)
return self._decode_pushers_rows(rows)
return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers)
return rows
async def get_all_updated_pushers_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -158,16 +164,19 @@ class PusherWorkerStore(SQLBaseStore):
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
@cached(num_args=1, max_entries=15000)
async def get_if_user_has_pusher(self, user_id):
@cachedInlineCallbacks(num_args=1, max_entries=15000)
def get_if_user_has_pusher(self, user_id):
# This only exists for the cachedList decorator
raise NotImplementedError()
@cachedList(
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
cached_method_name="get_if_user_has_pusher",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
async def get_if_users_have_pushers(self, user_ids):
rows = await self.db_pool.simple_select_many_batch(
def get_if_users_have_pushers(self, user_ids):
rows = yield self.db_pool.simple_select_many_batch(
table="pushers",
column="user_name",
iterable=user_ids,
@@ -180,38 +189,34 @@ class PusherWorkerStore(SQLBaseStore):
return result
async def update_pusher_last_stream_ordering(
@defer.inlineCallbacks
def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
) -> None:
await self.db_pool.simple_update_one(
):
yield self.db_pool.simple_update_one(
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
{"last_stream_ordering": last_stream_ordering},
desc="update_pusher_last_stream_ordering",
)
async def update_pusher_last_stream_ordering_and_success(
self,
app_id: str,
pushkey: str,
user_id: str,
last_stream_ordering: int,
last_success: int,
) -> bool:
@defer.inlineCallbacks
def update_pusher_last_stream_ordering_and_success(
self, app_id, pushkey, user_id, last_stream_ordering, last_success
):
"""Update the last stream ordering position we've processed up to for
the given pusher.
Args:
app_id
pushkey
user_id
last_stream_ordering
last_success
app_id (str)
pushkey (str)
last_stream_ordering (int)
last_success (int)
Returns:
True if the pusher still exists; False if it has been deleted.
Deferred[bool]: True if the pusher still exists; False if it has been deleted.
"""
updated = await self.db_pool.simple_update(
updated = yield self.db_pool.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={
@@ -223,18 +228,18 @@ class PusherWorkerStore(SQLBaseStore):
return bool(updated)
async def update_pusher_failing_since(
self, app_id, pushkey, user_id, failing_since
) -> None:
await self.db_pool.simple_update(
@defer.inlineCallbacks
def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
yield self.db_pool.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={"failing_since": failing_since},
desc="update_pusher_failing_since",
)
async def get_throttle_params_by_room(self, pusher_id):
res = await self.db_pool.simple_select_list(
@defer.inlineCallbacks
def get_throttle_params_by_room(self, pusher_id):
res = yield self.db_pool.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
@@ -250,10 +255,11 @@ class PusherWorkerStore(SQLBaseStore):
return params_by_room
async def set_throttle_params(self, pusher_id, room_id, params) -> None:
@defer.inlineCallbacks
def set_throttle_params(self, pusher_id, room_id, params):
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
await self.db_pool.simple_upsert(
yield self.db_pool.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
params,
@@ -266,7 +272,8 @@ class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
async def add_pusher(
@defer.inlineCallbacks
def add_pusher(
self,
user_id,
access_token,
@@ -280,11 +287,11 @@ class PusherStore(PusherWorkerStore):
data,
last_stream_ordering,
profile_tag="",
) -> None:
):
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
await self.db_pool.simple_upsert(
yield self.db_pool.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={
@@ -309,16 +316,15 @@ class PusherStore(PusherWorkerStore):
if user_has_pusher is not True:
# invalidate, since we the user might not have had a pusher before
await self.db_pool.runInteraction(
yield self.db_pool.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
self.get_if_user_has_pusher,
(user_id,),
)
async def delete_pusher_by_app_id_pushkey_user_id(
self, app_id, pushkey, user_id
) -> None:
@defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
def delete_pusher_txn(txn, stream_id):
self._invalidate_cache_and_stream(
txn, self.get_if_user_has_pusher, (user_id,)
@@ -345,6 +351,6 @@ class PusherStore(PusherWorkerStore):
)
with self._pushers_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
yield self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
+39 -48
View File
@@ -16,7 +16,7 @@
import abc
import logging
from typing import List, Optional, Tuple
from typing import List, Tuple
from twisted.internet import defer
@@ -25,7 +25,7 @@ from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -56,9 +56,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
raise NotImplementedError()
@cached()
async def get_users_with_read_receipts_in_room(self, room_id):
receipts = await self.get_receipts_for_room(room_id, "m.read")
@cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read")
return {r["user_id"] for r in receipts}
@cached(num_args=2)
@@ -84,9 +84,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
allow_none=True,
)
@cached(num_args=2)
async def get_receipts_for_user(self, user_id, receipt_type):
rows = await self.db_pool.simple_select_list(
@cachedInlineCallbacks(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type):
rows = yield self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
@@ -95,7 +95,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
return {row["room_id"]: row["event_id"] for row in rows}
async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
@defer.inlineCallbacks
def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
def f(txn):
sql = (
"SELECT rl.room_id, rl.event_id,"
@@ -109,7 +110,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return txn.fetchall()
rows = await self.db_pool.runInteraction(
rows = yield self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f
)
return {
@@ -121,61 +122,56 @@ class ReceiptsWorkerStore(SQLBaseStore):
for row in rows
}
async def get_linearized_receipts_for_rooms(
self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
) -> List[dict]:
@defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
"""Get receipts for multiple rooms for sending to clients.
Args:
room_id: List of room_ids.
to_key: Max stream id to fetch receipts upto.
from_key: Min stream id to fetch receipts from. None fetches
room_ids (list): List of room_ids.
to_key (int): Max stream id to fetch receipts upto.
from_key (int): Min stream id to fetch receipts from. None fetches
from the start.
Returns:
A list of receipts.
list: A list of receipts.
"""
room_ids = set(room_ids)
if from_key is not None:
# Only ask the database about rooms where there have been new
# receipts added since `from_key`
room_ids = self._receipts_stream_cache.get_entities_changed(
room_ids = yield self._receipts_stream_cache.get_entities_changed(
room_ids, from_key
)
results = await self._get_linearized_receipts_for_rooms(
results = yield self._get_linearized_receipts_for_rooms(
room_ids, to_key, from_key=from_key
)
return [ev for res in results.values() for ev in res]
async def get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[dict]:
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.
Args:
room_ids: The room id.
to_key: Max stream id to fetch receipts upto.
from_key: Min stream id to fetch receipts from. None fetches
room_ids (str): The room id.
to_key (int): Max stream id to fetch receipts upto.
from_key (int): Min stream id to fetch receipts from. None fetches
from the start.
Returns:
A list of receipts.
Deferred[list]: A list of receipts.
"""
if from_key is not None:
# Check the cache first to see if any new receipts have been added
# since`from_key`. If not we can no-op.
if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
return []
defer.succeed([])
return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
@cached(num_args=3, tree=True)
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[dict]:
@cachedInlineCallbacks(num_args=3, tree=True)
def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""See get_linearized_receipts_for_room
"""
@@ -199,7 +195,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return rows
rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
return []
@@ -216,8 +212,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids",
num_args=3,
inlineCallbacks=True,
)
async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
return {}
@@ -246,7 +243,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return self.db_pool.cursor_to_dict(txn)
txn_results = await self.db_pool.runInteraction(
txn_results = yield self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
@@ -349,7 +346,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
def _invalidate_get_users_with_receipts_in_room(
self, room_id: str, receipt_type: str, user_id: str
self, room_id, receipt_type, user_id
):
if receipt_type != "m.read":
return
@@ -475,21 +472,15 @@ class ReceiptsStore(ReceiptsWorkerStore):
return rx_ts
async def insert_receipt(
self,
room_id: str,
receipt_type: str,
user_id: str,
event_ids: List[str],
data: dict,
) -> Optional[Tuple[int, int]]:
@defer.inlineCallbacks
def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
representations.
"""
if not event_ids:
return None
return
if len(event_ids) == 1:
linearized_event_id = event_ids[0]
@@ -516,13 +507,13 @@ class ReceiptsStore(ReceiptsWorkerStore):
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
linearized_event_id = await self.db_pool.runInteraction(
linearized_event_id = yield self.db_pool.runInteraction(
"insert_receipt_conv", graph_to_linear
)
stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
event_ts = await self.db_pool.runInteraction(
event_ts = yield self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id,
@@ -544,7 +535,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
now - event_ts,
)
await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
max_persisted_id = self._receipts_id_gen.get_current_token()
+43 -11
View File
@@ -17,7 +17,9 @@
import logging
import re
from typing import Awaitable, Dict, List, Optional
from typing import Dict, List, Optional
from twisted.internet.defer import Deferred
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -302,7 +304,7 @@ class RegistrationWorkerStore(SQLBaseStore):
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
" access_tokens.device_id, access_tokens.valid_until_ms"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
@@ -561,7 +563,7 @@ class RegistrationWorkerStore(SQLBaseStore):
id_server (str)
Returns:
Awaitable
Deferred
"""
# We need to use an upsert, in case they user had already bound the
# threepid
@@ -950,7 +952,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname=None,
admin=False,
user_type=None,
shadow_banned=False,
):
"""Attempts to register an account.
@@ -967,8 +968,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
shadow_banned (bool): Whether the user is shadow-banned,
i.e. they may be told their requests succeeded but we ignore them.
Raises:
StoreError if the user_id could not be registered.
@@ -987,7 +986,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname,
admin,
user_type,
shadow_banned,
)
def _register_user(
@@ -1001,7 +999,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname,
admin,
user_type,
shadow_banned,
):
user_id_obj = UserID.from_string(user_id)
@@ -1031,7 +1028,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
"shadow_banned": shadow_banned,
},
)
else:
@@ -1046,7 +1042,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
"shadow_banned": shadow_banned,
},
)
@@ -1082,7 +1077,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
) -> Awaitable:
) -> Deferred:
"""Record a mapping from an external user id to a mxid
Args:
@@ -1350,6 +1345,43 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"validate_threepid_session_txn", validate_threepid_session_txn
)
def upsert_threepid_validation_session(
self,
medium,
address,
client_secret,
send_attempt,
session_id,
validated_at=None,
):
"""Upsert a threepid validation session
Args:
medium (str): The medium of the 3PID
address (str): The address of the 3PID
client_secret (str): A unique string provided by the client to
help identify this validation attempt
send_attempt (int): The latest send_attempt on this session
session_id (str): The id of this validation session
validated_at (int|None): The unix timestamp in milliseconds of
when the session was marked as valid
"""
insertion_values = {
"medium": medium,
"address": address,
"client_secret": client_secret,
}
if validated_at:
insertion_values["validated_at"] = validated_at
return self.db_pool.simple_upsert(
table="threepid_validation_session",
keyvalues={"session_id": session_id},
values={"last_send_attempt": send_attempt},
insertion_values=insertion_values,
desc="upsert_threepid_validation_session",
)
def start_or_continue_validation_session(
self,
medium,
+4
View File
@@ -35,6 +35,10 @@ from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
OpsLevel = collections.namedtuple(
"OpsLevel", ("ban_level", "kick_level", "redact_level")
)
RatelimitOverride = collections.namedtuple(
"RatelimitOverride", ("messages_per_second", "burst_count")
)
+13 -8
View File
@@ -17,6 +17,8 @@
import logging
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
@@ -90,7 +92,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
lambda: self._known_servers_count,
)
async def _count_known_servers(self):
@defer.inlineCallbacks
def _count_known_servers(self):
"""
Count the servers that this server knows about.
@@ -118,7 +121,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(query)
return list(txn)[0][0]
count = await self.db_pool.runInteraction("get_known_servers", _transact)
count = yield self.db_pool.runInteraction("get_known_servers", _transact)
# We always know about ourselves, even if we have nothing in
# room_memberships (for example, the server is new).
@@ -586,9 +589,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
raise NotImplementedError()
@cachedList(
cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
cached_method_name="_get_joined_profile_from_event_id",
list_name="event_ids",
inlineCallbacks=True,
)
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
@@ -596,11 +601,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event_ids: The member event IDs to lookup
Returns:
dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
"""
rows = await self.db_pool.simple_select_many_batch(
rows = yield self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
@@ -767,13 +772,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
async def get_membership_from_event_ids(
def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs.
"""
return await self.db_pool.simple_select_many_batch(
return self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
@@ -1,18 +0,0 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- A shadow-banned user may be told that their requests succeeded when they were
-- actually ignored.
ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN;
@@ -1,17 +0,0 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- This table is no longer used.
DROP TABLE IF EXISTS presence_allow_inbound;
+3 -2
View File
@@ -273,11 +273,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
cached_method_name="_get_state_group_for_event",
list_name="event_ids",
num_args=1,
inlineCallbacks=True,
)
async def _get_state_group_for_events(self, event_ids):
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
rows = await self.db_pool.simple_select_many_batch(
rows = yield self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,
+188 -212
View File
@@ -39,17 +39,15 @@ what sort order was used:
import abc
import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Optional
from twisted.internet import defer
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.engines import PostgresEngine
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -70,12 +68,8 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause(
direction: str,
column_names: Tuple[str, str],
from_token: Optional[Tuple[int, int]],
to_token: Optional[Tuple[int, int]],
engine: BaseDatabaseEngine,
) -> str:
direction, column_names, from_token, to_token, engine
):
"""Creates an SQL expression to bound the columns by the pagination
tokens.
@@ -96,19 +90,21 @@ def generate_pagination_where_clause(
token, but include those that match the to token.
Args:
direction: Whether we're paginating backwards("b") or forwards ("f").
column_names: The column names to bound. Must *not* be user defined as
these get inserted directly into the SQL statement without escapes.
from_token: The start point for the pagination. This is an exclusive
minimum bound if direction is "f", and an inclusive maximum bound if
direction is "b".
to_token: The endpoint point for the pagination. This is an inclusive
maximum bound if direction is "f", and an exclusive minimum bound if
direction is "b".
direction (str): Whether we're paginating backwards("b") or
forwards ("f").
column_names (tuple[str, str]): The column names to bound. Must *not*
be user defined as these get inserted directly into the SQL
statement without escapes.
from_token (tuple[int, int]|None): The start point for the pagination.
This is an exclusive minimum bound if direction is "f", and an
inclusive maximum bound if direction is "b".
to_token (tuple[int, int]|None): The endpoint point for the pagination.
This is an inclusive maximum bound if direction is "f", and an
exclusive minimum bound if direction is "b".
engine: The database engine to generate the clauses for
Returns:
The sql expression
str: The sql expression
"""
assert direction in ("b", "f")
@@ -136,12 +132,7 @@ def generate_pagination_where_clause(
return " AND ".join(where_clause)
def _make_generic_sql_bound(
bound: str,
column_names: Tuple[str, str],
values: Tuple[Optional[int], int],
engine: BaseDatabaseEngine,
) -> str:
def _make_generic_sql_bound(bound, column_names, values, engine):
"""Create an SQL expression that bounds the given column names by the
values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
@@ -151,18 +142,18 @@ def _make_generic_sql_bound(
out manually.
Args:
bound: The comparison operator to use. One of ">", "<", ">=",
bound (str): The comparison operator to use. One of ">", "<", ">=",
"<=", where the values are on the left and columns on the right.
names: The column names. Must *not* be user defined
names (tuple[str, str]): The column names. Must *not* be user defined
as these get inserted directly into the SQL statement without
escapes.
values: The values to bound the columns by. If
values (tuple[int|None, int]): The values to bound the columns by. If
the first value is None then only creates a bound on the second
column.
engine: The database engine to generate the SQL for
Returns:
The SQL statement
str
"""
assert bound in (">", "<", ">=", "<=")
@@ -202,7 +193,7 @@ def _make_generic_sql_bound(
)
def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
def filter_to_clause(event_filter):
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -300,35 +291,34 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def get_room_min_stream_ordering(self):
raise NotImplementedError()
async def get_room_events_stream_for_rooms(
self,
room_ids: Iterable[str],
from_key: str,
to_key: str,
limit: int = 0,
order: str = "DESC",
) -> Dict[str, Tuple[List[EventBase], str]]:
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(
self, room_ids, from_key, to_key, limit=0, order="DESC"
):
"""Get new room events in stream ordering since `from_key`.
Args:
room_ids
from_key: Token from which no events are returned before
to_key: Token from which no events are returned after. (This
room_id (str)
from_key (str): Token from which no events are returned before
to_key (str): Token from which no events are returned after. (This
is typically the current stream token)
limit: Maximum number of events to return
order: Either "DESC" or "ASC". Determines which events are
limit (int): Maximum number of events to return
order (str): Either "DESC" or "ASC". Determines which events are
returned when the result is limited. If "DESC" then the most
recent `limit` events are returned, otherwise returns the
oldest `limit` events.
Returns:
A map from room id to a tuple containing:
- list of recent events in the room
- stream ordering key for the start of the chunk of events returned.
Deferred[dict[str,tuple[list[FrozenEvent], str]]]
A map from room id to a tuple containing:
- list of recent events in the room
- stream ordering key for the start of the chunk of events returned.
"""
from_id = RoomStreamToken.parse_stream_token(from_key).stream
room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
room_ids = yield self._events_stream_cache.get_entities_changed(
room_ids, from_id
)
if not room_ids:
return {}
@@ -336,7 +326,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = {}
room_ids = list(room_ids)
for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
res = await make_deferred_yieldable(
res = yield make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -371,31 +361,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if self._events_stream_cache.has_entity_changed(room_id, from_key)
}
async def get_room_events_stream_for_room(
self,
room_id: str,
from_key: str,
to_key: str,
limit: int = 0,
order: str = "DESC",
) -> Tuple[List[EventBase], str]:
@defer.inlineCallbacks
def get_room_events_stream_for_room(
self, room_id, from_key, to_key, limit=0, order="DESC"
):
"""Get new room events in stream ordering since `from_key`.
Args:
room_id
from_key: Token from which no events are returned before
to_key: Token from which no events are returned after. (This
room_id (str)
from_key (str): Token from which no events are returned before
to_key (str): Token from which no events are returned after. (This
is typically the current stream token)
limit: Maximum number of events to return
order: Either "DESC" or "ASC". Determines which events are
limit (int): Maximum number of events to return
order (str): Either "DESC" or "ASC". Determines which events are
returned when the result is limited. If "DESC" then the most
recent `limit` events are returned, otherwise returns the
oldest `limit` events.
Returns:
The list of events (in ascending order) and the token from the start
of the chunk of events returned.
Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
events (in ascending order) and the token from the start of
the chunk of events returned.
"""
if from_key == to_key:
return [], from_key
@@ -403,7 +390,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
has_changed = yield self._events_stream_cache.has_entity_changed(
room_id, from_id
)
if not has_changed:
return [], from_key
@@ -421,9 +410,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f)
ret = await self.get_events_as_list(
ret = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -441,7 +430,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
async def get_membership_changes_for_user(self, user_id, from_key, to_key):
@defer.inlineCallbacks
def get_membership_changes_for_user(self, user_id, from_key, to_key):
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
@@ -470,9 +460,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows
rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f)
rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f)
ret = await self.get_events_as_list(
ret = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -480,26 +470,27 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret
async def get_recent_events_for_room(
self, room_id: str, limit: int, end_token: str
) -> Tuple[List[EventBase], str]:
@defer.inlineCallbacks
def get_recent_events_for_room(self, room_id, limit, end_token):
"""Get the most recent events in the room in topological ordering.
Args:
room_id
limit
end_token: The stream token representing now.
room_id (str)
limit (int)
end_token (str): The stream token representing now.
Returns:
A list of events and a token pointing to the start of the returned
events. The events returned are in ascending order.
Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
events and a token pointing to the start of the returned
events.
The events returned are in ascending order.
"""
rows, token = await self.get_recent_event_ids_for_room(
rows, token = yield self.get_recent_event_ids_for_room(
room_id, limit, end_token
)
events = await self.get_events_as_list(
events = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -507,19 +498,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token)
async def get_recent_event_ids_for_room(
self, room_id: str, limit: int, end_token: str
) -> Tuple[List[_EventDictReturn], str]:
@defer.inlineCallbacks
def get_recent_event_ids_for_room(self, room_id, limit, end_token):
"""Get the most recent events in the room in topological ordering.
Args:
room_id
limit
end_token: The stream token representing now.
room_id (str)
limit (int)
end_token (str): The stream token representing now.
Returns:
A list of _EventDictReturn and a token pointing to the start of the
returned events. The events returned are in ascending order.
Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
_EventDictReturn and a token pointing to the start of the returned
events.
The events returned are in ascending order.
"""
# Allow a zero limit here, and no-op.
if limit == 0:
@@ -527,7 +519,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token = RoomStreamToken.parse(end_token)
rows, token = await self.db_pool.runInteraction(
rows, token = yield self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
@@ -540,12 +532,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, token
def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
def get_room_event_before_stream_ordering(self, room_id, stream_ordering):
"""Gets details of the first event in a room at or before a stream ordering
Args:
room_id:
stream_ordering:
room_id (str):
stream_ordering (int):
Returns:
Deferred[(int, int, str)]:
@@ -582,67 +574,55 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
return "t%d-%d" % (topo, token)
async def get_stream_id_for_event(self, event_id: str) -> int:
"""The stream ID for an event
def get_stream_token_for_event(self, event_id):
"""The stream token for an event
Args:
event_id: The id of the event to look up a stream token for.
event_id(str): The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
A stream ID.
A deferred "s%d" stream token.
"""
return await self.db_pool.simple_select_one_onecol(
return self.db_pool.simple_select_one_onecol(
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
)
).addCallback(lambda row: "s%d" % (row,))
async def get_stream_token_for_event(self, event_id: str) -> str:
def get_topological_token_for_event(self, event_id):
"""The stream token for an event
Args:
event_id: The id of the event to look up a stream token for.
event_id(str): The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
A "s%d" stream token.
A deferred "t%d-%d" topological token.
"""
stream_id = await self.get_stream_id_for_event(event_id)
return "s%d" % (stream_id,)
async def get_topological_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
Args:
event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
A "t%d-%d" topological token.
"""
row = await self.db_pool.simple_select_one(
return self.db_pool.simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
).addCallback(
lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
)
return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
def get_max_topological_token(self, room_id, stream_key):
"""Get the max topological token in a room before the given stream
ordering.
Args:
room_id
stream_key
room_id (str)
stream_key (int)
Returns:
The maximum topological token.
Deferred[int]
"""
sql = (
"SELECT coalesce(max(topological_ordering), 0) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
row = await self.db_pool.execute(
return self.db_pool.execute(
"get_max_topological_token", None, sql, room_id, stream_key
)
return row[0][0] if row else 0
).addCallback(lambda r: r[0][0] if r else 0)
def _get_max_topological_txn(self, txn, room_id):
txn.execute(
@@ -654,18 +634,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows[0][0] if rows else 0
@staticmethod
def _set_before_and_after(
events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
):
def _set_before_and_after(events, rows, topo_order=True):
"""Inserts ordering information to events' internal metadata from
the DB rows.
Args:
events
rows
topo_order: Whether the events were ordered topologically or by stream
ordering. If true then all rows should have a non null
topological_ordering.
events (list[FrozenEvent])
rows (list[_EventDictReturn])
topo_order (bool): Whether the events were ordered topologically
or by stream ordering. If true then all rows should have a non
null topological_ordering.
"""
for event, row in zip(events, rows):
stream = row.stream_ordering
@@ -678,19 +656,25 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
internal.after = str(RoomStreamToken(topo, stream))
internal.order = (int(topo) if topo else 0, int(stream))
async def get_events_around(
self,
room_id: str,
event_id: str,
before_limit: int,
after_limit: int,
event_filter: Optional[Filter] = None,
) -> dict:
@defer.inlineCallbacks
def get_events_around(
self, room_id, event_id, before_limit, after_limit, event_filter=None
):
"""Retrieve events and pagination tokens around a given event in a
room.
Args:
room_id (str)
event_id (str)
before_limit (int)
after_limit (int)
event_filter (Filter|None)
Returns:
dict
"""
results = await self.db_pool.runInteraction(
results = yield self.db_pool.runInteraction(
"get_events_around",
self._get_events_around_txn,
room_id,
@@ -700,11 +684,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
events_before = await self.get_events_as_list(
events_before = yield self.get_events_as_list(
list(results["before"]["event_ids"]), get_prev_content=True
)
events_after = await self.get_events_as_list(
events_after = yield self.get_events_as_list(
list(results["after"]["event_ids"]), get_prev_content=True
)
@@ -716,23 +700,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
}
def _get_events_around_txn(
self,
txn,
room_id: str,
event_id: str,
before_limit: int,
after_limit: int,
event_filter: Optional[Filter],
) -> dict:
self, txn, room_id, event_id, before_limit, after_limit, event_filter
):
"""Retrieves event_ids and pagination tokens around a given event in a
room.
Args:
room_id
event_id
before_limit
after_limit
event_filter
room_id (str)
event_id (str)
before_limit (int)
after_limit (int)
event_filter (Filter|None)
Returns:
dict
@@ -780,23 +758,22 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"after": {"event_ids": events_after, "token": end_token},
}
async def get_all_new_events_stream(
self, from_id: int, current_id: int, limit: int
) -> Tuple[int, List[EventBase]]:
@defer.inlineCallbacks
def get_all_new_events_stream(self, from_id, current_id, limit):
"""Get all new events
Returns all events with from_id < stream_ordering <= current_id.
Args:
from_id: the stream_ordering of the last event we processed
current_id: the stream_ordering of the most recently processed event
limit: the maximum number of events to return
from_id (int): the stream_ordering of the last event we processed
current_id (int): the stream_ordering of the most recently processed event
limit (int): the maximum number of events to return
Returns:
A tuple of (next_id, events), where `next_id` is the next value to
pass as `from_id` (it will either be the stream_ordering of the
last returned event, or, if fewer than `limit` events were found,
the `current_id`).
Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where
`next_id` is the next value to pass as `from_id` (it will either be the
stream_ordering of the last returned event, or, if fewer than `limit` events
were found, `current_id`.
"""
def get_all_new_events_stream_txn(txn):
@@ -818,11 +795,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = await self.db_pool.runInteraction(
upper_bound, event_ids = yield self.db_pool.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn
)
events = await self.get_events_as_list(event_ids)
events = yield self.get_events_as_list(event_ids)
return upper_bound, events
@@ -840,21 +817,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="get_federation_out_pos",
)
async def update_federation_out_pos(self, typ: str, stream_id: int) -> None:
async def update_federation_out_pos(self, typ, stream_id):
if self._need_to_reset_federation_stream_positions:
await self.db_pool.runInteraction(
"_reset_federation_positions_txn", self._reset_federation_positions_txn
)
self._need_to_reset_federation_stream_positions = False
await self.db_pool.simple_update_one(
return await self.db_pool.simple_update_one(
table="federation_stream_position",
keyvalues={"type": typ, "instance_name": self._instance_name},
updatevalues={"stream_id": stream_id},
desc="update_federation_out_pos",
)
def _reset_federation_positions_txn(self, txn) -> None:
def _reset_federation_positions_txn(self, txn):
"""Fiddles with the `federation_stream_position` table to make it match
the configured federation sender instances during start up.
"""
@@ -915,37 +892,39 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
values={"stream_id": stream_id},
)
def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:
def has_room_changed_since(self, room_id, stream_id):
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
def _paginate_room_events_txn(
self,
txn,
room_id: str,
from_token: RoomStreamToken,
to_token: Optional[RoomStreamToken] = None,
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[_EventDictReturn], str]:
room_id,
from_token,
to_token=None,
direction="b",
limit=-1,
event_filter=None,
):
"""Returns list of events before or after a given token.
Args:
txn
room_id
from_token: The token used to stream from
to_token: A token which if given limits the results to only those before
direction: Either 'b' or 'f' to indicate whether we are paginating
forwards or backwards from `from_key`.
limit: The maximum number of events to return.
event_filter: If provided filters the events to
room_id (str)
from_token (RoomStreamToken): The token used to stream from
to_token (RoomStreamToken|None): A token which if given limits the
results to only those before
direction(char): Either 'b' or 'f' to indicate whether we are
paginating forwards or backwards from `from_key`.
limit (int): The maximum number of events to return.
event_filter (Filter|None): If provided filters the events to
those that match the filter.
Returns:
A list of _EventDictReturn and a token that points to the end of the
result set. If no events are returned then the end of the stream has
been reached (i.e. there are no events between `from_token` and
`to_token`), or `limit` is zero.
Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
as a list of _EventDictReturn and a token that points to the end
of the result set. If no events are returned then the end of the
stream has been reached (i.e. there are no events between
`from_token` and `to_token`), or `limit` is zero.
"""
assert int(limit) >= 0
@@ -1029,38 +1008,35 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, str(next_token)
async def paginate_room_events(
self,
room_id: str,
from_key: str,
to_key: Optional[str] = None,
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[EventBase], str]:
@defer.inlineCallbacks
def paginate_room_events(
self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None
):
"""Returns list of events before or after a given token.
Args:
room_id
from_key: The token used to stream from
to_key: A token which if given limits the results to only those before
direction: Either 'b' or 'f' to indicate whether we are paginating
forwards or backwards from `from_key`.
limit: The maximum number of events to return.
event_filter: If provided filters the events to those that match the filter.
room_id (str)
from_key (str): The token used to stream from
to_key (str|None): A token which if given limits the results to
only those before
direction(char): Either 'b' or 'f' to indicate whether we are
paginating forwards or backwards from `from_key`.
limit (int): The maximum number of events to return.
event_filter (Filter|None): If provided filters the events to
those that match the filter.
Returns:
The results as a list of events and a token that points to the end
of the result set. If no events are returned then the end of the
stream has been reached (i.e. there are no events between `from_key`
and `to_key`).
tuple[list[FrozenEvent], str]: Returns the results as a list of
events and a token that points to the end of the result set. If no
events are returned then the end of the stream has been reached
(i.e. there are no events between `from_key` and `to_key`).
"""
from_key = RoomStreamToken.parse(from_key)
if to_key:
to_key = RoomStreamToken.parse(to_key)
rows, token = await self.db_pool.runInteraction(
rows, token = yield self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
@@ -1071,7 +1047,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
events = await self.get_events_as_list(
events = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -1081,8 +1057,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
class StreamStore(StreamWorkerStore):
def get_room_max_stream_ordering(self) -> int:
def get_room_max_stream_ordering(self):
return self._stream_id_gen.get_current_token()
def get_room_min_stream_ordering(self) -> int:
def get_room_min_stream_ordering(self):
return self._backfill_id_gen.get_current_token()
@@ -38,8 +38,10 @@ class UserErasureWorkerStore(SQLBaseStore):
desc="is_user_erased",
).addCallback(operator.truth)
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
async def are_users_erased(self, user_ids):
@cachedList(
cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
)
def are_users_erased(self, user_ids):
"""
Checks which users in a list have requested erasure
@@ -47,14 +49,14 @@ class UserErasureWorkerStore(SQLBaseStore):
user_ids (iterable[str]): full user id to check
Returns:
dict[str, bool]:
Deferred[dict[str, bool]]:
for each user, whether the user has requested erasure.
"""
# this serves the dual purpose of (a) making sure we can do len and
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
rows = await self.db_pool.simple_select_many_batch(
rows = yield self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
iterable=user_ids,
@@ -63,7 +65,8 @@ class UserErasureWorkerStore(SQLBaseStore):
)
erased_users = {row["user_id"] for row in rows}
return {u: u in erased_users for u in user_ids}
res = {u: u in erased_users for u in user_ids}
return res
class UserErasureStore(UserErasureWorkerStore):
+3 -22
View File
@@ -51,15 +51,7 @@ JsonDict = Dict[str, Any]
class Requester(
namedtuple(
"Requester",
[
"user",
"access_token_id",
"is_guest",
"shadow_banned",
"device_id",
"app_service",
],
"Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]
)
):
"""
@@ -70,7 +62,6 @@ class Requester(
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
shadow_banned (bool): True if the user making this request has been shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
"""
@@ -86,7 +77,6 @@ class Requester(
"user_id": self.user.to_string(),
"access_token_id": self.access_token_id,
"is_guest": self.is_guest,
"shadow_banned": self.shadow_banned,
"device_id": self.device_id,
"app_server_id": self.app_service.id if self.app_service else None,
}
@@ -111,19 +101,13 @@ class Requester(
user=UserID.from_string(input["user_id"]),
access_token_id=input["access_token_id"],
is_guest=input["is_guest"],
shadow_banned=input["shadow_banned"],
device_id=input["device_id"],
app_service=appservice,
)
def create_requester(
user_id,
access_token_id=None,
is_guest=False,
shadow_banned=False,
device_id=None,
app_service=None,
user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None
):
"""
Create a new ``Requester`` object
@@ -133,7 +117,6 @@ def create_requester(
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
shadow_banned (bool): True if the user making this request is shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
@@ -142,9 +125,7 @@ def create_requester(
"""
if not isinstance(user_id, UserID):
user_id = UserID.from_string(user_id)
return Requester(
user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
)
return Requester(user_id, access_token_id, is_guest, device_id, app_service)
def get_domain_from_id(string):
+1
View File
@@ -32,6 +32,7 @@ json_encoder = json.JSONEncoder(separators=(",", ":"))
def unwrapFirstError(failure):
# defer.gatherResults and DeferredLists wrap failures.
failure.trap(defer.FirstError)
logger.info("DDD failure.value.subFailure: %s", failure.value.subFailure)
return failure.value.subFailure
+3 -1
View File
@@ -24,7 +24,9 @@ from synapse.api.errors import Codes, SynapseError
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
# Note: The : character is allowed here for older clients, but will be removed in a
# future release. Context: https://github.com/matrix-org/synapse/issues/6766
client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-\:]+$")
# random_string and random_string_with_symbols are used for a range of things,
# some cryptographically important, some less so. We use SystemRandom to make sure
+73
View File
@@ -1,4 +1,6 @@
from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
from synapse.appservice import ApplicationService
from synapse.types import create_requester
from tests import unittest
@@ -18,6 +20,77 @@ class TestRatelimiter(unittest.TestCase):
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
def test_allowed_user_via_can_requester_do_action(self):
user_requester = create_requester("@user:example.com")
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
allowed, time_allowed = limiter.can_requester_do_action(
user_requester, _time_now_s=0
)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
user_requester, _time_now_s=5
)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
user_requester, _time_now_s=10
)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
appservice = ApplicationService(
None, "example.com", id="foo", rate_limited=True,
)
as_requester = create_requester("@user:example.com", app_service=appservice)
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=0
)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=5
)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=10
)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
def test_allowed_appservice_via_can_requester_do_action(self):
appservice = ApplicationService(
None, "example.com", id="foo", rate_limited=False,
)
as_requester = create_requester("@user:example.com", app_service=appservice)
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=0
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=5
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=10
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
def test_allowed_via_ratelimit(self):
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-82
View File
@@ -1,82 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
import tempfile
from synapse.config import ConfigError
from synapse.util.stringutils import random_string
from tests import unittest
class BaseConfigTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.hs = hs
def test_loading_missing_templates(self):
# Use a temporary directory that exists on the system, but that isn't likely to
# contain template files
with tempfile.TemporaryDirectory() as tmp_dir:
# Attempt to load an HTML template from our custom template directory
template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0]
# If no errors, we should've gotten the default template instead
# Render the template
a_random_string = random_string(5)
html_content = template.render({"error_description": a_random_string})
# Check that our string exists in the template
self.assertIn(
a_random_string,
html_content,
"Template file did not contain our test string",
)
def test_loading_custom_templates(self):
# Use a temporary directory that exists on the system
with tempfile.TemporaryDirectory() as tmp_dir:
# Create a temporary bogus template file
with tempfile.NamedTemporaryFile(dir=tmp_dir) as tmp_template:
# Get temporary file's filename
template_filename = os.path.basename(tmp_template.name)
# Write a custom HTML template
contents = b"{{ test_variable }}"
tmp_template.write(contents)
tmp_template.flush()
# Attempt to load the template from our custom template directory
template = (
self.hs.config.read_templates([template_filename], tmp_dir)
)[0]
# Render the template
a_random_string = random_string(5)
html_content = template.render({"test_variable": a_random_string})
# Check that our string exists in the template
self.assertIn(
a_random_string,
html_content,
"Template file did not contain our test string",
)
def test_loading_template_from_nonexistent_custom_directory(self):
with self.assertRaises(ConfigError):
self.hs.config.read_templates(
["some_filename.html"], "a_nonexistent_directory"
)
+10 -20
View File
@@ -79,11 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(
side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
)
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -112,11 +110,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(
side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
)
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -152,11 +148,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(
side_effect=lambda *args, **kwargs: make_awaitable(None)
)
fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
handler.federation_handler.do_invite_join = Mock(
side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
return_value=make_awaitable(("", 1))
)
# Artificially raise the complexity
@@ -210,11 +204,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(
side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
)
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -242,11 +234,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(
side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
)
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
+1 -1
View File
@@ -19,7 +19,6 @@ from mock import Mock, call
from signedjson.key import generate_signing_key
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events.builder import EventBuilder
from synapse.handlers.presence import (
@@ -33,6 +32,7 @@ from synapse.handlers.presence import (
handle_update,
)
from synapse.rest.client.v1 import room
from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id
from tests import unittest
+2 -2
View File
@@ -64,7 +64,7 @@ class ProfileTestCase(unittest.TestCase):
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
yield self.store.create_profile(self.frank.localpart)
self.handler = hs.get_profile_handler()
self.hs = hs
@@ -157,7 +157,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
yield defer.ensureDeferred(self.store.create_profile("caroline"))
yield self.store.create_profile("caroline")
yield self.store.set_profile_displayname("caroline", "Caroline")
response = yield defer.ensureDeferred(
+1 -1
View File
@@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
([], 0)
)
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
None
)
+86 -1
View File
@@ -28,7 +28,7 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import account
from synapse.types import JsonDict, RoomAlias
from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string
from tests import unittest
@@ -675,6 +675,91 @@ class RoomMemberStateTestCase(RoomBase):
self.assertEquals(json.loads(content), channel.json_body)
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
servlets = [
profile.register_servlets,
room.register_servlets,
]
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
)
def test_join_local_ratelimit(self):
"""Tests that local joins are actually rate-limited."""
for i in range(5):
self.helper.create_room_as(self.user_id)
self.helper.create_room_as(self.user_id, expect_code=429)
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
)
def test_join_local_ratelimit_profile_change(self):
"""Tests that sending a profile update into all of the user's joined rooms isn't
rate-limited by the rate-limiter on joins."""
# Create and join more rooms than the rate-limiting config allows in a second.
room_ids = [
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
]
self.reactor.advance(1)
room_ids = room_ids + [
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
]
# Create a profile for the user, since it hasn't been done on registration.
store = self.hs.get_datastore()
store.create_profile(UserID.from_string(self.user_id).localpart)
# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
self.render(request)
self.assertEquals(channel.code, 200, channel.json_body)
# Check that all the rooms have been sent a profile update into.
for room_id in room_ids:
path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
room_id,
self.user_id,
)
request, channel = self.make_request("GET", path)
self.render(request)
self.assertEquals(channel.code, 200)
self.assertIn("displayname", channel.json_body)
self.assertEquals(channel.json_body["displayname"], "John Doe")
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
)
def test_join_local_ratelimit_idempotent(self):
"""Tests that the room join endpoints remain idempotent despite rate-limiting
on room joins."""
room_id = self.helper.create_room_as(self.user_id)
# Let's test both paths to be sure.
paths_to_test = [
"/_matrix/client/r0/rooms/%s/join",
"/_matrix/client/r0/join/%s",
]
for path in paths_to_test:
# Make sure we send more requests than the rate-limiting config would allow
# if all of these requests ended up joining the user to a room.
for i in range(6):
request, channel = self.make_request("POST", path % room_id, {})
self.render(request)
self.assertEquals(channel.code, 200)
class RoomMessagesTestCase(RoomBase):
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
+7 -3
View File
@@ -39,7 +39,9 @@ class RestHelper(object):
resource = attr.ib()
auth_user_id = attr.ib()
def create_room_as(self, room_creator=None, is_public=True, tok=None):
def create_room_as(
self, room_creator=None, is_public=True, tok=None, expect_code=200,
):
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"
@@ -54,9 +56,11 @@ class RestHelper(object):
)
render(request, self.resource, self.hs.get_reactor())
assert channel.result["code"] == b"200", channel.result
assert channel.result["code"] == b"%d" % expect_code, channel.result
self.auth_user_id = temp_id
return channel.json_body["room_id"]
if expect_code == 200:
return channel.json_body["room_id"]
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
self.change_membership(
+4 -12
View File
@@ -207,9 +207,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_appservices_state_down(self):
service = Mock(id=self.as_list[1]["id"])
yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
)
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
@@ -221,15 +219,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_appservices_state_multiple_up(self):
service = Mock(id=self.as_list[1]["id"])
yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.UP)
)
yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
)
yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.UP)
)
yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
+6 -10
View File
@@ -66,10 +66,8 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_1col(self):
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
self.datastore.db_pool.simple_insert(
table="tablename", values={"columname": "Value"}
)
yield self.datastore.db_pool.simple_insert(
table="tablename", values={"columname": "Value"}
)
self.mock_txn.execute.assert_called_with(
@@ -80,12 +78,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_3cols(self):
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
self.datastore.db_pool.simple_insert(
table="tablename",
# Use OrderedDict() so we can assert on the SQL generated
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
)
yield self.datastore.db_pool.simple_insert(
table="tablename",
# Use OrderedDict() so we can assert on the SQL generated
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
)
self.mock_txn.execute.assert_called_with(
+2 -2
View File
@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
self.requester = Requester(self.user, None, False, False, None, None)
self.requester = Requester(self.user, None, False, None, None)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
self.requester = Requester(self.user, None, False, False, None, None)
self.requester = Requester(self.user, None, False, None, None)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
+1 -1
View File
@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
room_creator = self.hs.get_room_creation_handler()
user = UserID("alice", "test")
requester = Requester(user, None, False, False, None, None)
requester = Requester(user, None, False, None, None)
# Real events, forward extremities
events = [(3, 2), (6, 2), (4, 6)]
+14 -16
View File
@@ -142,22 +142,20 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
return defer.ensureDeferred(
self.store.db_pool.simple_insert(
"events",
{
"stream_ordering": so,
"received_ts": ts,
"event_id": "event%i" % so,
"type": "",
"room_id": "",
"content": "",
"processed": True,
"outlier": False,
"topological_ordering": 0,
"depth": 0,
},
)
return self.store.db_pool.simple_insert(
"events",
{
"stream_ordering": so,
"received_ts": ts,
"event_id": "event%i" % so,
"type": "",
"room_id": "",
"content": "",
"processed": True,
"outlier": False,
"topological_ordering": 0,
"depth": 0,
},
)
# start with the base case where there are no events in the table
+1 -1
View File
@@ -35,7 +35,7 @@ class DataStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_users_paginate(self):
yield self.store.register_user(self.user.to_string(), "pass")
yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
yield self.store.create_profile(self.user.localpart)
yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
users, total = yield self.store.get_users_paginate(
+2 -2
View File
@@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_displayname(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.create_profile(self.u_frank.localpart)
yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
@@ -43,7 +43,7 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_avatar_url(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.create_profile(self.u_frank.localpart)
yield self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
+33 -16
View File
@@ -15,7 +15,6 @@
from twisted.internet import defer
from synapse.api.errors import NotFoundError
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
@@ -47,19 +46,30 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_storage()
# Get the topological token
event = self.get_success(
store.get_topological_token_for_event(last["event_id"])
)
event = store.get_topological_token_for_event(last["event_id"])
self.pump()
event = self.successResultOf(event)
# Purge everything before this topological token
self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
purge = defer.ensureDeferred(
storage.purge_events.purge_history(self.room_id, event, True)
)
self.pump()
self.assertEqual(self.successResultOf(purge), None)
# Try and get the events
get_first = store.get_event(first["event_id"])
get_second = store.get_event(second["event_id"])
get_third = store.get_event(third["event_id"])
get_last = store.get_event(last["event_id"])
self.pump()
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
self.get_failure(store.get_event(first["event_id"]), NotFoundError)
self.get_failure(store.get_event(second["event_id"]), NotFoundError)
self.get_failure(store.get_event(third["event_id"]), NotFoundError)
self.get_success(store.get_event(last["event_id"]))
self.failureResultOf(get_first)
self.failureResultOf(get_second)
self.failureResultOf(get_third)
self.successResultOf(get_last)
def test_purge_wont_delete_extrems(self):
"""
@@ -74,9 +84,9 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_datastore()
# Set the topological token higher than it should be
event = self.get_success(
storage.get_topological_token_for_event(last["event_id"])
)
event = storage.get_topological_token_for_event(last["event_id"])
self.pump()
event = self.successResultOf(event)
event = "t{}-{}".format(
*list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
)
@@ -88,7 +98,14 @@ class PurgeTests(HomeserverTestCase):
self.assertIn("greater than forward", f.value.args[0])
# Try and get the events
self.get_success(storage.get_event(first["event_id"]))
self.get_success(storage.get_event(second["event_id"]))
self.get_success(storage.get_event(third["event_id"]))
self.get_success(storage.get_event(last["event_id"]))
get_first = storage.get_event(first["event_id"])
get_second = storage.get_event(second["event_id"])
get_third = storage.get_event(third["event_id"])
get_last = storage.get_event(last["event_id"])
self.pump()
# Nothing is deleted.
self.successResultOf(get_first)
self.successResultOf(get_second)
self.successResultOf(get_third)
self.successResultOf(get_last)
+1 -1
View File
@@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
requester = Requester(user, None, False, False, None, None)
requester = Requester(user, None, False, None, None)
self.get_success(self.room_creator.create_room(requester, {}))
# Register the background update to run again.

Some files were not shown because too many files have changed in this diff Show More