Compare commits
12 Commits
dmr/warn-m ... shay/mx_ma
| Author | SHA1 | Date |
|---|---|---|
| | 85b30abfde | |
| | d809a4c8fb | |
| | df621cbaa5 | |
| | 301b9cdfcc | |
| | 21002db229 | |
| | 78b99de7c2 | |
| | 5ef673de4f | |
| | d743b25c8f | |
| | 30c8e7e408 | |
| | 6463244375 | |
| | 8a23bde823 | |
| | e8d1ec0e92 | |
8 changes: .github/workflows/tests.yml (vendored)
@@ -20,13 +20,9 @@ jobs:
      - run: scripts-dev/config-lint.sh

  lint:
    # This does a vanilla `poetry install` - no extras. I'm slightly anxious
    # that we might skip some typechecks on code that uses extras. However,
    # I think the right way to fix this is to mark any extras needed for
    # typechecking as development dependencies. To detect this, we ought to
    # turn up mypy's strictness: disallow unknown imports and accept fewer
    # uses of `Any`.
    uses: "matrix-org/backend-meta/.github/workflows/python-poetry-ci.yml@v1"
    with:
      typechecking-extras: "all"

  lint-crlf:
    runs-on: ubuntu-latest
1 change: changelog.d/12356.misc (new file)

@@ -0,0 +1 @@
Fix scripts-dev to pass typechecking.
1 change: changelog.d/12406.feature (new file)

@@ -0,0 +1 @@
Add a module API to allow modules to change actions for existing push rules of local users.
1 change: changelog.d/12480.misc (new file)

@@ -0,0 +1 @@
Use supervisord to supervise Postgres and Caddy in the Complement image to reduce restart time.
1 change: changelog.d/12505.misc (new file)

@@ -0,0 +1 @@
Use `make_awaitable` instead of `defer.succeed` for return values of mocks in tests.
1 change: changelog.d/12526.feature (new file)

@@ -0,0 +1 @@
Add new `enable_registration_token_3pid_bypass` configuration option to allow registrations via token as an alternative to verifying a 3pid.
1 change: changelog.d/12531.misc (new file)

@@ -0,0 +1 @@
Remove unused `# type: ignore`s.
1 change: changelog.d/12564.misc (new file)

@@ -0,0 +1 @@
Consistently check if an object is a `frozendict`.
1 change: changelog.d/12572.feature (new file)

@@ -0,0 +1 @@
Add the Synapse function `types.map_username_to_mxid_localpart` to the Module API.
@@ -20,6 +20,9 @@ RUN rm /etc/nginx/sites-enabled/default
# Copy Synapse worker, nginx and supervisord configuration template files
COPY ./docker/conf-workers/* /conf/

+# Copy a script to prefix log lines with the supervisor program name
+COPY ./docker/prefix-log /usr/local/bin/
+
# Expose nginx listener port
EXPOSE 8080/tcp
@@ -34,13 +34,16 @@ WORKDIR /data
# Copy the caddy config
COPY conf-workers/caddy.complement.json /root/caddy.json

+COPY conf-workers/postgres.supervisord.conf /etc/supervisor/conf.d/postgres.conf
+COPY conf-workers/caddy.supervisord.conf /etc/supervisor/conf.d/caddy.conf
+
# Copy the entrypoint
COPY conf-workers/start-complement-synapse-workers.sh /

# Expose caddy's listener ports
EXPOSE 8008 8448

-ENTRYPOINT /start-complement-synapse-workers.sh
+ENTRYPOINT ["/start-complement-synapse-workers.sh"]

# Update the healthcheck to have a shorter check interval
HEALTHCHECK --start-period=5s --interval=1s --timeout=1s \
7 changes: docker/complement/conf-workers/caddy.supervisord.conf (new file)

@@ -0,0 +1,7 @@
[program:caddy]
command=/usr/local/bin/prefix-log /root/caddy run --config /root/caddy.json
autorestart=unexpected
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
stderr_logfile=/dev/stderr
stderr_logfile_maxbytes=0
16 changes: docker/complement/conf-workers/postgres.supervisord.conf (new file)

@@ -0,0 +1,16 @@
[program:postgres]
command=/usr/local/bin/prefix-log /usr/bin/pg_ctlcluster 13 main start --foreground

# Lower priority number = starts first
priority=1

autorestart=unexpected
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
stderr_logfile=/dev/stderr
stderr_logfile_maxbytes=0

# Use 'Fast Shutdown' mode which aborts current transactions and closes connections quickly.
# (Default (TERM) is 'Smart Shutdown' which stops accepting new connections but
# lets existing connections close gracefully.)
stopsignal=INT
@@ -12,12 +12,6 @@ function log {
# Replace the server name in the caddy config
sed -i "s/{{ server_name }}/${SERVER_NAME}/g" /root/caddy.json

-log "starting postgres"
-pg_ctlcluster 13 main start
-
-log "starting caddy"
-/root/caddy start --config /root/caddy.json
-
# Set the server name of the homeserver
export SYNAPSE_SERVER_NAME=${SERVER_NAME}
@@ -2,11 +2,7 @@ version: 1

formatters:
  precise:
-    {% if worker_name %}
-    format: '%(asctime)s - worker:{{ worker_name }} - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s'
-    {% else %}
    format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s'
-    {% endif %}

handlers:
  {% if LOG_FILE_PATH %}
@@ -171,7 +171,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
# Templates for sections that may be inserted multiple times in config files
SUPERVISORD_PROCESS_CONFIG_BLOCK = """
[program:synapse_{name}]
-command=/usr/local/bin/python -m {app} \
+command=/usr/local/bin/prefix-log /usr/local/bin/python -m {app} \
    --config-path="{config_path}" \
    --config-path=/conf/workers/shared.yaml \
    --config-path=/conf/workers/{name}.yaml
12 changes: docker/prefix-log (new executable file)

@@ -0,0 +1,12 @@
#!/bin/bash
#
# Prefixes all lines on stdout and stderr with the process name (as determined by
# the SUPERVISOR_PROCESS_NAME env var, which is automatically set by Supervisor).
#
# Usage:
#   prefix-log command [args...]
#

exec 1> >(awk '{print "'"${SUPERVISOR_PROCESS_NAME}"' | "$0}' >&1)
exec 2> >(awk '{print "'"${SUPERVISOR_PROCESS_NAME}"' | "$0}' >&2)
exec "$@"
@@ -1323,6 +1323,12 @@ oembed:
#
#registration_requires_token: true

+# Allow users to submit a token during registration to bypass any required 3pid
+# steps configured in `registrations_require_3pid`.
+# Defaults to false, requiring that registration tokens (if enabled) complete a 3pid flow.
+#
+#enable_registration_token_3pid_bypass: false
+
# If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled.
#
16 changes: mypy.ini
@@ -7,6 +7,7 @@ show_error_codes = True
show_traceback = True
mypy_path = stubs
warn_unreachable = True
+warn_unused_ignores = True
local_partial_types = True
no_implicit_optional = True
@@ -23,10 +24,6 @@ files =
# https://docs.python.org/3/library/re.html#re.X
exclude = (?x)
  ^(
-  |scripts-dev/build_debian_packages.py
-  |scripts-dev/federation_client.py
-  |scripts-dev/release.py
-
  |synapse/storage/databases/__init__.py
  |synapse/storage/databases/main/cache.py
  |synapse/storage/databases/main/devices.py
@@ -134,6 +131,11 @@ disallow_untyped_defs = True
[mypy-synapse.metrics.*]
disallow_untyped_defs = True

+[mypy-synapse.metrics._reactor_metrics]
+# This module imports select.epoll. That exists on Linux, but doesn't on macOS.
+# See https://github.com/matrix-org/synapse/pull/11771.
+warn_unused_ignores = False
+
[mypy-synapse.module_api.*]
disallow_untyped_defs = True
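Since `warn_unused_ignores = True` now applies globally, mypy flags any `# type: ignore` it can prove unnecessary; the `_reactor_metrics` carve-out exists because its ignore is only "unused" on Linux. A minimal illustration of the kind of comment the new setting reports (this snippet is not from the Synapse tree):

```python
# With warn_unused_ignores = True, mypy emits
#     error: Unused "type: ignore" comment
# here, because the assignment already typechecks without it.
x: int = 1  # type: ignore
```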
@@ -302,6 +304,9 @@ ignore_missing_imports = True
[mypy-pympler.*]
ignore_missing_imports = True

+[mypy-redbaron.*]
+ignore_missing_imports = True
+
[mypy-rust_python_jaeger_reporter.*]
ignore_missing_imports = True
@@ -317,6 +322,9 @@ ignore_missing_imports = True
[mypy-signedjson.*]
ignore_missing_imports = True

+[mypy-srvlookup.*]
+ignore_missing_imports = True
+
[mypy-treq.*]
ignore_missing_imports = True
25 changes: poetry.lock (generated)
@@ -309,14 +309,15 @@ smmap = ">=3.0.1,<6"

[[package]]
name = "gitpython"
-version = "3.1.14"
-description = "Python Git Library"
+version = "3.1.27"
+description = "GitPython is a python library used to interact with Git repositories"
category = "dev"
optional = false
-python-versions = ">=3.4"
+python-versions = ">=3.7"

[package.dependencies]
gitdb = ">=4.0.1,<5"
+typing-extensions = {version = ">=3.7.4.3", markers = "python_version < \"3.8\""}

[[package]]
name = "hiredis"
@@ -1315,6 +1316,14 @@ category = "dev"
optional = false
python-versions = "*"

+[[package]]
+name = "types-commonmark"
+version = "0.9.2"
+description = "Typing stubs for commonmark"
+category = "dev"
+optional = false
+python-versions = "*"
+
[[package]]
name = "types-cryptography"
version = "3.3.15"
@@ -1553,7 +1562,7 @@ url_preview = ["lxml"]
[metadata]
lock-version = "1.1"
python-versions = "^3.7"
-content-hash = "f482a4f594a165dfe01ce253a22510d5faf38647ab0dcebc35789350cafd9bf0"
+content-hash = "3825cef058b8c9f520ef4b7acb92519be95db9a663a61c2e89a5fe431ed55655"

[metadata.files]
attrs = [
@@ -1766,8 +1775,8 @@ gitdb = [
    {file = "gitdb-4.0.9.tar.gz", hash = "sha256:bac2fd45c0a1c9cf619e63a90d62bdc63892ef92387424b855792a6cabe789aa"},
]
gitpython = [
-    {file = "GitPython-3.1.14-py3-none-any.whl", hash = "sha256:3283ae2fba31c913d857e12e5ba5f9a7772bbc064ae2bb09efafa71b0dd4939b"},
-    {file = "GitPython-3.1.14.tar.gz", hash = "sha256:be27633e7509e58391f10207cd32b2a6cf5b908f92d9cd30da2e514e1137af61"},
+    {file = "GitPython-3.1.27-py3-none-any.whl", hash = "sha256:5b68b000463593e05ff2b261acff0ff0972df8ab1b70d3cdbd41b546c8b8fc3d"},
+    {file = "GitPython-3.1.27.tar.gz", hash = "sha256:1c885ce809e8ba2d88a29befeb385fcea06338d3640712b59ca623c220bb5704"},
]
hiredis = [
    {file = "hiredis-2.0.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b4c8b0bc5841e578d5fb32a16e0c305359b987b850a06964bd5a62739d688048"},
@@ -2588,6 +2597,10 @@ types-bleach = [
    {file = "types-bleach-4.1.4.tar.gz", hash = "sha256:2d30c2c4fb6854088ac636471352c9a51bf6c089289800d2a8060820a01cd43a"},
    {file = "types_bleach-4.1.4-py3-none-any.whl", hash = "sha256:edffe173ed6d7b6f3543036a96204a9319c3bf6c3645917b14274e43f000cc9b"},
]
+types-commonmark = [
+    {file = "types-commonmark-0.9.2.tar.gz", hash = "sha256:b894b67750c52fd5abc9a40a9ceb9da4652a391d75c1b480bba9cef90f19fc86"},
+    {file = "types_commonmark-0.9.2-py3-none-any.whl", hash = "sha256:56f20199a1f9a2924443211a0ef97f8b15a8a956a7f4e9186be6950bf38d6d02"},
+]
types-cryptography = [
    {file = "types-cryptography-3.3.15.tar.gz", hash = "sha256:a7983a75a7b88a18f88832008f0ef140b8d1097888ec1a0824ec8fb7e105273b"},
    {file = "types_cryptography-3.3.15-py3-none-any.whl", hash = "sha256:d9b0dd5465d7898d400850e7f35e5518aa93a7e23d3e11757cd81b4777089046"},
@@ -251,6 +251,7 @@ flake8 = "*"
mypy = "==0.931"
mypy-zope = "==0.3.5"
types-bleach = ">=4.1.0"
+types-commonmark = ">=0.9.2"
types-jsonschema = ">=3.2.0"
types-opentracing = ">=2.4.2"
types-Pillow = ">=8.3.4"
@@ -270,7 +271,8 @@ idna = ">=2.5"

# The following are used by the release script
click = "==8.1.0"
-GitPython = "==3.1.14"
+# GitPython was == 3.1.14; bumped to 3.1.20, the first release with type hints.
+GitPython = ">=3.1.20"
commonmark = "==0.9.1"
pygithub = "==1.55"
# The following are executed as commands by the release script.
@@ -17,7 +17,8 @@ import subprocess
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
-from typing import Optional, Sequence
+from types import FrameType
+from typing import Collection, Optional, Sequence, Set

DISTS = (
    "debian:buster",  # oldstable: EOL 2022-08
@@ -41,15 +42,17 @@ projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

class Builder(object):
    def __init__(
-        self, redirect_stdout=False, docker_build_args: Optional[Sequence[str]] = None
+        self,
+        redirect_stdout: bool = False,
+        docker_build_args: Optional[Sequence[str]] = None,
    ):
        self.redirect_stdout = redirect_stdout
        self._docker_build_args = tuple(docker_build_args or ())
-        self.active_containers = set()
+        self.active_containers: Set[str] = set()
        self._lock = threading.Lock()
        self._failed = False

-    def run_build(self, dist, skip_tests=False):
+    def run_build(self, dist: str, skip_tests: bool = False) -> None:
        """Build deb for a single distribution"""

        if self._failed:
@@ -63,7 +66,7 @@ class Builder(object):
            self._failed = True
            raise

-    def _inner_build(self, dist, skip_tests=False):
+    def _inner_build(self, dist: str, skip_tests: bool = False) -> None:
        tag = dist.split(":", 1)[1]

        # Make the dir where the debs will live.
@@ -138,7 +141,7 @@ class Builder(object):
        stdout.close()
        print("Completed build of %s" % (dist,))

-    def kill_containers(self):
+    def kill_containers(self) -> None:
        with self._lock:
            active = list(self.active_containers)
@@ -156,8 +159,10 @@ class Builder(object):
            self.active_containers.remove(c)


-def run_builds(builder, dists, jobs=1, skip_tests=False):
-    def sig(signum, _frame):
+def run_builds(
+    builder: Builder, dists: Collection[str], jobs: int = 1, skip_tests: bool = False
+) -> None:
+    def sig(signum: int, _frame: Optional[FrameType]) -> None:
        print("Caught SIGINT")
        builder.kill_containers()
@@ -38,7 +38,7 @@ import argparse
import base64
import json
import sys
-from typing import Any, Optional
+from typing import Any, Dict, Optional, Tuple
from urllib import parse as urlparse

import requests
@@ -47,13 +47,14 @@ import signedjson.types
import srvlookup
import yaml
from requests.adapters import HTTPAdapter
+from urllib3 import HTTPConnectionPool

# uncomment the following to enable debug logging of http requests
# from httplib import HTTPConnection
# HTTPConnection.debuglevel = 1


-def encode_base64(input_bytes):
+def encode_base64(input_bytes: bytes) -> str:
    """Encode bytes as a base64 string without any padding."""

    input_len = len(input_bytes)
@@ -63,7 +64,7 @@ def encode_base64(input_bytes):
    return output_string


-def encode_canonical_json(value):
+def encode_canonical_json(value: object) -> bytes:
    return json.dumps(
        value,
        # Encode code-points outside of ASCII as UTF-8 rather than \u escapes
@@ -130,7 +131,7 @@ def request(
        sig,
        destination,
    )
-    authorization_headers.append(header.encode("ascii"))
+    authorization_headers.append(header)
    print("Authorization: %s" % header, file=sys.stderr)

    dest = "matrix://%s%s" % (destination, path)
@@ -139,7 +140,10 @@ def request(
    s = requests.Session()
    s.mount("matrix://", MatrixConnectionAdapter())

-    headers = {"Host": destination, "Authorization": authorization_headers[0]}
+    headers: Dict[str, str] = {
+        "Host": destination,
+        "Authorization": authorization_headers[0],
+    }

    if method == "POST":
        headers["Content-Type"] = "application/json"
@@ -154,7 +158,7 @@ def request(
    )


-def main():
+def main() -> None:
    parser = argparse.ArgumentParser(
        description="Signs and sends a federation request to a matrix homeserver"
    )
@@ -212,6 +216,7 @@ def main():
    if not args.server_name or not args.signing_key:
        read_args_from_config(args)

+    assert isinstance(args.signing_key, str)
    algorithm, version, key_base64 = args.signing_key.split()
    key = signedjson.key.decode_signing_key_base64(algorithm, version, key_base64)
@@ -233,7 +238,7 @@ def main():
    print("")


-def read_args_from_config(args):
+def read_args_from_config(args: argparse.Namespace) -> None:
    with open(args.config, "r") as fh:
        config = yaml.safe_load(fh)
@@ -250,7 +255,7 @@ def read_args_from_config(args):

class MatrixConnectionAdapter(HTTPAdapter):
    @staticmethod
-    def lookup(s, skip_well_known=False):
+    def lookup(s: str, skip_well_known: bool = False) -> Tuple[str, int]:
        if s[-1] == "]":
            # ipv6 literal (with no port)
            return s, 8448
@@ -276,7 +281,7 @@ class MatrixConnectionAdapter(HTTPAdapter):
        return s, 8448

    @staticmethod
-    def get_well_known(server_name):
+    def get_well_known(server_name: str) -> Optional[str]:
        uri = "https://%s/.well-known/matrix/server" % (server_name,)
        print("fetching %s" % (uri,), file=sys.stderr)
@@ -299,7 +304,9 @@ class MatrixConnectionAdapter(HTTPAdapter):
            print("Invalid response from %s: %s" % (uri, e), file=sys.stderr)
            return None

-    def get_connection(self, url, proxies=None):
+    def get_connection(
+        self, url: str, proxies: Optional[Dict[str, str]] = None
+    ) -> HTTPConnectionPool:
        parsed = urlparse.urlparse(url)

        (host, port) = self.lookup(parsed.netloc)
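For readers following the new `lookup` annotation: the resolution contract is that a server-name string maps to a `(host, port)` tuple, defaulting to port 8448. A rough, self-contained sketch of just the port-defaulting rule (the helper name is hypothetical, and SRV/.well-known resolution is deliberately omitted):

```python
# Hypothetical stand-in for the port-defaulting part of lookup(); ignores
# SRV records and .well-known delegation entirely.
from typing import Tuple

def split_server_name(s: str) -> Tuple[str, int]:
    if s[-1] == "]":
        return s, 8448  # ipv6 literal with no port, e.g. "[::1]"
    if ":" in s:
        host, port = s.rsplit(":", 1)
        return host, int(port)  # an explicit port always wins
    return s, 8448  # otherwise the federation default port is assumed

assert split_server_name("matrix.org") == ("matrix.org", 8448)
assert split_server_name("example.com:443") == ("example.com", 443)
```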
@@ -16,7 +16,7 @@
can crop up, e.g the cache descriptors.
"""

-from typing import Callable, Optional
+from typing import Callable, Optional, Type

from mypy.nodes import ARG_NAMED_OPT
from mypy.plugin import MethodSigContext, Plugin
@@ -94,7 +94,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
    return signature


-def plugin(version: str):
+def plugin(version: str) -> Type[SynapsePlugin]:
    # This is the entry point of the plugin, and lets us deal with the fact
    # that the mypy plugin interface is *not* stable by looking at the version
    # string.
@@ -25,7 +25,7 @@ import sys
import urllib.request
from os import path
from tempfile import TemporaryDirectory
-from typing import List, Optional
+from typing import Any, List, Optional, cast

import attr
import click
@@ -36,7 +36,9 @@ from github import Github
from packaging import version


-def run_until_successful(command, *args, **kwargs):
+def run_until_successful(
+    command: str, *args: Any, **kwargs: Any
+) -> subprocess.CompletedProcess:
    while True:
        completed_process = subprocess.run(command, *args, **kwargs)
        exit_code = completed_process.returncode
@@ -50,7 +52,7 @@ def run_until_successful(command, *args, **kwargs):


@click.group()
-def cli():
+def cli() -> None:
    """An interactive script to walk through the parts of creating a release.

    Requires the dev dependencies be installed, which can be done via:
@@ -81,7 +83,7 @@ def cli():


@cli.command()
-def prepare():
+def prepare() -> None:
    """Do the initial stages of creating a release, including creating release
    branch, updating changelog and pushing to GitHub.
    """
@@ -161,7 +163,9 @@ def prepare():
        click.get_current_context().abort()

    # Switch to the release branch.
-    parsed_new_version: version.Version = version.parse(new_version)
+    # Cast safety: parse() won't return a version.LegacyVersion from our
+    # version string format.
+    parsed_new_version = cast(version.Version, version.parse(new_version))

    # We assume for debian changelogs that we only do RCs or full releases.
    assert not parsed_new_version.is_devrelease
@@ -176,7 +180,6 @@ def prepare():
            # If the release branch only exists on the remote we check it out
            # locally.
            repo.git.checkout(release_branch_name)
-            release_branch = repo.active_branch
    else:
        # If a branch doesn't exist we create one. We ask which branch it
        # should be based off, defaulting to sensible values depending on the
|
||||
click.get_current_context().abort()
|
||||
|
||||
# Check out the base branch and ensure it's up to date
|
||||
repo.head.reference = base_branch
|
||||
repo.head.set_reference(base_branch, "check out the base branch")
|
||||
repo.head.reset(index=True, working_tree=True)
|
||||
if not base_branch.is_remote():
|
||||
update_branch(repo)
|
||||
|
||||
# Create the new release branch
|
||||
release_branch = repo.create_head(release_branch_name, commit=base_branch)
|
||||
# Type ignore will no longer be needed after GitPython 3.1.28.
|
||||
# See https://github.com/gitpython-developers/GitPython/pull/1419
|
||||
repo.create_head(release_branch_name, commit=base_branch) # type: ignore[arg-type]
|
||||
|
||||
# Switch to the release branch and ensure it's up to date.
|
||||
repo.git.checkout(release_branch_name)
|
||||
@@ -265,7 +270,7 @@ def prepare():

@cli.command()
@click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"])
-def tag(gh_token: Optional[str]):
+def tag(gh_token: Optional[str]) -> None:
    """Tags the release and generates a draft GitHub release"""

    # Make sure we're in a git repo.
@@ -293,7 +298,12 @@ def tag(gh_token: Optional[str]):

    click.echo_via_pager(changes)
    if click.confirm("Edit text?", default=False):
-        changes = click.edit(changes, require_save=False)
+        edited_changes = click.edit(changes, require_save=False)
+        # This assert is for mypy's benefit. click's docs are a little unclear, but
+        # when `require_save=False`, not saving the temp file in the editor returns
+        # the original string.
+        assert edited_changes is not None
+        changes = edited_changes

    repo.create_tag(tag_name, message=changes, sign=True)
@@ -347,7 +357,7 @@ def tag(gh_token: Optional[str]):

@cli.command()
@click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=True)
-def publish(gh_token: str):
+def publish(gh_token: str) -> None:
    """Publish release."""

    # Make sure we're in a git repo.
@@ -390,7 +400,7 @@ def publish(gh_token: str):


@cli.command()
-def upload():
+def upload() -> None:
    """Upload release to pypi."""

    current_version = get_package_version()
@@ -418,7 +428,7 @@ def upload():


@cli.command()
-def announce():
+def announce() -> None:
    """Generate markdown to announce the release."""

    current_version = get_package_version()
@@ -461,18 +471,19 @@ def get_package_version() -> version.Version:

def find_ref(repo: git.Repo, ref_name: str) -> Optional[git.HEAD]:
    """Find the branch/ref, looking first locally then in the remote."""
-    if ref_name in repo.refs:
-        return repo.refs[ref_name]
+    if ref_name in repo.references:
+        return repo.references[ref_name]
    elif ref_name in repo.remote().refs:
        return repo.remote().refs[ref_name]
    else:
        return None


-def update_branch(repo: git.Repo):
+def update_branch(repo: git.Repo) -> None:
    """Ensure branch is up to date if it has a remote"""
-    if repo.active_branch.tracking_branch():
-        repo.git.merge(repo.active_branch.tracking_branch().name)
+    tracking_branch = repo.active_branch.tracking_branch()
+    if tracking_branch:
+        repo.git.merge(tracking_branch.name)


def get_changes_for_version(wanted_version: version.Version) -> str:
@@ -536,7 +547,9 @@ def get_changes_for_version(wanted_version: version.Version) -> str:
    return "\n".join(version_changelog)


-def generate_and_write_changelog(current_version: version.Version, new_version: str):
+def generate_and_write_changelog(
+    current_version: version.Version, new_version: str
+) -> None:
    # We do this by getting a draft so that we can edit it before writing to the
    # changelog.
    result = run_until_successful(
@@ -558,8 +571,8 @@ def generate_and_write_changelog(current_version: version.Version, new_version:
        f.write(existing_content)

    # Remove all the news fragments
-    for f in glob.iglob("changelog.d/*.*"):
-        os.remove(f)
+    for filename in glob.iglob("changelog.d/*.*"):
+        os.remove(filename)


if __name__ == "__main__":
@@ -27,7 +27,7 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.util import json_encoder


-def main():
+def main() -> None:
    parser = argparse.ArgumentParser(
        description="""Adds a signature to a JSON object.
@@ -115,9 +115,7 @@ class SortedKeysView(KeysView[_KT_co], Sequence[_KT_co]):
    def __getitem__(self, index: slice) -> List[_KT_co]: ...
    def __delitem__(self, index: Union[int, slice]) -> None: ...

-class SortedItemsView(  # type: ignore
-    ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]]
-):
+class SortedItemsView(ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]]):
    def __iter__(self) -> Iterator[Tuple[_KT_co, _VT_co]]: ...
    @overload
    def __getitem__(self, index: int) -> Tuple[_KT_co, _VT_co]: ...
@@ -48,7 +48,6 @@ from twisted.logger import LoggingFile, LogLevel
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.python.threadpool import ThreadPool

-import synapse
from synapse.api.constants import MAX_PDU_SIZE
from synapse.app import check_bind_error
from synapse.app.phone_stats_home import start_phone_stats_home
@@ -60,6 +59,7 @@ from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.logging.context import PreserveLoggingContext
+from synapse.logging.opentracing import init_tracer
from synapse.metrics import install_gc_manager, register_threadpool
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
@@ -431,7 +431,7 @@ async def start(hs: "HomeServer") -> None:
    refresh_certificate(hs)

    # Start the tracer
-    synapse.logging.opentracing.init_tracer(hs)  # type: ignore[attr-defined] # noqa
+    init_tracer(hs)  # noqa

    # Instantiate the modules so they can register their web resources to the module API
    # before we start the listeners.
@@ -43,6 +43,9 @@ class RegistrationConfig(Config):
        self.registration_requires_token = config.get(
            "registration_requires_token", False
        )
+        self.enable_registration_token_3pid_bypasss = config.get(
+            "enable_registration_token_3pid_bypasss", False
+        )
        self.registration_shared_secret = config.get("registration_shared_secret")

        self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -309,6 +312,12 @@ class RegistrationConfig(Config):
        #
        #registration_requires_token: true

+        # Allow users to submit a token during registration to bypass any required 3pid
+        # steps configured in `registrations_require_3pid`.
+        # Defaults to false, requiring that registration tokens (if enabled) complete a 3pid flow.
+        #
+        #enable_registration_token_3pid_bypass: false
+
        # If set, allows registration of standard or admin accounts by anyone who
        # has the shared secret, even if registration is otherwise disabled.
        #
@@ -186,7 +186,7 @@ KNOWN_RESOURCES = {
class HttpResourceConfig:
    names: List[str] = attr.ib(
        factory=list,
-        validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)),  # type: ignore
+        validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)),
    )
    compress: bool = attr.ib(
        default=False,
@@ -231,9 +231,7 @@ class ManholeConfig:
class LimitRemoteRoomsConfig:
    enabled: bool = attr.ib(validator=attr.validators.instance_of(bool), default=False)
    complexity: Union[float, int] = attr.ib(
-        validator=attr.validators.instance_of(
-            (float, int)  # type: ignore[arg-type] # noqa
-        ),
+        validator=attr.validators.instance_of((float, int)),  # noqa
        default=1.0,
    )
    complexity_error: str = attr.ib(
@@ -27,7 +27,6 @@ from typing import (
)

import attr
-from frozendict import frozendict

from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
@@ -204,7 +203,9 @@ def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
    key_to_move = field.pop(-1)
    sub_dict = src
    for sub_field in field:  # e.g. sub_field => "content"
-        if sub_field in sub_dict and type(sub_dict[sub_field]) in [dict, frozendict]:
+        if sub_field in sub_dict and isinstance(
+            sub_dict[sub_field], collections.abc.Mapping
+        ):
            sub_dict = sub_dict[sub_field]
        else:
            return
@@ -622,7 +623,7 @@ def validate_canonicaljson(value: Any) -> None:
        # Note that Infinity, -Infinity, and NaN are also considered floats.
        raise SynapseError(400, "Bad JSON value: float", Codes.BAD_JSON)

-    elif isinstance(value, (dict, frozendict)):
+    elif isinstance(value, collections.abc.Mapping):
        for v in value.values():
            validate_canonicaljson(v)
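The pattern in these hunks, replacing `isinstance(x, (dict, frozendict))` with `isinstance(x, collections.abc.Mapping)`, works because both types satisfy the `Mapping` ABC. A quick check, assuming the `frozendict` package is installed (it is a Synapse dependency):

```python
# Both dict and frozendict pass a single isinstance check against the Mapping
# ABC, which is why the explicit (dict, frozendict) tuples could be dropped.
import collections.abc
from frozendict import frozendict

for value in ({"a": 1}, frozendict({"a": 1})):
    assert isinstance(value, collections.abc.Mapping)
```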
@@ -268,8 +268,8 @@ class FederationServer(FederationBase):
            transaction_id=transaction_id,
            destination=destination,
            origin=origin,
-            origin_server_ts=transaction_data.get("origin_server_ts"),  # type: ignore
-            pdus=transaction_data.get("pdus"),  # type: ignore
+            origin_server_ts=transaction_data.get("origin_server_ts"),  # type: ignore[arg-type]
+            pdus=transaction_data.get("pdus"),
            edus=transaction_data.get("edus"),
        )
@@ -229,21 +229,21 @@ class TransportLayerClient:
        """
        logger.debug(
            "send_data dest=%s, txid=%s",
-            transaction.destination,  # type: ignore
-            transaction.transaction_id,  # type: ignore
+            transaction.destination,
+            transaction.transaction_id,
        )

-        if transaction.destination == self.server_name:  # type: ignore
+        if transaction.destination == self.server_name:
            raise RuntimeError("Transport layer cannot send to itself!")

        # FIXME: This is only used by the tests. The actual json sent is
        # generated by the json_data_callback.
        json_data = transaction.get_dict()

-        path = _create_v1_path("/send/%s", transaction.transaction_id)  # type: ignore
+        path = _create_v1_path("/send/%s", transaction.transaction_id)

        return await self.client.put_json(
-            transaction.destination,  # type: ignore
+            transaction.destination,
            path=path,
            data=json_data,
            json_data_callback=json_data_callback,
@@ -481,7 +481,7 @@ class AuthHandler:
        sid = authdict["session"]

        # Convert the URI and method to strings.
-        uri = request.uri.decode("utf-8")  # type: ignore
+        uri = request.uri.decode("utf-8")
        method = request.method.decode("utf-8")

        # If there's no session ID, create a new session.
@@ -966,7 +966,7 @@ class OidcProvider:
                "Mapping provider does not support de-duplicating Matrix IDs"
            )

-        attributes = await self._user_mapping_provider.map_user_attributes(  # type: ignore
+        attributes = await self._user_mapping_provider.map_user_attributes(
            userinfo, token
        )
138 changes: synapse/handlers/push_rules.py (new file)
@@ -0,0 +1,138 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional, Union

import attr

from synapse.api.errors import SynapseError, UnrecognizedRequestError
from synapse.push.baserules import BASE_RULE_IDS
from synapse.storage.push_rule import RuleNotFoundException
from synapse.types import JsonDict

if TYPE_CHECKING:
    from synapse.server import HomeServer


@attr.s(slots=True, frozen=True, auto_attribs=True)
class RuleSpec:
    scope: str
    template: str
    rule_id: str
    attr: Optional[str]


class PushRulesHandler:
    """A class to handle changes in push rules for users."""

    def __init__(self, hs: "HomeServer"):
        self._notifier = hs.get_notifier()
        self._main_store = hs.get_datastores().main

    async def set_rule_attr(
        self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
    ) -> None:
        """Set an attribute (enabled or actions) on an existing push rule.

        Notifies listeners (e.g. sync handler) of the change.

        Args:
            user_id: the user for which to modify the push rule.
            spec: the spec of the push rule to modify.
            val: the value to change the attribute to.

        Raises:
            RuleNotFoundException if the rule being modified doesn't exist.
            SynapseError(400) if the value is malformed.
            UnrecognizedRequestError if the attribute to change is unknown.
            InvalidRuleException if we're trying to change the actions on a rule but
                the provided actions aren't compliant with the spec.
        """
        if spec.attr not in ("enabled", "actions"):
            # for the sake of potential future expansion, shouldn't report
            # 404 in the case of an unknown request so check it corresponds to
            # a known attribute first.
            raise UnrecognizedRequestError()

        namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}"
        rule_id = spec.rule_id
        is_default_rule = rule_id.startswith(".")
        if is_default_rule:
            if namespaced_rule_id not in BASE_RULE_IDS:
                raise RuleNotFoundException("Unknown rule %r" % (namespaced_rule_id,))
        if spec.attr == "enabled":
            if isinstance(val, dict) and "enabled" in val:
                val = val["enabled"]
            if not isinstance(val, bool):
                # Legacy fallback
                # This should *actually* take a dict, but many clients pass
                # bools directly, so let's not break them.
                raise SynapseError(400, "Value for 'enabled' must be boolean")
            await self._main_store.set_push_rule_enabled(
                user_id, namespaced_rule_id, val, is_default_rule
            )
        elif spec.attr == "actions":
            if not isinstance(val, dict):
                raise SynapseError(400, "Value must be a dict")
            actions = val.get("actions")
            if not isinstance(actions, list):
                raise SynapseError(400, "Value for 'actions' must be dict")
            check_actions(actions)
            rule_id = spec.rule_id
            is_default_rule = rule_id.startswith(".")
            if is_default_rule:
                if namespaced_rule_id not in BASE_RULE_IDS:
                    raise RuleNotFoundException(
                        "Unknown rule %r" % (namespaced_rule_id,)
                    )
            await self._main_store.set_push_rule_actions(
                user_id, namespaced_rule_id, actions, is_default_rule
            )
        else:
            raise UnrecognizedRequestError()

        self.notify_user(user_id)

    def notify_user(self, user_id: str) -> None:
        """Notify listeners about a push rule change.

        Args:
            user_id: the user ID the change is for.
        """
        stream_id = self._main_store.get_max_push_rules_stream_id()
        self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id])


def check_actions(actions: List[Union[str, JsonDict]]) -> None:
    """Check if the given actions are spec compliant.

    Args:
        actions: the actions to check.

    Raises:
        InvalidRuleException if the rules aren't compliant with the spec.
    """
    if not isinstance(actions, list):
        raise InvalidRuleException("No actions found")

    for a in actions:
        if a in ["notify", "dont_notify", "coalesce"]:
            pass
        elif isinstance(a, dict) and "set_tweak" in a:
            pass
        else:
            raise InvalidRuleException("Unrecognised action %s" % a)


class InvalidRuleException(Exception):
    pass
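To illustrate the new module's validation entry point, a small sketch of `check_actions` accepting spec-compliant actions and rejecting anything else (assumes a Synapse checkout is importable):

```python
from synapse.handlers.push_rules import InvalidRuleException, check_actions

# A string action from the known set, and a dict action carrying a tweak, both pass.
check_actions(["notify", {"set_tweak": "sound", "value": "default"}])

try:
    check_actions(["shout"])  # not a recognised action
except InvalidRuleException as e:
    print("rejected:", e)
```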
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import collections.abc
import logging
from typing import (
    TYPE_CHECKING,
@@ -24,7 +25,6 @@ from typing import (
)

import attr
-from frozendict import frozendict

from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
@@ -380,7 +380,7 @@ class RelationsHandler:
            # Do not bundle aggregations for an event which represents an edit or an
            # annotation. It does not make sense for them to have related events.
            relates_to = event.content.get("m.relates_to")
-            if isinstance(relates_to, (dict, frozendict)):
+            if isinstance(relates_to, collections.abc.Mapping):
                relation_type = relates_to.get("rel_type")
                if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
                    continue
@@ -357,7 +357,7 @@ class SearchHandler:
            itertools.chain(
                # The events_before and events_after for each context.
                itertools.chain.from_iterable(
-                    itertools.chain(context["events_before"], context["events_after"])  # type: ignore[arg-type]
+                    itertools.chain(context["events_before"], context["events_after"])
                    for context in contexts.values()
                ),
                # The returned events.
@@ -373,10 +373,10 @@ class SearchHandler:

        for context in contexts.values():
            context["events_before"] = self._event_serializer.serialize_events(
-                context["events_before"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
+                context["events_before"], time_now, bundle_aggregations=aggregations
            )
            context["events_after"] = self._event_serializer.serialize_events(
-                context["events_after"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
+                context["events_after"], time_now, bundle_aggregations=aggregations
            )

        results = [
@@ -256,7 +256,9 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)
        self.hs = hs
-        self._enabled = bool(hs.config.registration.registration_requires_token)
+        self._enabled = bool(
+            hs.config.registration.registration_requires_token
+        ) or bool(hs.config.registration.enable_registration_token_3pid_bypasss)
        self.store = hs.get_datastores().main

    def is_enabled(self) -> bool:
@@ -295,7 +295,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
        if isawaitable(raw_callback_return):
            callback_return = await raw_callback_return
        else:
-            callback_return = raw_callback_return  # type: ignore
+            callback_return = raw_callback_return

        return callback_return
@@ -469,7 +469,7 @@ class JsonResource(DirectServeJsonResource):
        if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
            callback_return = await raw_callback_return
        else:
-            callback_return = raw_callback_return  # type: ignore
+            callback_return = raw_callback_return

        return callback_return
@@ -722,6 +722,11 @@ P = ParamSpec("P")
R = TypeVar("R")


+async def _unwrap_awaitable(awaitable: Awaitable[R]) -> R:
+    """Unwraps an arbitrary awaitable by awaiting it."""
+    return await awaitable
+
+
@overload
def preserve_fn(  # type: ignore[misc]
    f: Callable[P, Awaitable[R]],
@@ -802,17 +807,20 @@ def run_in_background(  # type: ignore[misc]
        # by synchronous exceptions, so let's turn them into Failures.
        return defer.fail()

+    # `res` may be a coroutine, `Deferred`, some other kind of awaitable, or a plain
+    # value. Convert it to a `Deferred`.
    if isinstance(res, typing.Coroutine):
        # Wrap the coroutine in a `Deferred`.
        res = defer.ensureDeferred(res)
-
-    # At this point we should have a Deferred, if not then f was a synchronous
-    # function, wrap it in a Deferred for consistency.
-    if not isinstance(res, defer.Deferred):
-        # `res` is not a `Deferred` and not a `Coroutine`.
-        # There are no other types of `Awaitable`s we expect to encounter in Synapse.
-        assert not isinstance(res, Awaitable)
-
-        return defer.succeed(res)
+    elif isinstance(res, defer.Deferred):
+        pass
+    elif isinstance(res, Awaitable):
+        # `res` is probably some kind of completed awaitable, such as a `DoneAwaitable`
+        # or `Future` from `make_awaitable`.
+        res = defer.ensureDeferred(_unwrap_awaitable(res))
+    else:
+        # `res` is a plain value. Wrap it in a `Deferred`.
+        res = defer.succeed(res)

    if res.called and not res.paused:
        # The function should have maintained the logcontext, so we can
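The new `Awaitable` branch covers completed awaitables that are neither coroutines nor `Deferred`s. A runnable sketch of why `defer.ensureDeferred(_unwrap_awaitable(...))` fires synchronously for such objects; `DoneAwaitable` here is a stand-in modelled on the test helper the comment mentions, not Synapse's actual class:

```python
from twisted.internet import defer

class DoneAwaitable:
    """Stand-in for an already-completed awaitable (cf. make_awaitable in tests)."""
    def __init__(self, value):
        self.value = value
    def __await__(self):
        return self
    def __iter__(self):
        return self
    def __next__(self):
        # Completing immediately: StopIteration's value becomes the await result.
        raise StopIteration(self.value)

async def _unwrap_awaitable(awaitable):
    return await awaitable

d = defer.ensureDeferred(_unwrap_awaitable(DoneAwaitable(42)))
assert d.called  # the Deferred fires synchronously, no reactor needed
```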
@@ -82,6 +82,7 @@ from synapse.handlers.auth import (
    ON_LOGGED_OUT_CALLBACK,
    AuthHandler,
)
+from synapse.handlers.push_rules import RuleSpec, check_actions
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
    DirectServeHtmlResource,
@@ -109,12 +110,14 @@ from synapse.storage.state import StateFilter
from synapse.types import (
    DomainSpecificString,
    JsonDict,
+    JsonMapping,
    Requester,
    StateMap,
    UserID,
    UserInfo,
    UserProfile,
    create_requester,
+    map_username_to_mxid_localpart,
)
from synapse.util import Clock
from synapse.util.async_helpers import maybe_awaitable
@@ -151,6 +154,7 @@ __all__ = [
    "PRESENCE_ALL_USERS",
    "LoginResponse",
    "JsonDict",
+    "JsonMapping",
    "EventBase",
    "StateMap",
    "ProfileInfo",
@@ -193,6 +197,7 @@ class ModuleApi:
        self._clock: Clock = hs.get_clock()
        self._registration_handler = hs.get_registration_handler()
        self._send_email_handler = hs.get_send_email_handler()
+        self._push_rules_handler = hs.get_push_rules_handler()
        self.custom_template_dir = hs.config.server.custom_template_directory

        try:
@@ -569,6 +574,26 @@ class ModuleApi:
            return username
        return UserID(username, self._hs.hostname).to_string()

+    def normalize_username(
+        self, username: Union[str, bytes], case_sensitive: bool = False
+    ) -> str:
+        """Map a username onto a string suitable for a MXID
+
+        This follows the algorithm laid out at
+        https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets.
+
+        Added in Synapse v1.58.0
+
+        Args:
+            username: username to be mapped
+            case_sensitive: true if TEST and test should be mapped
+                onto different mxids
+
+        Returns:
+            string suitable for a mxid localpart
+        """
+        return map_username_to_mxid_localpart(username, case_sensitive)
+
    async def get_profile_for_user(self, localpart: str) -> ProfileInfo:
        """Look up the profile info for the user with the given localpart.
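A hedged usage sketch for module authors (`api` stands for the `ModuleApi` instance a module receives; the expected output follows the linked appendix algorithm and is an assumption, not verified output):

```python
# Hypothetical module code; `api` is the ModuleApi instance handed to a module.
localpart = api.normalize_username("Ernie_123")
# With case_sensitive=False (the default), upper case is folded away, so this
# is expected to be "ernie_123" - assuming no characters needed escaping.
```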
@@ -1350,6 +1375,68 @@ class ModuleApi:
        """
        await self._store.add_user_bound_threepid(user_id, medium, address, id_server)

+    def check_push_rule_actions(
+        self, actions: List[Union[str, Dict[str, str]]]
+    ) -> None:
+        """Checks if the given push rule actions are valid according to the Matrix
+        specification.
+
+        See https://spec.matrix.org/v1.2/client-server-api/#actions for the list of valid
+        actions.
+
+        Added in Synapse v1.58.0.
+
+        Args:
+            actions: the actions to check.
+
+        Raises:
+            synapse.module_api.errors.InvalidRuleException if the actions are invalid.
+        """
+        check_actions(actions)
+
+    async def set_push_rule_action(
+        self,
+        user_id: str,
+        scope: str,
+        kind: str,
+        rule_id: str,
+        actions: List[Union[str, Dict[str, str]]],
+    ) -> None:
+        """Changes the actions of an existing push rule for the given user.
+
+        See https://spec.matrix.org/v1.2/client-server-api/#push-rules for more
+        information about push rules and their syntax.
+
+        Can only be called on the main process.
+
+        Added in Synapse v1.58.0.
+
+        Args:
+            user_id: the user for which to change the push rule's actions.
+            scope: the push rule's scope, currently only "global" is allowed.
+            kind: the push rule's kind.
+            rule_id: the push rule's identifier.
+            actions: the actions to run when the rule's conditions match.
+
+        Raises:
+            RuntimeError if this method is called on a worker or `scope` is invalid.
+            synapse.module_api.errors.RuleNotFoundException if the rule being modified
+                can't be found.
+            synapse.module_api.errors.InvalidRuleException if the actions are invalid.
+        """
+        if self.worker_app is not None:
+            raise RuntimeError("module tried to change push rule actions on a worker")
+
+        if scope != "global":
+            raise RuntimeError(
+                "invalid scope %s, only 'global' is currently allowed" % scope
+            )
+
+        spec = RuleSpec(scope, kind, rule_id, "actions")
+        await self._push_rules_handler.set_rule_attr(
+            user_id, spec, {"actions": actions}
+        )


class PublicRoomListManager:
    """Contains methods for adding to, removing from and querying whether a room
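Putting the two new APIs together, a hypothetical module method might validate actions and then apply them; everything here (the method, its class, the chosen `rule_id`) is illustrative rather than taken from the PR:

```python
# Hypothetical module method; `self._api` is the ModuleApi passed to the module.
# The rule_id is illustrative only - any existing push rule of the user works.
async def mute_display_name_mentions(self, user_id: str) -> None:
    actions = ["dont_notify"]
    self._api.check_push_rule_actions(actions)  # raises InvalidRuleException if bad
    await self._api.set_push_rule_action(
        user_id=user_id,
        scope="global",  # only "global" is currently accepted
        kind="override",
        rule_id=".m.rule.contains_display_name",
        actions=actions,
    )
```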
@@ -1419,7 +1506,7 @@ class AccountDataManager:
                f"{user_id} is not local to this homeserver; can't access account data for remote users."
            )

-    async def get_global(self, user_id: str, data_type: str) -> Optional[JsonDict]:
+    async def get_global(self, user_id: str, data_type: str) -> Optional[JsonMapping]:
        """
        Gets some global account data, of a specified type, for the specified user.
@@ -20,10 +20,14 @@ from synapse.api.errors import (
    SynapseError,
)
from synapse.config._base import ConfigError
+from synapse.handlers.push_rules import InvalidRuleException
+from synapse.storage.push_rule import RuleNotFoundException

__all__ = [
    "InvalidClientCredentialsError",
    "RedirectException",
    "SynapseError",
    "ConfigError",
+    "InvalidRuleException",
+    "RuleNotFoundException",
]
@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union
-
-import attr
+from typing import TYPE_CHECKING, List, Sequence, Tuple, Union

from synapse.api.errors import (
    NotFoundError,
@@ -22,6 +20,7 @@ from synapse.api.errors import (
    SynapseError,
    UnrecognizedRequestError,
)
+from synapse.handlers.push_rules import InvalidRuleException, RuleSpec, check_actions
from synapse.http.server import HttpServer
from synapse.http.servlet import (
    RestServlet,
@@ -29,7 +28,6 @@ from synapse.http.servlet import (
    parse_string,
)
from synapse.http.site import SynapseRequest
-from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client._base import client_patterns
@@ -40,14 +38,6 @@ if TYPE_CHECKING:
    from synapse.server import HomeServer


-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class RuleSpec:
-    scope: str
-    template: str
-    rule_id: str
-    attr: Optional[str]
-
-
class PushRuleRestServlet(RestServlet):
    PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True)
    SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
@@ -60,6 +50,7 @@ class PushRuleRestServlet(RestServlet):
        self.store = hs.get_datastores().main
        self.notifier = hs.get_notifier()
        self._is_worker = hs.config.worker.worker_app is not None
+        self._push_rules_handler = hs.get_push_rules_handler()

    async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
        if self._is_worker:
@@ -81,8 +72,13 @@ class PushRuleRestServlet(RestServlet):
        user_id = requester.user.to_string()

        if spec.attr:
-            await self.set_rule_attr(user_id, spec, content)
-            self.notify_user(user_id)
+            try:
+                await self._push_rules_handler.set_rule_attr(user_id, spec, content)
+            except InvalidRuleException as e:
+                raise SynapseError(400, "Invalid actions: %s" % e)
+            except RuleNotFoundException:
+                raise NotFoundError("Unknown rule")

            return 200, {}

        if spec.rule_id.startswith("."):
@@ -98,23 +94,23 @@ class PushRuleRestServlet(RestServlet):

        before = parse_string(request, "before")
        if before:
-            before = _namespaced_rule_id(spec, before)
+            before = f"global/{spec.template}/{before}"

        after = parse_string(request, "after")
        if after:
-            after = _namespaced_rule_id(spec, after)
+            after = f"global/{spec.template}/{after}"

        try:
            await self.store.add_push_rule(
                user_id=user_id,
-                rule_id=_namespaced_rule_id_from_spec(spec),
+                rule_id=f"global/{spec.template}/{spec.rule_id}",
                priority_class=priority_class,
                conditions=conditions,
                actions=actions,
                before=before,
                after=after,
            )
-            self.notify_user(user_id)
+            self._push_rules_handler.notify_user(user_id)
        except InconsistentRuleException as e:
            raise SynapseError(400, str(e))
        except RuleNotFoundException as e:
@@ -133,11 +129,11 @@ class PushRuleRestServlet(RestServlet):
        requester = await self.auth.get_user_by_req(request)
        user_id = requester.user.to_string()

-        namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
+        namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}"

        try:
            await self.store.delete_push_rule(user_id, namespaced_rule_id)
-            self.notify_user(user_id)
+            self._push_rules_handler.notify_user(user_id)
            return 200, {}
        except StoreError as e:
            if e.code == 404:
@@ -172,55 +168,6 @@ class PushRuleRestServlet(RestServlet):
        else:
            raise UnrecognizedRequestError()

-    def notify_user(self, user_id: str) -> None:
-        stream_id = self.store.get_max_push_rules_stream_id()
-        self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
-
-    async def set_rule_attr(
-        self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
-    ) -> None:
-        if spec.attr not in ("enabled", "actions"):
-            # for the sake of potential future expansion, shouldn't report
-            # 404 in the case of an unknown request so check it corresponds to
-            # a known attribute first.
-            raise UnrecognizedRequestError()
-
-        namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
-        rule_id = spec.rule_id
-        is_default_rule = rule_id.startswith(".")
-        if is_default_rule:
-            if namespaced_rule_id not in BASE_RULE_IDS:
-                raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
-        if spec.attr == "enabled":
-            if isinstance(val, dict) and "enabled" in val:
-                val = val["enabled"]
-            if not isinstance(val, bool):
-                # Legacy fallback
-                # This should *actually* take a dict, but many clients pass
-                # bools directly, so let's not break them.
-                raise SynapseError(400, "Value for 'enabled' must be boolean")
-            await self.store.set_push_rule_enabled(
-                user_id, namespaced_rule_id, val, is_default_rule
-            )
-        elif spec.attr == "actions":
-            if not isinstance(val, dict):
-                raise SynapseError(400, "Value must be a dict")
-            actions = val.get("actions")
-            if not isinstance(actions, list):
-                raise SynapseError(400, "Value for 'actions' must be dict")
-            _check_actions(actions)
-            namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
-            rule_id = spec.rule_id
-            is_default_rule = rule_id.startswith(".")
-            if is_default_rule:
-                if namespaced_rule_id not in BASE_RULE_IDS:
-                    raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
-            await self.store.set_push_rule_actions(
-                user_id, namespaced_rule_id, actions, is_default_rule
-            )
-        else:
-            raise UnrecognizedRequestError()


def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec:
    """Turn a sequence of path components into a rule spec
@@ -291,24 +238,11 @@ def _rule_tuple_from_request_object(
        raise InvalidRuleException("No actions found")
    actions = req_obj["actions"]

-    _check_actions(actions)
+    check_actions(actions)

    return conditions, actions


-def _check_actions(actions: List[Union[str, JsonDict]]) -> None:
-    if not isinstance(actions, list):
-        raise InvalidRuleException("No actions found")
-
-    for a in actions:
-        if a in ["notify", "dont_notify", "coalesce"]:
-            pass
-        elif isinstance(a, dict) and "set_tweak" in a:
-            pass
-        else:
-            raise InvalidRuleException("Unrecognised action")
-
-
def _filter_ruleset_with_path(ruleset: JsonDict, path: List[str]) -> JsonDict:
    if path == []:
        raise UnrecognizedRequestError(
@@ -357,17 +291,5 @@ def _priority_class_from_spec(spec: RuleSpec) -> int:
|
||||
return pc
|
||||
|
||||
|
||||
def _namespaced_rule_id_from_spec(spec: RuleSpec) -> str:
|
||||
return _namespaced_rule_id(spec, spec.rule_id)
|
||||
|
||||
|
||||
def _namespaced_rule_id(spec: RuleSpec, rule_id: str) -> str:
|
||||
return "global/%s/%s" % (spec.template, rule_id)
|
||||
|
||||
|
||||
class InvalidRuleException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
PushRuleRestServlet(hs).register(http_server)
|
||||
|
||||
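The hunks above move action validation out of the servlet: `_check_actions` becomes `check_actions`, and `InvalidRuleException` moves with it (the test changes further down import it from `synapse.handlers.push_rules`). The accepted grammar is unchanged; a short sketch of what passes and what raises, assuming the relocated import path:

    from synapse.handlers.push_rules import InvalidRuleException, check_actions

    # Spec-compliant actions pass silently: known action strings, or set_tweak dicts.
    check_actions(["notify", {"set_tweak": "sound", "value": "default"}])

    try:
        check_actions(["shout"])  # hypothetical unknown action string
    except InvalidRuleException:
        pass  # rejected: not "notify", "dont_notify", "coalesce" or a set_tweak dict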
@@ -929,6 +929,10 @@ def _calculate_registration_flows(
         # always let users provide both MSISDN & email
         flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY])
 
+    # Add a flow that doesn't require any 3pids, if the config requests it.
+    if config.registration.enable_registration_token_3pid_bypasss:
+        flows.append([LoginType.REGISTRATION_TOKEN])
+
     # Prepend m.login.terms to all flows if we're requiring consent
     if config.consent.user_consent_at_registration:
         for flow in flows:
@@ -942,7 +946,8 @@ def _calculate_registration_flows(
     # Prepend registration token to all flows if we're requiring a token
     if config.registration.registration_requires_token:
         for flow in flows:
-            flow.insert(0, LoginType.REGISTRATION_TOKEN)
+            if LoginType.REGISTRATION_TOKEN not in flow:
+                flow.insert(0, LoginType.REGISTRATION_TOKEN)
 
     return flows
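The new membership check matters precisely because of the bypass flow added above: that flow already consists of just the registration token, so unconditionally prepending the token would duplicate it. A worked illustration (flow contents abbreviated; this is a sketch, not output copied from Synapse):

    REGISTRATION_TOKEN = "m.login.registration_token"

    # Suppose the earlier steps produced an email flow plus the new bypass flow.
    flows = [["m.login.email.identity"], [REGISTRATION_TOKEN]]

    for flow in flows:
        if REGISTRATION_TOKEN not in flow:
            flow.insert(0, REGISTRATION_TOKEN)

    # flows == [[REGISTRATION_TOKEN, "m.login.email.identity"],
    #           [REGISTRATION_TOKEN]]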
@@ -91,6 +91,7 @@ from synapse.handlers.presence import (
     WorkerPresenceHandler,
 )
 from synapse.handlers.profile import ProfileHandler
+from synapse.handlers.push_rules import PushRulesHandler
 from synapse.handlers.read_marker import ReadMarkerHandler
 from synapse.handlers.receipts import ReceiptsHandler
 from synapse.handlers.register import RegistrationHandler
@@ -810,6 +811,10 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_account_handler(self) -> AccountHandler:
         return AccountHandler(self)
 
+    @cache_in_self
+    def get_push_rules_handler(self) -> PushRulesHandler:
+        return PushRulesHandler(self)
+
     @cache_in_self
     def get_outbound_redis_connection(self) -> "ConnectionHandler":
         """
@@ -232,10 +232,10 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
             # is racy.
             # Have resolved to invalidate the whole cache for now and do
             # something about it if and when the perf becomes significant
-            self._invalidate_all_cache_and_stream(  # type: ignore[attr-defined]
+            self._invalidate_all_cache_and_stream(
                 txn, self.user_last_seen_monthly_active
             )
-            self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())  # type: ignore[attr-defined]
+            self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
 
         reserved_users = await self.get_registered_reserved_users()
         await self.db_pool.runInteraction(
@@ -363,7 +363,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
 
         if self._limit_usage_by_mau or self._mau_stats_only:
            # Trial users and guests should not be included as part of MAU group
-            is_guest = await self.is_guest(user_id)  # type: ignore[attr-defined]
+            is_guest = await self.is_guest(user_id)
             if is_guest:
                 return
             is_trial = await self.is_trial_user(user_id)
@@ -16,7 +16,7 @@ import abc
 import logging
 from typing import TYPE_CHECKING, Dict, List, Tuple, Union
 
-from synapse.api.errors import NotFoundError, StoreError
+from synapse.api.errors import StoreError
 from synapse.push.baserules import list_with_base_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -618,7 +618,7 @@ class PushRuleStore(PushRulesWorkerStore):
             are always stored in the database `push_rules` table).
 
         Raises:
-            NotFoundError if the rule does not exist.
+            RuleNotFoundException if the rule does not exist.
         """
         async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
@@ -668,8 +668,7 @@ class PushRuleStore(PushRulesWorkerStore):
             )
             txn.execute(sql, (user_id, rule_id))
             if txn.fetchone() is None:
-                # needed to set NOT_FOUND code.
-                raise NotFoundError("Push rule does not exist.")
+                raise RuleNotFoundException("Push rule does not exist.")
 
         self.db_pool.simple_upsert_txn(
             txn,
@@ -698,9 +697,6 @@ class PushRuleStore(PushRulesWorkerStore):
         """
         Sets the `actions` state of a push rule.
 
-        Will throw NotFoundError if the rule does not exist; the Code for this
-        is NOT_FOUND.
-
         Args:
             user_id: the user ID of the user who wishes to enable/disable the rule
                 e.g. '@tina:example.org'
@@ -712,6 +708,9 @@ class PushRuleStore(PushRulesWorkerStore):
             is_default_rule: True if and only if this is a server-default rule.
                 This skips the check for existence (as only user-created rules
                 are always stored in the database `push_rules` table).
+
+        Raises:
+            RuleNotFoundException if the rule does not exist.
         """
         actions_json = json_encoder.encode(actions)
 
@@ -744,7 +743,7 @@ class PushRuleStore(PushRulesWorkerStore):
         except StoreError as serr:
             if serr.code == 404:
                 # this sets the NOT_FOUND error Code
-                raise NotFoundError("Push rule does not exist")
+                raise RuleNotFoundException("Push rule does not exist")
             else:
                 raise
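With the storage layer now raising `RuleNotFoundException` instead of a 404-coded `SynapseError`, callers outside the REST layer (such as the new module API) no longer have HTTP concerns leaking out of the store. A minimal sketch of the calling pattern, assuming `RuleNotFoundException` is importable from the storage module edited here:

    from synapse.api.errors import NotFoundError
    from synapse.storage.databases.main.push_rule import RuleNotFoundException

    async def set_actions_or_404(store, user_id: str, rule_id: str, actions: list) -> None:
        try:
            await store.set_push_rule_actions(
                user_id, rule_id, actions, is_default_rule=False
            )
        except RuleNotFoundException:
            # Only the HTTP edge translates the error back into a 404.
            raise NotFoundError("Unknown rule")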
@@ -12,11 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import collections.abc
 import logging
 from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
 
-from frozendict import frozendict
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -160,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         predecessor = create_event.content.get("predecessor", None)
 
         # Ensure the key is a dictionary
-        if not isinstance(predecessor, (dict, frozendict)):
+        if not isinstance(predecessor, collections.abc.Mapping):
             return None
 
         # The keys must be strings since the data is JSON.
@@ -501,11 +501,11 @@ def _upgrade_existing_database(
 
                 if hasattr(module, "run_create"):
                     logger.info("Running %s:run_create", relative_path)
-                    module.run_create(cur, database_engine)  # type: ignore
+                    module.run_create(cur, database_engine)
 
                 if not is_empty and hasattr(module, "run_upgrade"):
                     logger.info("Running %s:run_upgrade", relative_path)
-                    module.run_upgrade(cur, database_engine, config=config)  # type: ignore
+                    module.run_upgrade(cur, database_engine, config=config)
             elif ext == ".pyc" or file_name == "__pycache__":
                 # Sometimes .pyc files turn up anyway even though we've
                 # disabled their generation; e.g. from distribution package
@@ -107,7 +107,7 @@ class TTLCache(Generic[KT, VT]):
             self._metrics.inc_hits()
             return e.value, e.expiry_time, e.ttl
 
-    def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:  # type: ignore
+    def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
         """Remove a value from the cache
 
         If key is in the cache, remove it and return its value, else return default.
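Only the now-unneeded `# type: ignore` is dropped here; `pop` itself still mirrors `dict.pop`, raising `KeyError` on a miss when no default is given and returning the default otherwise. A usage sketch (the cache name is illustrative):

    from synapse.util.caches.ttlcache import TTLCache

    cache: TTLCache[str, int] = TTLCache("example_cache")
    cache.set("answer", 42, ttl=60)  # entry expires 60 seconds from now

    assert cache.pop("answer") == 42      # removes and returns the value
    assert cache.pop("answer", -1) == -1  # gone now, so the default is returned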
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import collections.abc
 from typing import Any
 
 from frozendict import frozendict
@@ -35,7 +36,7 @@ def freeze(o: Any) -> Any:
 
 
 def unfreeze(o: Any) -> Any:
-    if isinstance(o, (dict, frozendict)):
+    if isinstance(o, collections.abc.Mapping):
         return {k: unfreeze(v) for k, v in o.items()}
 
     if isinstance(o, (bytes, str)):
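Both this replacement and the one in the state store above rely on `frozendict` implementing the `Mapping` interface, so a single ABC check covers plain dicts, frozendicts, and any other mapping type. A quick self-contained check:

    import collections.abc

    from frozendict import frozendict

    for candidate in ({"a": 1}, frozendict({"a": 1})):
        # Both satisfy the one check that replaces isinstance(o, (dict, frozendict)).
        assert isinstance(candidate, collections.abc.Mapping)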
@@ -83,7 +83,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
         )
 
         # mock up the response, and have the agent return it
-        self._mock_agent.request.return_value = defer.succeed(
+        self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
             _mock_response(
                 {
                     "pdus": [
@@ -226,7 +226,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         # Send the server a device list EDU for the other user, this will cause
         # it to try and resync the device lists.
         self.hs.get_federation_transport_client().query_user_devices.return_value = (
-            defer.succeed(
+            make_awaitable(
                 {
                     "stream_id": "1",
                     "user_id": "@user2:host2",
@@ -19,7 +19,6 @@ from unittest import mock
 from parameterized import parameterized
 from signedjson import key as key, sign as sign
 
-from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import RoomEncryptionAlgorithms
@@ -704,7 +703,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
 
         self.hs.get_federation_client().query_client_keys = mock.Mock(
-            return_value=defer.succeed(
+            return_value=make_awaitable(
                 {
                     "device_keys": {remote_user_id: {}},
                     "master_keys": {
@@ -777,14 +776,14 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         # Pretend we're sharing a room with the user we're querying. If not,
         # `_query_devices_for_destination` will return early.
         self.store.get_rooms_for_user = mock.Mock(
-            return_value=defer.succeed({"some_room_id"})
+            return_value=make_awaitable({"some_room_id"})
         )
 
         remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
         remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
 
         self.hs.get_federation_client().query_user_devices = mock.Mock(
-            return_value=defer.succeed(
+            return_value=make_awaitable(
                 {
                     "user_id": remote_user_id,
                     "stream_id": 1,
@@ -17,8 +17,6 @@
 from typing import Any, Type, Union
 from unittest.mock import Mock
 
-from twisted.internet import defer
-
 import synapse
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes
@@ -190,7 +188,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
 
         # check_password must return an awaitable
-        mock_password_provider.check_password.return_value = defer.succeed(True)
+        mock_password_provider.check_password.return_value = make_awaitable(True)
         channel = self._send_password_login("u", "p")
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual("@u:test", channel.json_body["user_id"])
@@ -226,13 +224,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.get_success(module_api.register_user("u"))
 
         # log in twice, to get two devices
-        mock_password_provider.check_password.return_value = defer.succeed(True)
+        mock_password_provider.check_password.return_value = make_awaitable(True)
         tok1 = self.login("u", "p")
         self.login("u", "p", device_id="dev2")
         mock_password_provider.reset_mock()
 
         # have the auth provider deny the request to start with
-        mock_password_provider.check_password.return_value = defer.succeed(False)
+        mock_password_provider.check_password.return_value = make_awaitable(False)
 
         # make the initial request which returns a 401
         session = self._start_delete_device_session(tok1, "dev2")
@@ -246,7 +244,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
 
         # Finally, check the request goes through when we allow it
-        mock_password_provider.check_password.return_value = defer.succeed(True)
+        mock_password_provider.check_password.return_value = make_awaitable(True)
         channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
         self.assertEqual(channel.code, 200)
         mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@@ -260,7 +258,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.register_user("localuser", "localpass")
 
         # check_password must return an awaitable
-        mock_password_provider.check_password.return_value = defer.succeed(False)
+        mock_password_provider.check_password.return_value = make_awaitable(False)
         channel = self._send_password_login("u", "p")
         self.assertEqual(channel.code, 403, channel.result)
 
@@ -277,7 +275,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.register_user("localuser", "localpass")
 
         # have the auth provider deny the request
-        mock_password_provider.check_password.return_value = defer.succeed(False)
+        mock_password_provider.check_password.return_value = make_awaitable(False)
 
         # log in twice, to get two devices
         tok1 = self.login("localuser", "localpass")
@@ -320,7 +318,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.register_user("localuser", "localpass")
 
         # check_password must return an awaitable
-        mock_password_provider.check_password.return_value = defer.succeed(False)
+        mock_password_provider.check_password.return_value = make_awaitable(False)
         channel = self._send_password_login("localuser", "localpass")
         self.assertEqual(channel.code, 403)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -342,7 +340,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.register_user("localuser", "localpass")
 
         # allow login via the auth provider
-        mock_password_provider.check_password.return_value = defer.succeed(True)
+        mock_password_provider.check_password.return_value = make_awaitable(True)
 
         # log in twice, to get two devices
         tok1 = self.login("localuser", "p")
@@ -359,7 +357,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.check_password.assert_not_called()
 
         # now try deleting with the local password
-        mock_password_provider.check_password.return_value = defer.succeed(False)
+        mock_password_provider.check_password.return_value = make_awaitable(False)
         channel = self._authed_delete_device(
             tok1, "dev2", session, "localuser", "localpass"
         )
@@ -413,7 +411,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400, channel.result)
         mock_password_provider.check_auth.assert_not_called()
 
-        mock_password_provider.check_auth.return_value = defer.succeed(
+        mock_password_provider.check_auth.return_value = make_awaitable(
             ("@user:bz", None)
         )
         channel = self._send_login("test.login_type", "u", test_field="y")
@@ -427,7 +425,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         # try a weird username. Again, it's unclear what we *expect* to happen
         # in these cases, but at least we can guard against the API changing
         # unexpectedly
-        mock_password_provider.check_auth.return_value = defer.succeed(
+        mock_password_provider.check_auth.return_value = make_awaitable(
            ("@ MALFORMED! :bz", None)
         )
         channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
@@ -477,7 +475,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
 
         # right params, but authing as the wrong user
-        mock_password_provider.check_auth.return_value = defer.succeed(
+        mock_password_provider.check_auth.return_value = make_awaitable(
             ("@user:bz", None)
         )
         body["auth"]["test_field"] = "foo"
@@ -490,7 +488,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
 
         # and finally, succeed
-        mock_password_provider.check_auth.return_value = defer.succeed(
+        mock_password_provider.check_auth.return_value = make_awaitable(
             ("@localuser:test", None)
         )
         channel = self._delete_device(tok1, "dev2", body)
@@ -508,9 +506,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.custom_auth_provider_callback_test_body()
 
     def custom_auth_provider_callback_test_body(self):
-        callback = Mock(return_value=defer.succeed(None))
+        callback = Mock(return_value=make_awaitable(None))
 
-        mock_password_provider.check_auth.return_value = defer.succeed(
+        mock_password_provider.check_auth.return_value = make_awaitable(
             ("@user:bz", callback)
         )
         channel = self._send_login("test.login_type", "u", test_field="y")
@@ -646,7 +644,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         login is disabled"""
         # register the user and log in twice via the test login type to get two devices,
         self.register_user("localuser", "localpass")
-        mock_password_provider.check_auth.return_value = defer.succeed(
+        mock_password_provider.check_auth.return_value = make_awaitable(
             ("@localuser:test", None)
         )
         channel = self._send_login("test.login_type", "localuser", test_field="")
@@ -193,8 +193,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": True})
     def test_get_or_create_user_mau_not_blocked(self):
-        # Type ignore: mypy doesn't like us assigning to methods.
-        self.store.count_monthly_users = Mock(  # type: ignore[assignment]
+        self.store.count_monthly_users = Mock(
             return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
         )
         # Ensure does not throw exception
@@ -202,8 +201,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": True})
     def test_get_or_create_user_mau_blocked(self):
-        # Type ignore: mypy doesn't like us assigning to methods.
-        self.store.get_monthly_active_count = Mock(  # type: ignore[assignment]
+        self.store.get_monthly_active_count = Mock(
             return_value=make_awaitable(self.lots_of_users)
         )
         self.get_failure(
@@ -211,8 +209,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             ResourceLimitError,
         )
 
-        # Type ignore: mypy doesn't like us assigning to methods.
-        self.store.get_monthly_active_count = Mock(  # type: ignore[assignment]
+        self.store.get_monthly_active_count = Mock(
             return_value=make_awaitable(self.hs.config.server.max_mau_value)
         )
         self.get_failure(
@@ -65,11 +65,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         # we mock out the keyring so as to skip the authentication check on the
         # federation API call.
         mock_keyring = Mock(spec=["verify_json_for_server"])
-        mock_keyring.verify_json_for_server.return_value = defer.succeed(True)
+        mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
 
         # we mock out the federation client too
         mock_federation_client = Mock(spec=["put_json"])
-        mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
+        mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
 
         # the tests assume that we are starting at unix time 1000
         reactor.pump((1000,))
@@ -98,7 +98,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.datastore = hs.get_datastores().main
         self.datastore.get_destination_retry_timings = Mock(
-            return_value=defer.succeed(None)
+            return_value=make_awaitable(None)
         )
 
         self.datastore.get_device_updates_by_remote = Mock(
@@ -15,7 +15,6 @@ from typing import Tuple
 from unittest.mock import Mock, patch
 from urllib.parse import quote
 
-from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
@@ -30,6 +29,7 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.storage.test_user_directory import GetUserDirectoryTables
+from tests.test_utils import make_awaitable
 from tests.test_utils.event_injection import inject_member_event
 from tests.unittest import override_config
 
@@ -439,7 +439,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
+        mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
         with patch.object(
             self.store, "remove_from_user_dir", mock_remove_from_user_dir
         ):
@@ -454,7 +454,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             self.store.register_user(user_id=r_user_id, password_hash=None)
         )
 
-        mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
+        mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
         with patch.object(
             self.store, "remove_from_user_dir", mock_remove_from_user_dir
         ):
@@ -11,8 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.errors import SynapseError
 from synapse.rest import admin
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 
@@ -22,7 +26,9 @@ class ModuleApiTestCase(HomeserverTestCase):
         admin.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, homeserver) -> None:
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
         self._store = homeserver.get_datastores().main
         self._module_api = homeserver.get_module_api()
         self._account_data_mgr = self._module_api.account_data_manager
@@ -91,7 +97,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         )
         with self.assertRaises(TypeError):
             # This throws an exception because it's a frozen dict.
-            the_data["wombat"] = False
+            the_data["wombat"] = False  # type: ignore[index]
 
     def test_put_global(self) -> None:
         """
@@ -143,15 +149,14 @@ class ModuleApiTestCase(HomeserverTestCase):
         with self.assertRaises(TypeError):
             # The account data type must be a string.
             self.get_success_or_raise(
-                self._module_api.account_data_manager.put_global(
-                    self.user_id, 42, {}  # type: ignore
-                )
+                self._module_api.account_data_manager.put_global(self.user_id, 42, {})  # type: ignore[arg-type]
             )
 
         with self.assertRaises(TypeError):
             # The account data dict must be a dict.
             # noinspection PyTypeChecker
             self.get_success_or_raise(
                 self._module_api.account_data_manager.put_global(
-                    self.user_id, "test.data", 42  # type: ignore
+                    self.user_id, "test.data", 42  # type: ignore[arg-type]
                 )
             )
@@ -19,8 +19,9 @@ from synapse.api.constants import EduTypes, EventTypes
 from synapse.events import EventBase
 from synapse.federation.units import Transaction
 from synapse.handlers.presence import UserPresenceState
+from synapse.handlers.push_rules import InvalidRuleException
 from synapse.rest import admin
-from synapse.rest.client import login, presence, profile, room
+from synapse.rest.client import login, notifications, presence, profile, room
 from synapse.types import create_requester
 
 from tests.events.test_presence_router import send_presence_update, sync_presence
@@ -38,6 +39,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         room.register_servlets,
         presence.register_servlets,
         profile.register_servlets,
+        notifications.register_servlets,
     ]
 
     def prepare(self, reactor, clock, homeserver):
@@ -553,6 +555,94 @@ class ModuleApiTestCase(HomeserverTestCase):
         self.assertEqual(state[("org.matrix.test", "")].state_key, "")
         self.assertEqual(state[("org.matrix.test", "")].content, {})
 
+    def test_set_push_rules_action(self) -> None:
+        """Test that a module can change the actions of an existing push rule for a user."""
+
+        # Create a room with 2 users in it. Push rules must not match if the user is the
+        # event's sender, so we need one user to send messages and one user to receive
+        # notifications.
+        user_id = self.register_user("user", "password")
+        tok = self.login("user", "password")
+
+        room_id = self.helper.create_room_as(user_id, is_public=True, tok=tok)
+
+        user_id2 = self.register_user("user2", "password")
+        tok2 = self.login("user2", "password")
+        self.helper.join(room_id, user_id2, tok=tok2)
+
+        # Register a 3rd user and join them to the room, so that we don't accidentally
+        # trigger 1:1 push rules.
+        user_id3 = self.register_user("user3", "password")
+        tok3 = self.login("user3", "password")
+        self.helper.join(room_id, user_id3, tok=tok3)
+
+        # Send a message as the second user and check that it notifies.
+        res = self.helper.send(room_id=room_id, body="here's a message", tok=tok2)
+        event_id = res["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            "/notifications",
+            access_token=tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
+        self.assertEqual(
+            channel.json_body["notifications"][0]["event"]["event_id"],
+            event_id,
+            channel.json_body,
+        )
+
+        # Change the .m.rule.message actions to not notify on new messages.
+        self.get_success(
+            defer.ensureDeferred(
+                self.module_api.set_push_rule_action(
+                    user_id=user_id,
+                    scope="global",
+                    kind="underride",
+                    rule_id=".m.rule.message",
+                    actions=["dont_notify"],
+                )
+            )
+        )
+
+        # Send another message as the second user and check that the number of
+        # notifications didn't change.
+        self.helper.send(room_id=room_id, body="here's another message", tok=tok2)
+
+        channel = self.make_request(
+            "GET",
+            "/notifications?from=",
+            access_token=tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
+
+    def test_check_push_rules_actions(self) -> None:
+        """Test that modules can check whether a list of push rules actions are spec
+        compliant.
+        """
+        with self.assertRaises(InvalidRuleException):
+            self.module_api.check_push_rule_actions(["foo"])
+
+        with self.assertRaises(InvalidRuleException):
+            self.module_api.check_push_rule_actions({"foo": "bar"})
+
+        self.module_api.check_push_rule_actions(["notify"])
+
+        self.module_api.check_push_rule_actions(
+            [{"set_tweak": "sound", "value": "default"}]
+        )
+
+    def test_normalize_username(self) -> None:
+        username = "Haxxor"
+        username2 = "_leet"
+        username3 = "aNoThErTeSt"
+        self.assertEqual(self.module_api.normalize_username(username), "haxxor")
+        self.assertEqual(self.module_api.normalize_username(username2), "=5fleet")
+        self.assertEqual(self.module_api.normalize_username(username3), "anothertest")
+
 
 class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
     """For testing ModuleApi functionality in a multi-worker setup"""
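For module authors, the push-rule tests above exercise the whole new surface: `check_push_rule_actions` validates a list of actions (raising `InvalidRuleException` otherwise), and `set_push_rule_action` rewrites the actions of an existing rule of a local user. A minimal sketch of a module using both; the module class and its config handling are illustrative scaffolding, not part of this changeset:

    from synapse.module_api import ModuleApi

    class QuietMessagesModule:
        def __init__(self, config: dict, api: ModuleApi):
            self._api = api

        async def silence_message_notifications(self, user_id: str) -> None:
            # Raises InvalidRuleException if the actions are not spec-compliant.
            self._api.check_push_rule_actions(["dont_notify"])

            # Change the actions of the user's default message rule.
            await self._api.set_push_rule_action(
                user_id=user_id,
                scope="global",
                kind="underride",
                rule_id=".m.rule.message",
                actions=["dont_notify"],
            )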
@@ -102,8 +102,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         for i in range(20):
             server_name = "other_server_%d" % (i,)
             room = self.create_room_with_remote_server(user, token, server_name)
-            mock_client1.reset_mock()  # type: ignore[attr-defined]
-            mock_client2.reset_mock()  # type: ignore[attr-defined]
+            mock_client1.reset_mock()
+            mock_client2.reset_mock()
 
             self.create_and_send_event(room, UserID.from_string(user))
             self.replicate()
@@ -167,8 +167,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         for i in range(20):
             server_name = "other_server_%d" % (i,)
             room = self.create_room_with_remote_server(user, token, server_name)
-            mock_client1.reset_mock()  # type: ignore[attr-defined]
-            mock_client2.reset_mock()  # type: ignore[attr-defined]
+            mock_client1.reset_mock()
+            mock_client2.reset_mock()
 
             self.get_success(
                 typing_handler.started_typing(
@@ -14,7 +14,6 @@
 from http import HTTPStatus
 from unittest.mock import Mock
 
-from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.handlers.presence import PresenceHandler
@@ -24,6 +23,7 @@ from synapse.types import UserID
 from synapse.util import Clock
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 
 class PresenceTestCase(unittest.HomeserverTestCase):
@@ -37,7 +37,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         presence_handler = Mock(spec=PresenceHandler)
-        presence_handler.set_state.return_value = defer.succeed(None)
+        presence_handler.set_state.return_value = make_awaitable(None)
 
         hs = self.setup_test_homeserver(
             "red",
@@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, List, Optional
 from unittest.mock import Mock, call
 from urllib import parse as urlparse
 
-from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
@@ -1426,9 +1425,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
 
     def test_simple(self) -> None:
         "Simple test for searching rooms over federation"
-        self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed(  # type: ignore[attr-defined]
-            {}
-        )
+        self.federation_client.get_public_rooms.return_value = make_awaitable({})  # type: ignore[attr-defined]
 
         search_filter = {"generic_search_term": "foobar"}
 
@@ -1456,7 +1453,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
         # with a 404, when using search filters.
         self.federation_client.get_public_rooms.side_effect = (  # type: ignore[attr-defined]
             HttpResponseException(404, "Not Found", b""),
-            defer.succeed({}),
+            make_awaitable({}),
         )
 
         search_filter = {"generic_search_term": "foobar"}
@@ -22,6 +22,7 @@ from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache
 from synapse.util import Clock
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.utils import MockClock
 
 
@@ -38,7 +39,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_executes_given_function(self):
-        cb = Mock(return_value=defer.succeed(self.mock_http_response))
+        cb = Mock(return_value=make_awaitable(self.mock_http_response))
         res = yield self.cache.fetch_or_execute(
             self.mock_key, cb, "some_arg", keyword="arg"
         )
@@ -47,7 +48,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_deduplicates_based_on_key(self):
-        cb = Mock(return_value=defer.succeed(self.mock_http_response))
+        cb = Mock(return_value=make_awaitable(self.mock_http_response))
         for i in range(3):  # invoke multiple times
             res = yield self.cache.fetch_or_execute(
                 self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
@@ -130,7 +131,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_cleans_up(self):
-        cb = Mock(return_value=defer.succeed(self.mock_http_response))
+        cb = Mock(return_value=make_awaitable(self.mock_http_response))
         yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
         # should NOT have cleaned up yet
         self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
@@ -14,8 +14,6 @@
 
 from unittest.mock import Mock
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType
 from synapse.api.errors import ResourceLimitError
 from synapse.rest import admin
@@ -68,16 +66,16 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
             return_value=make_awaitable(1000)
         )
         self._rlsn._server_notices_manager.send_notice = Mock(
-            return_value=defer.succeed(Mock())
+            return_value=make_awaitable(Mock())
         )
         self._send_notice = self._rlsn._server_notices_manager.send_notice
 
         self.user_id = "@user_id:test"
 
         self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
-            return_value=defer.succeed("!something:localhost")
+            return_value=make_awaitable("!something:localhost")
        )
-        self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
+        self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None))
         self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
 
     @override_config({"hs_disabled": True})
@@ -95,7 +93,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
     def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
         """Test when user has blocked notice, but should have it removed"""
 
-        self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
+        self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
         mock_event = Mock(
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
@@ -111,7 +109,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test when user has blocked notice, but notice ought to be there (NOOP)
         """
         self._rlsn._auth.check_auth_blocking = Mock(
-            return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
+            return_value=make_awaitable(None),
+            side_effect=ResourceLimitError(403, "foo"),
         )
 
         mock_event = Mock(
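A detail worth keeping in mind while reading these mocks: when a `Mock` is given both `return_value` and `side_effect`, the `side_effect` takes precedence (the `return_value` is only used if the side effect returns `mock.DEFAULT`), so the mocks above raise `ResourceLimitError` rather than returning the awaitable. A standalone illustration:

    from unittest.mock import Mock

    m = Mock(return_value="never returned", side_effect=RuntimeError("boom"))
    try:
        m()
    except RuntimeError:
        pass  # the side_effect fired; return_value was ignored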
@@ -130,7 +129,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test when user does not have blocked notice, but should have one
         """
         self._rlsn._auth.check_auth_blocking = Mock(
-            return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
+            return_value=make_awaitable(None),
+            side_effect=ResourceLimitError(403, "foo"),
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
@@ -141,7 +141,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user does not have blocked notice, nor should they (NOOP)
         """
-        self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
+        self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
 
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
@@ -152,7 +152,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test when user is not part of the MAU cohort - this should not ever
         happen - but ...
         """
-        self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
+        self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
         self._rlsn._store.user_last_seen_monthly_active = Mock(
             return_value=make_awaitable(None)
         )
@@ -167,7 +167,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         an alert message is not sent into the room
         """
         self._rlsn._auth.check_auth_blocking = Mock(
-            return_value=defer.succeed(None),
+            return_value=make_awaitable(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
             ),
@@ -182,7 +182,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test that when a server is disabled, that MAU limit alerting is ignored.
         """
         self._rlsn._auth.check_auth_blocking = Mock(
-            return_value=defer.succeed(None),
+            return_value=make_awaitable(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
             ),
@@ -199,14 +199,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         is suppressed that the room is returned to an unblocked state.
         """
         self._rlsn._auth.check_auth_blocking = Mock(
-            return_value=defer.succeed(None),
+            return_value=make_awaitable(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
             ),
         )
 
         self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
-            return_value=defer.succeed((True, []))
+            return_value=make_awaitable((True, []))
         )
 
         mock_event = Mock(
@@ -14,7 +14,6 @@
 from typing import Any, Dict, List
 from unittest.mock import Mock
 
-from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import UserTypes
@@ -259,10 +258,10 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
     def test_populate_monthly_users_should_update(self):
         self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
 
-        self.store.is_trial_user = Mock(return_value=defer.succeed(False))  # type: ignore[assignment]
+        self.store.is_trial_user = Mock(return_value=make_awaitable(False))  # type: ignore[assignment]
 
         self.store.user_last_seen_monthly_active = Mock(
-            return_value=defer.succeed(None)
+            return_value=make_awaitable(None)
         )
         d = self.store.populate_monthly_active_users("user_id")
         self.get_success(d)
@@ -272,9 +271,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
     def test_populate_monthly_users_should_not_update(self):
         self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
 
-        self.store.is_trial_user = Mock(return_value=defer.succeed(False))  # type: ignore[assignment]
+        self.store.is_trial_user = Mock(return_value=make_awaitable(False))  # type: ignore[assignment]
         self.store.user_last_seen_monthly_active = Mock(
-            return_value=defer.succeed(self.hs.get_clock().time_msec())
+            return_value=make_awaitable(self.hs.get_clock().time_msec())
         )
 
         d = self.store.populate_monthly_active_users("user_id")
@@ -233,7 +233,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         # Register mock device list retrieval on the federation client.
         federation_client = self.homeserver.get_federation_client()
         federation_client.query_user_devices = Mock(
-            return_value=succeed(
+            return_value=make_awaitable(
                 {
                     "user_id": remote_user_id,
                     "stream_id": 1,
@@ -52,7 +52,7 @@ def make_awaitable(result: TV) -> Awaitable[TV]:
     This uses Futures as they can be awaited multiple times so can be returned
     to multiple callers.
     """
-    future = Future()  # type: ignore
+    future: Future[TV] = Future()
     future.set_result(result)
     return future
 
@@ -69,7 +69,7 @@ def setup_awaitable_errors() -> Callable[[], None]:
 
     # State shared between unraisablehook and check_for_unraisable_exceptions.
     unraisable_exceptions = []
-    orig_unraisablehook = sys.unraisablehook  # type: ignore
+    orig_unraisablehook = sys.unraisablehook
 
     def unraisablehook(unraisable):
         unraisable_exceptions.append(unraisable.exc_value)
@@ -78,11 +78,11 @@ def setup_awaitable_errors() -> Callable[[], None]:
         """
         A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
         """
-        sys.unraisablehook = orig_unraisablehook  # type: ignore
+        sys.unraisablehook = orig_unraisablehook
         if unraisable_exceptions:
             raise unraisable_exceptions.pop()
 
-    sys.unraisablehook = unraisablehook  # type: ignore
+    sys.unraisablehook = unraisablehook
 
     return cleanup
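The docstring in the first hunk is the rationale for this changeset's wider `defer.succeed` to `make_awaitable` migration: a completed `Future` can be awaited any number of times, whereas a coroutine object is single-use. A standalone demonstration using plain `asyncio` rather than Synapse's test helpers:

    import asyncio

    async def main() -> None:
        async def value() -> int:
            return 1

        coro = value()
        assert await coro == 1
        try:
            await coro  # a coroutine object cannot be awaited twice
        except RuntimeError:
            pass

        fut: asyncio.Future[int] = asyncio.get_running_loop().create_future()
        fut.set_result(1)
        assert await fut == 1
        assert await fut == 1  # a finished Future can be re-awaited freely

    asyncio.run(main())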
|
||||
@@ -27,7 +27,7 @@ class ToTwistedHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
log_entry = self.format(record)
|
||||
log_level = record.levelname.lower().replace("warning", "warn")
|
||||
self.tx_log.emit( # type: ignore
|
||||
self.tx_log.emit(
|
||||
twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
|
||||
)
|
||||
|
||||
|
||||