1
0

Compare commits

...

12 Commits

Author SHA1 Message Date
H. Shay
85b30abfde update comment 2022-04-28 11:55:00 -07:00
H. Shay
d809a4c8fb update newsfragment number 2022-04-27 15:38:02 -07:00
H. Shay
df621cbaa5 lints 2022-04-27 15:34:11 -07:00
H. Shay
301b9cdfcc newsfragement 2022-04-27 15:17:45 -07:00
H. Shay
21002db229 add function to normalize username to module api + test 2022-04-27 15:17:33 -07:00
Sean Quah
78b99de7c2 Prefer make_awaitable over defer.succeed in tests (#12505)
When configuring the return values of mocks, prefer awaitables from
`make_awaitable` over `defer.succeed`. `Deferred`s are only awaitable
once, so it is inappropriate for a mock to return the same `Deferred`
multiple times.

Also update `run_in_background` to support functions that return
arbitrary awaitables.

Signed-off-by: Sean Quah <seanq@element.io>
2022-04-27 14:58:26 +01:00
Brendan Abolivier
5ef673de4f Add a module API to allow modules to edit push rule actions (#12406)
Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
2022-04-27 13:55:33 +00:00
reivilibre
d743b25c8f Use supervisord to supervise Postgres and Caddy in the Complement image. (#12480)
Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
2022-04-27 14:39:41 +01:00
David Robertson
30c8e7e408 Make scripts-dev pass mypy --disallow-untyped-defs (#12356)
Not enforced in config yet. One day.
2022-04-27 13:10:31 +00:00
David Robertson
6463244375 Remove unused # type: ignores (#12531)
Over time we've begun to use newer versions of mypy, typeshed, stub
packages---and of course we've improved our own annotations. This makes
some type ignore comments no longer necessary. I have removed them.

There was one exception: a module that imports `select.epoll`. The
ignore is redundant on Linux, but I've kept it ignored for those of us
who work on the source tree using not-Linux. (#11771)

I'm more interested in the config line which enforces this. I want
unused ignores to be reported, because I think it's useful feedback when
annotating to know when you've fixed a problem you had to previously
ignore.

* Installing extras before typechecking

Lacking an easy way to install all extras generically, let's bite the bullet and
make install the hand-maintained `all` extra before typechecking.

Now that https://github.com/matrix-org/backend-meta/pull/6 is merged to
the release/v1 branch.
2022-04-27 14:03:44 +01:00
Patrick Cloke
8a23bde823 Consistently use collections.abc.Mapping to check frozendict. (#12564) 2022-04-27 09:00:07 -04:00
Will Hunt
e8d1ec0e92 Add option to enable token registration without requiring 3pids (#12526) 2022-04-27 12:57:53 +00:00
70 changed files with 657 additions and 306 deletions

View File

@@ -20,13 +20,9 @@ jobs:
- run: scripts-dev/config-lint.sh
lint:
# This does a vanilla `poetry install` - no extras. I'm slightly anxious
# that we might skip some typechecks on code that uses extras. However,
# I think the right way to fix this is to mark any extras needed for
# typechecking as development dependencies. To detect this, we ought to
# turn up mypy's strictness: disallow unknown imports and be accept fewer
# uses of `Any`.
uses: "matrix-org/backend-meta/.github/workflows/python-poetry-ci.yml@v1"
with:
typechecking-extras: "all"
lint-crlf:
runs-on: ubuntu-latest

1
changelog.d/12356.misc Normal file
View File

@@ -0,0 +1 @@
Fix scripts-dev to pass typechecking.

View File

@@ -0,0 +1 @@
Add a module API to allow modules to change actions for existing push rules of local users.

1
changelog.d/12480.misc Normal file
View File

@@ -0,0 +1 @@
Use supervisord to supervise Postgres and Caddy in the Complement image to reduce restart time.

1
changelog.d/12505.misc Normal file
View File

@@ -0,0 +1 @@
Use `make_awaitable` instead of `defer.succeed` for return values of mocks in tests.

View File

@@ -0,0 +1 @@
Add new `enable_registration_token_3pid_bypass` configuration option to allow registrations via token as an alternative to verifying a 3pid.

1
changelog.d/12531.misc Normal file
View File

@@ -0,0 +1 @@
Remove unused `# type: ignore`s.

1
changelog.d/12564.misc Normal file
View File

@@ -0,0 +1 @@
Consistently check if an object is a `frozendict`.

View File

@@ -0,0 +1 @@
Add the Synapse function `types.map_username_to_mxid_localpart` to the Module API.

View File

@@ -20,6 +20,9 @@ RUN rm /etc/nginx/sites-enabled/default
# Copy Synapse worker, nginx and supervisord configuration template files
COPY ./docker/conf-workers/* /conf/
# Copy a script to prefix log lines with the supervisor program name
COPY ./docker/prefix-log /usr/local/bin/
# Expose nginx listener port
EXPOSE 8080/tcp

View File

@@ -34,13 +34,16 @@ WORKDIR /data
# Copy the caddy config
COPY conf-workers/caddy.complement.json /root/caddy.json
COPY conf-workers/postgres.supervisord.conf /etc/supervisor/conf.d/postgres.conf
COPY conf-workers/caddy.supervisord.conf /etc/supervisor/conf.d/caddy.conf
# Copy the entrypoint
COPY conf-workers/start-complement-synapse-workers.sh /
# Expose caddy's listener ports
EXPOSE 8008 8448
ENTRYPOINT /start-complement-synapse-workers.sh
ENTRYPOINT ["/start-complement-synapse-workers.sh"]
# Update the healthcheck to have a shorter check interval
HEALTHCHECK --start-period=5s --interval=1s --timeout=1s \

View File

@@ -0,0 +1,7 @@
[program:caddy]
command=/usr/local/bin/prefix-log /root/caddy run --config /root/caddy.json
autorestart=unexpected
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
stderr_logfile=/dev/stderr
stderr_logfile_maxbytes=0

View File

@@ -0,0 +1,16 @@
[program:postgres]
command=/usr/local/bin/prefix-log /usr/bin/pg_ctlcluster 13 main start --foreground
# Lower priority number = starts first
priority=1
autorestart=unexpected
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
stderr_logfile=/dev/stderr
stderr_logfile_maxbytes=0
# Use 'Fast Shutdown' mode which aborts current transactions and closes connections quickly.
# (Default (TERM) is 'Smart Shutdown' which stops accepting new connections but
# lets existing connections close gracefully.)
stopsignal=INT

View File

@@ -12,12 +12,6 @@ function log {
# Replace the server name in the caddy config
sed -i "s/{{ server_name }}/${SERVER_NAME}/g" /root/caddy.json
log "starting postgres"
pg_ctlcluster 13 main start
log "starting caddy"
/root/caddy start --config /root/caddy.json
# Set the server name of the homeserver
export SYNAPSE_SERVER_NAME=${SERVER_NAME}

View File

@@ -2,11 +2,7 @@ version: 1
formatters:
precise:
{% if worker_name %}
format: '%(asctime)s - worker:{{ worker_name }} - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s'
{% else %}
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s'
{% endif %}
handlers:
{% if LOG_FILE_PATH %}

View File

@@ -171,7 +171,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
# Templates for sections that may be inserted multiple times in config files
SUPERVISORD_PROCESS_CONFIG_BLOCK = """
[program:synapse_{name}]
command=/usr/local/bin/python -m {app} \
command=/usr/local/bin/prefix-log /usr/local/bin/python -m {app} \
--config-path="{config_path}" \
--config-path=/conf/workers/shared.yaml \
--config-path=/conf/workers/{name}.yaml

12
docker/prefix-log Executable file
View File

@@ -0,0 +1,12 @@
#!/bin/bash
#
# Prefixes all lines on stdout and stderr with the process name (as determined by
# the SUPERVISOR_PROCESS_NAME env var, which is automatically set by Supervisor).
#
# Usage:
# prefix-log command [args...]
#
exec 1> >(awk '{print "'"${SUPERVISOR_PROCESS_NAME}"' | "$0}' >&1)
exec 2> >(awk '{print "'"${SUPERVISOR_PROCESS_NAME}"' | "$0}' >&2)
exec "$@"

View File

@@ -1323,6 +1323,12 @@ oembed:
#
#registration_requires_token: true
# Allow users to submit a token during registration to bypass any required 3pid
# steps configured in `registrations_require_3pid`.
# Defaults to false, requiring that registration tokens (if enabled) complete a 3pid flow.
#
#enable_registration_token_3pid_bypass: false
# If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled.
#

View File

@@ -7,6 +7,7 @@ show_error_codes = True
show_traceback = True
mypy_path = stubs
warn_unreachable = True
warn_unused_ignores = True
local_partial_types = True
no_implicit_optional = True
@@ -23,10 +24,6 @@ files =
# https://docs.python.org/3/library/re.html#re.X
exclude = (?x)
^(
|scripts-dev/build_debian_packages.py
|scripts-dev/federation_client.py
|scripts-dev/release.py
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
@@ -134,6 +131,11 @@ disallow_untyped_defs = True
[mypy-synapse.metrics.*]
disallow_untyped_defs = True
[mypy-synapse.metrics._reactor_metrics]
# This module imports select.epoll. That exists on Linux, but doesn't on macOS.
# See https://github.com/matrix-org/synapse/pull/11771.
warn_unused_ignores = False
[mypy-synapse.module_api.*]
disallow_untyped_defs = True
@@ -302,6 +304,9 @@ ignore_missing_imports = True
[mypy-pympler.*]
ignore_missing_imports = True
[mypy-redbaron.*]
ignore_missing_imports = True
[mypy-rust_python_jaeger_reporter.*]
ignore_missing_imports = True
@@ -317,6 +322,9 @@ ignore_missing_imports = True
[mypy-signedjson.*]
ignore_missing_imports = True
[mypy-srvlookup.*]
ignore_missing_imports = True
[mypy-treq.*]
ignore_missing_imports = True

25
poetry.lock generated
View File

@@ -309,14 +309,15 @@ smmap = ">=3.0.1,<6"
[[package]]
name = "gitpython"
version = "3.1.14"
description = "Python Git Library"
version = "3.1.27"
description = "GitPython is a python library used to interact with Git repositories"
category = "dev"
optional = false
python-versions = ">=3.4"
python-versions = ">=3.7"
[package.dependencies]
gitdb = ">=4.0.1,<5"
typing-extensions = {version = ">=3.7.4.3", markers = "python_version < \"3.8\""}
[[package]]
name = "hiredis"
@@ -1315,6 +1316,14 @@ category = "dev"
optional = false
python-versions = "*"
[[package]]
name = "types-commonmark"
version = "0.9.2"
description = "Typing stubs for commonmark"
category = "dev"
optional = false
python-versions = "*"
[[package]]
name = "types-cryptography"
version = "3.3.15"
@@ -1553,7 +1562,7 @@ url_preview = ["lxml"]
[metadata]
lock-version = "1.1"
python-versions = "^3.7"
content-hash = "f482a4f594a165dfe01ce253a22510d5faf38647ab0dcebc35789350cafd9bf0"
content-hash = "3825cef058b8c9f520ef4b7acb92519be95db9a663a61c2e89a5fe431ed55655"
[metadata.files]
attrs = [
@@ -1766,8 +1775,8 @@ gitdb = [
{file = "gitdb-4.0.9.tar.gz", hash = "sha256:bac2fd45c0a1c9cf619e63a90d62bdc63892ef92387424b855792a6cabe789aa"},
]
gitpython = [
{file = "GitPython-3.1.14-py3-none-any.whl", hash = "sha256:3283ae2fba31c913d857e12e5ba5f9a7772bbc064ae2bb09efafa71b0dd4939b"},
{file = "GitPython-3.1.14.tar.gz", hash = "sha256:be27633e7509e58391f10207cd32b2a6cf5b908f92d9cd30da2e514e1137af61"},
{file = "GitPython-3.1.27-py3-none-any.whl", hash = "sha256:5b68b000463593e05ff2b261acff0ff0972df8ab1b70d3cdbd41b546c8b8fc3d"},
{file = "GitPython-3.1.27.tar.gz", hash = "sha256:1c885ce809e8ba2d88a29befeb385fcea06338d3640712b59ca623c220bb5704"},
]
hiredis = [
{file = "hiredis-2.0.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b4c8b0bc5841e578d5fb32a16e0c305359b987b850a06964bd5a62739d688048"},
@@ -2588,6 +2597,10 @@ types-bleach = [
{file = "types-bleach-4.1.4.tar.gz", hash = "sha256:2d30c2c4fb6854088ac636471352c9a51bf6c089289800d2a8060820a01cd43a"},
{file = "types_bleach-4.1.4-py3-none-any.whl", hash = "sha256:edffe173ed6d7b6f3543036a96204a9319c3bf6c3645917b14274e43f000cc9b"},
]
types-commonmark = [
{file = "types-commonmark-0.9.2.tar.gz", hash = "sha256:b894b67750c52fd5abc9a40a9ceb9da4652a391d75c1b480bba9cef90f19fc86"},
{file = "types_commonmark-0.9.2-py3-none-any.whl", hash = "sha256:56f20199a1f9a2924443211a0ef97f8b15a8a956a7f4e9186be6950bf38d6d02"},
]
types-cryptography = [
{file = "types-cryptography-3.3.15.tar.gz", hash = "sha256:a7983a75a7b88a18f88832008f0ef140b8d1097888ec1a0824ec8fb7e105273b"},
{file = "types_cryptography-3.3.15-py3-none-any.whl", hash = "sha256:d9b0dd5465d7898d400850e7f35e5518aa93a7e23d3e11757cd81b4777089046"},

View File

@@ -251,6 +251,7 @@ flake8 = "*"
mypy = "==0.931"
mypy-zope = "==0.3.5"
types-bleach = ">=4.1.0"
types-commonmark = ">=0.9.2"
types-jsonschema = ">=3.2.0"
types-opentracing = ">=2.4.2"
types-Pillow = ">=8.3.4"
@@ -270,7 +271,8 @@ idna = ">=2.5"
# The following are used by the release script
click = "==8.1.0"
GitPython = "==3.1.14"
# GitPython was == 3.1.14; bumped to 3.1.20, the first release with type hints.
GitPython = ">=3.1.20"
commonmark = "==0.9.1"
pygithub = "==1.55"
# The following are executed as commands by the release script.

View File

@@ -17,7 +17,8 @@ import subprocess
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Sequence
from types import FrameType
from typing import Collection, Optional, Sequence, Set
DISTS = (
"debian:buster", # oldstable: EOL 2022-08
@@ -41,15 +42,17 @@ projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
class Builder(object):
def __init__(
self, redirect_stdout=False, docker_build_args: Optional[Sequence[str]] = None
self,
redirect_stdout: bool = False,
docker_build_args: Optional[Sequence[str]] = None,
):
self.redirect_stdout = redirect_stdout
self._docker_build_args = tuple(docker_build_args or ())
self.active_containers = set()
self.active_containers: Set[str] = set()
self._lock = threading.Lock()
self._failed = False
def run_build(self, dist, skip_tests=False):
def run_build(self, dist: str, skip_tests: bool = False) -> None:
"""Build deb for a single distribution"""
if self._failed:
@@ -63,7 +66,7 @@ class Builder(object):
self._failed = True
raise
def _inner_build(self, dist, skip_tests=False):
def _inner_build(self, dist: str, skip_tests: bool = False) -> None:
tag = dist.split(":", 1)[1]
# Make the dir where the debs will live.
@@ -138,7 +141,7 @@ class Builder(object):
stdout.close()
print("Completed build of %s" % (dist,))
def kill_containers(self):
def kill_containers(self) -> None:
with self._lock:
active = list(self.active_containers)
@@ -156,8 +159,10 @@ class Builder(object):
self.active_containers.remove(c)
def run_builds(builder, dists, jobs=1, skip_tests=False):
def sig(signum, _frame):
def run_builds(
builder: Builder, dists: Collection[str], jobs: int = 1, skip_tests: bool = False
) -> None:
def sig(signum: int, _frame: Optional[FrameType]) -> None:
print("Caught SIGINT")
builder.kill_containers()

View File

@@ -38,7 +38,7 @@ import argparse
import base64
import json
import sys
from typing import Any, Optional
from typing import Any, Dict, Optional, Tuple
from urllib import parse as urlparse
import requests
@@ -47,13 +47,14 @@ import signedjson.types
import srvlookup
import yaml
from requests.adapters import HTTPAdapter
from urllib3 import HTTPConnectionPool
# uncomment the following to enable debug logging of http requests
# from httplib import HTTPConnection
# HTTPConnection.debuglevel = 1
def encode_base64(input_bytes):
def encode_base64(input_bytes: bytes) -> str:
"""Encode bytes as a base64 string without any padding."""
input_len = len(input_bytes)
@@ -63,7 +64,7 @@ def encode_base64(input_bytes):
return output_string
def encode_canonical_json(value):
def encode_canonical_json(value: object) -> bytes:
return json.dumps(
value,
# Encode code-points outside of ASCII as UTF-8 rather than \u escapes
@@ -130,7 +131,7 @@ def request(
sig,
destination,
)
authorization_headers.append(header.encode("ascii"))
authorization_headers.append(header)
print("Authorization: %s" % header, file=sys.stderr)
dest = "matrix://%s%s" % (destination, path)
@@ -139,7 +140,10 @@ def request(
s = requests.Session()
s.mount("matrix://", MatrixConnectionAdapter())
headers = {"Host": destination, "Authorization": authorization_headers[0]}
headers: Dict[str, str] = {
"Host": destination,
"Authorization": authorization_headers[0],
}
if method == "POST":
headers["Content-Type"] = "application/json"
@@ -154,7 +158,7 @@ def request(
)
def main():
def main() -> None:
parser = argparse.ArgumentParser(
description="Signs and sends a federation request to a matrix homeserver"
)
@@ -212,6 +216,7 @@ def main():
if not args.server_name or not args.signing_key:
read_args_from_config(args)
assert isinstance(args.signing_key, str)
algorithm, version, key_base64 = args.signing_key.split()
key = signedjson.key.decode_signing_key_base64(algorithm, version, key_base64)
@@ -233,7 +238,7 @@ def main():
print("")
def read_args_from_config(args):
def read_args_from_config(args: argparse.Namespace) -> None:
with open(args.config, "r") as fh:
config = yaml.safe_load(fh)
@@ -250,7 +255,7 @@ def read_args_from_config(args):
class MatrixConnectionAdapter(HTTPAdapter):
@staticmethod
def lookup(s, skip_well_known=False):
def lookup(s: str, skip_well_known: bool = False) -> Tuple[str, int]:
if s[-1] == "]":
# ipv6 literal (with no port)
return s, 8448
@@ -276,7 +281,7 @@ class MatrixConnectionAdapter(HTTPAdapter):
return s, 8448
@staticmethod
def get_well_known(server_name):
def get_well_known(server_name: str) -> Optional[str]:
uri = "https://%s/.well-known/matrix/server" % (server_name,)
print("fetching %s" % (uri,), file=sys.stderr)
@@ -299,7 +304,9 @@ class MatrixConnectionAdapter(HTTPAdapter):
print("Invalid response from %s: %s" % (uri, e), file=sys.stderr)
return None
def get_connection(self, url, proxies=None):
def get_connection(
self, url: str, proxies: Optional[Dict[str, str]] = None
) -> HTTPConnectionPool:
parsed = urlparse.urlparse(url)
(host, port) = self.lookup(parsed.netloc)

View File

@@ -16,7 +16,7 @@
can crop up, e.g the cache descriptors.
"""
from typing import Callable, Optional
from typing import Callable, Optional, Type
from mypy.nodes import ARG_NAMED_OPT
from mypy.plugin import MethodSigContext, Plugin
@@ -94,7 +94,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
return signature
def plugin(version: str):
def plugin(version: str) -> Type[SynapsePlugin]:
# This is the entry point of the plugin, and let's us deal with the fact
# that the mypy plugin interface is *not* stable by looking at the version
# string.

View File

@@ -25,7 +25,7 @@ import sys
import urllib.request
from os import path
from tempfile import TemporaryDirectory
from typing import List, Optional
from typing import Any, List, Optional, cast
import attr
import click
@@ -36,7 +36,9 @@ from github import Github
from packaging import version
def run_until_successful(command, *args, **kwargs):
def run_until_successful(
command: str, *args: Any, **kwargs: Any
) -> subprocess.CompletedProcess:
while True:
completed_process = subprocess.run(command, *args, **kwargs)
exit_code = completed_process.returncode
@@ -50,7 +52,7 @@ def run_until_successful(command, *args, **kwargs):
@click.group()
def cli():
def cli() -> None:
"""An interactive script to walk through the parts of creating a release.
Requires the dev dependencies be installed, which can be done via:
@@ -81,7 +83,7 @@ def cli():
@cli.command()
def prepare():
def prepare() -> None:
"""Do the initial stages of creating a release, including creating release
branch, updating changelog and pushing to GitHub.
"""
@@ -161,7 +163,9 @@ def prepare():
click.get_current_context().abort()
# Switch to the release branch.
parsed_new_version: version.Version = version.parse(new_version)
# Cast safety: parse() won't return a version.LegacyVersion from our
# version string format.
parsed_new_version = cast(version.Version, version.parse(new_version))
# We assume for debian changelogs that we only do RCs or full releases.
assert not parsed_new_version.is_devrelease
@@ -176,7 +180,6 @@ def prepare():
# If the release branch only exists on the remote we check it out
# locally.
repo.git.checkout(release_branch_name)
release_branch = repo.active_branch
else:
# If a branch doesn't exist we create one. We ask which one branch it
# should be based off, defaulting to sensible values depending on the
@@ -198,13 +201,15 @@ def prepare():
click.get_current_context().abort()
# Check out the base branch and ensure it's up to date
repo.head.reference = base_branch
repo.head.set_reference(base_branch, "check out the base branch")
repo.head.reset(index=True, working_tree=True)
if not base_branch.is_remote():
update_branch(repo)
# Create the new release branch
release_branch = repo.create_head(release_branch_name, commit=base_branch)
# Type ignore will no longer be needed after GitPython 3.1.28.
# See https://github.com/gitpython-developers/GitPython/pull/1419
repo.create_head(release_branch_name, commit=base_branch) # type: ignore[arg-type]
# Switch to the release branch and ensure it's up to date.
repo.git.checkout(release_branch_name)
@@ -265,7 +270,7 @@ def prepare():
@cli.command()
@click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"])
def tag(gh_token: Optional[str]):
def tag(gh_token: Optional[str]) -> None:
"""Tags the release and generates a draft GitHub release"""
# Make sure we're in a git repo.
@@ -293,7 +298,12 @@ def tag(gh_token: Optional[str]):
click.echo_via_pager(changes)
if click.confirm("Edit text?", default=False):
changes = click.edit(changes, require_save=False)
edited_changes = click.edit(changes, require_save=False)
# This assert is for mypy's benefit. click's docs are a little unclear, but
# when `require_save=False`, not saving the temp file in the editor returns
# the original string.
assert edited_changes is not None
changes = edited_changes
repo.create_tag(tag_name, message=changes, sign=True)
@@ -347,7 +357,7 @@ def tag(gh_token: Optional[str]):
@cli.command()
@click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=True)
def publish(gh_token: str):
def publish(gh_token: str) -> None:
"""Publish release."""
# Make sure we're in a git repo.
@@ -390,7 +400,7 @@ def publish(gh_token: str):
@cli.command()
def upload():
def upload() -> None:
"""Upload release to pypi."""
current_version = get_package_version()
@@ -418,7 +428,7 @@ def upload():
@cli.command()
def announce():
def announce() -> None:
"""Generate markdown to announce the release."""
current_version = get_package_version()
@@ -461,18 +471,19 @@ def get_package_version() -> version.Version:
def find_ref(repo: git.Repo, ref_name: str) -> Optional[git.HEAD]:
"""Find the branch/ref, looking first locally then in the remote."""
if ref_name in repo.refs:
return repo.refs[ref_name]
if ref_name in repo.references:
return repo.references[ref_name]
elif ref_name in repo.remote().refs:
return repo.remote().refs[ref_name]
else:
return None
def update_branch(repo: git.Repo):
def update_branch(repo: git.Repo) -> None:
"""Ensure branch is up to date if it has a remote"""
if repo.active_branch.tracking_branch():
repo.git.merge(repo.active_branch.tracking_branch().name)
tracking_branch = repo.active_branch.tracking_branch()
if tracking_branch:
repo.git.merge(tracking_branch.name)
def get_changes_for_version(wanted_version: version.Version) -> str:
@@ -536,7 +547,9 @@ def get_changes_for_version(wanted_version: version.Version) -> str:
return "\n".join(version_changelog)
def generate_and_write_changelog(current_version: version.Version, new_version: str):
def generate_and_write_changelog(
current_version: version.Version, new_version: str
) -> None:
# We do this by getting a draft so that we can edit it before writing to the
# changelog.
result = run_until_successful(
@@ -558,8 +571,8 @@ def generate_and_write_changelog(current_version: version.Version, new_version:
f.write(existing_content)
# Remove all the news fragments
for f in glob.iglob("changelog.d/*.*"):
os.remove(f)
for filename in glob.iglob("changelog.d/*.*"):
os.remove(filename)
if __name__ == "__main__":

View File

@@ -27,7 +27,7 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.util import json_encoder
def main():
def main() -> None:
parser = argparse.ArgumentParser(
description="""Adds a signature to a JSON object.

View File

@@ -115,9 +115,7 @@ class SortedKeysView(KeysView[_KT_co], Sequence[_KT_co]):
def __getitem__(self, index: slice) -> List[_KT_co]: ...
def __delitem__(self, index: Union[int, slice]) -> None: ...
class SortedItemsView( # type: ignore
ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]]
):
class SortedItemsView(ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]]):
def __iter__(self) -> Iterator[Tuple[_KT_co, _VT_co]]: ...
@overload
def __getitem__(self, index: int) -> Tuple[_KT_co, _VT_co]: ...

View File

@@ -48,7 +48,6 @@ from twisted.logger import LoggingFile, LogLevel
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.python.threadpool import ThreadPool
import synapse
from synapse.api.constants import MAX_PDU_SIZE
from synapse.app import check_bind_error
from synapse.app.phone_stats_home import start_phone_stats_home
@@ -60,6 +59,7 @@ from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.logging.context import PreserveLoggingContext
from synapse.logging.opentracing import init_tracer
from synapse.metrics import install_gc_manager, register_threadpool
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
@@ -431,7 +431,7 @@ async def start(hs: "HomeServer") -> None:
refresh_certificate(hs)
# Start the tracer
synapse.logging.opentracing.init_tracer(hs) # type: ignore[attr-defined] # noqa
init_tracer(hs) # noqa
# Instantiate the modules so they can register their web resources to the module API
# before we start the listeners.

View File

@@ -43,6 +43,9 @@ class RegistrationConfig(Config):
self.registration_requires_token = config.get(
"registration_requires_token", False
)
self.enable_registration_token_3pid_bypasss = config.get(
"enable_registration_token_3pid_bypasss", False
)
self.registration_shared_secret = config.get("registration_shared_secret")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -309,6 +312,12 @@ class RegistrationConfig(Config):
#
#registration_requires_token: true
# Allow users to submit a token during registration to bypass any required 3pid
# steps configured in `registrations_require_3pid`.
# Defaults to false, requiring that registration tokens (if enabled) complete a 3pid flow.
#
#enable_registration_token_3pid_bypass: false
# If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled.
#

View File

@@ -186,7 +186,7 @@ KNOWN_RESOURCES = {
class HttpResourceConfig:
names: List[str] = attr.ib(
factory=list,
validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)), # type: ignore
validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)),
)
compress: bool = attr.ib(
default=False,
@@ -231,9 +231,7 @@ class ManholeConfig:
class LimitRemoteRoomsConfig:
enabled: bool = attr.ib(validator=attr.validators.instance_of(bool), default=False)
complexity: Union[float, int] = attr.ib(
validator=attr.validators.instance_of(
(float, int) # type: ignore[arg-type] # noqa
),
validator=attr.validators.instance_of((float, int)), # noqa
default=1.0,
)
complexity_error: str = attr.ib(

View File

@@ -27,7 +27,6 @@ from typing import (
)
import attr
from frozendict import frozendict
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
@@ -204,7 +203,9 @@ def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
key_to_move = field.pop(-1)
sub_dict = src
for sub_field in field: # e.g. sub_field => "content"
if sub_field in sub_dict and type(sub_dict[sub_field]) in [dict, frozendict]:
if sub_field in sub_dict and isinstance(
sub_dict[sub_field], collections.abc.Mapping
):
sub_dict = sub_dict[sub_field]
else:
return
@@ -622,7 +623,7 @@ def validate_canonicaljson(value: Any) -> None:
# Note that Infinity, -Infinity, and NaN are also considered floats.
raise SynapseError(400, "Bad JSON value: float", Codes.BAD_JSON)
elif isinstance(value, (dict, frozendict)):
elif isinstance(value, collections.abc.Mapping):
for v in value.values():
validate_canonicaljson(v)

View File

@@ -268,8 +268,8 @@ class FederationServer(FederationBase):
transaction_id=transaction_id,
destination=destination,
origin=origin,
origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore
pdus=transaction_data.get("pdus"), # type: ignore
origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore[arg-type]
pdus=transaction_data.get("pdus"),
edus=transaction_data.get("edus"),
)

View File

@@ -229,21 +229,21 @@ class TransportLayerClient:
"""
logger.debug(
"send_data dest=%s, txid=%s",
transaction.destination, # type: ignore
transaction.transaction_id, # type: ignore
transaction.destination,
transaction.transaction_id,
)
if transaction.destination == self.server_name: # type: ignore
if transaction.destination == self.server_name:
raise RuntimeError("Transport layer cannot send to itself!")
# FIXME: This is only used by the tests. The actual json sent is
# generated by the json_data_callback.
json_data = transaction.get_dict()
path = _create_v1_path("/send/%s", transaction.transaction_id) # type: ignore
path = _create_v1_path("/send/%s", transaction.transaction_id)
return await self.client.put_json(
transaction.destination, # type: ignore
transaction.destination,
path=path,
data=json_data,
json_data_callback=json_data_callback,

View File

@@ -481,7 +481,7 @@ class AuthHandler:
sid = authdict["session"]
# Convert the URI and method to strings.
uri = request.uri.decode("utf-8") # type: ignore
uri = request.uri.decode("utf-8")
method = request.method.decode("utf-8")
# If there's no session ID, create a new session.

View File

@@ -966,7 +966,7 @@ class OidcProvider:
"Mapping provider does not support de-duplicating Matrix IDs"
)
attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
attributes = await self._user_mapping_provider.map_user_attributes(
userinfo, token
)

View File

@@ -0,0 +1,138 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional, Union
import attr
from synapse.api.errors import SynapseError, UnrecognizedRequestError
from synapse.push.baserules import BASE_RULE_IDS
from synapse.storage.push_rule import RuleNotFoundException
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RuleSpec:
scope: str
template: str
rule_id: str
attr: Optional[str]
class PushRulesHandler:
"""A class to handle changes in push rules for users."""
def __init__(self, hs: "HomeServer"):
self._notifier = hs.get_notifier()
self._main_store = hs.get_datastores().main
async def set_rule_attr(
self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
) -> None:
"""Set an attribute (enabled or actions) on an existing push rule.
Notifies listeners (e.g. sync handler) of the change.
Args:
user_id: the user for which to modify the push rule.
spec: the spec of the push rule to modify.
val: the value to change the attribute to.
Raises:
RuleNotFoundException if the rule being modified doesn't exist.
SynapseError(400) if the value is malformed.
UnrecognizedRequestError if the attribute to change is unknown.
InvalidRuleException if we're trying to change the actions on a rule but
the provided actions aren't compliant with the spec.
"""
if spec.attr not in ("enabled", "actions"):
# for the sake of potential future expansion, shouldn't report
# 404 in the case of an unknown request so check it corresponds to
# a known attribute first.
raise UnrecognizedRequestError()
namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}"
rule_id = spec.rule_id
is_default_rule = rule_id.startswith(".")
if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS:
raise RuleNotFoundException("Unknown rule %r" % (namespaced_rule_id,))
if spec.attr == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
if not isinstance(val, bool):
# Legacy fallback
# This should *actually* take a dict, but many clients pass
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
await self._main_store.set_push_rule_enabled(
user_id, namespaced_rule_id, val, is_default_rule
)
elif spec.attr == "actions":
if not isinstance(val, dict):
raise SynapseError(400, "Value must be a dict")
actions = val.get("actions")
if not isinstance(actions, list):
raise SynapseError(400, "Value for 'actions' must be dict")
check_actions(actions)
rule_id = spec.rule_id
is_default_rule = rule_id.startswith(".")
if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS:
raise RuleNotFoundException(
"Unknown rule %r" % (namespaced_rule_id,)
)
await self._main_store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
)
else:
raise UnrecognizedRequestError()
self.notify_user(user_id)
def notify_user(self, user_id: str) -> None:
"""Notify listeners about a push rule change.
Args:
user_id: the user ID the change is for.
"""
stream_id = self._main_store.get_max_push_rules_stream_id()
self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
def check_actions(actions: List[Union[str, JsonDict]]) -> None:
"""Check if the given actions are spec compliant.
Args:
actions: the actions to check.
Raises:
InvalidRuleException if the rules aren't compliant with the spec.
"""
if not isinstance(actions, list):
raise InvalidRuleException("No actions found")
for a in actions:
if a in ["notify", "dont_notify", "coalesce"]:
pass
elif isinstance(a, dict) and "set_tweak" in a:
pass
else:
raise InvalidRuleException("Unrecognised action %s" % a)
class InvalidRuleException(Exception):
pass

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
import logging
from typing import (
TYPE_CHECKING,
@@ -24,7 +25,6 @@ from typing import (
)
import attr
from frozendict import frozendict
from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
@@ -380,7 +380,7 @@ class RelationsHandler:
# Do not bundle aggregations for an event which represents an edit or an
# annotation. It does not make sense for them to have related events.
relates_to = event.content.get("m.relates_to")
if isinstance(relates_to, (dict, frozendict)):
if isinstance(relates_to, collections.abc.Mapping):
relation_type = relates_to.get("rel_type")
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
continue

View File

@@ -357,7 +357,7 @@ class SearchHandler:
itertools.chain(
# The events_before and events_after for each context.
itertools.chain.from_iterable(
itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type]
itertools.chain(context["events_before"], context["events_after"])
for context in contexts.values()
),
# The returned events.
@@ -373,10 +373,10 @@ class SearchHandler:
for context in contexts.values():
context["events_before"] = self._event_serializer.serialize_events(
context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
context["events_before"], time_now, bundle_aggregations=aggregations
)
context["events_after"] = self._event_serializer.serialize_events(
context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
context["events_after"], time_now, bundle_aggregations=aggregations
)
results = [

View File

@@ -256,7 +256,9 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self._enabled = bool(hs.config.registration.registration_requires_token)
self._enabled = bool(
hs.config.registration.registration_requires_token
) or bool(hs.config.registration.enable_registration_token_3pid_bypasss)
self.store = hs.get_datastores().main
def is_enabled(self) -> bool:

View File

@@ -295,7 +295,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
if isawaitable(raw_callback_return):
callback_return = await raw_callback_return
else:
callback_return = raw_callback_return # type: ignore
callback_return = raw_callback_return
return callback_return
@@ -469,7 +469,7 @@ class JsonResource(DirectServeJsonResource):
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
callback_return = raw_callback_return # type: ignore
callback_return = raw_callback_return
return callback_return

View File

@@ -722,6 +722,11 @@ P = ParamSpec("P")
R = TypeVar("R")
async def _unwrap_awaitable(awaitable: Awaitable[R]) -> R:
"""Unwraps an arbitrary awaitable by awaiting it."""
return await awaitable
@overload
def preserve_fn( # type: ignore[misc]
f: Callable[P, Awaitable[R]],
@@ -802,17 +807,20 @@ def run_in_background( # type: ignore[misc]
# by synchronous exceptions, so let's turn them into Failures.
return defer.fail()
# `res` may be a coroutine, `Deferred`, some other kind of awaitable, or a plain
# value. Convert it to a `Deferred`.
if isinstance(res, typing.Coroutine):
# Wrap the coroutine in a `Deferred`.
res = defer.ensureDeferred(res)
# At this point we should have a Deferred, if not then f was a synchronous
# function, wrap it in a Deferred for consistency.
if not isinstance(res, defer.Deferred):
# `res` is not a `Deferred` and not a `Coroutine`.
# There are no other types of `Awaitable`s we expect to encounter in Synapse.
assert not isinstance(res, Awaitable)
return defer.succeed(res)
elif isinstance(res, defer.Deferred):
pass
elif isinstance(res, Awaitable):
# `res` is probably some kind of completed awaitable, such as a `DoneAwaitable`
# or `Future` from `make_awaitable`.
res = defer.ensureDeferred(_unwrap_awaitable(res))
else:
# `res` is a plain value. Wrap it in a `Deferred`.
res = defer.succeed(res)
if res.called and not res.paused:
# The function should have maintained the logcontext, so we can

View File

@@ -82,6 +82,7 @@ from synapse.handlers.auth import (
ON_LOGGED_OUT_CALLBACK,
AuthHandler,
)
from synapse.handlers.push_rules import RuleSpec, check_actions
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
DirectServeHtmlResource,
@@ -109,12 +110,14 @@ from synapse.storage.state import StateFilter
from synapse.types import (
DomainSpecificString,
JsonDict,
JsonMapping,
Requester,
StateMap,
UserID,
UserInfo,
UserProfile,
create_requester,
map_username_to_mxid_localpart,
)
from synapse.util import Clock
from synapse.util.async_helpers import maybe_awaitable
@@ -151,6 +154,7 @@ __all__ = [
"PRESENCE_ALL_USERS",
"LoginResponse",
"JsonDict",
"JsonMapping",
"EventBase",
"StateMap",
"ProfileInfo",
@@ -193,6 +197,7 @@ class ModuleApi:
self._clock: Clock = hs.get_clock()
self._registration_handler = hs.get_registration_handler()
self._send_email_handler = hs.get_send_email_handler()
self._push_rules_handler = hs.get_push_rules_handler()
self.custom_template_dir = hs.config.server.custom_template_directory
try:
@@ -569,6 +574,26 @@ class ModuleApi:
return username
return UserID(username, self._hs.hostname).to_string()
def normalize_username(
self, username: Union[str, bytes], case_sensitive: bool = False
) -> str:
"""Map a username onto a string suitable for a MXID
This follows the algorithm laid out at
https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets.
Added in Synapse v1.58.0
Args:
username: username to be mapped
case_sensitive: true if TEST and test should be mapped
onto different mxids
Returns:
string suitable for a mxid localpart
"""
return map_username_to_mxid_localpart(username, case_sensitive)
async def get_profile_for_user(self, localpart: str) -> ProfileInfo:
"""Look up the profile info for the user with the given localpart.
@@ -1350,6 +1375,68 @@ class ModuleApi:
"""
await self._store.add_user_bound_threepid(user_id, medium, address, id_server)
def check_push_rule_actions(
self, actions: List[Union[str, Dict[str, str]]]
) -> None:
"""Checks if the given push rule actions are valid according to the Matrix
specification.
See https://spec.matrix.org/v1.2/client-server-api/#actions for the list of valid
actions.
Added in Synapse v1.58.0.
Args:
actions: the actions to check.
Raises:
synapse.module_api.errors.InvalidRuleException if the actions are invalid.
"""
check_actions(actions)
async def set_push_rule_action(
self,
user_id: str,
scope: str,
kind: str,
rule_id: str,
actions: List[Union[str, Dict[str, str]]],
) -> None:
"""Changes the actions of an existing push rule for the given user.
See https://spec.matrix.org/v1.2/client-server-api/#push-rules for more
information about push rules and their syntax.
Can only be called on the main process.
Added in Synapse v1.58.0.
Args:
user_id: the user for which to change the push rule's actions.
scope: the push rule's scope, currently only "global" is allowed.
kind: the push rule's kind.
rule_id: the push rule's identifier.
actions: the actions to run when the rule's conditions match.
Raises:
RuntimeError if this method is called on a worker or `scope` is invalid.
synapse.module_api.errors.RuleNotFoundException if the rule being modified
can't be found.
synapse.module_api.errors.InvalidRuleException if the actions are invalid.
"""
if self.worker_app is not None:
raise RuntimeError("module tried to change push rule actions on a worker")
if scope != "global":
raise RuntimeError(
"invalid scope %s, only 'global' is currently allowed" % scope
)
spec = RuleSpec(scope, kind, rule_id, "actions")
await self._push_rules_handler.set_rule_attr(
user_id, spec, {"actions": actions}
)
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room
@@ -1419,7 +1506,7 @@ class AccountDataManager:
f"{user_id} is not local to this homeserver; can't access account data for remote users."
)
async def get_global(self, user_id: str, data_type: str) -> Optional[JsonDict]:
async def get_global(self, user_id: str, data_type: str) -> Optional[JsonMapping]:
"""
Gets some global account data, of a specified type, for the specified user.

View File

@@ -20,10 +20,14 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config._base import ConfigError
from synapse.handlers.push_rules import InvalidRuleException
from synapse.storage.push_rule import RuleNotFoundException
__all__ = [
"InvalidClientCredentialsError",
"RedirectException",
"SynapseError",
"ConfigError",
"InvalidRuleException",
"RuleNotFoundException",
]

View File

@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union
import attr
from typing import TYPE_CHECKING, List, Sequence, Tuple, Union
from synapse.api.errors import (
NotFoundError,
@@ -22,6 +20,7 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
from synapse.handlers.push_rules import InvalidRuleException, RuleSpec, check_actions
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -29,7 +28,6 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client._base import client_patterns
@@ -40,14 +38,6 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RuleSpec:
scope: str
template: str
rule_id: str
attr: Optional[str]
class PushRuleRestServlet(RestServlet):
PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True)
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
@@ -60,6 +50,7 @@ class PushRuleRestServlet(RestServlet):
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker.worker_app is not None
self._push_rules_handler = hs.get_push_rules_handler()
async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
if self._is_worker:
@@ -81,8 +72,13 @@ class PushRuleRestServlet(RestServlet):
user_id = requester.user.to_string()
if spec.attr:
await self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
try:
await self._push_rules_handler.set_rule_attr(user_id, spec, content)
except InvalidRuleException as e:
raise SynapseError(400, "Invalid actions: %s" % e)
except RuleNotFoundException:
raise NotFoundError("Unknown rule")
return 200, {}
if spec.rule_id.startswith("."):
@@ -98,23 +94,23 @@ class PushRuleRestServlet(RestServlet):
before = parse_string(request, "before")
if before:
before = _namespaced_rule_id(spec, before)
before = f"global/{spec.template}/{before}"
after = parse_string(request, "after")
if after:
after = _namespaced_rule_id(spec, after)
after = f"global/{spec.template}/{after}"
try:
await self.store.add_push_rule(
user_id=user_id,
rule_id=_namespaced_rule_id_from_spec(spec),
rule_id=f"global/{spec.template}/{spec.rule_id}",
priority_class=priority_class,
conditions=conditions,
actions=actions,
before=before,
after=after,
)
self.notify_user(user_id)
self._push_rules_handler.notify_user(user_id)
except InconsistentRuleException as e:
raise SynapseError(400, str(e))
except RuleNotFoundException as e:
@@ -133,11 +129,11 @@ class PushRuleRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}"
try:
await self.store.delete_push_rule(user_id, namespaced_rule_id)
self.notify_user(user_id)
self._push_rules_handler.notify_user(user_id)
return 200, {}
except StoreError as e:
if e.code == 404:
@@ -172,55 +168,6 @@ class PushRuleRestServlet(RestServlet):
else:
raise UnrecognizedRequestError()
def notify_user(self, user_id: str) -> None:
stream_id = self.store.get_max_push_rules_stream_id()
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
async def set_rule_attr(
self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
) -> None:
if spec.attr not in ("enabled", "actions"):
# for the sake of potential future expansion, shouldn't report
# 404 in the case of an unknown request so check it corresponds to
# a known attribute first.
raise UnrecognizedRequestError()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
rule_id = spec.rule_id
is_default_rule = rule_id.startswith(".")
if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS:
raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
if spec.attr == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
if not isinstance(val, bool):
# Legacy fallback
# This should *actually* take a dict, but many clients pass
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
await self.store.set_push_rule_enabled(
user_id, namespaced_rule_id, val, is_default_rule
)
elif spec.attr == "actions":
if not isinstance(val, dict):
raise SynapseError(400, "Value must be a dict")
actions = val.get("actions")
if not isinstance(actions, list):
raise SynapseError(400, "Value for 'actions' must be dict")
_check_actions(actions)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
rule_id = spec.rule_id
is_default_rule = rule_id.startswith(".")
if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
await self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
)
else:
raise UnrecognizedRequestError()
def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec:
"""Turn a sequence of path components into a rule spec
@@ -291,24 +238,11 @@ def _rule_tuple_from_request_object(
raise InvalidRuleException("No actions found")
actions = req_obj["actions"]
_check_actions(actions)
check_actions(actions)
return conditions, actions
def _check_actions(actions: List[Union[str, JsonDict]]) -> None:
if not isinstance(actions, list):
raise InvalidRuleException("No actions found")
for a in actions:
if a in ["notify", "dont_notify", "coalesce"]:
pass
elif isinstance(a, dict) and "set_tweak" in a:
pass
else:
raise InvalidRuleException("Unrecognised action")
def _filter_ruleset_with_path(ruleset: JsonDict, path: List[str]) -> JsonDict:
if path == []:
raise UnrecognizedRequestError(
@@ -357,17 +291,5 @@ def _priority_class_from_spec(spec: RuleSpec) -> int:
return pc
def _namespaced_rule_id_from_spec(spec: RuleSpec) -> str:
return _namespaced_rule_id(spec, spec.rule_id)
def _namespaced_rule_id(spec: RuleSpec, rule_id: str) -> str:
return "global/%s/%s" % (spec.template, rule_id)
class InvalidRuleException(Exception):
pass
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PushRuleRestServlet(hs).register(http_server)

View File

@@ -929,6 +929,10 @@ def _calculate_registration_flows(
# always let users provide both MSISDN & email
flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY])
# Add a flow that doesn't require any 3pids, if the config requests it.
if config.registration.enable_registration_token_3pid_bypasss:
flows.append([LoginType.REGISTRATION_TOKEN])
# Prepend m.login.terms to all flows if we're requiring consent
if config.consent.user_consent_at_registration:
for flow in flows:
@@ -942,7 +946,8 @@ def _calculate_registration_flows(
# Prepend registration token to all flows if we're requiring a token
if config.registration.registration_requires_token:
for flow in flows:
flow.insert(0, LoginType.REGISTRATION_TOKEN)
if LoginType.REGISTRATION_TOKEN not in flow:
flow.insert(0, LoginType.REGISTRATION_TOKEN)
return flows

View File

@@ -91,6 +91,7 @@ from synapse.handlers.presence import (
WorkerPresenceHandler,
)
from synapse.handlers.profile import ProfileHandler
from synapse.handlers.push_rules import PushRulesHandler
from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.register import RegistrationHandler
@@ -810,6 +811,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_account_handler(self) -> AccountHandler:
return AccountHandler(self)
@cache_in_self
def get_push_rules_handler(self) -> PushRulesHandler:
return PushRulesHandler(self)
@cache_in_self
def get_outbound_redis_connection(self) -> "ConnectionHandler":
"""

View File

@@ -232,10 +232,10 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
# is racy.
# Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant
self._invalidate_all_cache_and_stream( # type: ignore[attr-defined]
self._invalidate_all_cache_and_stream(
txn, self.user_last_seen_monthly_active
)
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined]
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
reserved_users = await self.get_registered_reserved_users()
await self.db_pool.runInteraction(
@@ -363,7 +363,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]
is_guest = await self.is_guest(user_id)
if is_guest:
return
is_trial = await self.is_trial_user(user_id)

View File

@@ -16,7 +16,7 @@ import abc
import logging
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from synapse.api.errors import NotFoundError, StoreError
from synapse.api.errors import StoreError
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -618,7 +618,7 @@ class PushRuleStore(PushRulesWorkerStore):
are always stored in the database `push_rules` table).
Raises:
NotFoundError if the rule does not exist.
RuleNotFoundException if the rule does not exist.
"""
async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
@@ -668,8 +668,7 @@ class PushRuleStore(PushRulesWorkerStore):
)
txn.execute(sql, (user_id, rule_id))
if txn.fetchone() is None:
# needed to set NOT_FOUND code.
raise NotFoundError("Push rule does not exist.")
raise RuleNotFoundException("Push rule does not exist.")
self.db_pool.simple_upsert_txn(
txn,
@@ -698,9 +697,6 @@ class PushRuleStore(PushRulesWorkerStore):
"""
Sets the `actions` state of a push rule.
Will throw NotFoundError if the rule does not exist; the Code for this
is NOT_FOUND.
Args:
user_id: the user ID of the user who wishes to enable/disable the rule
e.g. '@tina:example.org'
@@ -712,6 +708,9 @@ class PushRuleStore(PushRulesWorkerStore):
is_default_rule: True if and only if this is a server-default rule.
This skips the check for existence (as only user-created rules
are always stored in the database `push_rules` table).
Raises:
RuleNotFoundException if the rule does not exist.
"""
actions_json = json_encoder.encode(actions)
@@ -744,7 +743,7 @@ class PushRuleStore(PushRulesWorkerStore):
except StoreError as serr:
if serr.code == 404:
# this sets the NOT_FOUND error Code
raise NotFoundError("Push rule does not exist")
raise RuleNotFoundException("Push rule does not exist")
else:
raise

View File

@@ -12,11 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
from frozendict import frozendict
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -160,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
predecessor = create_event.content.get("predecessor", None)
# Ensure the key is a dictionary
if not isinstance(predecessor, (dict, frozendict)):
if not isinstance(predecessor, collections.abc.Mapping):
return None
# The keys must be strings since the data is JSON.

View File

@@ -501,11 +501,11 @@ def _upgrade_existing_database(
if hasattr(module, "run_create"):
logger.info("Running %s:run_create", relative_path)
module.run_create(cur, database_engine) # type: ignore
module.run_create(cur, database_engine)
if not is_empty and hasattr(module, "run_upgrade"):
logger.info("Running %s:run_upgrade", relative_path)
module.run_upgrade(cur, database_engine, config=config) # type: ignore
module.run_upgrade(cur, database_engine, config=config)
elif ext == ".pyc" or file_name == "__pycache__":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package

View File

@@ -107,7 +107,7 @@ class TTLCache(Generic[KT, VT]):
self._metrics.inc_hits()
return e.value, e.expiry_time, e.ttl
def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: # type: ignore
def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
"""Remove a value from the cache
If key is in the cache, remove it and return its value, else return default.

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
from typing import Any
from frozendict import frozendict
@@ -35,7 +36,7 @@ def freeze(o: Any) -> Any:
def unfreeze(o: Any) -> Any:
if isinstance(o, (dict, frozendict)):
if isinstance(o, collections.abc.Mapping):
return {k: unfreeze(v) for k, v in o.items()}
if isinstance(o, (bytes, str)):

View File

@@ -83,7 +83,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
)
# mock up the response, and have the agent return it
self._mock_agent.request.return_value = defer.succeed(
self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
_mock_response(
{
"pdus": [

View File

@@ -226,7 +226,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists.
self.hs.get_federation_transport_client().query_user_devices.return_value = (
defer.succeed(
make_awaitable(
{
"stream_id": "1",
"user_id": "@user2:host2",

View File

@@ -19,7 +19,6 @@ from unittest import mock
from parameterized import parameterized
from signedjson import key as key, sign as sign
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
@@ -704,7 +703,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
self.hs.get_federation_client().query_client_keys = mock.Mock(
return_value=defer.succeed(
return_value=make_awaitable(
{
"device_keys": {remote_user_id: {}},
"master_keys": {
@@ -777,14 +776,14 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock(
return_value=defer.succeed({"some_room_id"})
return_value=make_awaitable({"some_room_id"})
)
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
self.hs.get_federation_client().query_user_devices = mock.Mock(
return_value=defer.succeed(
return_value=make_awaitable(
{
"user_id": remote_user_id,
"stream_id": 1,

View File

@@ -17,8 +17,6 @@
from typing import Any, Type, Union
from unittest.mock import Mock
from twisted.internet import defer
import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
@@ -190,7 +188,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(True)
mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
@@ -226,13 +224,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.get_success(module_api.register_user("u"))
# log in twice, to get two devices
mock_password_provider.check_password.return_value = defer.succeed(True)
mock_password_provider.check_password.return_value = make_awaitable(True)
tok1 = self.login("u", "p")
self.login("u", "p", device_id="dev2")
mock_password_provider.reset_mock()
# have the auth provider deny the request to start with
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)
# make the initial request which returns a 401
session = self._start_delete_device_session(tok1, "dev2")
@@ -246,7 +244,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# Finally, check the request goes through when we allow it
mock_password_provider.check_password.return_value = defer.succeed(True)
mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@@ -260,7 +258,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 403, channel.result)
@@ -277,7 +275,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# have the auth provider deny the request
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)
# log in twice, to get two devices
tok1 = self.login("localuser", "localpass")
@@ -320,7 +318,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -342,7 +340,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# allow login via the auth provider
mock_password_provider.check_password.return_value = defer.succeed(True)
mock_password_provider.check_password.return_value = make_awaitable(True)
# log in twice, to get two devices
tok1 = self.login("localuser", "p")
@@ -359,7 +357,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_not_called()
# now try deleting with the local password
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._authed_delete_device(
tok1, "dev2", session, "localuser", "localpass"
)
@@ -413,7 +411,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
channel = self._send_login("test.login_type", "u", test_field="y")
@@ -427,7 +425,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# try a weird username. Again, it's unclear what we *expect* to happen
# in these cases, but at least we can guard against the API changing
# unexpectedly
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
@@ -477,7 +475,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# right params, but authing as the wrong user
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
body["auth"]["test_field"] = "foo"
@@ -490,7 +488,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# and finally, succeed
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@localuser:test", None)
)
channel = self._delete_device(tok1, "dev2", body)
@@ -508,9 +506,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.custom_auth_provider_callback_test_body()
def custom_auth_provider_callback_test_body(self):
callback = Mock(return_value=defer.succeed(None))
callback = Mock(return_value=make_awaitable(None))
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
@@ -646,7 +644,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass")
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")

View File

@@ -193,8 +193,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self):
# Type ignore: mypy doesn't like us assigning to methods.
self.store.count_monthly_users = Mock( # type: ignore[assignment]
self.store.count_monthly_users = Mock(
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
# Ensure does not throw exception
@@ -202,8 +201,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_blocked(self):
# Type ignore: mypy doesn't like us assigning to methods.
self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users)
)
self.get_failure(
@@ -211,8 +209,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
# Type ignore: mypy doesn't like us assigning to methods.
self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
self.get_failure(

View File

@@ -65,11 +65,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
mock_keyring.verify_json_for_server.return_value = defer.succeed(True)
mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
# we mock out the federation client too
mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
# the tests assume that we are starting at unix time 1000
reactor.pump((1000,))
@@ -98,7 +98,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore = hs.get_datastores().main
self.datastore.get_destination_retry_timings = Mock(
return_value=defer.succeed(None)
return_value=make_awaitable(None)
)
self.datastore.get_device_updates_by_remote = Mock(

View File

@@ -15,7 +15,6 @@ from typing import Tuple
from unittest.mock import Mock, patch
from urllib.parse import quote
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -30,6 +29,7 @@ from synapse.util import Clock
from tests import unittest
from tests.storage.test_user_directory import GetUserDirectoryTables
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config
@@ -439,7 +439,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir
):
@@ -454,7 +454,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.store.register_user(user_id=r_user_id, password_hash=None)
)
mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir
):

View File

@@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import SynapseError
from synapse.rest import admin
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -22,7 +26,9 @@ class ModuleApiTestCase(HomeserverTestCase):
admin.register_servlets,
]
def prepare(self, reactor, clock, homeserver) -> None:
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self._store = homeserver.get_datastores().main
self._module_api = homeserver.get_module_api()
self._account_data_mgr = self._module_api.account_data_manager
@@ -91,7 +97,7 @@ class ModuleApiTestCase(HomeserverTestCase):
)
with self.assertRaises(TypeError):
# This throws an exception because it's a frozen dict.
the_data["wombat"] = False
the_data["wombat"] = False # type: ignore[index]
def test_put_global(self) -> None:
"""
@@ -143,15 +149,14 @@ class ModuleApiTestCase(HomeserverTestCase):
with self.assertRaises(TypeError):
# The account data type must be a string.
self.get_success_or_raise(
self._module_api.account_data_manager.put_global(
self.user_id, 42, {} # type: ignore
)
self._module_api.account_data_manager.put_global(self.user_id, 42, {}) # type: ignore[arg-type]
)
with self.assertRaises(TypeError):
# The account data dict must be a dict.
# noinspection PyTypeChecker
self.get_success_or_raise(
self._module_api.account_data_manager.put_global(
self.user_id, "test.data", 42 # type: ignore
self.user_id, "test.data", 42 # type: ignore[arg-type]
)
)

View File

@@ -19,8 +19,9 @@ from synapse.api.constants import EduTypes, EventTypes
from synapse.events import EventBase
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
from synapse.handlers.push_rules import InvalidRuleException
from synapse.rest import admin
from synapse.rest.client import login, presence, profile, room
from synapse.rest.client import login, notifications, presence, profile, room
from synapse.types import create_requester
from tests.events.test_presence_router import send_presence_update, sync_presence
@@ -38,6 +39,7 @@ class ModuleApiTestCase(HomeserverTestCase):
room.register_servlets,
presence.register_servlets,
profile.register_servlets,
notifications.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
@@ -553,6 +555,94 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(state[("org.matrix.test", "")].state_key, "")
self.assertEqual(state[("org.matrix.test", "")].content, {})
def test_set_push_rules_action(self) -> None:
"""Test that a module can change the actions of an existing push rule for a user."""
# Create a room with 2 users in it. Push rules must not match if the user is the
# event's sender, so we need one user to send messages and one user to receive
# notifications.
user_id = self.register_user("user", "password")
tok = self.login("user", "password")
room_id = self.helper.create_room_as(user_id, is_public=True, tok=tok)
user_id2 = self.register_user("user2", "password")
tok2 = self.login("user2", "password")
self.helper.join(room_id, user_id2, tok=tok2)
# Register a 3rd user and join them to the room, so that we don't accidentally
# trigger 1:1 push rules.
user_id3 = self.register_user("user3", "password")
tok3 = self.login("user3", "password")
self.helper.join(room_id, user_id3, tok=tok3)
# Send a message as the second user and check that it notifies.
res = self.helper.send(room_id=room_id, body="here's a message", tok=tok2)
event_id = res["event_id"]
channel = self.make_request(
"GET",
"/notifications",
access_token=tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
self.assertEqual(
channel.json_body["notifications"][0]["event"]["event_id"],
event_id,
channel.json_body,
)
# Change the .m.rule.message actions to not notify on new messages.
self.get_success(
defer.ensureDeferred(
self.module_api.set_push_rule_action(
user_id=user_id,
scope="global",
kind="underride",
rule_id=".m.rule.message",
actions=["dont_notify"],
)
)
)
# Send another message as the second user and check that the number of
# notifications didn't change.
self.helper.send(room_id=room_id, body="here's another message", tok=tok2)
channel = self.make_request(
"GET",
"/notifications?from=",
access_token=tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
def test_check_push_rules_actions(self) -> None:
"""Test that modules can check whether a list of push rules actions are spec
compliant.
"""
with self.assertRaises(InvalidRuleException):
self.module_api.check_push_rule_actions(["foo"])
with self.assertRaises(InvalidRuleException):
self.module_api.check_push_rule_actions({"foo": "bar"})
self.module_api.check_push_rule_actions(["notify"])
self.module_api.check_push_rule_actions(
[{"set_tweak": "sound", "value": "default"}]
)
def test_normalize_username(self) -> None:
username = "Haxxor"
username2 = "_leet"
username3 = "aNoThErTeSt"
self.assertEqual(self.module_api.normalize_username(username), "haxxor")
self.assertEqual(self.module_api.normalize_username(username2), "=5fleet")
self.assertEqual(self.module_api.normalize_username(username3), "anothertest")
class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup"""

View File

@@ -102,8 +102,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
for i in range(20):
server_name = "other_server_%d" % (i,)
room = self.create_room_with_remote_server(user, token, server_name)
mock_client1.reset_mock() # type: ignore[attr-defined]
mock_client2.reset_mock() # type: ignore[attr-defined]
mock_client1.reset_mock()
mock_client2.reset_mock()
self.create_and_send_event(room, UserID.from_string(user))
self.replicate()
@@ -167,8 +167,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
for i in range(20):
server_name = "other_server_%d" % (i,)
room = self.create_room_with_remote_server(user, token, server_name)
mock_client1.reset_mock() # type: ignore[attr-defined]
mock_client2.reset_mock() # type: ignore[attr-defined]
mock_client1.reset_mock()
mock_client2.reset_mock()
self.get_success(
typing_handler.started_typing(

View File

@@ -14,7 +14,6 @@
from http import HTTPStatus
from unittest.mock import Mock
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.presence import PresenceHandler
@@ -24,6 +23,7 @@ from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
class PresenceTestCase(unittest.HomeserverTestCase):
@@ -37,7 +37,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
presence_handler = Mock(spec=PresenceHandler)
presence_handler.set_state.return_value = defer.succeed(None)
presence_handler.set_state.return_value = make_awaitable(None)
hs = self.setup_test_homeserver(
"red",

View File

@@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, List, Optional
from unittest.mock import Mock, call
from urllib import parse as urlparse
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -1426,9 +1425,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
def test_simple(self) -> None:
"Simple test for searching rooms over federation"
self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined]
{}
)
self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined]
search_filter = {"generic_search_term": "foobar"}
@@ -1456,7 +1453,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
# with a 404, when using search filters.
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
HttpResponseException(404, "Not Found", b""),
defer.succeed({}),
make_awaitable({}),
)
search_filter = {"generic_search_term": "foobar"}

View File

@@ -22,6 +22,7 @@ from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionC
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
from tests.utils import MockClock
@@ -38,7 +39,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_executes_given_function(self):
cb = Mock(return_value=defer.succeed(self.mock_http_response))
cb = Mock(return_value=make_awaitable(self.mock_http_response))
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg"
)
@@ -47,7 +48,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_deduplicates_based_on_key(self):
cb = Mock(return_value=defer.succeed(self.mock_http_response))
cb = Mock(return_value=make_awaitable(self.mock_http_response))
for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
@@ -130,7 +131,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cleans_up(self):
cb = Mock(return_value=defer.succeed(self.mock_http_response))
cb = Mock(return_value=make_awaitable(self.mock_http_response))
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
# should NOT have cleaned up yet
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)

View File

@@ -14,8 +14,6 @@
from unittest.mock import Mock
from twisted.internet import defer
from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType
from synapse.api.errors import ResourceLimitError
from synapse.rest import admin
@@ -68,16 +66,16 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return_value=make_awaitable(1000)
)
self._rlsn._server_notices_manager.send_notice = Mock(
return_value=defer.succeed(Mock())
return_value=make_awaitable(Mock())
)
self._send_notice = self._rlsn._server_notices_manager.send_notice
self.user_id = "@user_id:test"
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
return_value=defer.succeed("!something:localhost")
return_value=make_awaitable("!something:localhost")
)
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None))
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
@override_config({"hs_disabled": True})
@@ -95,7 +93,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
"""Test when user has blocked notice, but should have it removed"""
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
@@ -111,7 +109,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user has blocked notice, but notice ought to be there (NOOP)
"""
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
return_value=make_awaitable(None),
side_effect=ResourceLimitError(403, "foo"),
)
mock_event = Mock(
@@ -130,7 +129,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user does not have blocked notice, but should have one
"""
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
return_value=make_awaitable(None),
side_effect=ResourceLimitError(403, "foo"),
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -141,7 +141,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user does not have blocked notice, nor should they (NOOP)
"""
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -152,7 +152,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user is not part of the MAU cohort - this should not ever
happen - but ...
"""
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(None)
)
@@ -167,7 +167,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
an alert message is not sent into the room
"""
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
),
@@ -182,7 +182,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test that when a server is disabled, that MAU limit alerting is ignored.
"""
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
),
@@ -199,14 +199,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
is suppressed that the room is returned to an unblocked state.
"""
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
),
)
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
return_value=defer.succeed((True, []))
return_value=make_awaitable((True, []))
)
mock_event = Mock(

View File

@@ -14,7 +14,6 @@
from typing import Any, Dict, List
from unittest.mock import Mock
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import UserTypes
@@ -259,10 +258,10 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def test_populate_monthly_users_should_update(self):
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None)
return_value=make_awaitable(None)
)
d = self.store.populate_monthly_active_users("user_id")
self.get_success(d)
@@ -272,9 +271,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def test_populate_monthly_users_should_not_update(self):
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
return_value=make_awaitable(self.hs.get_clock().time_msec())
)
d = self.store.populate_monthly_active_users("user_id")

View File

@@ -233,7 +233,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register mock device list retrieval on the federation client.
federation_client = self.homeserver.get_federation_client()
federation_client.query_user_devices = Mock(
return_value=succeed(
return_value=make_awaitable(
{
"user_id": remote_user_id,
"stream_id": 1,

View File

@@ -52,7 +52,7 @@ def make_awaitable(result: TV) -> Awaitable[TV]:
This uses Futures as they can be awaited multiple times so can be returned
to multiple callers.
"""
future = Future() # type: ignore
future: Future[TV] = Future()
future.set_result(result)
return future
@@ -69,7 +69,7 @@ def setup_awaitable_errors() -> Callable[[], None]:
# State shared between unraisablehook and check_for_unraisable_exceptions.
unraisable_exceptions = []
orig_unraisablehook = sys.unraisablehook # type: ignore
orig_unraisablehook = sys.unraisablehook
def unraisablehook(unraisable):
unraisable_exceptions.append(unraisable.exc_value)
@@ -78,11 +78,11 @@ def setup_awaitable_errors() -> Callable[[], None]:
"""
A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
"""
sys.unraisablehook = orig_unraisablehook # type: ignore
sys.unraisablehook = orig_unraisablehook
if unraisable_exceptions:
raise unraisable_exceptions.pop()
sys.unraisablehook = unraisablehook # type: ignore
sys.unraisablehook = unraisablehook
return cleanup

View File

@@ -27,7 +27,7 @@ class ToTwistedHandler(logging.Handler):
def emit(self, record):
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
self.tx_log.emit( # type: ignore
self.tx_log.emit(
twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
)